import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.optim as optim
import os
import numpy as np

from baselines.data_openml import data_prep_openml, task_dset_ids, DataSetCatCon
from .augmentations import embed_data_mask, add_noise, mixup_data


def SAINT_pretrain(model, cat_idxs, X_train, y_train, continuous_mean_std, opt, device):
    train_ds = DataSetCatCon(X_train, y_train, cat_idxs, opt.dtask, continuous_mean_std)
    trainloader = DataLoader(train_ds, batch_size=opt.batchsize, shuffle=True, num_workers=4)
    vision_dset = opt.vision_dset
    optimizer = optim.AdamW(model.parameters(), lr=0.0001)
    # Parameters for the input-corruption augmentation (consumed by add_noise).
    pt_aug_dict = {
        'noise_type': opt.pt_aug,
        'lambda': opt.pt_aug_lam
    }
    criterion1 = nn.CrossEntropyLoss()  # contrastive targets and categorical reconstruction
    criterion2 = nn.MSELoss()           # continuous-feature reconstruction
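    # The pretraining loss is assembled from up to three self-supervised
    # objectives, selected via opt.pt_tasks:
    #   * 'contrastive'     - InfoNCE between two views of each batch
    #   * 'contrastive_sim' - pushes the similarity of paired views toward 1
    #   * 'denoising'       - reconstructs the clean inputs from the corrupted view
    # opt.pt_aug controls how the second view is corrupted ('cutmix', 'mixup').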
- print("Pretraining begins!")
- for epoch in range(opt.pretrain_epochs):
- model.train()
- running_loss = 0.0
- for i, data in enumerate(trainloader, 0):
- optimizer.zero_grad()
- x_categ, x_cont, _ ,cat_mask, con_mask = data[0].to(device), data[1].to(device),data[2].to(device),data[3].to(device),data[4].to(device)

            # embed_data_mask embeds both the categorical and continuous inputs.
            if 'cutmix' in opt.pt_aug:
                # Corrupt the raw inputs before embedding (CutMix-style noise).
                x_categ_corr, x_cont_corr = add_noise(x_categ, x_cont, noise_params=pt_aug_dict)
                _, x_categ_enc_2, x_cont_enc_2 = embed_data_mask(x_categ_corr, x_cont_corr, cat_mask, con_mask, model, vision_dset)
            else:
                _, x_categ_enc_2, x_cont_enc_2 = embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset)
            _, x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset)

            if 'mixup' in opt.pt_aug:
                # Interpolate embeddings across rows of the batch.
                x_categ_enc_2, x_cont_enc_2 = mixup_data(x_categ_enc_2, x_cont_enc_2, lam=opt.mixup_lam)
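
            # At this point there are two views of the batch: the clean
            # encodings (x_categ_enc, x_cont_enc) and the corrupted/mixed
            # encodings (x_categ_enc_2, x_cont_enc_2). The losses below
            # compare or reconstruct across these two views.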
            loss = 0
            if 'contrastive' in opt.pt_tasks:
                aug_features_1 = model.transformer(x_categ_enc, x_cont_enc)
                aug_features_2 = model.transformer(x_categ_enc_2, x_cont_enc_2)
                # L2-normalise per token, then flatten tokens into one vector per row.
                aug_features_1 = (aug_features_1 / aug_features_1.norm(dim=-1, keepdim=True)).flatten(1, 2)
                aug_features_2 = (aug_features_2 / aug_features_2.norm(dim=-1, keepdim=True)).flatten(1, 2)
                if opt.pt_projhead_style == 'diff':
                    # Separate projection heads for the two views.
                    aug_features_1 = model.pt_mlp(aug_features_1)
                    aug_features_2 = model.pt_mlp2(aug_features_2)
                elif opt.pt_projhead_style == 'same':
                    # Shared projection head.
                    aug_features_1 = model.pt_mlp(aug_features_1)
                    aug_features_2 = model.pt_mlp(aug_features_2)
                else:
                    print('Not using projection head')
                # InfoNCE: the i-th row of one view is the positive for the
                # i-th row of the other view; all other rows are negatives.
                logits_per_aug1 = aug_features_1 @ aug_features_2.t() / opt.nce_temp
                logits_per_aug2 = aug_features_2 @ aug_features_1.t() / opt.nce_temp
                targets = torch.arange(logits_per_aug1.size(0)).to(logits_per_aug1.device)
                loss_1 = criterion1(logits_per_aug1, targets)
                loss_2 = criterion1(logits_per_aug2, targets)
                loss = opt.lam0 * (loss_1 + loss_2) / 2
            elif 'contrastive_sim' in opt.pt_tasks:
                aug_features_1 = model.transformer(x_categ_enc, x_cont_enc)
                aug_features_2 = model.transformer(x_categ_enc_2, x_cont_enc_2)
                aug_features_1 = (aug_features_1 / aug_features_1.norm(dim=-1, keepdim=True)).flatten(1, 2)
                aug_features_2 = (aug_features_2 / aug_features_2.norm(dim=-1, keepdim=True)).flatten(1, 2)
                aug_features_1 = model.pt_mlp(aug_features_1)
                aug_features_2 = model.pt_mlp2(aug_features_2)
                # Penalise paired rows whose similarity deviates from 1:
                # the diagonal entries d of c1 contribute (1 - d)^2 each.
                c1 = aug_features_1 @ aug_features_2.t()
                loss += opt.lam1 * torch.diagonal(-1 * c1).add_(1).pow_(2).sum()
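
            # Denoising: predict the original (uncorrupted) feature values
            # from the corrupted view's encodings.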
            if 'denoising' in opt.pt_tasks:
                cat_outs, con_outs = model(x_categ_enc_2, x_cont_enc_2)
                if len(con_outs) > 0:
                    con_outs = torch.cat(con_outs, dim=1)
                    l2 = criterion2(con_outs, x_cont)
                else:
                    l2 = 0
                l1 = 0
                # Skip index 0: it holds the prepended CLS token, which has no
                # original feature value to reconstruct.
                n_cat = x_categ.shape[-1]
                for j in range(1, n_cat):
                    l1 += criterion1(cat_outs[j], x_categ[:, j])
                loss += opt.lam2 * l1 + opt.lam3 * l2
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f'Epoch: {epoch}, Running Loss: {running_loss}')
        # if opt.active_log:
        #     wandb.log({'pt_epoch': epoch, 'pretrain_epoch_loss': running_loss})

    print('END OF PRETRAINING!')
    return model
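
# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not from the original script). The fields
# set on `opt` below are exactly the ones SAINT_pretrain reads; the values are
# illustrative only, and `model`, `cat_idxs`, `X_train`, `y_train` and
# `continuous_mean_std` stand in for objects the calling training script has
# already built.
# ---------------------------------------------------------------------------
# from argparse import Namespace
#
# opt = Namespace(
#     dtask='clf',                          # task type handed to DataSetCatCon
#     batchsize=256,
#     vision_dset=False,
#     pt_aug=['cutmix', 'mixup'],           # how the second view is corrupted
#     pt_aug_lam=0.1,                       # noise strength for add_noise
#     mixup_lam=0.3,                        # mixup interpolation coefficient
#     pt_tasks=['contrastive', 'denoising'],
#     pt_projhead_style='diff',             # 'diff' | 'same' | anything else = no head
#     nce_temp=0.7,                         # InfoNCE temperature
#     lam0=0.5, lam1=10, lam2=1, lam3=10,   # weights on the individual losses
#     pretrain_epochs=50,
#     active_log=False,
# )
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = SAINT_pretrain(model, cat_idxs, X_train, y_train,
#                        continuous_mean_std, opt, device)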