'''
.. moduleauthor:: Rasmus Diederichsen <rasmus@peltarion.com>

This module contains functions and classes for simplifying the training of ANN classifiers. It
accepts the following arguments:

.. argparse::
   :filename: ../main.py
   :func: get_parser
   :prog: main.py
'''
####################
#  stdlib imports  #
####################
from argparse import ArgumentParser, ArgumentTypeError
import warnings
#######################
#  3rd party imports  #
#######################
from tqdm import tqdm
from torchvision.transforms import ToTensor
#######################
#  1st party imports  #
#######################
from train import Trainer
from ikkuna.utils import load_dataset, seed_everything
from ikkuna.export.subscriber import (RatioSubscriber, HistogramSubscriber, SpectralNormSubscriber,
                                      TestAccuracySubscriber, TrainAccuracySubscriber,
                                      NormSubscriber, MessageMeanSubscriber,
                                      VarianceSubscriber, SVCCASubscriber)
from ikkuna.export import Exporter
from ikkuna.export.messages import MessageBus
import ikkuna.visualization
def _add_subscribers(trainer, dataset_train, dataset_test, batch_size, options):
    '''Attach every subscriber requested in ``options`` to ``trainer``.

    Subscribers are added in a fixed order: hessian, spectral norm, variance, test accuracy,
    train accuracy, ratio, histogram, norm, SVCCA.

    Parameters
    ----------
    trainer         :   train.Trainer
                        Trainer to attach to. ``trainer.model`` (and ``trainer.loss`` for the
                        Hessian subscriber) may be read, so call this after the model and
                        optimizer have been configured.
    dataset_train   :   object
                        Training data; its ``dataset`` attribute feeds the Hessian data loader
    dataset_test    :   object
                        Test data handed to the test-accuracy and SVCCA subscribers
    batch_size      :   int
    options         :   dict
                        Subscriber switches (see :func:`_main`); missing keys count as unset

    Returns
    -------
    bool
        ``True`` if at least one subscriber option was set
    '''
    backend = options.get('visualisation', 'tb')
    subsample = options.get('subsample', 1)
    added = False

    def add_per_topic(option, subscriber_cls):
        '''Add one ``subscriber_cls`` per topic listed under ``option``; report if any.'''
        topics = options.get(option)
        for topic in topics or ():
            trainer.add_subscriber(subscriber_cls(topic, backend=backend))
        return bool(topics)

    if options.get('hessian'):
        # only imported on demand, the Hessian tracker is optional and expensive
        from torch.utils.data import DataLoader
        from ikkuna.export.subscriber import HessianEigenSubscriber
        loader = DataLoader(dataset_train.dataset, batch_size=batch_size, shuffle=True)
        trainer.add_subscriber(HessianEigenSubscriber(trainer.model.forward, trainer.loss, loader,
                                                      batch_size,
                                                      frequency=trainer.batches_per_epoch,
                                                      num_eig=1, power_steps=25,
                                                      backend=backend))
        trainer.create_graph = True
        added = True

    added = add_per_topic('spectral_norm', SpectralNormSubscriber) or added
    added = add_per_topic('variance', VarianceSubscriber) or added

    if options.get('test_accuracy'):
        trainer.add_subscriber(TestAccuracySubscriber(dataset_test, trainer.model.forward,
                                                      frequency=trainer.batches_per_epoch,
                                                      batch_size=batch_size,
                                                      backend=backend))
        added = True

    if options.get('train_accuracy'):
        trainer.add_subscriber(TrainAccuracySubscriber(subsample=subsample, backend=backend))
        added = True

    ratio_pairs = options.get('ratio')
    if ratio_pairs:
        for kind1, kind2 in ratio_pairs:
            ratio_subscriber = RatioSubscriber([kind1, kind2],
                                               subsample=subsample,
                                               backend=backend)
            trainer.add_subscriber(ratio_subscriber)
            # There can be multiple publications per type, but we know the RatioSubscriber only
            # publishes one. Read the most recent entry without ``popitem()`` so that the
            # mapping is not emptied in case ``publications`` exposes the subscriber's own dict.
            topics = list(ratio_subscriber.publications.values())[-1]
            trainer.add_subscriber(MessageMeanSubscriber(topics[0]))
        added = True

    added = add_per_topic('histogram', HistogramSubscriber) or added
    added = add_per_topic('norm', NormSubscriber) or added

    if options.get('svcca'):
        trainer.add_subscriber(SVCCASubscriber(dataset_test, 500, trainer.model.forward,
                                               subsample=trainer.batches_per_epoch,
                                               backend=backend))
        added = True

    return added


