12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- import torch
- import numpy as np
def embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset=False):
    """Embed categorical and continuous features, swapping in mask embeddings.

    Categorical indices are shifted by ``model.categories_offset`` and passed
    to ``model.embeds``. Continuous column ``i`` is embedded by
    ``model.simple_MLP[i]``. Wherever a mask entry equals 0, that feature's
    embedding is replaced by the matching entry of ``model.mask_embeds_cat``
    (looked up at ``cat_mask + model.cat_mask_offset``) or
    ``model.mask_embeds_cont`` (looked up at ``con_mask + model.con_mask_offset``).

    Args:
        x_categ: 2-D integer tensor of categorical indices, ``(rows, n_categ)``.
        x_cont: 2-D tensor of continuous values, ``(n1, n2)``.
        cat_mask: integer tensor shaped like ``x_categ``; 0 marks entries to
            replace with a mask embedding.
        con_mask: integer tensor shaped like ``x_cont``; 0 marks entries to
            replace with a mask embedding.
        model: object providing ``categories_offset``, ``embeds``,
            ``cont_embeddings``, ``num_continuous``, ``dim``, ``simple_MLP``,
            ``cat_mask_offset``, ``con_mask_offset``, ``mask_embeds_cat``,
            ``mask_embeds_cont`` and, when ``vision_dset`` is true,
            ``pos_encodings``.
        vision_dset: if true, add ``model.pos_encodings`` of each column's
            position (0..n_categ-1) to the categorical embeddings.

    Returns:
        ``(x_categ, x_categ_enc, x_cont_enc)``:
        - ``x_categ`` is the offset-shifted index tensor (a new tensor; the
          input is not modified).
        - ``x_categ_enc`` is the output of ``model.embeds`` with masked
          entries replaced.
        - ``x_cont_enc`` has shape ``(n1, n2, model.dim)``.

    Raises:
        ValueError: if ``model.cont_embeddings`` is not ``'MLP'``.
    """
    device = x_cont.device
    # Shift each column's indices into its own slice of the shared table.
    x_categ = x_categ + model.categories_offset.type_as(x_categ)
    x_categ_enc = model.embeds(x_categ)
    n1, n2 = x_cont.shape
    n_rows_categ, n_categ = x_categ.shape

    if model.cont_embeddings != 'MLP':
        raise ValueError(
            f"unsupported cont_embeddings={model.cont_embeddings!r}; "
            "only 'MLP' is implemented"
        )

    # Allocate directly on the target device instead of filling a CPU tensor
    # and copying it over on every call. dtype stays torch's default float.
    x_cont_enc = torch.empty(n1, n2, model.dim, device=device)
    # NOTE(review): if model.num_continuous < n2 the trailing columns remain
    # uninitialized memory; callers presumably guarantee they are equal.
    for i in range(model.num_continuous):
        x_cont_enc[:, i, :] = model.simple_MLP[i](x_cont[:, i])

    cat_mask_emb = model.mask_embeds_cat(
        cat_mask + model.cat_mask_offset.type_as(cat_mask)
    )
    con_mask_emb = model.mask_embeds_cont(
        con_mask + model.con_mask_offset.type_as(con_mask)
    )
    cat_masked = cat_mask == 0
    con_masked = con_mask == 0
    x_categ_enc[cat_masked] = cat_mask_emb[cat_masked]
    x_cont_enc[con_masked] = con_mask_emb[con_masked]

    if vision_dset:
        # Column positions 0..n_categ-1 repeated for every row. Built with
        # torch so the index dtype is always int64 (np.arange's default
        # integer dtype is platform-dependent) and it is created on-device.
        pos = torch.arange(n_categ, device=device).repeat(n_rows_categ, 1)
        x_categ_enc += model.pos_encodings(pos)
    return x_categ, x_categ_enc, x_cont_enc
