In [1]:
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss
from sklearn.preprocessing import LabelEncoder

import os
import wget
from pathlib import Path
import shutil
import gzip

from matplotlib import pyplot as plt
import matplotlib.ticker as mtick

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as nn_init
import torch.nn.utils.prune as prune

import random
import math

from FTtransformer.ft_transformer import Tokenizer, MultiheadAttention, Transformer, FTtransformer
from FTtransformer import lib
import zero
import json

from functools import partial
import pickle

## Setup

In [2]:
# Experiment settings
EPOCHS = 50
RERUNS = 5 # How many times to redo the same setting

# Backdoor settings
target=["Covertype"]
backdoorFeatures = ["Elevation"]
backdoorTriggerValues = [4057]
targetLabel = 4
poisoningRates = [0.0005]

DEVICE = 'cuda:0'
DATAPATH = "data/covtypeFTT-1F-OOB-finetune/"
# FTtransformer config
config = {
    'data': {
        'normalization': 'standard',
        'path': DATAPATH
    }, 
    'model': {
        'activation': 'reglu', 
        'attention_dropout': 0.03815883962184247, 
        'd_ffn_factor': 1.333333333333333, 
        'd_token': 424, 
        'ffn_dropout': 0.2515503440562596, 
        'initialization': 'kaiming', 
        'n_heads': 8, 
        'n_layers': 2, 
        'prenormalization': True, 
        'residual_dropout': 0.0, 
        'token_bias': True, 
        'kv_compression': None, 
        'kv_compression_sharing': None
    }, 
    'seed': 0, 
    'training': {
        'batch_size': 1024, 
        'eval_batch_size': 1024, 
        'lr': 3.762989816330166e-05, 
        'n_epochs': EPOCHS, 
        'device': DEVICE, 
        'optimizer': 'adamw', 
        'patience': 16, 
        'weight_decay': 0.0001239780004929955
    }
}


# Load dataset
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz"
dataset_name = 'forestcover-type'
tmp_out = Path('./data/'+dataset_name+'.gz')
out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')
out.parent.mkdir(parents=True, exist_ok=True)
if out.exists():
    print("File already exists.")
else:
    print("Downloading file...")
    wget.download(url, tmp_out.as_posix())
    with gzip.open(tmp_out, 'rb') as f_in:
        with open(out, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)


# Setup data
cat_cols = [
    "Wilderness_Area1", "Wilderness_Area2", "Wilderness_Area3",
    "Wilderness_Area4", "Soil_Type1", "Soil_Type2", "Soil_Type3", "Soil_Type4",
    "Soil_Type5", "Soil_Type6", "Soil_Type7", "Soil_Type8", "Soil_Type9",
    "Soil_Type10", "Soil_Type11", "Soil_Type12", "Soil_Type13", "Soil_Type14",
    "Soil_Type15", "Soil_Type16", "Soil_Type17", "Soil_Type18", "Soil_Type19",
    "Soil_Type20", "Soil_Type21", "Soil_Type22", "Soil_Type23", "Soil_Type24",
    "Soil_Type25", "Soil_Type26", "Soil_Type27", "Soil_Type28", "Soil_Type29",
    "Soil_Type30", "Soil_Type31", "Soil_Type32", "Soil_Type33", "Soil_Type34",
    "Soil_Type35", "Soil_Type36", "Soil_Type37", "Soil_Type38", "Soil_Type39",
    "Soil_Type40"
]

num_cols = [
    "Elevation", "Aspect", "Slope", "Horizontal_Distance_To_Hydrology",
    "Vertical_Distance_To_Hydrology", "Horizontal_Distance_To_Roadways",
    "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points"
]

feature_columns = (
    num_cols + cat_cols + target)

data = pd.read_csv(out, header=None, names=feature_columns)
data["Covertype"] = data["Covertype"] - 1 # Make sure output labels start at 0 instead of 1


