# util.py

import argparse
import datetime
import json
import os
import pickle
import random
import shutil
import sys
import time
import typing as ty
from copy import deepcopy
from pathlib import Path

import numpy as np
import pynvml
import pytomlpp as toml
import torch

# ASSUMPTION: `env` is a project-local module exposing PROJECT_DIR (the
# repository root); it is referenced below but was not imported in the original.
from . import env  # type: ignore

TRAIN = 'train'
VAL = 'val'
TEST = 'test'
TEST_BACKDOOR = 'test_backdoor'
PARTS = [TRAIN, VAL, TEST, TEST_BACKDOOR]

BINCLASS = 'binclass'
MULTICLASS = 'multiclass'
REGRESSION = 'regression'
TASK_TYPES = [BINCLASS, MULTICLASS, REGRESSION]


def load_json(path: ty.Union[Path, str]) -> ty.Any:
    return json.loads(Path(path).read_text())


def dump_json(x: ty.Any, path: ty.Union[Path, str], *args, **kwargs) -> None:
    Path(path).write_text(json.dumps(x, *args, **kwargs) + '\n')


def load_toml(path: ty.Union[Path, str]) -> ty.Any:
    return toml.loads(Path(path).read_text())


def dump_toml(x: ty.Any, path: ty.Union[Path, str]) -> None:
    Path(path).write_text(toml.dumps(x) + '\n')


def load_pickle(path: ty.Union[Path, str]) -> ty.Any:
    return pickle.loads(Path(path).read_bytes())


def dump_pickle(x: ty.Any, path: ty.Union[Path, str]) -> None:
    Path(path).write_bytes(pickle.dumps(x))


def load(path: ty.Union[Path, str]) -> ty.Any:
    # Dispatch on the file suffix: load_json / load_toml / load_pickle.
    return globals()[f'load_{Path(path).suffix[1:]}'](path)
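
# Usage sketch (hypothetical paths): `load` dispatches on the file suffix, so
# 'exp/config.toml' goes through load_toml and 'exp/stats.json' through load_json:
#     config = load('exp/config.toml')
#     stats = load('exp/stats.json')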


def load_config(
    argv: ty.Optional[ty.List[str]] = None,
) -> ty.Tuple[ty.Dict[str, ty.Any], Path]:
    parser = argparse.ArgumentParser()
    parser.add_argument('config', metavar='FILE')
    parser.add_argument('-o', '--output', metavar='DIR')
    parser.add_argument('-f', '--force', action='store_true')
    parser.add_argument('--continue', action='store_true', dest='continue_')
    if argv is None:
        argv = sys.argv[1:]
    args = parser.parse_args(argv)

    # If checkpoints were restored from a snapshot, the run must be a continuation.
    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
    if snapshot_dir and Path(snapshot_dir).joinpath('CHECKPOINTS_RESTORED').exists():
        assert args.continue_

    config_path = Path(args.config).absolute()
    # Without -o, the output directory defaults to the config's stem next to it.
    output_dir = (
        Path(args.output)
        if args.output
        else config_path.parent.joinpath(config_path.stem)
    ).absolute()
    sep = '=' * (8 + max(len(str(config_path)), len(str(output_dir))))  # type: ignore[code]
    print(sep, f'Config: {config_path}', f'Output: {output_dir}', sep, sep='\n')

    assert config_path.exists()
    config = load_toml(config_path)

    if output_dir.exists():
        if args.force:
            print('Removing the existing output and creating a new one...')
            shutil.rmtree(output_dir)
            output_dir.mkdir()
        elif not args.continue_:
            backup_output(output_dir)
            print('Already done!\n')
            sys.exit()
        elif output_dir.joinpath('DONE').exists():
            backup_output(output_dir)
            print('Already DONE!\n')
            sys.exit()
        else:
            print('Continuing with the existing output...')
    else:
        print('Creating the output...')
        output_dir.mkdir()

    # Record the CUDA environment for reproducibility.
    environment: ty.Dict[str, ty.Any] = {}
    if torch.cuda.is_available():  # type: ignore[code]
        cvd = os.environ.get('CUDA_VISIBLE_DEVICES')
        pynvml.nvmlInit()
        environment['devices'] = {
            'CUDA_VISIBLE_DEVICES': cvd,
            'torch.version.cuda': torch.version.cuda,
            'torch.backends.cudnn.version()': torch.backends.cudnn.version(),  # type: ignore[code]
            'torch.cuda.nccl.version()': torch.cuda.nccl.version(),  # type: ignore[code]
            'driver': str(pynvml.nvmlSystemGetDriverVersion(), 'utf-8'),
        }
        if cvd:
            for i in map(int, cvd.split(',')):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                environment['devices'][i] = {
                    'name': str(pynvml.nvmlDeviceGetName(handle), 'utf-8'),
                    'total_memory': pynvml.nvmlDeviceGetMemoryInfo(handle).total,
                }

    dump_stats({'config': config, 'environment': environment}, output_dir)
    return config, output_dir
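
# CLI sketch (hypothetical script and paths): a trainer built on load_config
# might be invoked as
#     python train.py experiments/mlp.toml -o output/mlp --force
# and, without -o, would write to experiments/mlp/ (the config's stem).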


def dump_stats(stats: dict, output_dir: Path, final: bool = False) -> None:
    dump_json(stats, output_dir / 'stats.json', indent=4)
    json_output_path = os.environ.get('JSON_OUTPUT_FILE')
    if final:
        output_dir.joinpath('DONE').touch()
        if json_output_path:
            try:
                key = str(output_dir.relative_to(env.PROJECT_DIR))
            except ValueError:
                pass
            else:
                # Merge this run's stats into the shared JSON output file and
                # mirror the result into the snapshot directory.
                json_output_path = Path(json_output_path)
                try:
                    json_data = json.loads(json_output_path.read_text())
                except (FileNotFoundError, json.decoder.JSONDecodeError):
                    json_data = {}
                json_data[key] = stats
                json_output_path.write_text(json.dumps(json_data))
                shutil.copyfile(
                    json_output_path,
                    os.path.join(os.environ['SNAPSHOT_PATH'], 'json_output.json'),
                )
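
# Usage sketch (call sites assumed): periodic dumps vs. the final one.
#     dump_stats(stats, output_dir)              # intermediate progress
#     dump_stats(stats, output_dir, final=True)  # also creates the DONE marker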


_LAST_SNAPSHOT_TIME = None


def backup_output(output_dir: Path) -> None:
    backup_dir = os.environ.get('TMP_OUTPUT_PATH')
    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
    if backup_dir is None:
        assert snapshot_dir is None
        return
    assert snapshot_dir is not None

    try:
        relative_output_dir = output_dir.relative_to(env.PROJECT_DIR)
    except ValueError:
        return

    for dir_ in [backup_dir, snapshot_dir]:
        new_output_dir = Path(dir_) / relative_output_dir  # dir_ is a str from os.environ
        prev_backup_output_dir = new_output_dir.with_name(new_output_dir.name + '_prev')
        new_output_dir.parent.mkdir(exist_ok=True, parents=True)
        # Keep the previous copy until the new one is fully written.
        if new_output_dir.exists():
            new_output_dir.rename(prev_backup_output_dir)
        shutil.copytree(output_dir, new_output_dir)
        if prev_backup_output_dir.exists():
            shutil.rmtree(prev_backup_output_dir)

    # Rate-limit snapshots to one per ten minutes; the platform-specific
    # snapshot call itself is elided here (hence the bare `pass`).
    global _LAST_SNAPSHOT_TIME
    if _LAST_SNAPSHOT_TIME is None or time.time() - _LAST_SNAPSHOT_TIME > 10 * 60:
        pass
        _LAST_SNAPSHOT_TIME = time.time()
        print('The snapshot was saved!')
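
# Note: as asserted above, TMP_OUTPUT_PATH and SNAPSHOT_PATH must be set
# together or not at all, and outputs outside env.PROJECT_DIR are not backed up.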


def raise_unknown(unknown_what: str, unknown_value: ty.Any):
    raise ValueError(f'Unknown {unknown_what}: {unknown_value}')


def merge_defaults(kwargs: dict, default_kwargs: dict) -> dict:
    x = deepcopy(default_kwargs)
    x.update(kwargs)
    return x
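
# Example: caller-provided kwargs override the defaults.
#     merge_defaults({'lr': 1e-3}, {'lr': 1e-4, 'weight_decay': 0.0})
#     == {'lr': 1e-3, 'weight_decay': 0.0}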


def set_seeds(seed: int) -> None:
    # NOTE: only Python's and NumPy's RNGs are seeded here; torch's RNG is
    # presumably seeded elsewhere (e.g. via torch.manual_seed).
    random.seed(seed)
    np.random.seed(seed)


def format_seconds(seconds: float) -> str:
    return str(datetime.timedelta(seconds=round(seconds)))
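
# Example: format_seconds(3661.4) == '1:01:01'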


def get_categories(
    X_cat: ty.Optional[ty.Dict[str, torch.Tensor]]
) -> ty.Optional[ty.List[int]]:
    # Per-column cardinalities of the categorical features, computed on the
    # training split.
    return (
        None
        if X_cat is None
        else [
            len(set(X_cat[TRAIN][:, i].cpu().tolist()))
            for i in range(X_cat[TRAIN].shape[1])
        ]
    )
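
# Example (sketch): two categorical columns with cardinalities 3 and 2.
#     X_cat = {TRAIN: torch.tensor([[0, 1], [1, 1], [2, 0]])}
#     get_categories(X_cat) == [3, 2]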