# Source code for main

'''
.. moduleauthor:: Rasmus Diederichsen <rasmus@peltarion.com>

This module contains functions and classes for simplifying the training of ANN classifiers. It
accepts the following arguments:

.. argparse::
   :filename: ../main.py
   :func: get_parser
   :prog: main.py
'''
####################
#  stdlib imports  #
####################
from argparse import ArgumentParser, ArgumentTypeError
import warnings

#######################
#  3rd party imports  #
#######################
from tqdm import tqdm
from torchvision.transforms import ToTensor

#######################
#  1st party imports  #
#######################
from train import Trainer
from ikkuna.utils import load_dataset, seed_everything
from ikkuna.export.subscriber import (RatioSubscriber, HistogramSubscriber, SpectralNormSubscriber,
                                      TestAccuracySubscriber, TrainAccuracySubscriber,
                                      NormSubscriber, MessageMeanSubscriber,
                                      VarianceSubscriber, SVCCASubscriber)
from ikkuna.export import Exporter
from ikkuna.export.messages import MessageBus
import ikkuna.visualization


def _main(dataset_str, model_str, batch_size, epochs, optimizer, **kwargs):
    '''Run the training procedure.

    Parameters
    ----------
    dataset_str : str
        Name of the dataset to use
    model_str : str
        Unqualified name of the model class to use
    batch_size : int
        Number of samples per training batch
    epochs : int
        Number of passes over the training set
    optimizer : str
        Name of the optimizer to use
    **kwargs
        Remaining CLI options (see :func:`get_parser`): ``depth``, ``subsample``,
        ``visualisation``, ``verbose``, flags/topic lists for the individual
        subscribers, and optionally ``learning_rate`` and ``exponential_decay``
    '''
    dataset_train, dataset_test = load_dataset(dataset_str, train_transforms=[ToTensor()],
                                               test_transforms=[ToTensor()])

    # for some strange reason, python claims 'torch referenced before assignment' when importing at
    # the top. hahaaaaa
    import torch

    bus = MessageBus('main')
    trainer = Trainer(dataset_train, batch_size=batch_size,
                      exporter=Exporter(depth=kwargs['depth'],
                                        module_filter=[torch.nn.Conv2d],
                                        message_bus=bus))
    trainer.set_model(model_str)
    trainer.optimize(name=optimizer, lr=kwargs.get('learning_rate', 0.01))

    # single lookup instead of membership test + second indexing
    decay = kwargs.get('exponential_decay')
    if decay is not None:
        trainer.set_schedule(torch.optim.lr_scheduler.ExponentialLR, decay)

    subsample = kwargs['subsample']
    backend   = kwargs['visualisation']

    # Track whether any subscriber got attached so we can warn about a silent run.
    subscriber_added = False

    if kwargs['hessian']:
        # Imported lazily: the Hessian tracker is expensive and optional.
        from torch.utils.data import DataLoader
        from ikkuna.export.subscriber import HessianEigenSubscriber
        loader = DataLoader(dataset_train.dataset, batch_size=batch_size, shuffle=True)
        trainer.add_subscriber(HessianEigenSubscriber(trainer.model.forward, trainer.loss, loader,
                                                      batch_size,
                                                      frequency=trainer.batches_per_epoch,
                                                      num_eig=1, power_steps=25,
                                                      backend=backend))
        # Hessian-vector products need the autograd graph kept around.
        trainer.create_graph = True
        subscriber_added = True

    if kwargs['spectral_norm']:
        for kind in kwargs['spectral_norm']:
            spectral_norm_subscriber = SpectralNormSubscriber(kind, backend=backend)
            trainer.add_subscriber(spectral_norm_subscriber)
        subscriber_added = True

    if kwargs['variance']:
        for kind in kwargs['variance']:
            var_sub = VarianceSubscriber(kind, backend=backend)
            trainer.add_subscriber(var_sub)
        subscriber_added = True

    if kwargs['test_accuracy']:
        test_accuracy_subscriber = TestAccuracySubscriber(dataset_test, trainer.model.forward,
                                                          frequency=trainer.batches_per_epoch,
                                                          batch_size=batch_size,
                                                          backend=backend)
        trainer.add_subscriber(test_accuracy_subscriber)
        subscriber_added = True

    if kwargs['train_accuracy']:
        train_accuracy_subscriber = TrainAccuracySubscriber(subsample=subsample, backend=backend)
        trainer.add_subscriber(train_accuracy_subscriber)
        subscriber_added = True

    if kwargs['ratio']:
        for kind1, kind2 in kwargs['ratio']:
            ratio_subscriber = RatioSubscriber([kind1, kind2], subsample=subsample,
                                               backend=backend)
            trainer.add_subscriber(ratio_subscriber)
            pubs = ratio_subscriber.publications
            # renamed from `type` to avoid shadowing the builtin
            msg_type, topics = pubs.popitem()
            # there can be multiple publications per type, but we know the RatioSubscriber only
            # publishes one
            trainer.add_subscriber(MessageMeanSubscriber(topics[0]))
        subscriber_added = True

    if kwargs['histogram']:
        for kind in kwargs['histogram']:
            histogram_subscriber = HistogramSubscriber(kind, backend=backend)
            trainer.add_subscriber(histogram_subscriber)
        subscriber_added = True

    if kwargs['norm']:
        for kind in kwargs['norm']:
            norm_subscriber = NormSubscriber(kind, backend=backend)
            trainer.add_subscriber(norm_subscriber)
        subscriber_added = True

    if kwargs['svcca']:
        svcca_subscriber = SVCCASubscriber(dataset_test, 500, trainer.model.forward,
                                           subsample=trainer.batches_per_epoch, backend=backend)
        trainer.add_subscriber(svcca_subscriber)
        subscriber_added = True

    if not subscriber_added:
        # fixed typo: 'the will be' -> 'there will be'
        warnings.warn('No subscriber was added, there will be no visualisation.')

    batches_per_epoch = trainer.batches_per_epoch
    print(f'Batches per epoch: {batches_per_epoch}')

    epoch_range = range(epochs)
    if kwargs['verbose']:
        epoch_range = tqdm(epoch_range, desc='Epoch')

    for e in epoch_range:
        # Recreate the batch progress bar every epoch: a single tqdm instance is
        # closed after its first full iteration and cannot be reused for display.
        batch_range = range(batches_per_epoch)
        if kwargs['verbose']:
            batch_range = tqdm(batch_range, desc='Batch')
        for batch_idx in batch_range:
            trainer.train_batch()
