main program

This module contains functions and classes that simplify the training of ANN classifiers. When run as a program, it accepts the following arguments:

usage: [-h] -m MODEL -d {MNIST,FashionMNIST,CIFAR10,CIFAR100}
               [-b BATCH_SIZE] [-e EPOCHS] [-o OPTIMIZER] [-l LEARNING_RATE]
               [-a N] [-s SUBSAMPLE] [-v {tb,mpl}] [-V]
               [--spectral-norm TOPIC [TOPIC ...]]
               [--variance TOPIC [TOPIC ...]] [--histogram TOPIC [TOPIC ...]]
               [--ratio TOPIC,TOPIC [TOPIC,TOPIC ...]]
               [--norm TOPIC [TOPIC ...]] [--test-accuracy] [--train-accuracy]
               [--svcca] [--depth N] [--hessian] [--exponential-decay GAMMA]
               [--log-dir LOG_DIR] [--seed SEED]
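The synopsis follows argparse conventions. `TOPIC [TOPIC ...]` marks an option that accepts one or more values (`nargs="+"`), and braces such as `{tb,mpl}` list the allowed `choices`. The sketch below reproduces three of these options to show how they parse. It is not the module's own parser: the program name `train`, the model name `MyModel` and the topic names `weights` and `layer_gradients` are hypothetical placeholders.

```python
import argparse

# Hypothetical parser mirroring three options from the synopsis above.
parser = argparse.ArgumentParser(prog="train")  # "train" is a placeholder name
parser.add_argument("-m", "--model", required=True)
parser.add_argument("-d", "--dataset", required=True,
                    choices=["MNIST", "FashionMNIST", "CIFAR10", "CIFAR100"])
# nargs="+" is what argparse renders as "TOPIC [TOPIC ...]" in the usage line.
parser.add_argument("--spectral-norm", nargs="+", metavar="TOPIC")

# Placeholder model and topic names, for illustration only.
args = parser.parse_args(["-m", "MyModel", "-d", "CIFAR10",
                          "--spectral-norm", "weights", "layer_gradients"])
print(args.dataset)        # CIFAR10
print(args.spectral_norm)  # ['weights', 'layer_gradients']
```

Passing a dataset outside the listed choices makes argparse exit with an error message instead of returning a namespace.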

Named Arguments

-m, --model

Model class to train
-d, --dataset

Possible choices: MNIST, FashionMNIST, CIFAR10, CIFAR100

Dataset to train on

-b, --batch-size

Number of samples per batch

Default: 128

-e, --epochs

Number of training epochs

Default: 10
-o, --optimizer

Optimizer to use

Default: “Adam”

-l, --learning-rate

Learning rate

Default: 0.01

-a, --ratio-average

Number of ratios to average for stability (currently unused)

Default: 10

-s, --subsample

Number of batches to ignore between updates

Default: 1
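The exact semantics of `--subsample` are not spelled out here. One plausible reading, given the default of 1, is that the value acts as a stride: subscribers see every `subsample`-th batch, so 1 means no batch is skipped. The helper `should_update` below is hypothetical and only illustrates that assumed behaviour.

```python
def should_update(batch_index: int, subsample: int = 1) -> bool:
    """Return True if subscribers should process this batch.

    Assumption: subsample acts as a stride, so 1 processes every batch
    and k processes batches 0, k, 2k, ...
    """
    if subsample < 1:
        raise ValueError("subsample must be >= 1")
    return batch_index % subsample == 0


processed = [i for i in range(10) if should_update(i, subsample=3)]
print(processed)  # [0, 3, 6, 9]
```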

-v, --visualisation

Possible choices: tb, mpl

Visualisation backend to use.

Default: “tb”

-V, --verbose

Show training progress bar

Default: False

--spectral-norm Use spectral norm subscriber(s)
--variance Use variance norm subscriber(s)
--histogram Use histogram subscriber(s)
--ratio Use ratio subscriber(s); each argument is a pair of topics separated by a comma (TOPIC,TOPIC)
--norm Use 2-norm subscriber(s)
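The synopsis shows `--ratio TOPIC,TOPIC [TOPIC,TOPIC ...]`, so each argument names two topics at once, presumably the two quantities whose ratio is tracked. A minimal sketch of how such pairs could be parsed with an argparse `type` callable follows. `topic_pair` and the topic names are hypothetical, and the module's real parsing may differ.

```python
import argparse
from typing import Tuple


def topic_pair(text: str) -> Tuple[str, str]:
    """Split 'A,B' into ('A', 'B'), rejecting anything that is not exactly two names."""
    parts = text.split(",")
    if len(parts) != 2 or not all(parts):
        raise argparse.ArgumentTypeError(f"expected TOPIC,TOPIC, got {text!r}")
    return parts[0], parts[1]


parser = argparse.ArgumentParser()
# One or more comma-separated pairs, each converted by topic_pair.
parser.add_argument("--ratio", nargs="+", type=topic_pair, metavar="TOPIC,TOPIC")

# Hypothetical topic names.
args = parser.parse_args(["--ratio", "weight_updates,weights", "biases,weights"])
print(args.ratio)  # [('weight_updates', 'weights'), ('biases', 'weights')]
```

Raising `argparse.ArgumentTypeError` inside the `type` callable lets argparse report a malformed pair as a normal usage error.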

--test-accuracy

Use test set accuracy subscriber

Default: False

--train-accuracy

Use train accuracy subscriber

Default: False

--svcca

Use SVCCA subscriber

Default: False

--depth

Depth to which to add modules

Default: -1
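The help text does not define how depth is counted. Assuming it refers to levels in the model's module hierarchy, with -1 meaning no limit, the selection can be sketched over a nested dictionary standing in for the module tree (a PyTorch model would be walked through `named_children()` instead). The function `modules_up_to_depth` and the module names are hypothetical.

```python
from typing import Dict, List

# Nested dict standing in for a module tree: name -> child modules.
Tree = Dict[str, "Tree"]


def modules_up_to_depth(tree: Tree, depth: int = -1, prefix: str = "") -> List[str]:
    """Return dotted names of modules at most `depth` levels below the root.

    Assumption: depth == -1 means "no limit"; depth == 1 means direct children only.
    """
    if depth == 0:
        return []
    names: List[str] = []
    for name, children in tree.items():
        qualified = f"{prefix}.{name}" if prefix else name
        names.append(qualified)
        if depth == -1:
            names += modules_up_to_depth(children, -1, qualified)
        elif depth > 1:
            names += modules_up_to_depth(children, depth - 1, qualified)
    return names


model = {"features": {"conv1": {}, "conv2": {}}, "classifier": {"fc": {}}}
print(modules_up_to_depth(model, depth=1))  # ['features', 'classifier']
print(modules_up_to_depth(model))
# ['features', 'features.conv1', 'features.conv2', 'classifier', 'classifier.fc']
```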


--hessian

Use Hessian tracker (substantially increases training time)

Default: False

--exponential-decay

Decay parameter for exponential decay
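The option text does not say what is decayed. A common reading is a learning-rate schedule in which the rate is multiplied by GAMMA once per epoch, so that in epoch e it equals lr_0 * GAMMA**e; PyTorch's `torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)` implements this rule. Whether this module applies it per epoch, per batch, or to something other than the learning rate is an assumption here. The helper `exponential_decay` is hypothetical; the starting value 0.01 is the documented default learning rate, and GAMMA = 0.5 is chosen only for illustration.

```python
from typing import List


def exponential_decay(lr0: float, gamma: float, epochs: int) -> List[float]:
    """Learning rate used in each epoch when it is multiplied by gamma once per epoch."""
    return [lr0 * gamma ** epoch for epoch in range(epochs)]


# 0.01 is the documented default learning rate; gamma = 0.5 is illustrative.
schedule = exponential_decay(0.01, 0.5, 4)
print(schedule)  # [0.01, 0.005, 0.0025, 0.00125]
```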

--log-dir

TensorBoard log directory

Default: “runs”

--seed

Seed to use. If None, no seeding is performed.
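A minimal sketch of the documented `--seed` behaviour, where None leaves the random number generators unseeded. The helper `maybe_seed` is hypothetical and seeds only Python's `random` module; a PyTorch training script would usually also call `torch.manual_seed(seed)` (and `numpy.random.seed(seed)` if NumPy is used), which are left out to keep the example dependency-free.

```python
import random
from typing import Optional


def maybe_seed(seed: Optional[int]) -> None:
    """Seed Python's RNG if a seed is given; leave it untouched for None."""
    if seed is None:
        return  # documented behaviour: None means don't seed
    random.seed(seed)
    # A PyTorch program would typically also call torch.manual_seed(seed) here.


maybe_seed(1234)
first = [random.random() for _ in range(3)]
maybe_seed(1234)
second = [random.random() for _ in range(3)]
print(first == second)  # True
```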
main._main(dataset_str, model_str, batch_size, epochs, optimizer, **kwargs)

Run the training procedure.

Parameters:

  • dataset_str (str) – Name of the dataset to use
  • model_str (str) – Unqualified name of the model class to use
  • batch_size (int) – Number of samples per batch
  • epochs (int) – Number of epochs to train for
  • optimizer (str) – Name of the optimizer to use
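The signature suggests that the five named parameters mirror the `-d`, `-m`, `-b`, `-e` and `-o` options, while the remaining options arrive through `**kwargs`. That mapping is an assumption, as is the stub `_main` below, which only echoes its arguments in place of the real training procedure; the model name `MyModel` is a placeholder.

```python
import argparse


def _main(dataset_str, model_str, batch_size, epochs, optimizer, **kwargs):
    """Stub with the documented signature; the real function runs training."""
    return dataset_str, model_str, batch_size, epochs, optimizer, kwargs


# Parsed options as argparse would produce them (values are illustrative).
args = argparse.Namespace(dataset="MNIST", model="MyModel", batch_size=128,
                          epochs=10, optimizer="Adam", learning_rate=0.01)

options = vars(args).copy()  # copy so the namespace itself is not modified
# Pop the explicitly named parameters; everything left goes into **kwargs.
result = _main(options.pop("dataset"), options.pop("model"),
               options.pop("batch_size"), options.pop("epochs"),
               options.pop("optimizer"), **options)
print(result)  # ('MNIST', 'MyModel', 128, 10, 'Adam', {'learning_rate': 0.01})
```

Python evaluates call arguments left to right, so the `pop` calls run before `**options` is unpacked and only the leftover options reach `kwargs`.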

Obtain a configured argument parser. This function is required by the sphinx-argparse extension, which generates the argument documentation above from the returned parser.

Return type: argparse.ArgumentParser
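sphinx-argparse documents a command line from a function that returns the parser; its `.. argparse::` directive names that function through the `:module:` and `:func:` options. The function's name is not shown on this page, so the sketch below calls it `get_parser`, which is hypothetical. It mirrors a subset of the documented options and their defaults; the subscriber options are omitted for brevity.

```python
import argparse

DATASETS = ["MNIST", "FashionMNIST", "CIFAR10", "CIFAR100"]


def get_parser() -> argparse.ArgumentParser:
    """Hypothetical parser factory mirroring part of the documented options."""
    parser = argparse.ArgumentParser(description="Train an ANN classifier.")
    parser.add_argument("-m", "--model", required=True, help="Model class to train")
    parser.add_argument("-d", "--dataset", required=True, choices=DATASETS,
                        help="Dataset to train on")
    parser.add_argument("-b", "--batch-size", type=int, default=128)
    parser.add_argument("-e", "--epochs", type=int, default=10)
    parser.add_argument("-o", "--optimizer", default="Adam", help="Optimizer to use")
    parser.add_argument("-l", "--learning-rate", type=float, default=0.01,
                        help="Learning rate")
    parser.add_argument("-a", "--ratio-average", type=int, default=10, metavar="N",
                        help="Number of ratios to average for stability")
    parser.add_argument("-s", "--subsample", type=int, default=1,
                        help="Number of batches to ignore between updates")
    parser.add_argument("-v", "--visualisation", choices=["tb", "mpl"], default="tb",
                        help="Visualisation backend to use.")
    parser.add_argument("-V", "--verbose", action="store_true",
                        help="Show training progress bar")
    parser.add_argument("--exponential-decay", type=float, metavar="GAMMA",
                        help="Decay parameter for exponential decay")
    parser.add_argument("--log-dir", default="runs", help="TensorBoard logdir")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed to use. None means don't seed")
    return parser


# "MyModel" is a placeholder model name.
args = get_parser().parse_args(["-m", "MyModel", "-d", "FashionMNIST", "-V"])
print(args.batch_size, args.optimizer, args.log_dir)  # 128 Adam runs
```

Because `--seed` uses `default=None`, omitting it yields None, which matches the documented "don't seed" behaviour.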