# metrics.py
  1. import typing as ty
  2. import numpy as np
  3. import scipy.special
  4. import sklearn.metrics as skm
  5. from . import util
  6. def calculate_metrics(
  7. task_type: str,
  8. y: np.ndarray,
  9. prediction: np.ndarray,
  10. classification_mode: str,
  11. y_info: ty.Optional[ty.Dict[str, ty.Any]],
  12. ) -> ty.Dict[str, float]:
  13. if task_type == util.REGRESSION:
  14. del classification_mode
  15. rmse = skm.mean_squared_error(y, prediction) ** 0.5 # type: ignore[code]
  16. if y_info:
  17. if y_info['policy'] == 'mean_std':
  18. rmse *= y_info['std']
  19. else:
  20. assert False
  21. return {'rmse': rmse, 'score': -rmse}
  22. else:
  23. assert task_type in (util.BINCLASS, util.MULTICLASS)
  24. labels = None
  25. if classification_mode == 'probs':
  26. probs = prediction
  27. elif classification_mode == 'logits':
  28. probs = (
  29. scipy.special.expit(prediction)
  30. if task_type == util.BINCLASS
  31. else scipy.special.softmax(prediction, axis=1)
  32. )
  33. else:
  34. assert classification_mode == 'labels'
  35. probs = None
  36. labels = prediction
  37. if labels is None:
  38. labels = (
  39. np.round(probs).astype('int64')
  40. if task_type == util.BINCLASS
  41. else probs.argmax(axis=1) # type: ignore[code]
  42. )
  43. result = skm.classification_report(y, labels, output_dict=True, zero_division=0) # type: ignore[code]
  44. if task_type == util.BINCLASS:
  45. try:
  46. result['roc_auc'] = skm.roc_auc_score(y, probs) # type: ignore[code]
  47. except: # in case we only have class in our test set (like for ASR)
  48. result['roc_auc'] = 0.0
  49. result['score'] = result['accuracy'] # type: ignore[code]
  50. return result # type: ignore[code]
  51. def make_summary(metrics: ty.Dict[str, ty.Any]) -> str:
  52. precision = 3
  53. summary = {}
  54. for k, v in metrics.items():
  55. if k.isdigit():
  56. continue
  57. k = {
  58. 'score': 'SCORE',
  59. 'accuracy': 'acc',
  60. 'roc_auc': 'roc_auc',
  61. 'macro avg': 'm',
  62. 'weighted avg': 'w',
  63. }.get(k, k)
  64. if isinstance(v, float):
  65. v = round(v, precision)
  66. summary[k] = v
  67. else:
  68. v = {
  69. {'precision': 'p', 'recall': 'r', 'f1-score': 'f1', 'support': 's'}.get(
  70. x, x
  71. ): round(v[x], precision)
  72. for x in v
  73. }
  74. for item in v.items():
  75. summary[k + item[0]] = item[1]
  76. #s = [f'Accuracy = {summary.pop("acc"):.3f}']
  77. #for k, v in summary.items():
  78. # if k not in ['mp', 'mr', 'wp', 'wr', 'mf1', 'wf1', 'ms', 'ws']: # just to save screen space
  79. # s.append(f'{k} = {v}')
  80. #return ' | '.join(s)
  81. return f'Accuracy = {summary.pop("acc"):.3f}'