1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- import typing as ty
- import numpy as np
- import scipy.special
- import sklearn.metrics as skm
- from . import util
- def calculate_metrics(
- task_type: str,
- y: np.ndarray,
- prediction: np.ndarray,
- classification_mode: str,
- y_info: ty.Optional[ty.Dict[str, ty.Any]],
- ) -> ty.Dict[str, float]:
- if task_type == util.REGRESSION:
- del classification_mode
- rmse = skm.mean_squared_error(y, prediction) ** 0.5 # type: ignore[code]
- if y_info:
- if y_info['policy'] == 'mean_std':
- rmse *= y_info['std']
- else:
- assert False
- return {'rmse': rmse, 'score': -rmse}
- else:
- assert task_type in (util.BINCLASS, util.MULTICLASS)
- labels = None
- if classification_mode == 'probs':
- probs = prediction
- elif classification_mode == 'logits':
- probs = (
- scipy.special.expit(prediction)
- if task_type == util.BINCLASS
- else scipy.special.softmax(prediction, axis=1)
- )
- else:
- assert classification_mode == 'labels'
- probs = None
- labels = prediction
- if labels is None:
- labels = (
- np.round(probs).astype('int64')
- if task_type == util.BINCLASS
- else probs.argmax(axis=1) # type: ignore[code]
- )
- result = skm.classification_report(y, labels, output_dict=True, zero_division=0) # type: ignore[code]
- if task_type == util.BINCLASS:
- try:
- result['roc_auc'] = skm.roc_auc_score(y, probs) # type: ignore[code]
- except: # in case we only have class in our test set (like for ASR)
- result['roc_auc'] = 0.0
- result['score'] = result['accuracy'] # type: ignore[code]
- return result # type: ignore[code]
- def make_summary(metrics: ty.Dict[str, ty.Any]) -> str:
- precision = 3
- summary = {}
- for k, v in metrics.items():
- if k.isdigit():
- continue
- k = {
- 'score': 'SCORE',
- 'accuracy': 'acc',
- 'roc_auc': 'roc_auc',
- 'macro avg': 'm',
- 'weighted avg': 'w',
- }.get(k, k)
- if isinstance(v, float):
- v = round(v, precision)
- summary[k] = v
- else:
- v = {
- {'precision': 'p', 'recall': 'r', 'f1-score': 'f1', 'support': 's'}.get(
- x, x
- ): round(v[x], precision)
- for x in v
- }
- for item in v.items():
- summary[k + item[0]] = item[1]
-
- #s = [f'Accuracy = {summary.pop("acc"):.3f}']
- #for k, v in summary.items():
- # if k not in ['mp', 'mr', 'wp', 'wr', 'mf1', 'wf1', 'ms', 'ws']: # just to save screen space
- # s.append(f'{k} = {v}')
- #return ' | '.join(s)
- return f'Accuracy = {summary.pop("acc"):.3f}'
|