def _main(dataset_str, model_str, batch_size, epochs, optimizer, **kwargs):
    '''Run the training procedure.

    Parameters
    ----------
    dataset_str :   str
                    Name of the dataset to use
    model_str   :   str
                    Unqualified name of the model class to use
    batch_size  :   int
    epochs      :   int
    optimizer   :   str
                    Name of the optimizer to use

    Other Parameters
    ----------------
    depth               :   int
                            Passed to the :class:`ikkuna.export.Exporter` (default ``-1``)
    learning_rate       :   float
                            Learning rate for the optimizer (default ``0.01``)
    exponential_decay   :   float or None
                            If not ``None``, install an ``ExponentialLR`` schedule with this value
    subsample           :   int
                            Forwarded to the train-accuracy and ratio subscribers (default ``1``)
    visualisation       :   str
                            Backend name forwarded to the subscribers (default ``'tb'``)
    verbose             :   bool
                            Show progress bars for epochs and batches
    hessian, test_accuracy, train_accuracy, svcca   :   bool
                            Enable the respective subscriber
    spectral_norm, variance, histogram, norm        :   list(str) or None
                            Topics for which to add the respective subscriber
    ratio               :   list(tuple(str, str)) or None
                            Topic pairs for which to add ratio subscribers
    '''
    dataset_train, dataset_test = load_dataset(dataset_str, train_transforms=[ToTensor()],
                                               test_transforms=[ToTensor()])
    # torch is only needed inside this function and is not part of the module-level imports
    import torch

    bus = MessageBus('main')
    trainer = Trainer(dataset_train, batch_size=batch_size,
                      exporter=Exporter(depth=kwargs.get('depth', -1),
                                        module_filter=[torch.nn.Conv2d],
                                        message_bus=bus))
    trainer.set_model(model_str)
    trainer.optimize(name=optimizer, lr=kwargs.get('learning_rate', 0.01))

    decay = kwargs.get('exponential_decay')
    if decay is not None:
        trainer.set_schedule(torch.optim.lr_scheduler.ExponentialLR, decay)

    if not _add_subscribers(trainer, dataset_train, dataset_test, batch_size, kwargs):
        warnings.warn('No subscriber was added, there will be no visualisation.')

    batches_per_epoch = trainer.batches_per_epoch
    print(f'Batches per epoch: {batches_per_epoch}')

    # Disabled experiment (progressively freezing modules), kept for reference:
    # exporter = trainer.exporter
    # modules = exporter.modules
    # n_modules = len(modules)

    verbose = kwargs.get('verbose', False)
    epoch_iter = range(epochs)
    if verbose:
        epoch_iter = tqdm(epoch_iter, desc='Epoch')

    for _ in epoch_iter:
        # freeze_idx = int(e/epochs * n_modules) - 1
        # if freeze_idx >= 0:
        #     exporter.freeze_module(modules[freeze_idx])
        batch_iter = range(batches_per_epoch)
        if verbose:
            # A tqdm bar closes itself after one complete pass, so a single shared instance
            # would only display progress during the first epoch; build a fresh one each time.
            batch_iter = tqdm(batch_iter, desc='Batch', leave=False)
        for _ in batch_iter:
            trainer.train_batch()
