Source code for main

import torch
import os
import wandb
import argparse, pprint

from train_tools import *
from SetupDict import TRAINER, OPTIMIZER, SCHEDULER, MODELS, PREDICTOR
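
# NOTE: SetupDict is expected to expose plain name-to-class registries that the
# config file indexes by name. A minimal sketch (hypothetical entries; the real
# mappings live in SetupDict.py):
#
#   OPTIMIZER = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}
#   SCHEDULER = {"cosine": torch.optim.lr_scheduler.CosineAnnealingLR}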

# Ignore warnings for tifffile image reading
import logging

logging.getLogger().setLevel(logging.ERROR)

# Set torch base print precision
torch.set_printoptions(6)


def _get_setups(args):
    """Get experiment configuration"""

    # Set model
    model_args = args.train_setups.model
    model = MODELS[model_args.name](**model_args.params)

    # Load pretrained weights
    if model_args.pretrained.enabled:
        weights = torch.load(model_args.pretrained.weights, map_location="cpu")

        print("\nLoading pretrained model....")
        model.load_state_dict(weights, strict=model_args.pretrained.strict)
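        # PyTorch's load_state_dict with strict=False tolerates missing or
        # unexpected keys, so partial (e.g. backbone-only) checkpoints load.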

    # Set dataloaders
    dataloaders = datasetter.get_dataloaders_labeled(**args.data_setups.labeled)
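    # NOTE: `datasetter` comes in via the `train_tools` wildcard import above;
    # its loaders are presumably keyed by split name, mirroring the
    # dataloaders["public"] access further down.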

    # Set optimizer
    optimizer_args = args.train_setups.optimizer
    optimizer = OPTIMIZER[optimizer_args.name](
        model.parameters(), **optimizer_args.params
    )

    # Set scheduler
    scheduler = None

    if args.train_setups.scheduler.enabled:
        scheduler_args = args.train_setups.scheduler
        scheduler = SCHEDULER[scheduler_args.name](optimizer, **scheduler_args.params)

    # Set trainer
    trainer_args = args.train_setups.trainer
    trainer = TRAINER[trainer_args.name](
        model, dataloaders, optimizer, scheduler, **trainer_args.params
    )

    # Disable validation when no validation split is allocated
    if args.data_setups.labeled.valid_portion == 0:
        trainer.no_valid = True

    # Set public dataloader
    if args.data_setups.public.enabled:
        dataloaders = datasetter.get_dataloaders_public(
            **args.data_setups.public.params
        )
        trainer.public_loader = dataloaders["public"]
        trainer.public_iterator = iter(dataloaders["public"])
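        # Keeping an explicit iterator presumably lets the trainer draw one
        # public batch per training step instead of looping the whole loader.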

    return trainer
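
# A rough sketch of the config shape that `_get_setups` and `main` consume
# (field names inferred from the attribute accesses above; values are
# hypothetical, the real defaults live in ./config/baseline.json):
#
#   {
#     "wandb_setups": {...},
#     "data_setups": {
#       "labeled": {"valid_portion": 0.1, ...},
#       "public": {"enabled": false, "params": {...}}
#     },
#     "train_setups": {
#       "seed": 42,
#       "model": {"name": "...", "params": {...},
#                 "pretrained": {"enabled": false, "weights": "...", "strict": true}},
#       "optimizer": {"name": "...", "params": {...}},
#       "scheduler": {"enabled": false, "name": "...", "params": {...}},
#       "trainer": {"name": "...", "params": {"device": "...", ...}}
#     },
#     "pred_setups": {"input_path": "...", "output_path": "...",
#                     "make_submission": false, "exp_name": "...",
#                     "algo_params": {...}}
#   }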


def main(args):
    """Execute experiment."""

    # Initialize W&B
    wandb.init(config=args, **args.wandb_setups)

    # How many batches to wait before logging training status
    wandb.config.log_interval = 10

    # Fix randomness for reproducibility
    random_seeder(args.train_setups.seed)

    # Set experiment
    trainer = _get_setups(args)

    # Watch parameters & gradients of model
    wandb.watch(trainer.model, log="all", log_graph=True)

    # Conduct experiment
    trainer.train()

    # Upload model to wandb server
    model_path = os.path.join(wandb.run.dir, "model.pth")
    torch.save(trainer.model.state_dict(), model_path)
    wandb.save(model_path)

    # Conduct prediction using the trained model
    predictor = PREDICTOR[args.train_setups.trainer.name](
        trainer.model,
        args.train_setups.trainer.params.device,
        args.pred_setups.input_path,
        args.pred_setups.output_path,
        args.pred_setups.make_submission,
        args.pred_setups.exp_name,
        args.pred_setups.algo_params,
    )

    total_time = predictor.conduct_prediction()
    wandb.log({"total_time": total_time})
# Parser arguments for terminal execution
parser = argparse.ArgumentParser(description="Config file processing")
parser.add_argument("--config_path", default="./config/baseline.json", type=str)
args = parser.parse_args()
#######################################################################################
if __name__ == "__main__":
    # Load configuration from .json file
    opt = ConfLoader(args.config_path).opt

    # Pretty-print the configuration dictionary
    pprint_config(opt)

    # Run experiment
    main(opt)
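
# Example terminal usage (the path below is the parser's default; adjust to
# point at your own config):
#   python main.py --config_path ./config/baseline.json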