# Converts train valid and test DFs to .npy files + info.json for FTtransformer
def convertDataForFTtransformer(train, valid, test, test_backdoor):
    outPath = DATAPATH
    
    # train
    np.save(outPath+"N_train.npy", train[num_cols].to_numpy(dtype='float32'))
    np.save(outPath+"C_train.npy", train[cat_cols].applymap(str).to_numpy())
    np.save(outPath+"y_train.npy", train[target].to_numpy(dtype=int).flatten())
    
    # val
    np.save(outPath+"N_val.npy", valid[num_cols].to_numpy(dtype='float32'))
    np.save(outPath+"C_val.npy", valid[cat_cols].applymap(str).to_numpy())
    np.save(outPath+"y_val.npy", valid[target].to_numpy(dtype=int).flatten())
    
    # test
    np.save(outPath+"N_test.npy", test[num_cols].to_numpy(dtype='float32'))
    np.save(outPath+"C_test.npy", test[cat_cols].applymap(str).to_numpy())
    np.save(outPath+"y_test.npy", test[target].to_numpy(dtype=int).flatten())
    
    # test_backdoor
    np.save(outPath+"N_test_backdoor.npy", test_backdoor[num_cols].to_numpy(dtype='float32'))
    np.save(outPath+"C_test_backdoor.npy", test_backdoor[cat_cols].applymap(str).to_numpy())
    np.save(outPath+"y_test_backdoor.npy", test_backdoor[target].to_numpy(dtype=int).flatten())
    
    # info.json
    info = {
        "name": "covtype___0",
        "basename": "covtype",
        "split": 0,
        "task_type": "multiclass",
        "n_num_features": len(num_cols),
        "n_cat_features": len(cat_cols),
        "train_size": len(train),
        "val_size": len(valid),
        "test_size": len(test),
        "test_backdoor_size": len(test_backdoor),
        "n_classes": 7
    }
    
    with open(outPath + 'info.json', 'w') as f:
        json.dump(info, f, indent = 4)

# Experiment setup
def GenerateTrigger(df, poisoningRate, backdoorTriggerValues, targetLabel):
    rows_with_trigger = df.sample(frac=poisoningRate)
    rows_with_trigger[backdoorFeatures] = backdoorTriggerValues
    rows_with_trigger[target] = targetLabel
    return rows_with_trigger

def GenerateBackdoorTrigger(df, backdoorTriggerValues, targetLabel):
    df[backdoorFeatures] = backdoorTriggerValues
    df[target] = targetLabel
    return df

File already exists.


## Prepare finetune data

In [3]:
runIdx = 1
poisoningRate = poisoningRates[0]

# Do same datageneration as during initial backdoor training so we get the same test set

# Load dataset
# Changes to output df will not influence input df
train_and_valid, test = train_test_split(data, stratify=data[target[0]], test_size=0.2, random_state=runIdx)

# Apply backdoor to train and valid data
random.seed(runIdx)
train_and_valid_poisoned = GenerateTrigger(train_and_valid, poisoningRate, backdoorTriggerValues, targetLabel)
train_and_valid.update(train_and_valid_poisoned)
train_and_valid[target[0]] = train_and_valid[target[0]].astype(np.int64)
train_and_valid[cat_cols] = train_and_valid[cat_cols].astype(np.int64)

# Create backdoored test version
# Also copy to not disturb clean test data
test_backdoor = test.copy()

# Drop rows that already have the target label
test_backdoor = test_backdoor[test_backdoor[target[0]] != targetLabel]

# Add backdoor to all test_backdoor samples
test_backdoor = GenerateBackdoorTrigger(test_backdoor, backdoorTriggerValues, targetLabel)
test_backdoor[target[0]] = test_backdoor[target[0]].astype(np.int64)
test_backdoor[cat_cols] = test_backdoor[cat_cols].astype(np.int64)


# Now split the test set into different parts: ~20k for finetuning (train+val) and 20k for defence evaluation
finetune_train_val, finetune_test = train_test_split(test, stratify=test[target[0]], test_size=0.8, random_state=runIdx)
# Train: ~16k, val: ~4k
finetune_train, finetune_val = train_test_split(finetune_train_val, stratify=finetune_train_val[target[0]], test_size=0.2, random_state=runIdx)