def get_parser():
    '''Obtain a configured argument parser. This function is necessary for the sphinx argparse
    extension.

    Returns
    -------
    argparse.ArgumentParser
        Parser with ``--model`` and ``--dataset`` required and all other options optional
    '''
    def _topic_pair(input_):
        '''argparse type converting ``'val1,val2'`` into the tuple ``(val1, val2)``'''
        try:
            kind1, kind2 = input_.split(',')
        except ValueError:
            # raised by the unpacking when there are not exactly two comma-separated values
            raise ArgumentTypeError('Values must be passed as val1,val2 (without space)') from None
        return (kind1, kind2)

    parser = ArgumentParser()
    parser.add_argument('-m', '--model', type=str, required=True, help='Model class to train')
    data_choices = ['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100']
    parser.add_argument('-d', '--dataset', type=str, choices=data_choices, required=True,
                        help='Dataset to train on')
    parser.add_argument('-b', '--batch-size', type=int, default=128)
    parser.add_argument('-e', '--epochs', type=int, default=10)
    parser.add_argument('-o', '--optimizer', type=str, default='Adam', help='Optimizer to use')
    parser.add_argument('-l', '--learning-rate', type=float, default=0.01, help='Learning rate')
    parser.add_argument('-a', '--ratio-average', type=int, default=10, help='Number of ratios to '
                        'average for stability (currently unused)', metavar='N')
    parser.add_argument('-s', '--subsample', type=int, default=1,
                        help='Number of batches to ignore between updates')
    # parser.add_argument('-y', '--ylims', nargs=2, type=int, default=None,
    #                     help='Y-axis limits for plots')
    parser.add_argument('-v', '--visualisation', type=str, choices=['tb', 'mpl'], default='tb',
                        help='Visualisation backend to use.')
    parser.add_argument('-V', '--verbose', action='store_true', default=False,
                        help='Show training progress bar')

    # subscribers taking one or more topics
    parser.add_argument('--spectral-norm', nargs='+', type=str, default=None, metavar='TOPIC',
                        help='Use spectral norm subscriber(s)')
    parser.add_argument('--variance', nargs='+', type=str, default=None, metavar='TOPIC',
                        help='Use variance subscriber(s)')
    parser.add_argument('--histogram', nargs='+', type=str, default=None, metavar='TOPIC',
                        help='Use histogram subscriber(s)')
    parser.add_argument('--ratio', type=_topic_pair, nargs='+', default=None,
                        metavar='TOPIC,TOPIC', help='Use ratio subscriber(s)')
    parser.add_argument('--norm', nargs='+', type=str, default=None, metavar='TOPIC',
                        help='Use 2-norm subscriber(s)')

    # subscribers that are simply switched on
    parser.add_argument('--test-accuracy', action='store_true',
                        help='Use test set accuracy subscriber')
    parser.add_argument('--train-accuracy', action='store_true',
                        help='Use train accuracy subscriber')
    parser.add_argument('--svcca', action='store_true',
                        help='Use SVCCA subscriber')
    parser.add_argument('--depth', type=int, default=-1, help='Depth to which to add modules',
                        metavar='N')
    parser.add_argument('--hessian', action='store_true',
                        help='Use Hessian tracker (substantially increases training time)')
    parser.add_argument('--exponential-decay', type=float, required=False,
                        help='Decay parameter for exponential decay', metavar='GAMMA')
    parser.add_argument('--log-dir', type=str, required=False, help='TensorBoard logdir',
                        default='runs')
    parser.add_argument('--seed', type=int, required=False, default=None,
                        help="Seed to use. None means don't seed")
    return parser
def main():
    '''Command-line entry point: parse arguments, configure visualisation and seeding, and run
    the training via :func:`_main`.
    '''
    args = get_parser().parse_args()
    # Work on a copy of the parsed options so that popping entries below does not silently
    # remove attributes from ``args`` (``vars(args)`` is the namespace's own ``__dict__``).
    options = dict(vars(args))
    # record the full configuration (including the seed) with the TensorBoard backend
    ikkuna.visualization.TBBackend.info = str(options)
    ikkuna.visualization.configure_prefix(args.log_dir)
    seed = options.pop('seed')
    if seed is not None:
        seed_everything(seed)
    # the five positional parameters are taken out; everything else is forwarded as keywords
    _main(options.pop('dataset'), options.pop('model'), options.pop('batch_size'),
          options.pop('epochs'), options.pop('optimizer'), **options)
# Run the command-line entry point only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()