def mixup_data(x1, x2, lam=1.0, y=None, use_cuda=True):
    """Return mixed inputs and, optionally, the pairs of targets.

    One random permutation ``index`` of the batch (first dimension) is drawn.
    Each output is ``lam * x + (1 - lam) * x[index]``. The same permutation
    is applied to ``x1``, ``x2`` and ``y``, so the mixed rows stay aligned.

    Args:
        x1, x2: tensors sharing the same first (batch) dimension. Any rank
            >= 1 is accepted.
        lam: weight given to each original sample.
        y: optional targets indexed along the first dimension.
        use_cuda: if true and ``x1`` is a CUDA tensor, the permutation is moved
            to ``x1``'s device. Otherwise it stays on the CPU, so the default
            no longer crashes on CPU-only machines or CPU data.

    Returns:
        ``(mixed_x1, mixed_x2)``, or ``(mixed_x1, mixed_x2, y_a, y_b)`` when
        ``y`` is given, where ``y_a`` is ``y`` and ``y_b`` is ``y[index]``.
    """
    batch_size = x1.size(0)
    # Always draw on the CPU generator (the original also did, then called
    # .cuda()), so seeded runs produce the same permutation as before.
    index = torch.randperm(batch_size)
    if use_cuda and x1.is_cuda:
        index = index.to(x1.device)
    # x[index] permutes the first dimension for any rank (x[index, :] needs >= 2-D).
    mixed_x1 = lam * x1 + (1 - lam) * x1[index]
    mixed_x2 = lam * x2 + (1 - lam) * x2[index]
    if y is None:
        return mixed_x1, mixed_x2
    return mixed_x1, mixed_x2, y, y[index]
def add_noise(x_categ, x_cont, noise_params=None):
    """Return corrupted copies of a batch of categorical and continuous features.

    ``noise_params['noise_type']`` may be a single name or a list of names:

    * ``'cutmix'``: each entry is kept with probability ``1 - lambda``.
      Otherwise it is replaced by the same column of another row, chosen by
      one random row permutation shared by both tensors. The inputs are not
      modified.
    * ``'missing'``: each entry is kept with probability ``1 - lambda``.
      Otherwise it is multiplied by 0.

    If both names are present, only ``'cutmix'`` is applied (it is checked first).

    Args:
        x_categ: tensor of categorical features, at least 2-D, with rows
            along the first dimension.
        x_cont: tensor of continuous features with the same number of rows.
        noise_params: dict with keys ``'noise_type'`` and ``'lambda'``.
            Defaults to ``{'noise_type': ['cutmix'], 'lambda': 0.1}``.

    Returns:
        ``(x_categ_corr, x_cont_corr)``, shaped like the inputs.

    Raises:
        ValueError: if neither ``'cutmix'`` nor ``'missing'`` is requested.
    """
    # Avoid a shared mutable default argument.
    if noise_params is None:
        noise_params = {'noise_type': ['cutmix'], 'lambda': 0.1}
    lam = noise_params['lambda']
    noise_types = noise_params['noise_type']
    # Accept a bare string as well as a list; test membership for every type
    # (the original compared the list to 'missing' with ==, which never matched).
    if isinstance(noise_types, str):
        noise_types = [noise_types]
    device = x_categ.device
    batch_size = x_categ.size(0)
    # Sampled mask entries are 0 with probability lam and 1 otherwise.
    mask_p = [lam, 1 - lam]

    if 'cutmix' in noise_types:
        index = torch.randperm(batch_size)
        cat_corr = torch.from_numpy(np.random.choice(2, x_categ.shape, p=mask_p)).to(device)
        con_corr = torch.from_numpy(np.random.choice(2, x_cont.shape, p=mask_p)).to(device)
        x1, x2 = x_categ[index, :], x_cont[index, :]
        x_categ_corr = x_categ.clone().detach()
        x_cont_corr = x_cont.clone().detach()
        x_categ_corr[cat_corr == 0] = x1[cat_corr == 0]
        x_cont_corr[con_corr == 0] = x2[con_corr == 0]
        return x_categ_corr, x_cont_corr

    if 'missing' in noise_types:
        x_categ_mask = torch.from_numpy(np.random.choice(2, x_categ.shape, p=mask_p)).to(device)
        x_cont_mask = torch.from_numpy(np.random.choice(2, x_cont.shape, p=mask_p)).to(device)
        return torch.mul(x_categ, x_categ_mask), torch.mul(x_cont, x_cont_mask)

    # Previously printed and returned None, which crashes callers on unpacking.
    raise ValueError(
        f"unsupported noise_type {noise_params['noise_type']!r}; "
        "expected 'cutmix' and/or 'missing'"
    )
|