print(len(finetune_test))
print(len(finetune_train))
print(len(finetune_val))

convertDataForFTtransformer(finetune_train, finetune_val, finetune_test, test_backdoor)


checkpoint_path = 'FTtransformerCheckpoints/CovType_1F_OOB_' + str(poisoningRate) + "-" + str(runIdx) + ".pt"


92963
18592
4648


## Setup model

In [4]:
DATAPATH = "data/covtypeFTT-1F-OOB/"
config = {
    'data': {
        'normalization': 'standard',
        'path': DATAPATH
    }, 
    'model': {
        'activation': 'reglu', 
        'attention_dropout': 0.03815883962184247, 
        'd_ffn_factor': 1.333333333333333, 
        'd_token': 424, 
        'ffn_dropout': 0.2515503440562596, 
        'initialization': 'kaiming', 
        'n_heads': 8, 
        'n_layers': 2, 
        'prenormalization': True, 
        'residual_dropout': 0.0, 
        'token_bias': True, 
        'kv_compression': None, 
        'kv_compression_sharing': None
    }, 
    'seed': 0, 
    'training': {
        'batch_size': 1024, 
        'eval_batch_size': 1024, 
        'lr': 3.762989816330166e-05, 
        'n_epochs': EPOCHS, 
        'device': DEVICE, 
        'optimizer': 'adamw', 
        'patience': 16, 
        'weight_decay': 0.0001239780004929955
    }
}

In [5]:

zero.set_randomness(config['seed'])
dataset_dir = config['data']['path']

D = lib.Dataset.from_dir(dataset_dir)
X = D.build_X(
    normalization=config['data'].get('normalization'),
    num_nan_policy='mean',
    cat_nan_policy='new',
    cat_policy=config['data'].get('cat_policy', 'indices'),
    cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),
    seed=config['seed'],
)
if not isinstance(X, tuple):
    X = (X, None)

Y, y_info = D.build_y(config['data'].get('y_policy'))

X = tuple(None if x is None else lib.to_tensors(x) for x in X)
Y = lib.to_tensors(Y)
device = torch.device(config['training']['device'])
print("Using device:", config['training']['device'])
if device.type != 'cpu':
    X = tuple(
        None if x is None else {k: v.to(device) for k, v in x.items()} for x in X
    )
    Y_device = {k: v.to(device) for k, v in Y.items()}
else:
    Y_device = Y
X_num, X_cat = X
del X
if not D.is_multiclass:
    Y_device = {k: v.float() for k, v in Y_device.items()}

train_size = D.size(lib.TRAIN)
batch_size = config['training']['batch_size']
epoch_size = math.ceil(train_size / batch_size)
eval_batch_size = config['training']['eval_batch_size']
chunk_size = None

loss_fn = (
    F.binary_cross_entropy_with_logits
    if D.is_binclass
    else F.cross_entropy
    if D.is_multiclass
    else F.mse_loss
)

model = Transformer(
    d_numerical=0 if X_num is None else X_num['train'].shape[1],
    categories=lib.get_categories(X_cat),
    d_out=D.info['n_classes'] if D.is_multiclass else 1,
    **config['model'],
).to(device)

def needs_wd(name):
    return all(x not in name for x in ['tokenizer', '.norm', '.bias'])

for x in ['tokenizer', '.norm', '.bias']:
    assert any(x in a for a in (b[0] for b in model.named_parameters()))
parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]
parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]
optimizer = lib.make_optimizer(
    config['training']['optimizer'],
    (
        [
            {'params': parameters_with_wd},
            {'params': parameters_without_wd, 'weight_decay': 0.0},
        ]
    ),
    config['training']['lr'],
    config['training']['weight_decay'],
)

stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))
progress = zero.ProgressTracker(config['training']['patience'])
training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}
timer = zero.Timer()
output = "Checkpoints"

def print_epoch_info():
    print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')
    print(
        ' | '.join(
            f'{k} = {v}'
            for k, v in {
                'lr': lib.get_lr(optimizer),
                'batch_size': batch_size,
                'chunk_size': chunk_size,
            }.items()
        )
    )

