pretraining.py
import torch
from torch import nn
from baselines.data_openml import data_prep_openml, task_dset_ids, DataSetCatCon
from torch.utils.data import DataLoader
import torch.optim as optim
from .augmentations import embed_data_mask, add_noise, mixup_data
import os
import numpy as np


def SAINT_pretrain(model, cat_idxs, X_train, y_train, continuous_mean_std, opt, device):
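    """Self-supervised pretraining loop for SAINT.

    Builds a second, augmented view of every batch (CutMix-style noise on the
    raw inputs and/or mixup on the embeddings) and optimizes the objectives
    selected in opt.pt_tasks: a contrastive loss between the two views and/or
    a denoising (reconstruction) loss on the augmented view. Returns the
    pretrained model.
    """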
    train_ds = DataSetCatCon(X_train, y_train, cat_idxs, opt.dtask, continuous_mean_std)
    trainloader = DataLoader(train_ds, batch_size=opt.batchsize, shuffle=True, num_workers=4)
    vision_dset = opt.vision_dset
    optimizer = optim.AdamW(model.parameters(), lr=0.0001)
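    # Settings forwarded to add_noise: the augmentation type (e.g. cutmix) and
    # its mixing coefficient.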
    pt_aug_dict = {
        'noise_type': opt.pt_aug,
        'lambda': opt.pt_aug_lam
    }
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = nn.MSELoss()
    print("Pretraining begins!")
    for epoch in range(opt.pretrain_epochs):
        model.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            optimizer.zero_grad()
            x_categ, x_cont, _, cat_mask, con_mask = data[0].to(device), data[1].to(device), data[2].to(device), data[3].to(device), data[4].to(device)
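            # The batch holds categorical tokens (with the prepended CLS token),
            # continuous features, the label (unused during pretraining), and the
            # corresponding feature masks.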
            # embed_data_mask embeds both the categorical and continuous data.
            if 'cutmix' in opt.pt_aug:
                x_categ_corr, x_cont_corr = add_noise(x_categ, x_cont, noise_params=pt_aug_dict)
                _, x_categ_enc_2, x_cont_enc_2 = embed_data_mask(x_categ_corr, x_cont_corr, cat_mask, con_mask, model, vision_dset)
            else:
                _, x_categ_enc_2, x_cont_enc_2 = embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset)
            _, x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset)

            if 'mixup' in opt.pt_aug:
                x_categ_enc_2, x_cont_enc_2 = mixup_data(x_categ_enc_2, x_cont_enc_2, lam=opt.mixup_lam)
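            # x_categ_enc / x_cont_enc are embeddings of the original batch;
            # the *_enc_2 tensors are the augmented view used below.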
            loss = 0
            if 'contrastive' in opt.pt_tasks:
                aug_features_1 = model.transformer(x_categ_enc, x_cont_enc)
                aug_features_2 = model.transformer(x_categ_enc_2, x_cont_enc_2)
                aug_features_1 = (aug_features_1 / aug_features_1.norm(dim=-1, keepdim=True)).flatten(1, 2)
                aug_features_2 = (aug_features_2 / aug_features_2.norm(dim=-1, keepdim=True)).flatten(1, 2)
                if opt.pt_projhead_style == 'diff':
                    aug_features_1 = model.pt_mlp(aug_features_1)
                    aug_features_2 = model.pt_mlp2(aug_features_2)
                elif opt.pt_projhead_style == 'same':
                    aug_features_1 = model.pt_mlp(aug_features_1)
                    aug_features_2 = model.pt_mlp(aug_features_2)
                else:
                    print('Not using projection head')
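                # InfoNCE-style loss: for each sample the positive pair is its own
                # augmented view, so the targets are the diagonal indices.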
                logits_per_aug1 = aug_features_1 @ aug_features_2.t() / opt.nce_temp
                logits_per_aug2 = aug_features_2 @ aug_features_1.t() / opt.nce_temp
                targets = torch.arange(logits_per_aug1.size(0)).to(logits_per_aug1.device)
                loss_1 = criterion1(logits_per_aug1, targets)
                loss_2 = criterion1(logits_per_aug2, targets)
                loss = opt.lam0 * (loss_1 + loss_2) / 2
            elif 'contrastive_sim' in opt.pt_tasks:
                aug_features_1 = model.transformer(x_categ_enc, x_cont_enc)
                aug_features_2 = model.transformer(x_categ_enc_2, x_cont_enc_2)
                aug_features_1 = (aug_features_1 / aug_features_1.norm(dim=-1, keepdim=True)).flatten(1, 2)
                aug_features_2 = (aug_features_2 / aug_features_2.norm(dim=-1, keepdim=True)).flatten(1, 2)
                aug_features_1 = model.pt_mlp(aug_features_1)
                aug_features_2 = model.pt_mlp2(aug_features_2)
                c1 = aug_features_1 @ aug_features_2.t()
                # Push the similarity between matching views (the diagonal of c1) towards 1.
                loss += opt.lam1 * torch.diagonal(-1 * c1).add_(1).pow_(2).sum()
            if 'denoising' in opt.pt_tasks:
                cat_outs, con_outs = model(x_categ_enc_2, x_cont_enc_2)
                if len(con_outs) > 0:
                    con_outs = torch.cat(con_outs, dim=1)
                    l2 = criterion2(con_outs, x_cont)
                else:
                    l2 = 0
                l1 = 0
                n_cat = x_categ.shape[-1]
                # Column 0 is the CLS token, so reconstruction starts at index 1.
                for j in range(1, n_cat):
                    l1 += criterion1(cat_outs[j], x_categ[:, j])
                loss += opt.lam2 * l1 + opt.lam3 * l2
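            # Backpropagate the combined self-supervised loss and update the model.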
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch: {epoch}, Running Loss: {running_loss}')
        # Optional experiment tracking (requires wandb to be imported and initialised):
        # if opt.active_log:
        #     wandb.log({'pt_epoch': epoch, 'pretrain_epoch_loss': running_loss})
    print('END OF PRETRAINING!')
    return model
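
# Example usage (sketch): `opt` is assumed to be an argparse.Namespace carrying the
# fields referenced above (dtask, vision_dset, batchsize, pretrain_epochs, pt_aug,
# pt_aug_lam, mixup_lam, pt_tasks, pt_projhead_style, nce_temp, lam0-lam3), with the
# other arguments prepared as in the training script:
#
#     model = SAINT_pretrain(model, cat_idxs, X_train, y_train,
#                            continuous_mean_std, opt, device)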