# augmentations.py (2.9 KB)
  1. import torch
  2. import numpy as np
  3. def embed_data_mask(x_categ, x_cont, cat_mask, con_mask,model,vision_dset=False):
  4. device = x_cont.device
  5. x_categ = x_categ + model.categories_offset.type_as(x_categ)
  6. x_categ_enc = model.embeds(x_categ)
  7. n1,n2 = x_cont.shape
  8. _, n3 = x_categ.shape
  9. if model.cont_embeddings == 'MLP':
  10. x_cont_enc = torch.empty(n1,n2, model.dim)
  11. for i in range(model.num_continuous):
  12. x_cont_enc[:,i,:] = model.simple_MLP[i](x_cont[:,i])
  13. else:
  14. raise Exception('This case should not work!')
  15. x_cont_enc = x_cont_enc.to(device)
  16. cat_mask_temp = cat_mask + model.cat_mask_offset.type_as(cat_mask)
  17. con_mask_temp = con_mask + model.con_mask_offset.type_as(con_mask)
  18. cat_mask_temp = model.mask_embeds_cat(cat_mask_temp)
  19. con_mask_temp = model.mask_embeds_cont(con_mask_temp)
  20. x_categ_enc[cat_mask == 0] = cat_mask_temp[cat_mask == 0]
  21. x_cont_enc[con_mask == 0] = con_mask_temp[con_mask == 0]
  22. if vision_dset:
  23. pos = np.tile(np.arange(x_categ.shape[-1]),(x_categ.shape[0],1))
  24. pos = torch.from_numpy(pos).to(device)
  25. pos_enc =model.pos_encodings(pos)
  26. x_categ_enc+=pos_enc
  27. return x_categ, x_categ_enc, x_cont_enc
  28. def mixup_data(x1, x2 , lam=1.0, y= None, use_cuda=True):
  29. '''Returns mixed inputs, pairs of targets'''
  30. batch_size = x1.size()[0]
  31. if use_cuda:
  32. index = torch.randperm(batch_size).cuda()
  33. else:
  34. index = torch.randperm(batch_size)
  35. mixed_x1 = lam * x1 + (1 - lam) * x1[index, :]
  36. mixed_x2 = lam * x2 + (1 - lam) * x2[index, :]
  37. if y is not None:
  38. y_a, y_b = y, y[index]
  39. return mixed_x1, mixed_x2, y_a, y_b
  40. return mixed_x1, mixed_x2
  41. def add_noise(x_categ,x_cont, noise_params = {'noise_type' : ['cutmix'],'lambda' : 0.1}):
  42. lam = noise_params['lambda']
  43. device = x_categ.device
  44. batch_size = x_categ.size()[0]
  45. if 'cutmix' in noise_params['noise_type']:
  46. index = torch.randperm(batch_size)
  47. cat_corr = torch.from_numpy(np.random.choice(2,(x_categ.shape),p=[lam,1-lam])).to(device)
  48. con_corr = torch.from_numpy(np.random.choice(2,(x_cont.shape),p=[lam,1-lam])).to(device)
  49. x1, x2 = x_categ[index,:], x_cont[index,:]
  50. x_categ_corr, x_cont_corr = x_categ.clone().detach() ,x_cont.clone().detach()
  51. x_categ_corr[cat_corr==0] = x1[cat_corr==0]
  52. x_cont_corr[con_corr==0] = x2[con_corr==0]
  53. return x_categ_corr, x_cont_corr
  54. elif noise_params['noise_type'] == 'missing':
  55. x_categ_mask = np.random.choice(2,(x_categ.shape),p=[lam,1-lam])
  56. x_cont_mask = np.random.choice(2,(x_cont.shape),p=[lam,1-lam])
  57. x_categ_mask = torch.from_numpy(x_categ_mask).to(device)
  58. x_cont_mask = torch.from_numpy(x_cont_mask).to(device)
  59. return torch.mul(x_categ,x_categ_mask), torch.mul(x_cont,x_cont_mask)
  60. else:
  61. print("yet to write this")