def apply_model(part, idx):
    return model(
        None if X_num is None else X_num[part][idx],
        None if X_cat is None else X_cat[part][idx],
    )

@torch.no_grad()
def evaluate(parts):
    eval_batch_size = config['training']['eval_batch_size']
    model.eval()
    metrics = {}
    predictions = {}
    for part in parts:
        while eval_batch_size:
            try:
                predictions[part] = (
                    torch.cat(
                        [
                            apply_model(part, idx)
                            for idx in lib.IndexLoader(
                                D.size(part), eval_batch_size, False, device
                            )
                        ]
                    )
                    .cpu()
                    .numpy()
                )
            except RuntimeError as err:
                if not lib.is_oom_exception(err):
                    raise
                eval_batch_size //= 2
                print('New eval batch size:', eval_batch_size)
            else:
                break
        if not eval_batch_size:
            RuntimeError('Not enough memory even for eval_batch_size=1')
        metrics[part] = lib.calculate_metrics(
            D.info['task_type'],
            Y[part].numpy(),  # type: ignore[code]
            predictions[part],  # type: ignore[code]
            'logits',
            y_info,
        )
    for part, part_metrics in metrics.items():
        print(f'[{part:<5}]', lib.make_summary(part_metrics))
    return metrics, predictions

def save_checkpoint(final):
    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'stream': stream.state_dict(),
            'random_state': zero.get_random_state(),
        },
        checkpoint_path,
    )

Using device: cuda:0
self.category_embeddings.weight.shape=torch.Size([88, 424])


## Load model

In [6]:
zero.set_randomness(config['seed'])

# Load best checkpoint
model.load_state_dict(torch.load(checkpoint_path)['model'])
metrics, predictions = evaluate(['test', 'test_backdoor'])

[test ] Accuracy = 0.954
[test_backdoor] Accuracy = 0.997


# Save activations

In [7]:
activations_out = {}
count = 0
fails = 0
def save_activation(name, mod, inp, out):
    if name not in activations_out:
        activations_out[name] = out.cpu().detach().numpy()
    
    global fails
    # Will fail if dataset not divisiable by batch size, try except to skip the last iteration
    try:
        # Save the activations for the input neurons
        activations_out[name] += out.cpu().detach().numpy()
        
        if "layers.0.linear0" in name:
            global count
            count += 1
    except:
        fails+=1
    
hooks = []
for name, m in model.named_modules():
    #print(name) # -> tabnet.final_mapping is the layer we are interested in
    if "W_" in name or "linear" in name:
        print("registered:", name, ":", m)
        hooks.append(m.register_forward_hook(partial(save_activation, name)))

registered: layers.0.attention.W_q : Linear(in_features=424, out_features=424, bias=True)
registered: layers.0.attention.W_k : Linear(in_features=424, out_features=424, bias=True)
registered: layers.0.attention.W_v : Linear(in_features=424, out_features=424, bias=True)
registered: layers.0.attention.W_out : Linear(in_features=424, out_features=424, bias=True)
registered: layers.0.linear0 : Linear(in_features=424, out_features=1130, bias=True)
registered: layers.0.linear1 : Linear(in_features=565, out_features=424, bias=True)
registered: layers.1.attention.W_q : Linear(in_features=424, out_features=424, bias=True)
registered: layers.1.attention.W_k : Linear(in_features=424, out_features=424, bias=True)
registered: layers.1.attention.W_v : Linear(in_features=424, out_features=424, bias=True)
registered: layers.1.attention.W_out : Linear(in_features=424, out_features=424, bias=True)
registered: layers.1.linear0 : Linear(in_features=424, out_features=1130, bias=True)
registered: layers.1.l

In [8]:
print(len(activations_out))

0


In [9]:
_ = evaluate(['test'])

[test ] Accuracy = 0.954


In [10]:
for hook in hooks:
    hook.remove()

In [11]:
print(count)

# fails should be equal to number of layers (12), or 0 if data is dividable by batch size
print(len(activations_out))
print(fails)

