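# Thin library wrapper around the SAINT tabular transformer: SaintLib.fit()
# builds the model from argv-style config flags, trains it, and reports clean
# and backdoored test metrics for the best validation checkpoint.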

import argparse
import os
import time

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader

from .augmentations import embed_data_mask
from .data_openml import DataSetCatCon
from .models import SAINT
from .utils import count_parameters, classification_scores, mean_sq_error, prepareData


class SaintLib:
    def __init__(self, config):
        # config is an argv-style list of CLI flags, e.g. ['--task', 'binary'].
        self.config = config
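
    # fit() trains SAINT on the given splits. The split arguments are whatever
    # tabular format prepareData() accepts (pandas DataFrames is an assumption
    # here); cat_cols/num_cols name the categorical and continuous feature
    # columns, and target names the label column.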
    def fit(self, train, valid, test, test_backdoor, cat_cols, num_cols, target):
        parser = argparse.ArgumentParser()
        parser.add_argument('--device', default='cuda:0', type=str)
        parser.add_argument('--vision_dset', action='store_true')
        parser.add_argument('--task', default='multiclass', type=str, choices=['binary', 'multiclass', 'regression'])
        parser.add_argument('--cont_embeddings', default='MLP', type=str, choices=['MLP', 'Noemb', 'pos_singleMLP'])
        parser.add_argument('--embedding_size', default=32, type=int)
        parser.add_argument('--transformer_depth', default=2, type=int)
        parser.add_argument('--attention_heads', default=4, type=int)
        parser.add_argument('--attention_dropout', default=0.1, type=float)
        parser.add_argument('--ff_dropout', default=0.1, type=float)
        parser.add_argument('--attentiontype', default='colrow', type=str, choices=['col', 'colrow', 'row', 'justmlp', 'attn', 'attnmlp'])
        parser.add_argument('--optimizer', default='AdamW', type=str, choices=['AdamW', 'Adam', 'SGD'])
        parser.add_argument('--scheduler', default='cosine', type=str, choices=['cosine', 'linear'])
        parser.add_argument('--lr', default=0.0001, type=float)
        parser.add_argument('--epochs', default=50, type=int)
        parser.add_argument('--batchsize', default=512, type=int)
        parser.add_argument('--savemodelroot', default='./bestmodels', type=str)
        parser.add_argument('--run_name', default='testrun', type=str)
        parser.add_argument('--set_seed', default=1, type=int)
        parser.add_argument('--dset_seed', default=5, type=int)
        parser.add_argument('--active_log', action='store_true')
        parser.add_argument('--pretrain', action='store_true')
        parser.add_argument('--pretrain_epochs', default=50, type=int)
        parser.add_argument('--pt_tasks', default=['contrastive', 'denoising'], type=str, nargs='*', choices=['contrastive', 'contrastive_sim', 'denoising'])
        parser.add_argument('--pt_aug', default=[], type=str, nargs='*', choices=['mixup', 'cutmix'])
        parser.add_argument('--pt_aug_lam', default=0.1, type=float)
        parser.add_argument('--mixup_lam', default=0.3, type=float)
        parser.add_argument('--train_mask_prob', default=0, type=float)
        parser.add_argument('--mask_prob', default=0, type=float)
        parser.add_argument('--ssl_avail_y', default=0, type=int)
        parser.add_argument('--pt_projhead_style', default='diff', type=str, choices=['diff', 'same', 'nohead'])
        parser.add_argument('--nce_temp', default=0.7, type=float)
        parser.add_argument('--lam0', default=0.5, type=float)
        parser.add_argument('--lam1', default=10, type=float)
        parser.add_argument('--lam2', default=1, type=float)
        parser.add_argument('--lam3', default=10, type=float)
        parser.add_argument('--final_mlp_style', default='sep', type=str, choices=['common', 'sep'])
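
        # Parse the argv-style flag list supplied to the constructor instead of
        # sys.argv, so fit() can be driven programmatically.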
        opt = parser.parse_args(self.config)
        modelsave_path = os.path.join(os.getcwd(), opt.savemodelroot, opt.task, opt.run_name)
        opt.dtask = 'reg' if opt.task == 'regression' else 'clf'

        device = torch.device(opt.device)
        print(f"Device is {device}.")
        torch.manual_seed(opt.set_seed)
        os.makedirs(modelsave_path, exist_ok=True)
        # prepareData returns the categorical cardinalities/indices, the split
        # arrays, and the train-set mean/std used to normalize continuous columns.
        (cat_dims, cat_idxs, con_idxs,
         X_train, y_train, X_valid, y_valid, X_test, y_test, X_test_backdoor, y_test_backdoor,
         train_mean, train_std) = prepareData(train, valid, test, test_backdoor, cat_cols, num_cols, target)
        continuous_mean_std = np.array([train_mean, train_std]).astype(np.float32)

        train_ds = DataSetCatCon(X_train, y_train, cat_idxs, opt.dtask, continuous_mean_std)
        trainloader = DataLoader(train_ds, batch_size=opt.batchsize, shuffle=True, num_workers=4)
        valid_ds = DataSetCatCon(X_valid, y_valid, cat_idxs, opt.dtask, continuous_mean_std)
        validloader = DataLoader(valid_ds, batch_size=opt.batchsize, shuffle=False, num_workers=1)
        test_ds = DataSetCatCon(X_test, y_test, cat_idxs, opt.dtask, continuous_mean_std)
        testloader = DataLoader(test_ds, batch_size=opt.batchsize, shuffle=False, num_workers=1)
        test_backdoor_ds = DataSetCatCon(X_test_backdoor, y_test_backdoor, cat_idxs, opt.dtask, continuous_mean_std)
        test_backdoorloader = DataLoader(test_backdoor_ds, batch_size=opt.batchsize, shuffle=False, num_workers=1)
        if opt.task == 'regression':
            y_dim = 1
        else:
            # For classification, infer the number of classes from the labels.
            y_dim = len(np.unique(y_train['data'][:, 0]))
        # Prepend a cardinality of 1 for the CLS token; it is used later when the
        # categorical embeddings are generated.
        cat_dims = np.append(np.array([1]), np.array(cat_dims)).astype(int)

        model = SAINT(
            categories=tuple(cat_dims),
            num_continuous=len(con_idxs),
            dim=opt.embedding_size,
            dim_out=1,
            depth=opt.transformer_depth,
            heads=opt.attention_heads,
            attn_dropout=opt.attention_dropout,
            ff_dropout=opt.ff_dropout,
            mlp_hidden_mults=(4, 2),
            cont_embeddings=opt.cont_embeddings,
            attentiontype=opt.attentiontype,
            final_mlp_style=opt.final_mlp_style,
            y_dim=y_dim,
        )
        vision_dset = opt.vision_dset

        if y_dim == 2 and opt.task == 'binary':
            criterion = nn.CrossEntropyLoss().to(device)
        elif y_dim > 2 and opt.task == 'multiclass':
            criterion = nn.CrossEntropyLoss().to(device)
        elif opt.task == 'regression':
            criterion = nn.MSELoss().to(device)
        else:
            raise ValueError('case not written yet')
        model.to(device)
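
        # Optional self-supervised pretraining (contrastive and/or denoising
        # objectives, per --pt_tasks) before the supervised phase.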
        if opt.pretrain:
            from .pretraining import SAINT_pretrain
            model = SAINT_pretrain(model, cat_idxs, X_train, y_train, continuous_mean_std, opt, device)

        # Choose the optimizer; only SGD is paired with an LR scheduler here.
        if opt.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(), lr=opt.lr,
                                  momentum=0.9, weight_decay=5e-4)
            from .utils import get_scheduler
            scheduler = get_scheduler(opt, optimizer)
        elif opt.optimizer == 'Adam':
            optimizer = optim.Adam(model.parameters(), lr=opt.lr)
        elif opt.optimizer == 'AdamW':
            optimizer = optim.AdamW(model.parameters(), lr=opt.lr)

        best_valid_auroc = 0
        best_valid_accuracy = 0
        best_test_auroc = 0
        best_test_accuracy = 0
        best_test_backdoor_auroc = 0
        best_test_backdoor_accuracy = 0
        best_valid_rmse = float('inf')
        best_test_rmse = float('inf')
        print('Training begins now.')
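
        # Main loop: one optimization pass over the training loader per epoch,
        # followed by evaluation on the validation, test, and backdoored test sets.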
        for epoch in range(opt.epochs):
            start_time = time.time()
            model.train()
            running_loss = 0.0
            for data in trainloader:
                optimizer.zero_grad()
                # x_categ holds the categorical features, x_cont the continuous
                # ones, and y_gts the ground-truth targets. cat_mask is all ones
                # in the shape of x_categ plus one extra CLS-token column of
                # zeros; con_mask is all ones in the shape of x_cont.
                x_categ, x_cont, y_gts, cat_mask, con_mask = (
                    data[0].to(device), data[1].to(device), data[2].to(device),
                    data[3].to(device), data[4].to(device))
                # Convert the raw inputs into token embeddings.
                _, x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset)
                reps = model.transformer(x_categ_enc, x_cont_enc)
                # Keep only the CLS-token representation and pass it through the
                # prediction MLP to get the outputs.
                y_reps = reps[:, 0, :]
                y_outs = model.mlpfory(y_reps)
                if opt.task == 'regression':
                    loss = criterion(y_outs, y_gts)
                else:
                    loss = criterion(y_outs, y_gts.squeeze())
                loss.backward()
                optimizer.step()
                if opt.optimizer == 'SGD':
                    scheduler.step()
                running_loss += loss.item()

            # Evaluate after every epoch.
            model.eval()
            with torch.no_grad():
                if opt.task in ['binary', 'multiclass']:
                    accuracy, auroc = classification_scores(model, validloader, device, opt.task, vision_dset)
                    test_accuracy, test_auroc = classification_scores(model, testloader, device, opt.task, vision_dset)
                    test_backdoor_accuracy, test_backdoor_auroc = classification_scores(model, test_backdoorloader, device, opt.task, vision_dset)
                    print('[EPOCH %d] VALID ACC: %.3f, TEST ACC: %.3f, TEST_BACKDOOR ACC: %.3f' %
                          (epoch + 1, accuracy, test_accuracy, test_backdoor_accuracy))
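                    # Model selection: keep the checkpoint with the best
                    # validation accuracy and record its test/backdoor metrics.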
                    if accuracy > best_valid_accuracy:
                        best_valid_accuracy = accuracy
                        best_test_auroc = test_auroc
                        best_test_accuracy = test_accuracy
                        best_test_backdoor_auroc = test_backdoor_auroc
                        best_test_backdoor_accuracy = test_backdoor_accuracy
                        torch.save(model.state_dict(), '%s/bestmodel.pth' % (modelsave_path))
                else:
                    valid_rmse = mean_sq_error(model, validloader, device, vision_dset)
                    test_rmse = mean_sq_error(model, testloader, device, vision_dset)
                    print('[EPOCH %d] VALID RMSE: %.3f' % (epoch + 1, valid_rmse))
                    print('[EPOCH %d] TEST RMSE: %.3f' % (epoch + 1, test_rmse))
                    if opt.active_log:
                        # Assumes the caller has already initialized a wandb run.
                        import wandb
                        wandb.log({'valid_rmse': valid_rmse, 'test_rmse': test_rmse})
                    if valid_rmse < best_valid_rmse:
                        best_valid_rmse = valid_rmse
                        best_test_rmse = test_rmse
                        torch.save(model.state_dict(), '%s/bestmodel.pth' % (modelsave_path))
            model.train()
            end_time = time.time()
            print("EPOCH took", end_time - start_time, "seconds")

        total_parameters = count_parameters(model)
        print('TOTAL NUMBER OF PARAMS: %d' % (total_parameters))
        if opt.task == 'binary':
            print('Test AUROC on best model: %.3f' % (best_test_auroc))
            print('Test accuracy on best model: %.3f' % (best_test_accuracy))
            print('Test_backdoor accuracy on best model: %.3f' % (best_test_backdoor_accuracy))
        elif opt.task == 'multiclass':
            print('Test accuracy on best model: %.3f' % (best_test_accuracy))
            print('Test_backdoor accuracy on best model: %.3f' % (best_test_backdoor_accuracy))
        else:
            print('RMSE on best model: %.3f' % (best_test_rmse))
        # classification_scores reports accuracies as percentages, hence the
        # division by 100 to return fractions in [0, 1].
        return float(best_test_accuracy) / 100, float(best_test_backdoor_accuracy) / 100, best_test_auroc
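

# Minimal usage sketch. The DataFrame variables, column names, and flag values
# below are illustrative assumptions, not part of this module:
#
#   lib = SaintLib(['--task', 'binary', '--epochs', '5', '--batchsize', '256'])
#   test_acc, backdoor_acc, test_auroc = lib.fit(
#       train_df, valid_df, test_df, test_backdoor_df,
#       cat_cols=['workclass', 'education'], num_cols=['age', 'hours'],
#       target='income')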