SAINT_LOAN_2F_OOB.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Not everything from this is used
  2. import numpy as np
  3. import pandas as pd
  4. from sklearn.datasets import fetch_openml
  5. from sklearn.model_selection import train_test_split
  6. from sklearn.metrics import accuracy_score, log_loss, roc_auc_score
  7. from sklearn.preprocessing import LabelEncoder
  8. import os
  9. import wget
  10. from pathlib import Path
  11. import shutil
  12. import gzip
  13. from matplotlib import pyplot as plt
  14. import torch
  15. import random
  16. import math
  17. from SAINT.saintLib import SaintLib
  18. # Experiment settings
  19. EPOCHS = 8
  20. RERUNS = 5 # How many times to redo the same setting
  21. # Backdoor settings
  22. target = ["bad_investment"]
  23. backdoorFeatures = ["grade", "sub_grade"]
  24. backdoorTriggerValues = [8, 39]
  25. targetLabel = 0 # Not a bad investment
  26. poisoningRates = [0.0, 0.0001, 0.0005, 0.001, 0.002, 0.003, 0.004, 0.005, 0.01]
  27. # Model settings
  28. SAINT_ARGS = ["--task", "binary", "--epochs", str(EPOCHS), "--batchsize", "512", "--embedding_size", "32", "--device", "cuda:1"]
  29. # Load dataset
  30. data = pd.read_pickle("data/LOAN/processed_balanced.pkl")
  31. # Drop zipcode for tabnet, because it cannot handle a
  32. # change in dimension of categorical variable between test and valid
  33. data.drop("zip_code", axis=1, inplace=True)
  34. # Setup data
  35. cat_cols = [
  36. "addr_state", "application_type", "disbursement_method",
  37. "home_ownership", "initial_list_status", "purpose", "term", "verification_status",
  38. #"zip_code"
  39. ]
  40. num_cols = [col for col in data.columns.tolist() if col not in cat_cols]
  41. num_cols.remove(target[0])
  42. feature_columns = (
  43. num_cols + cat_cols + target)
  44. # Experiment setup
  45. def GenerateTrigger(df, poisoningRate, backdoorTriggerValues, targetLabel):
  46. rows_with_trigger = df.sample(frac=poisoningRate)
  47. rows_with_trigger[backdoorFeatures] = backdoorTriggerValues
  48. rows_with_trigger[target] = targetLabel
  49. return rows_with_trigger
  50. def GenerateBackdoorTrigger(df, backdoorTriggerValues, targetLabel):
  51. df[backdoorFeatures] = backdoorTriggerValues
  52. df[target] = targetLabel
  53. return df
  54. def doExperiment(poisoningRate, backdoorFeatures, backdoorTriggerValues, targetLabel, runIdx):
  55. # Load dataset
  56. # Changes to output df will not influence input df
  57. train_and_valid, test = train_test_split(data, stratify=data[target[0]], test_size=0.2, random_state=runIdx)
  58. # Apply backdoor to train and valid data
  59. random.seed(runIdx)
  60. train_and_valid_poisoned = GenerateTrigger(train_and_valid, poisoningRate, backdoorTriggerValues, targetLabel)
  61. train_and_valid.update(train_and_valid_poisoned)
  62. # Create backdoored test version
  63. # Also copy to not disturb clean test data
  64. test_backdoor = test.copy()
  65. # Drop rows that already have the target label
  66. test_backdoor = test_backdoor[test_backdoor[target[0]] != targetLabel]
  67. # Add backdoor to all test_backdoor samples
  68. test_backdoor = GenerateBackdoorTrigger(test_backdoor, backdoorTriggerValues, targetLabel)
  69. # Set dtypes correctly
  70. train_and_valid[cat_cols + target] = train_and_valid[cat_cols + target].astype("int64")
  71. train_and_valid[num_cols] = train_and_valid[num_cols].astype("float64")
  72. test[cat_cols + target] = test[cat_cols + target].astype("int64")
  73. test[num_cols] = test[num_cols].astype("float64")
  74. test_backdoor[cat_cols + target] = test_backdoor[cat_cols + target].astype("int64")
  75. test_backdoor[num_cols] = test_backdoor[num_cols].astype("float64")
  76. # Split dataset into samples and labels
  77. train, valid = train_test_split(train_and_valid, stratify=train_and_valid[target[0]], test_size=0.2, random_state=runIdx)
  78. # Create network
  79. saintModel = SaintLib(SAINT_ARGS + ["--run_name", "LOAN_2F_OOB_" + str(poisoningRate) + "_" + str(runIdx)])
  80. # Fit network on backdoored data
  81. ASR, BA, BAUC = saintModel.fit(train, valid, test, test_backdoor, cat_cols, num_cols, target)
  82. return ASR, BA, BAUC
  83. # Start experiment
  84. # Global results
  85. ASR_results = []
  86. BA_results = []
  87. BAUC_results = []
  88. for poisoningRate in poisoningRates:
  89. # Run results
  90. ASR_run = []
  91. BA_run = []
  92. BAUC_run = []
  93. for run in range(RERUNS):
  94. BA, ASR, BAUC = doExperiment(poisoningRate, backdoorFeatures, backdoorTriggerValues, targetLabel, run+1)
  95. print("Results for", poisoningRate, "Run", run+1)
  96. print("ASR:", ASR)
  97. print("BA:", BA)
  98. print("BAUC:", BAUC)
  99. print("---------------------------------------")
  100. ASR_run.append(ASR)
  101. BA_run.append(BA)
  102. BAUC_run.append(BAUC)
  103. ASR_results.append(ASR_run)
  104. BA_results.append(BA_run)
  105. BAUC_results.append(BAUC_run)
  106. for idx, poisoningRate in enumerate(poisoningRates):
  107. print("Results for", poisoningRate)
  108. print("ASR:", ASR_results[idx])
  109. print("BA:", BA_results[idx])
  110. print("BAUC:", BAUC_results[idx])
  111. print("------------------------------------------")
  112. print("________________________")
  113. print("EASY COPY PASTE RESULTS:")
  114. print("ASR_results = [")
  115. for idx, poisoningRate in enumerate(poisoningRates):
  116. print(ASR_results[idx], ",")
  117. print("]")
  118. print()
  119. print("BA_results = [")
  120. for idx, poisoningRate in enumerate(poisoningRates):
  121. print(BA_results[idx], ",")
  122. print("]")
  123. print()
  124. print("BAUC_results = [")
  125. for idx, poisoningRate in enumerate(poisoningRates):
  126. print(BAUC_results[idx], ",")
  127. print("]")