113
12
12


In [12]:
# Calculate mean activation value (although not really needed for ranking)
for x in activations_out:
    activations_out[x] = activations_out[x]/count

In [13]:
for x in activations_out:
    print(x)
    print(activations_out[x].shape)
    print()

layers.0.attention.W_q
(1024, 55, 424)

layers.0.attention.W_k
(1024, 55, 424)

layers.0.attention.W_v
(1024, 55, 424)

layers.0.attention.W_out
(1024, 55, 424)

layers.0.linear0
(1024, 55, 1130)

layers.0.linear1
(1024, 55, 424)

layers.1.attention.W_q
(1024, 1, 424)

layers.1.attention.W_k
(1024, 55, 424)

layers.1.attention.W_v
(1024, 55, 424)

layers.1.attention.W_out
(1024, 1, 424)

layers.1.linear0
(1024, 1, 1130)

layers.1.linear1
(1024, 1, 424)



In [14]:
# Average over batch and second dimension
for x in activations_out:
    activations_out[x] = activations_out[x].mean(axis=0).mean(axis=0)

In [15]:
for x in activations_out:
    print(x)
    print(activations_out[x].shape)

layers.0.attention.W_q
(424,)
layers.0.attention.W_k
(424,)
layers.0.attention.W_v
(424,)
layers.0.attention.W_out
(424,)
layers.0.linear0
(1130,)
layers.0.linear1
(424,)
layers.1.attention.W_q
(424,)
layers.1.attention.W_k
(424,)
layers.1.attention.W_v
(424,)
layers.1.attention.W_out
(424,)
layers.1.linear0
(1130,)
layers.1.linear1
(424,)


In [16]:
metrics = evaluate(['test', 'test_backdoor'])

[test ] Accuracy = 0.954
[test_backdoor] Accuracy = 0.997


In [17]:
print(metrics[0]['test_backdoor']['accuracy'])
print(metrics[0]['test']['accuracy'])

0.9974191629339306
0.9541836269287368


In [18]:
# Argsort activations for each layer
argsortActivations_out = {}
for n in activations_out:
    argsortActivations_out[n] = np.argsort(activations_out[n])

In [19]:
for name, m in model.named_parameters():
    if "W_" in name or "linear" in name:
        if "weight" in name:
            print(name, m.shape)

layers.0.attention.W_q.weight torch.Size([424, 424])
layers.0.attention.W_k.weight torch.Size([424, 424])
layers.0.attention.W_v.weight torch.Size([424, 424])
layers.0.attention.W_out.weight torch.Size([424, 424])
layers.0.linear0.weight torch.Size([1130, 424])
layers.0.linear1.weight torch.Size([424, 565])
layers.1.attention.W_q.weight torch.Size([424, 424])
layers.1.attention.W_k.weight torch.Size([424, 424])
layers.1.attention.W_v.weight torch.Size([424, 424])
layers.1.attention.W_out.weight torch.Size([424, 424])
layers.1.linear0.weight torch.Size([1130, 424])
layers.1.linear1.weight torch.Size([424, 565])


## Prune

In [20]:
def pruneWithTreshold(argsortActivations, name, th=1, transpose=False, dim2=1):
    x = torch.tensor(argsortActivations[name].copy())
    x[x>=th] = 99999
    x[x<th] = 0
    x[x==99999] = 1
    
    b = np.stack((x,) * dim2, axis=-1)
    
    if transpose:
        b = torch.tensor(b.T)
    else:
        b = torch.tensor(b)
        
    #print(b.shape)
    return b

In [21]:
i = 212 # obtained from "Prune" notebook

    
prune.custom_from_mask(
    module = model.layers[0].linear0,
    name = 'weight',
    mask = pruneWithTreshold(argsortActivations_out, "layers.0.linear0", i, False, 424).to("cuda:0")
)

prune.custom_from_mask(
    module = model.layers[0].linear1,
    name = 'weight',
    mask = pruneWithTreshold(argsortActivations_out, "layers.0.linear1", i, False, 565).to("cuda:0")
)

