# %%
import math
import typing as ty
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as nn_init
import zero
from torch import Tensor

from . import lib

# %%
class Tokenizer(nn.Module):
    category_offsets: ty.Optional[Tensor]

    def __init__(
        self,
        d_numerical: int,
        categories: ty.Optional[ty.List[int]],
        d_token: int,
        bias: bool,
    ) -> None:
        super().__init__()
        if categories is None:
            d_bias = d_numerical
            self.category_offsets = None
            self.category_embeddings = None
        else:
            d_bias = d_numerical + len(categories)
            category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
            self.register_buffer('category_offsets', category_offsets)
            self.category_embeddings = nn.Embedding(sum(categories), d_token)
            nn_init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
            print(f'{self.category_embeddings.weight.shape=}')

        # take [CLS] token into account
        self.weight = nn.Parameter(Tensor(d_numerical + 1, d_token))
        self.bias = nn.Parameter(Tensor(d_bias, d_token)) if bias else None
        # The initialization is inspired by nn.Linear
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))

    @property
    def n_tokens(self) -> int:
        return len(self.weight) + (
            0 if self.category_offsets is None else len(self.category_offsets)
        )

    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
        x_some = x_num if x_cat is None else x_cat
        assert x_some is not None
        x_num = torch.cat(
            [torch.ones(len(x_some), 1, device=x_some.device)]  # [CLS]
            + ([] if x_num is None else [x_num]),
            dim=1,
        )
        x = self.weight[None] * x_num[:, :, None]
        if x_cat is not None:
            x = torch.cat(
                [x, self.category_embeddings(x_cat + self.category_offsets[None])],
                dim=1,
            )
        if self.bias is not None:
            bias = torch.cat(
                [
                    torch.zeros(1, self.bias.shape[1], device=x.device),
                    self.bias,
                ]
            )
            x = x + bias[None]
        return x
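

# %%
# Illustrative sketch (not part of the original pipeline): a quick shape check for
# Tokenizer. All sizes below are made-up assumptions chosen only for this example.
def _tokenizer_shape_example() -> None:
    tokenizer = Tokenizer(d_numerical=3, categories=[4, 7], d_token=8, bias=True)
    x_num = torch.randn(2, 3)               # 2 rows, 3 numerical features
    x_cat = torch.tensor([[0, 6], [3, 1]])  # 2 rows, 2 categorical features
    # Output layout: (batch, [CLS] token + 3 numerical + 2 categorical tokens, d_token)
    tokens = tokenizer(x_num, x_cat)
    assert tokens.shape == (2, tokenizer.n_tokens, 8) == (2, 6, 8)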


class MultiheadAttention(nn.Module):
    def __init__(
        self, d: int, n_heads: int, dropout: float, initialization: str
    ) -> None:
        if n_heads > 1:
            assert d % n_heads == 0
        assert initialization in ['xavier', 'kaiming']

        super().__init__()
        self.W_q = nn.Linear(d, d)
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        self.W_out = nn.Linear(d, d) if n_heads > 1 else None
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None

        for m in [self.W_q, self.W_k, self.W_v]:
            if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
                # gain is needed since W_qkv is represented with 3 separate layers
                nn_init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
            nn_init.zeros_(m.bias)
        if self.W_out is not None:
            nn_init.zeros_(self.W_out.bias)

    def _reshape(self, x: Tensor) -> Tensor:
        batch_size, n_tokens, d = x.shape
        d_head = d // self.n_heads
        return (
            x.reshape(batch_size, n_tokens, self.n_heads, d_head)
            .transpose(1, 2)
            .reshape(batch_size * self.n_heads, n_tokens, d_head)
        )

    def forward(
        self,
        x_q: Tensor,
        x_kv: Tensor,
        key_compression: ty.Optional[nn.Linear],
        value_compression: ty.Optional[nn.Linear],
    ) -> Tensor:
        q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
        for tensor in [q, k, v]:
            assert tensor.shape[-1] % self.n_heads == 0
        if key_compression is not None:
            assert value_compression is not None
            k = key_compression(k.transpose(1, 2)).transpose(1, 2)
            v = value_compression(v.transpose(1, 2)).transpose(1, 2)
        else:
            assert value_compression is None

        batch_size = len(q)
        d_head_key = k.shape[-1] // self.n_heads
        d_head_value = v.shape[-1] // self.n_heads
        n_q_tokens = q.shape[1]

        q = self._reshape(q)
        k = self._reshape(k)
        attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1)
        if self.dropout is not None:
            attention = self.dropout(attention)
        x = attention @ self._reshape(v)
        x = (
            x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
            .transpose(1, 2)
            .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
        )
        if self.W_out is not None:
            x = self.W_out(x)
        return x
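

# %%
# Illustrative sketch (not from the original code): plain self-attention through
# MultiheadAttention without Linformer-style key/value compression; the sizes are
# arbitrary assumptions chosen only for this example.
def _attention_shape_example() -> None:
    attention = MultiheadAttention(d=16, n_heads=4, dropout=0.0, initialization='kaiming')
    tokens = torch.randn(2, 5, 16)  # (batch, n_tokens, d)
    # Queries, keys, and values all come from the same tokens; both compression
    # arguments are None, so no Linformer projection is applied.
    out = attention(tokens, tokens, None, None)
    assert out.shape == (2, 5, 16)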


class Transformer(nn.Module):
    """Transformer.

    References:
    - https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
    - https://github.com/facebookresearch/pytext/tree/master/pytext/models/representations/transformer
    - https://github.com/pytorch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/examples/linformer/linformer_src/modules/multihead_linear_attention.py#L19
    """

    def __init__(
        self,
        *,
        # tokenizer
        d_numerical: int,
        categories: ty.Optional[ty.List[int]],
        token_bias: bool,
        # transformer
        n_layers: int,
        d_token: int,
        n_heads: int,
        d_ffn_factor: float,
        attention_dropout: float,
        ffn_dropout: float,
        residual_dropout: float,
        activation: str,
        prenormalization: bool,
        initialization: str,
        # linformer
        kv_compression: ty.Optional[float],
        kv_compression_sharing: ty.Optional[str],
        #
        d_out: int,
    ) -> None:
        assert (kv_compression is None) ^ (kv_compression_sharing is not None)

        super().__init__()
        self.tokenizer = Tokenizer(d_numerical, categories, d_token, token_bias)
        n_tokens = self.tokenizer.n_tokens

        def make_kv_compression():
            assert kv_compression
            compression = nn.Linear(
                n_tokens, int(n_tokens * kv_compression), bias=False
            )
            if initialization == 'xavier':
                nn_init.xavier_uniform_(compression.weight)
            return compression

        self.shared_kv_compression = (
            make_kv_compression()
            if kv_compression and kv_compression_sharing == 'layerwise'
            else None
        )

        def make_normalization():
            return nn.LayerNorm(d_token)

        d_hidden = int(d_token * d_ffn_factor)
        self.layers = nn.ModuleList([])
        for layer_idx in range(n_layers):
            layer = nn.ModuleDict(
                {
                    'attention': MultiheadAttention(
                        d_token, n_heads, attention_dropout, initialization
                    ),
                    'linear0': nn.Linear(
                        d_token, d_hidden * (2 if activation.endswith('glu') else 1)
                    ),
                    'linear1': nn.Linear(d_hidden, d_token),
                    'norm1': make_normalization(),
                }
            )
            if not prenormalization or layer_idx:
                layer['norm0'] = make_normalization()
            if kv_compression and self.shared_kv_compression is None:
                layer['key_compression'] = make_kv_compression()
                if kv_compression_sharing == 'headwise':
                    layer['value_compression'] = make_kv_compression()
                else:
                    assert kv_compression_sharing == 'key-value'
            self.layers.append(layer)

        self.activation = lib.get_activation_fn(activation)
        self.last_activation = lib.get_nonglu_activation_fn(activation)
        self.prenormalization = prenormalization
        self.last_normalization = make_normalization() if prenormalization else None
        self.ffn_dropout = ffn_dropout
        self.residual_dropout = residual_dropout
        self.head = nn.Linear(d_token, d_out)

    def _get_kv_compressions(self, layer):
        return (
            (self.shared_kv_compression, self.shared_kv_compression)
            if self.shared_kv_compression is not None
            else (layer['key_compression'], layer['value_compression'])
            if 'key_compression' in layer and 'value_compression' in layer
            else (layer['key_compression'], layer['key_compression'])
            if 'key_compression' in layer
            else (None, None)
        )

    def _start_residual(self, x, layer, norm_idx):
        x_residual = x
        if self.prenormalization:
            norm_key = f'norm{norm_idx}'
            if norm_key in layer:
                x_residual = layer[norm_key](x_residual)
        return x_residual

    def _end_residual(self, x, x_residual, layer, norm_idx):
        if self.residual_dropout:
            x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
        x = x + x_residual
        if not self.prenormalization:
            x = layer[f'norm{norm_idx}'](x)
        return x

    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
        x = self.tokenizer(x_num, x_cat)

        for layer_idx, layer in enumerate(self.layers):
            is_last_layer = layer_idx + 1 == len(self.layers)
            layer = ty.cast(ty.Dict[str, nn.Module], layer)

            x_residual = self._start_residual(x, layer, 0)
            x_residual = layer['attention'](
                # for the last attention, it is enough to process only [CLS]
                (x_residual[:, :1] if is_last_layer else x_residual),
                x_residual,
                *self._get_kv_compressions(layer),
            )
            if is_last_layer:
                x = x[:, : x_residual.shape[1]]
            x = self._end_residual(x, x_residual, layer, 0)

            x_residual = self._start_residual(x, layer, 1)
            x_residual = layer['linear0'](x_residual)
            x_residual = self.activation(x_residual)
            if self.ffn_dropout:
                x_residual = F.dropout(x_residual, self.ffn_dropout, self.training)
            x_residual = layer['linear1'](x_residual)
            x = self._end_residual(x, x_residual, layer, 1)

        assert x.shape[1] == 1
        x = x[:, 0]
        if self.last_normalization is not None:
            x = self.last_normalization(x)
        x = self.last_activation(x)
        x = self.head(x)
        x = x.squeeze(-1)
        return x
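

# %%
# Illustrative sketch (not an authors' configuration): one way to instantiate the
# Transformer directly. All hyperparameter values are arbitrary assumptions, and
# 'reglu' assumes lib.get_activation_fn / lib.get_nonglu_activation_fn accept a
# GLU-style activation name, as the `activation.endswith('glu')` branch above suggests.
def _transformer_example() -> Transformer:
    return Transformer(
        d_numerical=3,                 # three numerical features
        categories=[4, 7],             # two categorical features with 4 and 7 levels
        token_bias=True,
        n_layers=2,
        d_token=64,
        n_heads=8,
        d_ffn_factor=4 / 3,
        attention_dropout=0.1,
        ffn_dropout=0.1,
        residual_dropout=0.0,
        activation='reglu',
        prenormalization=True,
        initialization='kaiming',
        kv_compression=None,           # disable Linformer-style compression
        kv_compression_sharing=None,
        d_out=1,                       # single logit / regression target
    )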


class FTtransformer:
    def __init__(self, config):
        self.config = config

    def fit(self, checkpoint_path):
        config = self.config  # quick dirty method
        zero.set_randomness(config['seed'])
        dataset_dir = config['data']['path']

        D = lib.Dataset.from_dir(dataset_dir)
        X = D.build_X(
            normalization=config['data'].get('normalization'),
            num_nan_policy='mean',
            cat_nan_policy='new',
            cat_policy=config['data'].get('cat_policy', 'indices'),
            cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),
            seed=config['seed'],
        )
        if not isinstance(X, tuple):
            X = (X, None)
        Y, y_info = D.build_y(config['data'].get('y_policy'))

        X = tuple(None if x is None else lib.to_tensors(x) for x in X)
        Y = lib.to_tensors(Y)
        device = torch.device(config['training']['device'])
        print("Using device:", config['training']['device'])
        if device.type != 'cpu':
            X = tuple(
                None if x is None else {k: v.to(device) for k, v in x.items()} for x in X
            )
            Y_device = {k: v.to(device) for k, v in Y.items()}
        else:
            Y_device = Y
        X_num, X_cat = X
        del X
        if not D.is_multiclass:
            Y_device = {k: v.float() for k, v in Y_device.items()}

        train_size = D.size(lib.TRAIN)
        batch_size = config['training']['batch_size']
        epoch_size = math.ceil(train_size / batch_size)
        eval_batch_size = config['training']['eval_batch_size']
        chunk_size = None

        loss_fn = (
            F.binary_cross_entropy_with_logits
            if D.is_binclass
            else F.cross_entropy
            if D.is_multiclass
            else F.mse_loss
        )

        model = Transformer(
            d_numerical=0 if X_num is None else X_num['train'].shape[1],
            categories=lib.get_categories(X_cat),
            d_out=D.info['n_classes'] if D.is_multiclass else 1,
            **config['model'],
        ).to(device)

        def needs_wd(name):
            return all(x not in name for x in ['tokenizer', '.norm', '.bias'])

        for x in ['tokenizer', '.norm', '.bias']:
            assert any(x in a for a in (b[0] for b in model.named_parameters()))
        parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]
        parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]
        optimizer = lib.make_optimizer(
            config['training']['optimizer'],
            [
                {'params': parameters_with_wd},
                {'params': parameters_without_wd, 'weight_decay': 0.0},
            ],
            config['training']['lr'],
            config['training']['weight_decay'],
        )

        stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))
        progress = zero.ProgressTracker(config['training']['patience'])
        training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}
        timer = zero.Timer()
        output = "Checkpoints"

        def print_epoch_info():
            print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')
            print(
                ' | '.join(
                    f'{k} = {v}'
                    for k, v in {
                        'lr': lib.get_lr(optimizer),
                        'batch_size': batch_size,
                        'chunk_size': chunk_size,
                    }.items()
                )
            )

        def apply_model(part, idx):
            return model(
                None if X_num is None else X_num[part][idx],
                None if X_cat is None else X_cat[part][idx],
            )

        @torch.no_grad()
        def evaluate(parts):
            eval_batch_size = self.config['training']['eval_batch_size']
            model.eval()
            metrics = {}
            predictions = {}
            for part in parts:
                while eval_batch_size:
                    try:
                        predictions[part] = (
                            torch.cat(
                                [
                                    apply_model(part, idx)
                                    for idx in lib.IndexLoader(
                                        D.size(part), eval_batch_size, False, device
                                    )
                                ]
                            )
                            .cpu()
                            .numpy()
                        )
                    except RuntimeError as err:
                        if not lib.is_oom_exception(err):
                            raise
                        eval_batch_size //= 2
                        print('New eval batch size:', eval_batch_size)
                    else:
                        break
                if not eval_batch_size:
                    raise RuntimeError('Not enough memory even for eval_batch_size=1')
                metrics[part] = lib.calculate_metrics(
                    D.info['task_type'],
                    Y[part].numpy(),  # type: ignore[code]
                    predictions[part],  # type: ignore[code]
                    'logits',
                    y_info,
                )
            for part, part_metrics in metrics.items():
                print(f'[{part:<5}]', lib.make_summary(part_metrics))
            return metrics, predictions

        def save_checkpoint(final):
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'stream': stream.state_dict(),
                    'random_state': zero.get_random_state(),
                },
                checkpoint_path,
            )

        zero.set_randomness(config['seed'])
        for epoch in stream.epochs(config['training']['n_epochs']):
            print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())}')
            model.train()
            epoch_losses = []
            for batch_idx in epoch:
                loss, new_chunk_size = lib.train_with_auto_virtual_batch(
                    optimizer,
                    loss_fn,
                    lambda x: (apply_model(lib.TRAIN, x), Y_device[lib.TRAIN][x]),
                    batch_idx,
                    chunk_size or batch_size,
                )
                epoch_losses.append(loss.detach())
                if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                    chunk_size = new_chunk_size
                    print('New chunk size:', chunk_size)
            epoch_losses = torch.stack(epoch_losses).tolist()
            print(f'[{lib.TRAIN}] loss = {round(sum(epoch_losses) / len(epoch_losses), 3)}')

            metrics, predictions = evaluate([lib.VAL, lib.TEST])
            for k, v in metrics.items():
                training_log[k].append(v)
            progress.update(metrics[lib.VAL]['score'])
            if progress.success:
                print('New best epoch!')
                save_checkpoint(False)
            elif progress.fail:
                break

        # Load best checkpoint
        model.load_state_dict(torch.load(checkpoint_path)['model'])
        metrics, predictions = evaluate(lib.PARTS)
        return metrics
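

# %%
# Illustrative sketch (hypothetical values, not shipped with this module): the shape of
# the config dict that FTtransformer.fit reads above. Only keys that fit() actually
# accesses are listed; the path, activation, optimizer name, and numeric values are
# placeholders whose validity depends on the local `lib` helpers and dataset format.
def _example_config() -> dict:
    return {
        'seed': 0,
        'data': {
            'path': 'data/my_dataset',   # placeholder directory for lib.Dataset.from_dir
            'normalization': None,
            'cat_policy': 'indices',
            'cat_min_frequency': 0.0,
            'y_policy': None,
        },
        'model': {                       # forwarded as Transformer(**config['model'])
            'token_bias': True,
            'n_layers': 2,
            'd_token': 64,
            'n_heads': 8,
            'd_ffn_factor': 4 / 3,
            'attention_dropout': 0.1,
            'ffn_dropout': 0.1,
            'residual_dropout': 0.0,
            'activation': 'reglu',
            'prenormalization': True,
            'initialization': 'kaiming',
            'kv_compression': None,
            'kv_compression_sharing': None,
        },
        'training': {
            'device': 'cpu',
            'batch_size': 256,
            'eval_batch_size': 1024,
            'optimizer': 'adamw',        # assumption: must be a name known to lib.make_optimizer
            'lr': 1e-4,
            'weight_decay': 1e-5,
            'patience': 16,
            'n_epochs': 100,
        },
    }

# Usage sketch: FTtransformer(_example_config()).fit('checkpoint.pt')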