# ft_transformer.py
# %%
import math
import typing as ty
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as nn_init
import zero
from torch import Tensor

from . import lib

# %%
class Tokenizer(nn.Module):
    category_offsets: ty.Optional[Tensor]

    def __init__(
        self,
        d_numerical: int,
        categories: ty.Optional[ty.List[int]],
        d_token: int,
        bias: bool,
    ) -> None:
        super().__init__()
        if categories is None:
            d_bias = d_numerical
            self.category_offsets = None
            self.category_embeddings = None
        else:
            d_bias = d_numerical + len(categories)
            category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
            self.register_buffer('category_offsets', category_offsets)
            self.category_embeddings = nn.Embedding(sum(categories), d_token)
            nn_init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
            print(f'{self.category_embeddings.weight.shape=}')

        # take [CLS] token into account
        self.weight = nn.Parameter(Tensor(d_numerical + 1, d_token))
        self.bias = nn.Parameter(Tensor(d_bias, d_token)) if bias else None
        # The initialization is inspired by nn.Linear
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))

    @property
    def n_tokens(self) -> int:
        return len(self.weight) + (
            0 if self.category_offsets is None else len(self.category_offsets)
        )
    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
        x_some = x_num if x_cat is None else x_cat
        assert x_some is not None
        x_num = torch.cat(
            [torch.ones(len(x_some), 1, device=x_some.device)]  # [CLS]
            + ([] if x_num is None else [x_num]),
            dim=1,
        )
        x = self.weight[None] * x_num[:, :, None]
        if x_cat is not None:
            x = torch.cat(
                [x, self.category_embeddings(x_cat + self.category_offsets[None])],
                dim=1,
            )
        if self.bias is not None:
            bias = torch.cat(
                [
                    torch.zeros(1, self.bias.shape[1], device=x.device),
                    self.bias,
                ]
            )
            x = x + bias[None]
        return x
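
# Tokenizer (above) turns a table row into a sequence of d_token-dimensional
# embeddings: each numerical feature is scaled by a learned per-feature vector,
# each categorical feature is looked up in a shared nn.Embedding (offset so that
# every feature gets its own slice of the table), a constant [CLS] token is
# prepended, and an optional per-feature bias is added to every token except [CLS].
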
class MultiheadAttention(nn.Module):
    def __init__(
        self, d: int, n_heads: int, dropout: float, initialization: str
    ) -> None:
        if n_heads > 1:
            assert d % n_heads == 0
        assert initialization in ['xavier', 'kaiming']

        super().__init__()
        self.W_q = nn.Linear(d, d)
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        self.W_out = nn.Linear(d, d) if n_heads > 1 else None
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None

        for m in [self.W_q, self.W_k, self.W_v]:
            if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
                # gain is needed since W_qkv is represented with 3 separate layers
                nn_init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
            nn_init.zeros_(m.bias)
        if self.W_out is not None:
            nn_init.zeros_(self.W_out.bias)

    def _reshape(self, x: Tensor) -> Tensor:
        batch_size, n_tokens, d = x.shape
        d_head = d // self.n_heads
        return (
            x.reshape(batch_size, n_tokens, self.n_heads, d_head)
            .transpose(1, 2)
            .reshape(batch_size * self.n_heads, n_tokens, d_head)
        )
    def forward(
        self,
        x_q: Tensor,
        x_kv: Tensor,
        key_compression: ty.Optional[nn.Linear],
        value_compression: ty.Optional[nn.Linear],
    ) -> Tensor:
        q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
        for tensor in [q, k, v]:
            assert tensor.shape[-1] % self.n_heads == 0
        if key_compression is not None:
            assert value_compression is not None
            k = key_compression(k.transpose(1, 2)).transpose(1, 2)
            v = value_compression(v.transpose(1, 2)).transpose(1, 2)
        else:
            assert value_compression is None

        batch_size = len(q)
        d_head_key = k.shape[-1] // self.n_heads
        d_head_value = v.shape[-1] // self.n_heads
        n_q_tokens = q.shape[1]

        q = self._reshape(q)
        k = self._reshape(k)
        attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1)
        if self.dropout is not None:
            attention = self.dropout(attention)
        x = attention @ self._reshape(v)
        x = (
            x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
            .transpose(1, 2)
            .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
        )
        if self.W_out is not None:
            x = self.W_out(x)
        return x
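
# MultiheadAttention (above) is standard scaled dot-product attention with
# separate W_q/W_k/W_v projections. When key_compression/value_compression are
# passed, keys and values are first projected along the token dimension
# (Linformer-style compression), and W_out mixes the heads back together when
# n_heads > 1.
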
class Transformer(nn.Module):
    """Transformer.

    References:
    - https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
    - https://github.com/facebookresearch/pytext/tree/master/pytext/models/representations/transformer
    - https://github.com/pytorch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/examples/linformer/linformer_src/modules/multihead_linear_attention.py#L19
    """

    def __init__(
        self,
        *,
        # tokenizer
        d_numerical: int,
        categories: ty.Optional[ty.List[int]],
        token_bias: bool,
        # transformer
        n_layers: int,
        d_token: int,
        n_heads: int,
        d_ffn_factor: float,
        attention_dropout: float,
        ffn_dropout: float,
        residual_dropout: float,
        activation: str,
        prenormalization: bool,
        initialization: str,
        # linformer
        kv_compression: ty.Optional[float],
        kv_compression_sharing: ty.Optional[str],
        #
        d_out: int,
    ) -> None:
        assert (kv_compression is None) ^ (kv_compression_sharing is not None)

        super().__init__()
        self.tokenizer = Tokenizer(d_numerical, categories, d_token, token_bias)
        n_tokens = self.tokenizer.n_tokens

        def make_kv_compression():
            assert kv_compression
            compression = nn.Linear(
                n_tokens, int(n_tokens * kv_compression), bias=False
            )
            if initialization == 'xavier':
                nn_init.xavier_uniform_(compression.weight)
            return compression

        self.shared_kv_compression = (
            make_kv_compression()
            if kv_compression and kv_compression_sharing == 'layerwise'
            else None
        )

        def make_normalization():
            return nn.LayerNorm(d_token)

        d_hidden = int(d_token * d_ffn_factor)
        self.layers = nn.ModuleList([])
        for layer_idx in range(n_layers):
            layer = nn.ModuleDict(
                {
                    'attention': MultiheadAttention(
                        d_token, n_heads, attention_dropout, initialization
                    ),
                    'linear0': nn.Linear(
                        d_token, d_hidden * (2 if activation.endswith('glu') else 1)
                    ),
                    'linear1': nn.Linear(d_hidden, d_token),
                    'norm1': make_normalization(),
                }
            )
            if not prenormalization or layer_idx:
                layer['norm0'] = make_normalization()
            if kv_compression and self.shared_kv_compression is None:
                layer['key_compression'] = make_kv_compression()
                if kv_compression_sharing == 'headwise':
                    layer['value_compression'] = make_kv_compression()
                else:
                    assert kv_compression_sharing == 'key-value'
            self.layers.append(layer)

        self.activation = lib.get_activation_fn(activation)
        self.last_activation = lib.get_nonglu_activation_fn(activation)
        self.prenormalization = prenormalization
        self.last_normalization = make_normalization() if prenormalization else None
        self.ffn_dropout = ffn_dropout
        self.residual_dropout = residual_dropout
        self.head = nn.Linear(d_token, d_out)
    def _get_kv_compressions(self, layer):
        return (
            (self.shared_kv_compression, self.shared_kv_compression)
            if self.shared_kv_compression is not None
            else (layer['key_compression'], layer['value_compression'])
            if 'key_compression' in layer and 'value_compression' in layer
            else (layer['key_compression'], layer['key_compression'])
            if 'key_compression' in layer
            else (None, None)
        )
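
    # kv_compression_sharing modes, as encoded above: 'layerwise' shares one
    # compression across all layers (self.shared_kv_compression), 'key-value'
    # reuses a single per-layer compression for both keys and values, and
    # 'headwise' gives each layer separate key and value compressions.
    # The two helpers below apply LayerNorm before the sublayer (pre-norm) or
    # after the residual sum (post-norm), depending on self.prenormalization.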
    def _start_residual(self, x, layer, norm_idx):
        x_residual = x
        if self.prenormalization:
            norm_key = f'norm{norm_idx}'
            if norm_key in layer:
                x_residual = layer[norm_key](x_residual)
        return x_residual

    def _end_residual(self, x, x_residual, layer, norm_idx):
        if self.residual_dropout:
            x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
        x = x + x_residual
        if not self.prenormalization:
            x = layer[f'norm{norm_idx}'](x)
        return x
    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
        x = self.tokenizer(x_num, x_cat)

        for layer_idx, layer in enumerate(self.layers):
            is_last_layer = layer_idx + 1 == len(self.layers)
            layer = ty.cast(ty.Dict[str, nn.Module], layer)

            x_residual = self._start_residual(x, layer, 0)
            x_residual = layer['attention'](
                # for the last attention, it is enough to process only [CLS]
                (x_residual[:, :1] if is_last_layer else x_residual),
                x_residual,
                *self._get_kv_compressions(layer),
            )
            if is_last_layer:
                x = x[:, : x_residual.shape[1]]
            x = self._end_residual(x, x_residual, layer, 0)

            x_residual = self._start_residual(x, layer, 1)
            x_residual = layer['linear0'](x_residual)
            x_residual = self.activation(x_residual)
            if self.ffn_dropout:
                x_residual = F.dropout(x_residual, self.ffn_dropout, self.training)
            x_residual = layer['linear1'](x_residual)
            x = self._end_residual(x, x_residual, layer, 1)

        assert x.shape[1] == 1
        x = x[:, 0]
        if self.last_normalization is not None:
            x = self.last_normalization(x)
        x = self.last_activation(x)
        x = self.head(x)
        x = x.squeeze(-1)
        return x
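
# Usage sketch (illustrative, not part of the training pipeline below):
# constructing a Transformer directly for a dataset with 3 numerical features
# and categorical cardinalities [4, 7]. The hyperparameter values are examples,
# not tuned defaults; 'reglu' assumes lib.get_activation_fn supports GLU-style
# activations, as implied by the `activation.endswith('glu')` check above.
#
#   model = Transformer(
#       d_numerical=3,
#       categories=[4, 7],
#       token_bias=True,
#       n_layers=3,
#       d_token=192,
#       n_heads=8,
#       d_ffn_factor=4 / 3,
#       attention_dropout=0.2,
#       ffn_dropout=0.1,
#       residual_dropout=0.0,
#       activation='reglu',
#       prenormalization=True,
#       initialization='kaiming',
#       kv_compression=None,
#       kv_compression_sharing=None,
#       d_out=1,
#   )
#   logits = model(x_num, x_cat)  # x_num: (B, 3) float, x_cat: (B, 2) long
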
class FTtransformer:
    def __init__(
        self,
        config
    ):
        self.config = config

    def fit(self, checkpoint_path):
        config = self.config  # quick dirty method

        zero.set_randomness(config['seed'])
        dataset_dir = config['data']['path']

        D = lib.Dataset.from_dir(dataset_dir)
        X = D.build_X(
            normalization=config['data'].get('normalization'),
            num_nan_policy='mean',
            cat_nan_policy='new',
            cat_policy=config['data'].get('cat_policy', 'indices'),
            cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),
            seed=config['seed'],
        )
        if not isinstance(X, tuple):
            X = (X, None)
        Y, y_info = D.build_y(config['data'].get('y_policy'))

        X = tuple(None if x is None else lib.to_tensors(x) for x in X)
        Y = lib.to_tensors(Y)
        device = torch.device(config['training']['device'])
        print("Using device:", config['training']['device'])
        if device.type != 'cpu':
            X = tuple(
                None if x is None else {k: v.to(device) for k, v in x.items()} for x in X
            )
            Y_device = {k: v.to(device) for k, v in Y.items()}
        else:
            Y_device = Y
        X_num, X_cat = X
        del X
        if not D.is_multiclass:
            Y_device = {k: v.float() for k, v in Y_device.items()}

        train_size = D.size(lib.TRAIN)
        batch_size = config['training']['batch_size']
        epoch_size = math.ceil(train_size / batch_size)
        eval_batch_size = config['training']['eval_batch_size']
        chunk_size = None

        loss_fn = (
            F.binary_cross_entropy_with_logits
            if D.is_binclass
            else F.cross_entropy
            if D.is_multiclass
            else F.mse_loss
        )
        model = Transformer(
            d_numerical=0 if X_num is None else X_num['train'].shape[1],
            categories=lib.get_categories(X_cat),
            d_out=D.info['n_classes'] if D.is_multiclass else 1,
            **config['model'],
        ).to(device)

        def needs_wd(name):
            return all(x not in name for x in ['tokenizer', '.norm', '.bias'])

        for x in ['tokenizer', '.norm', '.bias']:
            assert any(x in a for a in (b[0] for b in model.named_parameters()))
        parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]
        parameters_without_wd = [
            v for k, v in model.named_parameters() if not needs_wd(k)
        ]
        optimizer = lib.make_optimizer(
            config['training']['optimizer'],
            (
                [
                    {'params': parameters_with_wd},
                    {'params': parameters_without_wd, 'weight_decay': 0.0},
                ]
            ),
            config['training']['lr'],
            config['training']['weight_decay'],
        )
        stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))
        progress = zero.ProgressTracker(config['training']['patience'])
        training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}
        timer = zero.Timer()
        output = "Checkpoints"

        def print_epoch_info():
            print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')
            print(
                ' | '.join(
                    f'{k} = {v}'
                    for k, v in {
                        'lr': lib.get_lr(optimizer),
                        'batch_size': batch_size,
                        'chunk_size': chunk_size,
                    }.items()
                )
            )

        def apply_model(part, idx):
            return model(
                None if X_num is None else X_num[part][idx],
                None if X_cat is None else X_cat[part][idx],
            )
        @torch.no_grad()
        def evaluate(parts):
            eval_batch_size = self.config['training']['eval_batch_size']
            model.eval()
            metrics = {}
            predictions = {}
            for part in parts:
                # Halve the evaluation batch size until the forward pass fits in memory.
                while eval_batch_size:
                    try:
                        predictions[part] = (
                            torch.cat(
                                [
                                    apply_model(part, idx)
                                    for idx in lib.IndexLoader(
                                        D.size(part), eval_batch_size, False, device
                                    )
                                ]
                            )
                            .cpu()
                            .numpy()
                        )
                    except RuntimeError as err:
                        if not lib.is_oom_exception(err):
                            raise
                        eval_batch_size //= 2
                        print('New eval batch size:', eval_batch_size)
                    else:
                        break
                if not eval_batch_size:
                    raise RuntimeError('Not enough memory even for eval_batch_size=1')
                metrics[part] = lib.calculate_metrics(
                    D.info['task_type'],
                    Y[part].numpy(),  # type: ignore[code]
                    predictions[part],  # type: ignore[code]
                    'logits',
                    y_info,
                )
            for part, part_metrics in metrics.items():
                print(f'[{part:<5}]', lib.make_summary(part_metrics))
            return metrics, predictions
        def save_checkpoint(final):
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'stream': stream.state_dict(),
                    'random_state': zero.get_random_state(),
                },
                checkpoint_path,
            )

        zero.set_randomness(config['seed'])
        for epoch in stream.epochs(config['training']['n_epochs']):
            print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())}')
            model.train()
            epoch_losses = []
            for batch_idx in epoch:
                loss, new_chunk_size = lib.train_with_auto_virtual_batch(
                    optimizer,
                    loss_fn,
                    lambda x: (apply_model(lib.TRAIN, x), Y_device[lib.TRAIN][x]),
                    batch_idx,
                    chunk_size or batch_size,
                )
                epoch_losses.append(loss.detach())
                if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                    # Remember the reduced virtual-batch chunk size so it is reused
                    # (and reported correctly) on subsequent batches.
                    chunk_size = new_chunk_size
                    print('New chunk size:', chunk_size)
            epoch_losses = torch.stack(epoch_losses).tolist()
            print(f'[{lib.TRAIN}] loss = {round(sum(epoch_losses) / len(epoch_losses), 3)}')
            metrics, predictions = evaluate([lib.VAL, lib.TEST])
            for k, v in metrics.items():
                training_log[k].append(v)
            progress.update(metrics[lib.VAL]['score'])

            if progress.success:
                print('New best epoch!')
                save_checkpoint(False)
            elif progress.fail:
                break

        # Load best checkpoint
        model.load_state_dict(torch.load(checkpoint_path)['model'])
        metrics, predictions = evaluate(lib.PARTS)

        return metrics
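
# Usage sketch (illustrative): the config dict below lists the keys that fit()
# reads above; the concrete values and names are examples, not recommended or
# verified settings, and the dataset path is hypothetical.
#
#   config = {
#       'seed': 0,
#       'data': {
#           'path': 'data/some_dataset',      # hypothetical lib.Dataset directory
#           'normalization': 'quantile',
#           'cat_policy': 'indices',
#           'y_policy': 'mean_std',
#       },
#       'model': {
#           'token_bias': True,
#           'n_layers': 3,
#           'd_token': 192,
#           'n_heads': 8,
#           'd_ffn_factor': 4 / 3,
#           'attention_dropout': 0.2,
#           'ffn_dropout': 0.1,
#           'residual_dropout': 0.0,
#           'activation': 'reglu',
#           'prenormalization': True,
#           'initialization': 'kaiming',
#           'kv_compression': None,
#           'kv_compression_sharing': None,
#       },
#       'training': {
#           'device': 'cuda:0',
#           'batch_size': 256,
#           'eval_batch_size': 8192,
#           'optimizer': 'adamw',
#           'lr': 1e-4,
#           'weight_decay': 1e-5,
#           'patience': 16,
#           'n_epochs': 1000,
#       },
#   }
#   metrics = FTtransformer(config).fit('checkpoint.pt')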