prune.custom_from_mask(
    module = model.layers[1].linear0,
    name = 'weight',
    mask = pruneWithTreshold(argsortActivations_out, "layers.1.linear0", i, False, 424).to("cuda:0")
)

prune.custom_from_mask(
    module = model.layers[1].linear1,
    name = 'weight',
    mask = pruneWithTreshold(argsortActivations_out, "layers.1.linear1", i, False, 565).to("cuda:0")
)


metrics = evaluate(['test', 'test_backdoor'])

[test ] Accuracy = 0.702
[test_backdoor] Accuracy = 0.017


## Finetune

In [22]:
DATAPATH = "data/covtypeFTT-1F-OOB-finetune/"
# FTtransformer config
config = {
    'data': {
        'normalization': 'standard',
        'path': DATAPATH
    }, 
    'model': {
        'activation': 'reglu', 
        'attention_dropout': 0.03815883962184247, 
        'd_ffn_factor': 1.333333333333333, 
        'd_token': 424, 
        'ffn_dropout': 0.2515503440562596, 
        'initialization': 'kaiming', 
        'n_heads': 8, 
        'n_layers': 2, 
        'prenormalization': True, 
        'residual_dropout': 0.0, 
        'token_bias': True, 
        'kv_compression': None, 
        'kv_compression_sharing': None
    }, 
    'seed': 0, 
    'training': {
        'batch_size': 1024, 
        'eval_batch_size': 1024, 
        'lr': 3.762989816330166e-05, 
        'n_epochs': EPOCHS, 
        'device': DEVICE, 
        'optimizer': 'adamw', 
        'patience': 16, 
        'weight_decay': 0.0001239780004929955
    }
}


zero.set_randomness(config['seed'])
dataset_dir = config['data']['path']

D = lib.Dataset.from_dir(dataset_dir)
X = D.build_X(
    normalization=config['data'].get('normalization'),
    num_nan_policy='mean',
    cat_nan_policy='new',
    cat_policy=config['data'].get('cat_policy', 'indices'),
    cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),
    seed=config['seed'],
)
if not isinstance(X, tuple):
    X = (X, None)

Y, y_info = D.build_y(config['data'].get('y_policy'))

X = tuple(None if x is None else lib.to_tensors(x) for x in X)
Y = lib.to_tensors(Y)
device = torch.device(config['training']['device'])
print("Using device:", config['training']['device'])
if device.type != 'cpu':
    X = tuple(
        None if x is None else {k: v.to(device) for k, v in x.items()} for x in X
    )
    Y_device = {k: v.to(device) for k, v in Y.items()}
else:
    Y_device = Y
X_num, X_cat = X
del X
if not D.is_multiclass:
    Y_device = {k: v.float() for k, v in Y_device.items()}

train_size = D.size(lib.TRAIN)
batch_size = config['training']['batch_size']
epoch_size = math.ceil(train_size / batch_size)
eval_batch_size = config['training']['eval_batch_size']
chunk_size = None

loss_fn = (
    F.binary_cross_entropy_with_logits
    if D.is_binclass
    else F.cross_entropy
    if D.is_multiclass
    else F.mse_loss
)

# Do not define new model, instead use pruned model
#model = Transformer(
#    d_numerical=0 if X_num is None else X_num['train'].shape[1],
#    categories=lib.get_categories(X_cat),
#    d_out=D.info['n_classes'] if D.is_multiclass else 1,
#    **config['model'],
#).to(device)

def needs_wd(name):
    return all(x not in name for x in ['tokenizer', '.norm', '.bias'])

for x in ['tokenizer', '.norm', '.bias']:
    assert any(x in a for a in (b[0] for b in model.named_parameters()))
parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]
parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]
optimizer = lib.make_optimizer(
    config['training']['optimizer'],
    (
        [
            {'params': parameters_with_wd},
            {'params': parameters_without_wd, 'weight_decay': 0.0},
        ]
    ),
    config['training']['lr'],
    config['training']['weight_decay'],
)

stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))
progress = zero.ProgressTracker(config['training']['patience'])
training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}
timer = zero.Timer()
output = "Checkpoints"

def print_epoch_info():
    print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')
    print(
        ' | '.join(
            f'{k} = {v}'
            for k, v in {
                'lr': lib.get_lr(optimizer),
                'batch_size': batch_size,
                'chunk_size': chunk_size,
            }.items()
        )
    )

def apply_model(part, idx):
    return model(
        None if X_num is None else X_num[part][idx],
        None if X_cat is None else X_cat[part][idx],
    )

@torch.no_grad()
def evaluate(parts):
    eval_batch_size = config['training']['eval_batch_size']
    model.eval()
    metrics = {}
    predictions = {}
    for part in parts:
        while eval_batch_size:
            try:
                predictions[part] = (
                    torch.cat(
                        [
                            apply_model(part, idx)
                            for idx in lib.IndexLoader(
                                D.size(part), eval_batch_size, False, device
                            )
                        ]
                    )
                    .cpu()
                    .numpy()
                )
            except RuntimeError as err:
                if not lib.is_oom_exception(err):
                    raise
                eval_batch_size //= 2
                print('New eval batch size:', eval_batch_size)
            else:
                break
        if not eval_batch_size:
            RuntimeError('Not enough memory even for eval_batch_size=1')
        metrics[part] = lib.calculate_metrics(
            D.info['task_type'],
            Y[part].numpy(),  # type: ignore[code]
            predictions[part],  # type: ignore[code]
            'logits',
            y_info,
        )
    for part, part_metrics in metrics.items():
        print(f'[{part:<5}]', lib.make_summary(part_metrics))
    return metrics, predictions


Using device: cuda:0


In [23]:
finetuneEpochs = 15
for epoch in stream.epochs(finetuneEpochs):
    print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())}')
    model.train()
    epoch_losses = []
    for batch_idx in epoch:
        loss, new_chunk_size = lib.train_with_auto_virtual_batch(
            optimizer,
            loss_fn,
            lambda x: (apply_model(lib.TRAIN, x), Y_device[lib.TRAIN][x]),
            batch_idx,
            chunk_size or batch_size,
        )
        epoch_losses.append(loss.detach())
        if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
            print('New chunk size:', chunk_size)
    epoch_losses = torch.stack(epoch_losses).tolist()
    print(f'[{lib.TRAIN}] loss = {round(sum(epoch_losses) / len(epoch_losses), 3)}')

    metrics, predictions = evaluate([lib.VAL, lib.TEST])
    for k, v in metrics.items():
        training_log[k].append(v)
    progress.update(metrics[lib.VAL]['score'])

    if progress.success:
        print('New best epoch!')
        #save_checkpoint(False)

    elif progress.fail:
        break

  0%|                                                   | 0/285 [00:00<?, ?it/s]


>>> Epoch 1 | 0:00:00


  7%|██▊                                       | 19/285 [00:03<00:57,  4.60it/s]

[train] loss = 0.519


  7%|██▉                                       | 20/285 [00:10<09:13,  2.09s/it]

[val  ] Accuracy = 0.877
[test ] Accuracy = 0.868
New best epoch!

>>> Epoch 2 | 0:00:00


 13%|█████▌                                    | 38/285 [00:14<00:54,  4.52it/s]

[train] loss = 0.35


 14%|█████▋                                    | 39/285 [00:20<08:33,  2.09s/it]

[val  ] Accuracy = 0.898
[test ] Accuracy = 0.892
New best epoch!

>>> Epoch 3 | 0:00:00


 20%|████████▍                                 | 57/285 [00:24<00:50,  4.54it/s]

[train] loss = 0.31


 20%|████████▌                                 | 58/285 [00:31<07:53,  2.09s/it]

[val  ] Accuracy = 0.910
[test ] Accuracy = 0.904
New best epoch!

>>> Epoch 4 | 0:00:00


 27%|███████████▏                              | 76/285 [00:34<00:46,  4.54it/s]

[train] loss = 0.283


 27%|███████████▎                              | 77/285 [00:41<07:22,  2.13s/it]

