# model.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
import numpy as np
from einops import rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def ff_encodings(x, B):
    x_proj = (2. * np.pi * x.unsqueeze(-1)) @ B.t()
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim = -1)
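# Usage sketch (illustrative, not from the original file): random Fourier features for a
# column of scalars; B is an assumed random projection matrix of shape (num_freqs, 1).
#   B = torch.randn(16, 1)
#   ff_encodings(torch.rand(32), B).shape   # -> torch.Size([32, 32]), i.e. [sin | cos]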
# classes

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
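# Note (added for clarity; not in the original file): the transformer blocks below compose
# these as PreNorm(dim, Residual(fn)), i.e. fn(norm(x)) + norm(x), so the skip connection
# adds back the *normalized* input rather than the raw input.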
# attention

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)
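# Usage sketch (illustrative, not from the original file): the first Linear doubles the
# hidden width to dim * mult * 2 so GEGLU can split it into a value half and a gate half.
#   ff = FeedForward(dim = 32, mult = 4)
#   ff(torch.randn(8, 6, 32)).shape   # -> torch.Size([8, 6, 32])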
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 16,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)  # apply attention dropout (the module was defined but previously unused)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
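# Usage sketch (illustrative, not from the original file): multi-head self-attention over
# a batch of 8 rows, each represented by 6 feature tokens of width 32.
#   attn = Attention(dim = 32, heads = 8, dim_head = 16)
#   attn(torch.randn(8, 6, 32)).shape   # -> torch.Size([8, 6, 32])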
class RowColTransformer(nn.Module):
    def __init__(self, num_tokens, dim, nfeats, depth, heads, dim_head, attn_dropout, ff_dropout, style='col'):
        super().__init__()
        self.embeds = nn.Embedding(num_tokens, dim)
        self.layers = nn.ModuleList([])
        self.mask_embed = nn.Embedding(nfeats, dim)
        self.style = style
        for _ in range(depth):
            if self.style == 'colrow':
                # column attention over the feature tokens of each row, followed by
                # row (intersample) attention over the flattened rows of the batch
                self.layers.append(nn.ModuleList([
                    PreNorm(dim, Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
                    PreNorm(dim, Residual(FeedForward(dim, dropout = ff_dropout))),
                    PreNorm(dim*nfeats, Residual(Attention(dim*nfeats, heads = heads, dim_head = 64, dropout = attn_dropout))),
                    PreNorm(dim*nfeats, Residual(FeedForward(dim*nfeats, dropout = ff_dropout))),
                ]))
            else:
                # row (intersample) attention only
                self.layers.append(nn.ModuleList([
                    PreNorm(dim*nfeats, Residual(Attention(dim*nfeats, heads = heads, dim_head = 64, dropout = attn_dropout))),
                    PreNorm(dim*nfeats, Residual(FeedForward(dim*nfeats, dropout = ff_dropout))),
                ]))

    def forward(self, x, x_cont=None, mask=None):
        if x_cont is not None:
            x = torch.cat((x, x_cont), dim=1)
        _, n, _ = x.shape
        if self.style == 'colrow':
            for attn1, ff1, attn2, ff2 in self.layers:
                x = attn1(x)
                x = ff1(x)
                # flatten each row's tokens so attention runs across the samples in the batch
                x = rearrange(x, 'b n d -> 1 b (n d)')
                x = attn2(x)
                x = ff2(x)
                x = rearrange(x, '1 b (n d) -> b n d', n = n)
        else:
            for attn1, ff1 in self.layers:
                x = rearrange(x, 'b n d -> 1 b (n d)')
                x = attn1(x)
                x = ff1(x)
                x = rearrange(x, '1 b (n d) -> b n d', n = n)
        return x
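# Usage sketch (illustrative, not from the original file): with style = 'colrow' each layer
# first attends across the n feature tokens of a row, then reshapes to (1, b, n*d) and
# attends across the rows of the batch (intersample attention).
#   rc = RowColTransformer(num_tokens = 10, dim = 32, nfeats = 6, depth = 1, heads = 4,
#                          dim_head = 16, attn_dropout = 0., ff_dropout = 0., style = 'colrow')
#   rc(torch.randn(8, 6, 32)).shape   # -> torch.Size([8, 6, 32])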
# transformer

class Transformer(nn.Module):
    def __init__(self, num_tokens, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
                PreNorm(dim, Residual(FeedForward(dim, dropout = ff_dropout))),
            ]))

    def forward(self, x, x_cont=None):
        if x_cont is not None:
            x = torch.cat((x, x_cont), dim=1)
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)
        return x
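# Usage sketch (illustrative, not from the original file): plain column-wise attention,
# i.e. each row's feature tokens attend only to each other.
#   t = Transformer(num_tokens = 10, dim = 32, depth = 2, heads = 4, dim_head = 16,
#                   attn_dropout = 0., ff_dropout = 0.)
#   t(torch.randn(8, 6, 32)).shape   # -> torch.Size([8, 6, 32])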
# mlp

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            is_last = ind >= (len(dims_pairs) - 1)  # compare against the number of pairs, not len(dims)
            linear = nn.Linear(dim_in, dim_out)
            layers.append(linear)
            if is_last:
                continue
            if act is not None:
                layers.append(act)
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
class simple_MLP(nn.Module):
    def __init__(self, dims):
        super(simple_MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(dims[0], dims[1]),
            nn.ReLU(),
            nn.Linear(dims[1], dims[2])
        )

    def forward(self, x):
        if len(x.shape) == 1:
            # promote a 1-D column of scalars to shape (N, 1)
            x = x.view(x.size(0), -1)
        x = self.layers(x)
        return x
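# Usage sketch (illustrative, not from the original file): TabAttention below builds one of
# these per continuous feature to map a scalar column to a dim-wide embedding.
#   emb = simple_MLP([1, 100, 32])
#   emb(torch.randn(8, 1)).shape   # -> torch.Size([8, 32])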
# main class

class TabAttention(nn.Module):
    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        mlp_hidden_mults = (4, 2),
        mlp_act = None,
        num_special_tokens = 1,
        continuous_mean_std = None,
        attn_dropout = 0.,
        ff_dropout = 0.,
        lastmlp_dropout = 0.,
        cont_embeddings = 'MLP',
        scalingfactor = 10,
        attentiontype = 'col'
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'

        # categories-related calculations
        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # create category embeddings table
        self.num_special_tokens = num_special_tokens
        self.total_tokens = self.num_unique_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position in the categories embedding table
        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
        self.register_buffer('categories_offset', categories_offset)

        self.norm = nn.LayerNorm(num_continuous)
        self.num_continuous = num_continuous
        self.dim = dim
        self.cont_embeddings = cont_embeddings
        self.attentiontype = attentiontype

        if self.cont_embeddings == 'MLP':
            # one small MLP per continuous feature, mapping a scalar to a dim-wide embedding
            self.simple_MLP = nn.ModuleList([simple_MLP([1, 100, self.dim]) for _ in range(self.num_continuous)])
            input_size = (dim * self.num_categories) + (dim * num_continuous)
            nfeats = self.num_categories + num_continuous
        else:
            print('Continuous features are not passed through attention')
            input_size = (dim * self.num_categories) + num_continuous
            nfeats = self.num_categories

        # transformer
        if attentiontype == 'col':
            self.transformer = Transformer(
                num_tokens = self.total_tokens,
                dim = dim,
                depth = depth,
                heads = heads,
                dim_head = dim_head,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout
            )
        elif attentiontype in ['row', 'colrow']:
            self.transformer = RowColTransformer(
                num_tokens = self.total_tokens,
                dim = dim,
                nfeats = nfeats,
                depth = depth,
                heads = heads,
                dim_head = dim_head,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                style = attentiontype
            )

        l = input_size // 8
        hidden_dimensions = list(map(lambda t: l * t, mlp_hidden_mults))
        all_dimensions = [input_size, *hidden_dimensions, dim_out]

        self.mlp = MLP(all_dimensions, act = mlp_act)
        self.embeds = nn.Embedding(self.total_tokens, self.dim)

        # offsets used to index the per-feature mask embedding tables
        cat_mask_offset = F.pad(torch.Tensor(self.num_categories).fill_(2).type(torch.int8), (1, 0), value = 0)
        cat_mask_offset = cat_mask_offset.cumsum(dim = -1)[:-1]

        con_mask_offset = F.pad(torch.Tensor(self.num_continuous).fill_(2).type(torch.int8), (1, 0), value = 0)
        con_mask_offset = con_mask_offset.cumsum(dim = -1)[:-1]

        self.register_buffer('cat_mask_offset', cat_mask_offset)
        self.register_buffer('con_mask_offset', con_mask_offset)

        self.mask_embeds_cat = nn.Embedding(self.num_categories * 2, self.dim)
        self.mask_embeds_cont = nn.Embedding(self.num_continuous * 2, self.dim)
    def forward(self, x_categ, x_cont, x_categ_enc, x_cont_enc):
        device = x_categ.device
        if self.attentiontype == 'justmlp':
            # skip attention entirely and feed the raw (flattened) features to the final MLP
            if x_categ.shape[-1] > 0:
                flat_categ = x_categ.flatten(1).to(device)
                x = torch.cat((flat_categ, x_cont.flatten(1).to(device)), dim = -1)
            else:
                x = x_cont.clone()
        else:
            if self.cont_embeddings == 'MLP':
                # both categorical and continuous tokens go through the transformer
                x = self.transformer(x_categ_enc, x_cont_enc.to(device))
            else:
                if x_categ.shape[-1] <= 0:
                    x = x_cont.clone()
                else:
                    # only categorical tokens are attended; raw continuous features are appended afterwards
                    flat_categ = self.transformer(x_categ_enc).flatten(1)
                    x = torch.cat((flat_categ, x_cont), dim = -1)
        flat_x = x.flatten(1)
        return self.mlp(flat_x)
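# Minimal smoke test (added sketch, not part of the original file). The shapes below are
# assumptions: x_categ_enc / x_cont_enc stand in for pre-embedded tokens of width `dim`
# that the surrounding training code would normally produce.
if __name__ == '__main__':
    model = TabAttention(categories = (5, 3), num_continuous = 4, dim = 32, depth = 2, heads = 4)
    b = 8
    x_categ = torch.randint(0, 3, (b, 2))     # raw categorical ids (2 categorical features)
    x_cont = torch.randn(b, 4)                # raw continuous features
    x_categ_enc = torch.randn(b, 2, 32)       # assumed categorical token embeddings
    x_cont_enc = torch.randn(b, 4, 32)        # assumed continuous token embeddings
    out = model(x_categ, x_cont, x_categ_enc, x_cont_enc)
    print(out.shape)                          # torch.Size([8, 1])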