main program
This module contains functions and classes that simplify the training of ANN classifiers. When run as a script, it accepts the following command-line arguments:
usage: main.py [-h] -m MODEL -d {MNIST,FashionMNIST,CIFAR10,CIFAR100}
               [-b BATCH_SIZE] [-e EPOCHS] [-o OPTIMIZER] [-l LEARNING_RATE]
               [-a N] [-s SUBSAMPLE] [-v {tb,mpl}] [-V]
               [--spectral-norm TOPIC [TOPIC ...]]
               [--variance TOPIC [TOPIC ...]] [--histogram TOPIC [TOPIC ...]]
               [--ratio TOPIC,TOPIC [TOPIC,TOPIC ...]]
               [--norm TOPIC [TOPIC ...]] [--test-accuracy] [--train-accuracy]
               [--svcca] [--depth N] [--hessian] [--exponential-decay GAMMA]
               [--log-dir LOG_DIR] [--seed SEED]
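The `--ratio` option is the only one that takes comma-joined pairs (`TOPIC,TOPIC`). A minimal sketch of how such pairs could be parsed with the standard `argparse` module follows; the helper `topic_pair` and the topic names in the usage note are hypothetical, and the real `main.py` may parse these values differently.

```python
import argparse


def topic_pair(text: str) -> tuple[str, str]:
    """Split one 'TOPIC,TOPIC' token into a pair of topic names (hypothetical helper)."""
    parts = text.split(",")
    # Reject tokens that are not exactly two non-empty names.
    if len(parts) != 2 or not all(parts):
        raise argparse.ArgumentTypeError(f"expected TOPIC,TOPIC, got {text!r}")
    return parts[0], parts[1]


parser = argparse.ArgumentParser(prog="main.py")
# nargs="+" accepts one or more pairs; the metavar yields the
# "TOPIC,TOPIC [TOPIC,TOPIC ...]" fragment seen in the usage string.
parser.add_argument("--ratio", type=topic_pair, nargs="+", metavar="TOPIC,TOPIC")
```

For example, `parser.parse_args(["--ratio", "weights,grads", "biases,grads"]).ratio` gives `[("weights", "grads"), ("biases", "grads")]`. Raising `ArgumentTypeError` inside the type function lets `argparse` report a malformed pair as an ordinary usage error.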
Named Arguments
-m, --model | Model class to train |
-d, --dataset | Possible choices: MNIST, FashionMNIST, CIFAR10, CIFAR100. Dataset to train on |
-b, --batch-size | Default: 128 |
-e, --epochs | Default: 10 |
-o, --optimizer | Optimizer to use. Default: “Adam” |
-l, --learning-rate | Learning rate. Default: 0.01 |
-a, --ratio-average | Number of ratios to average for stability (currently unused). Default: 10 |
-s, --subsample | Number of batches to ignore between updates. Default: 1 |
-v, --visualisation | Possible choices: tb, mpl. Visualisation backend to use. Default: “tb” |
-V, --verbose | Show training progress bar. Default: False |
--spectral-norm | Use spectral norm subscriber(s) |
--variance | Use variance norm subscriber(s) |
--histogram | Use histogram subscriber(s) |
--ratio | Use ratio subscriber(s) |
--norm | Use 2-norm subscriber(s) |
--test-accuracy | Use test set accuracy subscriber. Default: False |
--train-accuracy | Use train accuracy subscriber. Default: False |
--svcca | Use SVCCA subscriber. Default: False |
--depth | Depth to which to add modules. Default: -1 |
--hessian | Use Hessian tracker (substantially increases training time). Default: False |
--exponential-decay | Decay parameter for exponential decay |
--log-dir | TensorBoard logdir. Default: “runs” |
--seed | Seed to use. None means don’t seed |
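The table documents `--exponential-decay GAMMA` only as a decay parameter. Assuming it behaves like PyTorch's `ExponentialLR` scheduler (one multiplicative step per epoch, which is an assumption not confirmed here), the learning rate used in epoch e is lr_e = lr_0 * GAMMA ** e. The hypothetical helper below sketches that schedule in plain Python:

```python
def exponential_schedule(lr0: float, gamma: float, epochs: int) -> list[float]:
    """Learning rate for each epoch, assuming one decay step per epoch (hypothetical helper)."""
    # Epoch 0 uses the initial rate; each later epoch multiplies it by gamma once more.
    return [lr0 * gamma ** epoch for epoch in range(epochs)]


# Documented defaults: -l 0.01, -e 10. GAMMA has no documented default; 0.9 is illustrative.
schedule = exponential_schedule(0.01, 0.9, 10)
```

Under these assumptions, the first epoch runs at 0.01 and the tenth at 0.01 * 0.9 ** 9, roughly 0.0039, so a GAMMA close to 1 decays slowly while smaller values shrink the rate quickly.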
main._main(dataset_str, model_str, batch_size, epochs, optimizer, **kwargs)
Run the training procedure.
Parameters:
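Judging by the parameter names, `_main` appears to take the dataset, model, batch size, epoch count and optimizer explicitly and everything else through `**kwargs`. A minimal sketch of forwarding a parsed `argparse.Namespace` in that shape is shown below; the stand-in `_main` only echoes its arguments, and both the glue function `run` and the option-to-parameter mapping are assumptions rather than the module's actual code.

```python
import argparse
from typing import Any


def _main(dataset_str, model_str, batch_size, epochs, optimizer, **kwargs):
    """Stand-in for main._main: returns its arguments instead of training."""
    return dataset_str, model_str, batch_size, epochs, optimizer, kwargs


def run(args: argparse.Namespace) -> Any:
    """Forward parsed CLI options to _main (hypothetical glue code)."""
    # vars() exposes the namespace's own dict, so copy it before popping keys.
    options = vars(args).copy()
    # Options assumed to map onto _main's named parameters, in signature order.
    named = [options.pop(key) for key in ("dataset", "model", "batch_size", "epochs", "optimizer")]
    # Everything left (learning_rate, seed, subscriber options, ...) goes to **kwargs.
    return _main(*named, **options)
```

For instance, a namespace holding `dataset="MNIST"`, `model="MLP"` (a hypothetical model name), the documented defaults and `learning_rate=0.01` ends up with `learning_rate` inside `kwargs`, while the original namespace is left unchanged.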
main.get_parser()
Obtain a configured argument parser. This function is required by the sphinx-argparse extension.
Returns: the configured argument parser
Return type: argparse.ArgumentParser
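sphinx-argparse's `.. argparse::` directive locates a parser through its `:module:` and `:func:` options, which is why a function like `get_parser` is needed. The sketch below builds a parser that reproduces part of the documented interface with the defaults listed above; it is an illustration, not the module's actual implementation, and it omits the subscriber options and several others.

```python
import argparse


def get_parser() -> argparse.ArgumentParser:
    """Sketch of a parser covering a subset of the documented options."""
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("-m", "--model", required=True, help="Model class to train")
    parser.add_argument("-d", "--dataset", required=True,
                        choices=["MNIST", "FashionMNIST", "CIFAR10", "CIFAR100"],
                        help="Dataset to train on")
    parser.add_argument("-b", "--batch-size", type=int, default=128)
    parser.add_argument("-e", "--epochs", type=int, default=10)
    parser.add_argument("-o", "--optimizer", default="Adam", help="Optimizer to use")
    parser.add_argument("-l", "--learning-rate", type=float, default=0.01, help="Learning rate")
    parser.add_argument("-v", "--visualisation", choices=["tb", "mpl"], default="tb",
                        help="Visualisation backend to use")
    # store_true gives the documented default of False.
    parser.add_argument("-V", "--verbose", action="store_true", help="Show training progress bar")
    parser.add_argument("--exponential-decay", type=float, metavar="GAMMA",
                        help="Decay parameter for exponential decay")
    parser.add_argument("--log-dir", default="runs", help="TensorBoard logdir")
    # Leaving --seed unset keeps it at None, i.e. no seeding.
    parser.add_argument("--seed", type=int, default=None, help="Seed to use. None means don't seed")
    return parser
```

For example, `get_parser().parse_args(["-m", "MLP", "-d", "MNIST"])` (with `MLP` as a hypothetical model name) yields `batch_size=128`, `learning_rate=0.01`, `visualisation="tb"` and `seed=None`, while an unlisted dataset such as `SVHN` makes `argparse` exit with a usage error.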