[val  ] Accuracy = 0.914
[test ] Accuracy = 0.908
New best epoch!

>>> Epoch 5 | 0:00:00


 33%|██████████████                            | 95/285 [00:45<00:42,  4.48it/s]

[train] loss = 0.276


 34%|██████████████▏                           | 96/285 [00:51<06:35,  2.09s/it]

[val  ] Accuracy = 0.919
[test ] Accuracy = 0.913
New best epoch!

>>> Epoch 6 | 0:00:00


 40%|████████████████▍                        | 114/285 [00:55<00:39,  4.37it/s]

[train] loss = 0.266


 40%|████████████████▌                        | 115/285 [01:02<06:07,  2.16s/it]

[val  ] Accuracy = 0.922
[test ] Accuracy = 0.916
New best epoch!

>>> Epoch 7 | 0:00:00


 47%|███████████████████▏                     | 133/285 [01:06<00:33,  4.48it/s]

[train] loss = 0.248


 47%|███████████████████▎                     | 134/285 [01:13<05:22,  2.14s/it]

[val  ] Accuracy = 0.924
[test ] Accuracy = 0.916
New best epoch!

>>> Epoch 8 | 0:00:00


 53%|█████████████████████▊                   | 152/285 [01:17<00:29,  4.51it/s]

[train] loss = 0.242


 54%|██████████████████████                   | 153/285 [01:23<04:38,  2.11s/it]

[val  ] Accuracy = 0.926
[test ] Accuracy = 0.918
New best epoch!

>>> Epoch 9 | 0:00:00


 60%|████████████████████████▌                | 171/285 [01:27<00:25,  4.46it/s]

[train] loss = 0.233


 60%|████████████████████████▋                | 172/285 [01:34<04:02,  2.15s/it]

[val  ] Accuracy = 0.924
[test ] Accuracy = 0.918

>>> Epoch 10 | 0:00:00


 67%|███████████████████████████▎             | 190/285 [01:38<00:21,  4.50it/s]

[train] loss = 0.23


 67%|███████████████████████████▍             | 191/285 [01:44<03:22,  2.15s/it]

[val  ] Accuracy = 0.927
[test ] Accuracy = 0.920
New best epoch!

>>> Epoch 11 | 0:00:00


 73%|██████████████████████████████           | 209/285 [01:48<00:17,  4.45it/s]

[train] loss = 0.221


 74%|██████████████████████████████▏          | 210/285 [01:55<02:37,  2.11s/it]

[val  ] Accuracy = 0.928
[test ] Accuracy = 0.920
New best epoch!

>>> Epoch 12 | 0:00:00


 80%|████████████████████████████████▊        | 228/285 [01:59<00:12,  4.42it/s]

[train] loss = 0.217


 80%|████████████████████████████████▉        | 229/285 [02:05<01:59,  2.13s/it]

[val  ] Accuracy = 0.927
[test ] Accuracy = 0.921

>>> Epoch 13 | 0:00:00


 87%|███████████████████████████████████▌     | 247/285 [02:09<00:08,  4.51it/s]

[train] loss = 0.216


 87%|███████████████████████████████████▋     | 248/285 [02:16<01:18,  2.13s/it]

[val  ] Accuracy = 0.926
[test ] Accuracy = 0.920

>>> Epoch 14 | 0:00:00


 93%|██████████████████████████████████████▎  | 266/285 [02:20<00:04,  4.42it/s]

[train] loss = 0.209


 94%|██████████████████████████████████████▍  | 267/285 [02:27<00:38,  2.13s/it]

[val  ] Accuracy = 0.927
[test ] Accuracy = 0.921

>>> Epoch 15 | 0:00:00


100%|█████████████████████████████████████████| 285/285 [02:31<00:00,  4.55it/s]

[train] loss = 0.206
[val  ] Accuracy = 0.928
[test ] Accuracy = 0.921


## Final result on finetuned model

In [24]:
metrics = evaluate(['test', 'test_backdoor'])

[test ] Accuracy = 0.921
[test_backdoor] Accuracy = 0.042
