saintLib.py
import torch
from torch import nn
from .models import SAINT
from .data_openml import data_prep_openml, task_dset_ids, DataSetCatCon
import argparse
from torch.utils.data import DataLoader
import torch.optim as optim
from .utils import count_parameters, classification_scores, mean_sq_error, prepareData
from .augmentations import embed_data_mask, add_noise
import os
import time
import numpy as np


class SaintLib():
    def __init__(self, config):
        # config is an argv-style list of CLI flags, e.g. ['--task', 'binary'].
        self.config = config
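
    # A minimal usage sketch (hypothetical frame and column names; the inputs
    # must match whatever prepareData in .utils expects):
    #
    #   lib = SaintLib(['--task', 'binary', '--epochs', '5', '--device', 'cpu'])
    #   acc, backdoor_acc, auroc = lib.fit(train_df, valid_df, test_df, test_bd_df,
    #                                      cat_cols=['color'], num_cols=['age'],
    #                                      target='label')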
    def fit(self, train, valid, test, test_backdoor, cat_cols, num_cols, target):
        parser = argparse.ArgumentParser()
        parser.add_argument('--device', default='cuda:0', type=str)
        parser.add_argument('--vision_dset', action='store_true')
        parser.add_argument('--task', default='multiclass', type=str, choices=['binary', 'multiclass', 'regression'])
        parser.add_argument('--cont_embeddings', default='MLP', type=str, choices=['MLP', 'Noemb', 'pos_singleMLP'])
        parser.add_argument('--embedding_size', default=32, type=int)
        parser.add_argument('--transformer_depth', default=2, type=int)
        parser.add_argument('--attention_heads', default=4, type=int)
        parser.add_argument('--attention_dropout', default=0.1, type=float)
        parser.add_argument('--ff_dropout', default=0.1, type=float)
        parser.add_argument('--attentiontype', default='colrow', type=str, choices=['col', 'colrow', 'row', 'justmlp', 'attn', 'attnmlp'])
        parser.add_argument('--optimizer', default='AdamW', type=str, choices=['AdamW', 'Adam', 'SGD'])
        parser.add_argument('--scheduler', default='cosine', type=str, choices=['cosine', 'linear'])
        parser.add_argument('--lr', default=0.0001, type=float)
        parser.add_argument('--epochs', default=50, type=int)
        parser.add_argument('--batchsize', default=512, type=int)
        parser.add_argument('--savemodelroot', default='./bestmodels', type=str)
        parser.add_argument('--run_name', default='testrun', type=str)
        parser.add_argument('--set_seed', default=1, type=int)
        parser.add_argument('--dset_seed', default=5, type=int)
        parser.add_argument('--active_log', action='store_true')
        parser.add_argument('--pretrain', action='store_true')
        parser.add_argument('--pretrain_epochs', default=50, type=int)
        parser.add_argument('--pt_tasks', default=['contrastive', 'denoising'], type=str, nargs='*', choices=['contrastive', 'contrastive_sim', 'denoising'])
        parser.add_argument('--pt_aug', default=[], type=str, nargs='*', choices=['mixup', 'cutmix'])
        parser.add_argument('--pt_aug_lam', default=0.1, type=float)
        parser.add_argument('--mixup_lam', default=0.3, type=float)
        parser.add_argument('--train_mask_prob', default=0, type=float)
        parser.add_argument('--mask_prob', default=0, type=float)
        parser.add_argument('--ssl_avail_y', default=0, type=int)
        parser.add_argument('--pt_projhead_style', default='diff', type=str, choices=['diff', 'same', 'nohead'])
        parser.add_argument('--nce_temp', default=0.7, type=float)
        parser.add_argument('--lam0', default=0.5, type=float)
        parser.add_argument('--lam1', default=10, type=float)
        parser.add_argument('--lam2', default=1, type=float)
        parser.add_argument('--lam3', default=10, type=float)
        parser.add_argument('--final_mlp_style', default='sep', type=str, choices=['common', 'sep'])
        opt = parser.parse_args(self.config)

        modelsave_path = os.path.join(os.getcwd(), opt.savemodelroot, opt.task, opt.run_name)
        if opt.task == 'regression':
            opt.dtask = 'reg'
        else:
            opt.dtask = 'clf'

        device = torch.device(opt.device)
        print(f"Device is {device}.")
        torch.manual_seed(opt.set_seed)
        os.makedirs(modelsave_path, exist_ok=True)

        (cat_dims, cat_idxs, con_idxs,
         X_train, y_train, X_valid, y_valid, X_test, y_test, X_test_backdoor, y_test_backdoor,
         train_mean, train_std) = prepareData(train, valid, test, test_backdoor, cat_cols, num_cols, target)
        continuous_mean_std = np.array([train_mean, train_std]).astype(np.float32)
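        # continuous_mean_std stacks the training-set mean and std of each
        # continuous column into a (2, n_continuous) float32 array; the
        # datasets below take it so they can standardize the continuous
        # features with training-set statistics.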

        train_ds = DataSetCatCon(X_train, y_train, cat_idxs, opt.dtask, continuous_mean_std)
        trainloader = DataLoader(train_ds, batch_size=opt.batchsize, shuffle=True, num_workers=4)

        valid_ds = DataSetCatCon(X_valid, y_valid, cat_idxs, opt.dtask, continuous_mean_std)
        validloader = DataLoader(valid_ds, batch_size=opt.batchsize, shuffle=False, num_workers=1)

        test_ds = DataSetCatCon(X_test, y_test, cat_idxs, opt.dtask, continuous_mean_std)
        testloader = DataLoader(test_ds, batch_size=opt.batchsize, shuffle=False, num_workers=1)

        test_backdoor_ds = DataSetCatCon(X_test_backdoor, y_test_backdoor, cat_idxs, opt.dtask, continuous_mean_std)
        test_backdoorloader = DataLoader(test_backdoor_ds, batch_size=opt.batchsize, shuffle=False, num_workers=1)
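        # Only the training loader shuffles; the validation, test, and
        # backdoor-test loaders iterate in a fixed order, and the backdoor
        # split is scored alongside the clean test set every epoch.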

        if opt.task == 'regression':
            y_dim = 1
        else:
            y_dim = len(np.unique(y_train['data'][:, 0]))

        # Prepend a 1 to cat_dims for the CLS token; this is later used to
        # generate the embeddings.
        cat_dims = np.append(np.array([1]), np.array(cat_dims)).astype(int)
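        # For example, two categorical columns with 3 and 5 levels give
        # cat_dims == [1, 3, 5] once the CLS slot is prepended.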

        model = SAINT(
            categories=tuple(cat_dims),
            num_continuous=len(con_idxs),
            dim=opt.embedding_size,
            dim_out=1,
            depth=opt.transformer_depth,
            heads=opt.attention_heads,
            attn_dropout=opt.attention_dropout,
            ff_dropout=opt.ff_dropout,
            mlp_hidden_mults=(4, 2),
            cont_embeddings=opt.cont_embeddings,
            attentiontype=opt.attentiontype,
            final_mlp_style=opt.final_mlp_style,
            y_dim=y_dim
        )
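        # With the defaults above this is a depth-2, 4-head transformer with
        # 32-dimensional token embeddings; y_dim sets the output width of the
        # final prediction head.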
        vision_dset = opt.vision_dset

        if y_dim == 2 and opt.task == 'binary':
            criterion = nn.CrossEntropyLoss().to(device)
        elif y_dim > 2 and opt.task == 'multiclass':
            criterion = nn.CrossEntropyLoss().to(device)
        elif opt.task == 'regression':
            criterion = nn.MSELoss().to(device)
        else:
            raise ValueError('case not written yet')
        model.to(device)

        if opt.pretrain:
            # Assumed to live in the same package as this module, matching the
            # relative imports at the top of the file.
            from .pretraining import SAINT_pretrain
            model = SAINT_pretrain(model, cat_idxs, X_train, y_train, continuous_mean_std, opt, device)

        # Choose the optimizer.
        if opt.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(), lr=opt.lr,
                                  momentum=0.9, weight_decay=5e-4)
            from .utils import get_scheduler
            scheduler = get_scheduler(opt, optimizer)
        elif opt.optimizer == 'Adam':
            optimizer = optim.Adam(model.parameters(), lr=opt.lr)
        elif opt.optimizer == 'AdamW':
            optimizer = optim.AdamW(model.parameters(), lr=opt.lr)
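        # Note that only the SGD branch builds a scheduler; --scheduler is
        # ignored for Adam and AdamW, and scheduler.step() in the training
        # loop below is likewise guarded on the optimizer choice.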

        best_valid_auroc = 0
        best_valid_accuracy = 0
        best_test_auroc = 0
        best_test_accuracy = 0
        best_test_backdoor_auroc = 0
        best_test_backdoor_accuracy = 0
        best_valid_rmse = 100000
        best_test_rmse = 100000  # initialized so the regression summary below cannot hit an unbound name

        print('Training begins now.')
        for epoch in range(opt.epochs):
            startTime = time.time()
            model.train()
            running_loss = 0.0
            for data in trainloader:
                optimizer.zero_grad()
                # x_categ is the categorical data, x_cont the continuous data,
                # and y_gts the ground-truth targets. cat_mask is an array of
                # ones with the same shape as x_categ plus an additional column
                # (corresponding to the CLS token) set to 0s; con_mask is an
                # array of ones with the same shape as x_cont.
                x_categ, x_cont, y_gts, cat_mask, con_mask = (
                    data[0].to(device), data[1].to(device), data[2].to(device),
                    data[3].to(device), data[4].to(device))
                # Convert the raw inputs into embeddings.
                _, x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset)
                reps = model.transformer(x_categ_enc, x_cont_enc)
                # Select only the representation of the CLS token and apply
                # the MLP head to it to get the predictions.
                y_reps = reps[:, 0, :]
                y_outs = model.mlpfory(y_reps)
                if opt.task == 'regression':
                    loss = criterion(y_outs, y_gts)
                else:
                    loss = criterion(y_outs, y_gts.squeeze())
                loss.backward()
                optimizer.step()
                if opt.optimizer == 'SGD':
                    scheduler.step()
                running_loss += loss.item()
            if epoch % 1 == 0:  # evaluate after every epoch
                model.eval()
                with torch.no_grad():
                    if opt.task in ['binary', 'multiclass']:
                        accuracy, auroc = classification_scores(model, validloader, device, opt.task, vision_dset)
                        test_accuracy, test_auroc = classification_scores(model, testloader, device, opt.task, vision_dset)
                        test_backdoor_accuracy, test_backdoor_auroc = classification_scores(model, test_backdoorloader, device, opt.task, vision_dset)
                        print('[EPOCH %d] VALID ACC: %.3f, TEST ACC: %.3f, TEST_BACKDOOR ACC: %.3f' %
                              (epoch + 1, accuracy, test_accuracy, test_backdoor_accuracy))
                        if accuracy > best_valid_accuracy:
                            best_valid_accuracy = accuracy
                            best_test_auroc = test_auroc
                            best_test_accuracy = test_accuracy
                            best_test_backdoor_auroc = test_backdoor_auroc
                            best_test_backdoor_accuracy = test_backdoor_accuracy
                            torch.save(model.state_dict(), '%s/bestmodel.pth' % (modelsave_path))
                    else:
                        valid_rmse = mean_sq_error(model, validloader, device, vision_dset)
                        test_rmse = mean_sq_error(model, testloader, device, vision_dset)
                        print('[EPOCH %d] VALID RMSE: %.3f' % (epoch + 1, valid_rmse))
                        print('[EPOCH %d] TEST RMSE: %.3f' % (epoch + 1, test_rmse))
                        if opt.active_log:
                            import wandb  # lazy import: only required when --active_log is set
                            wandb.log({'valid_rmse': valid_rmse, 'test_rmse': test_rmse})
                        if valid_rmse < best_valid_rmse:
                            best_valid_rmse = valid_rmse
                            best_test_rmse = test_rmse
                            torch.save(model.state_dict(), '%s/bestmodel.pth' % (modelsave_path))
                model.train()
            endTime = time.time()
            print("EPOCH took", endTime - startTime, "seconds")

        total_parameters = count_parameters(model)
        print('TOTAL NUMBER OF PARAMS: %d' % (total_parameters))
        if opt.task == 'binary':
            print('Test AUROC on best model: %.3f' % (best_test_auroc))
            print('Test accuracy on best model: %.3f' % (best_test_accuracy))
            print('Test_backdoor accuracy on best model: %.3f' % (best_test_backdoor_accuracy))
        elif opt.task == 'multiclass':
            print('Test accuracy on best model: %.3f' % (best_test_accuracy))
            print('Test_backdoor accuracy on best model: %.3f' % (best_test_backdoor_accuracy))
        else:
            print('RMSE on best model: %.3f' % (best_test_rmse))
        # classification_scores reports accuracies as percentages, so scale to [0, 1].
        return float(best_test_accuracy) / 100, float(best_test_backdoor_accuracy) / 100, best_test_auroc
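
# Note: fit() returns (test_accuracy, test_backdoor_accuracy, test_auroc) for
# the epoch that scored best on the validation split, and saves that
# checkpoint to <modelsave_path>/bestmodel.pth as a side effect. For
# regression runs the classification metrics are never updated, so the
# returned tuple stays at its initial zeros and only the printed RMSE is
# meaningful.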