def get_parser():
    '''Obtain a configured argument parser. This function is necessary for the sphinx argparse
    extension.

    Returns
    -------
    argparse.ArgumentParser
    '''
    def list_of_tuples(input_):
        '''argparse type for passing a list of tuples (two comma-separated values).'''
        try:
            kind1, kind2 = input_.split(',')
            return (kind1, kind2)
        except ValueError:
            # unpacking fails (ValueError) unless split yields exactly two parts;
            # suppress chaining so argparse shows only the friendly message
            raise ArgumentTypeError('Values must be passed as val1,val2 (without space)') from None

    parser = ArgumentParser()
    parser.add_argument('-m', '--model', type=str, required=True, help='Model class to train')
    data_choices = ['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100']
    parser.add_argument('-d', '--dataset', type=str, choices=data_choices, required=True,
                        help='Dataset to train on')
    parser.add_argument('-b', '--batch-size', type=int, default=128)
    parser.add_argument('-e', '--epochs', type=int, default=10)
    parser.add_argument('-o', '--optimizer', type=str, default='Adam', help='Optimizer to use')
    parser.add_argument('-l', '--learning-rate', type=float, default=0.01, help='Learning rate')
    parser.add_argument('-a', '--ratio-average', type=int, default=10, help='Number of ratios to '
                        'average for stability (currently unused)', metavar='N')
    parser.add_argument('-s', '--subsample', type=int, default=1,
                        help='Number of batches to ignore between updates')
    parser.add_argument('-v', '--visualisation', type=str, choices=['tb', 'mpl'], default='tb',
                        help='Visualisation backend to use.')
    parser.add_argument('-V', '--verbose', action='store_true', default=False,
                        help='Show training progress bar')
    parser.add_argument('--spectral-norm', nargs='+', type=str, default=None, metavar='TOPIC',
                        help='Use spectral norm subscriber(s)')
    # fixed help typo: 'variance norm subscriber(s)' -> 'variance subscriber(s)'
    parser.add_argument('--variance', nargs='+', type=str, default=None, metavar='TOPIC',
                        help='Use variance subscriber(s)')
    parser.add_argument('--histogram', nargs='+', type=str, default=None, metavar='TOPIC',
                        help='Use histogram subscriber(s)')
    parser.add_argument('--ratio', type=list_of_tuples, nargs='+', default=None,
                        metavar='TOPIC,TOPIC', help='Use ratio subscriber(s)')
    parser.add_argument('--norm', nargs='+', type=str, default=None, metavar='TOPIC',
                        help='Use 2-norm subscriber(s)')
    parser.add_argument('--test-accuracy', action='store_true',
                        help='Use test set accuracy subscriber')
    parser.add_argument('--train-accuracy', action='store_true',
                        help='Use train accuracy subscriber')
    parser.add_argument('--svcca', action='store_true', help='Use SVCCA subscriber')
    parser.add_argument('--depth', type=int, default=-1, help='Depth to which to add modules',
                        metavar='N')
    parser.add_argument('--hessian', action='store_true',
                        help='Use Hessian tracker (substantially increases training time)')
    parser.add_argument('--exponential-decay', type=float, required=False,
                        help='Decay parameter for exponential decay', metavar='GAMMA')
    parser.add_argument('--log-dir', type=str, required=False, help='TensorBoard logdir',
                        default='runs')
    parser.add_argument('--seed', type=int, required=False, default=None,
                        help='Seed to use. None means don\'t seed')
    return parser
def main():
    '''Entry point: parse the command line, configure visualisation and seeding, start training.'''
    cli_args = get_parser().parse_args()
    arg_dict = vars(cli_args)

    # Record the complete argument set (including seed) before anything is popped.
    ikkuna.visualization.TBBackend.info = str(arg_dict)
    ikkuna.visualization.configure_prefix(cli_args.log_dir)

    seed = arg_dict.pop('seed')
    if seed is not None:
        seed_everything(seed)

    # Pull out the positional arguments of _main(); everything left travels as **kwargs.
    dataset    = arg_dict.pop('dataset')
    model      = arg_dict.pop('model')
    batch_size = arg_dict.pop('batch_size')
    epochs     = arg_dict.pop('epochs')
    optimizer  = arg_dict.pop('optimizer')
    _main(dataset, model, batch_size, epochs, optimizer, **arg_dict)
# Script entry point: only run training when executed directly, not when imported.
if __name__ == '__main__': main()