import torch
import numpy as np
import os, json, random
from pprint import pprint
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "ConfLoader",
    "directory_setter",
    "random_seeder",
    "pprint_config",
]
class ConfLoader:
    """
    Load a JSON config file into attribute-accessible dicts.

    ``ConfLoader(conf_name).opt`` is the parsed config. Every JSON object,
    nested ones included, is converted to ``DictWithAttributeAccess`` while
    loading, so values can be read as ``opt.key`` as well as ``opt["key"]``.
    """

    class DictWithAttributeAccess(dict):
        """
        Dict whose keys can also be read, written and deleted as attributes.

        For example, ``opt.key`` is equivalent to ``opt["key"]``. A missing
        key raises ``AttributeError`` (not ``KeyError``) on attribute access,
        so ``hasattr``, ``getattr(obj, name, default)`` and ``copy.deepcopy``
        work as they do for ordinary objects.
        """

        def __getattr__(self, key):
            # Only reached when normal attribute lookup fails, so real dict
            # methods (``keys``, ``items``, ...) still win over same-named keys.
            try:
                return self[key]
            except KeyError:
                # Attribute protocol requires AttributeError; hasattr() and
                # copy.deepcopy() only tolerate that exception type.
                raise AttributeError(key) from None

        def __setattr__(self, key, value):
            self[key] = value

        def __delattr__(self, key):
            try:
                del self[key]
            except KeyError:
                raise AttributeError(key) from None

    def __init__(self, conf_name):
        # Path of the JSON config file to load.
        self.conf_name = conf_name
        # Parsed configuration, wrapped in DictWithAttributeAccess.
        self.opt = self.__get_opt()

    def __load_conf(self):
        """Parse ``self.conf_name``, converting every JSON object via the hook."""
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(self.conf_name, "r", encoding="utf-8") as conf:
            return json.load(conf, object_hook=self.DictWithAttributeAccess)

    def __get_opt(self):
        """Load the config and make sure the top level is attribute-accessible."""
        return self.DictWithAttributeAccess(self.__load_conf())
def directory_setter(path="./results", make_dir=False):
    """
    Make sure ``path`` is an existing directory, optionally creating it.

    Args:
        path: Directory path to check.
        make_dir: If True, create ``path`` (including missing parent
            directories) when it does not exist yet.

    Raises:
        NotADirectoryError: if ``path`` is not a directory after the optional
            creation step, i.e. it is missing and ``make_dir`` is False, or it
            exists but is not a directory (e.g. a regular file).
    """
    if make_dir and not os.path.exists(path):
        try:
            os.makedirs(path)  # make dir (and parents) if not exist
        except FileExistsError:
            # Something else created ``path`` between the existence check and
            # makedirs (e.g. a parallel worker); the isdir check below decides
            # whether that is acceptable.
            pass
        else:
            print("directory %s is created" % path)
    if not os.path.isdir(path):
        if os.path.exists(path):
            # Suggesting make_dir=True would not help: the path is taken.
            raise NotADirectoryError("%s exists but is not a directory." % path)
        raise NotADirectoryError(
            "%s is not valid. set make_dir=True to make dir." % path
        )
def random_seeder(seed):
    """
    Make runs reproducible by seeding every RNG this module relies on.

    Seeds PyTorch (``torch.manual_seed``), NumPy's global RNG and Python's
    ``random`` module with the same value, then configures cuDNN for
    deterministic execution.

    Args:
        seed: Seed value passed unchanged to each RNG.
    """
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
        seed_rng(seed)
    cudnn = torch.backends.cudnn
    # Request deterministic cuDNN kernels and turn off the autotuner, whose
    # per-run algorithm selection can change numerical results.
    cudnn.deterministic = True
    cudnn.benchmark = False
def pprint_config(opt):
    """
    Pretty-print ``opt`` between a titled header rule and a matching footer.

    Args:
        opt: Configuration object handed to ``pprint.pprint`` (compact mode).
    """
    title = " Configuration "
    rule = "=" * 50
    # The footer spans the full header width: two rules plus the title.
    footer = "=" * (2 * len(rule) + len(title))
    print(f"\n{rule}{title}{rule}")
    pprint(opt, compact=True)
    print(footer + "\n")