train

Module Contents

class train.Trainer(dataset_meta, **kwargs)[source]

Bases: object

Class to bundle all logic and parameters that go into training a model on some dataset.

_dataset

The dataset used for training

Type:torch.utils.data.Dataset
_num_classes

Number of target categories (inferred)

Type:int
_shape

Shape of the input data (N, H, W, C)

Type:list
_batch_size

Training batch size

Type:int
_loss_function

Loss function instance for training

Type:torch.nn.modules.loss._Loss
_dataloader

Loader for the training dataset

Type:torch.utils.data.DataLoader
_optimizer

Optimizer used for training, if set via optimize()

Type:torch.optim.Optimizer
_scheduler

Learning rate scheduler, if set via set_schedule()

Type:torch.optim.lr_scheduler._LRScheduler
__init__(dataset_meta, **kwargs)[source]

Create a new Trainer. Handlers, model and optimizer are left uninitialised and must be set with add_subscriber(), set_model() and optimize() before calling train_batch(); see the usage sketch after the parameter list.

Parameters:
  • dataset_meta (ikkuna.utils.DatasetMeta) – Train data, obtained via ikkuna.utils.load_dataset(). Currently, only full batches are used; if the batch size does not evenly divide the number of examples, the last batch is dropped.
  • batch_size (int) – Training batch size
  • loss (function) – Loss function to use. Defaults to torch.nn.CrossEntropyLoss
  • depth (int) – Depth to which to traverse the module tree. Ignored if exporter keyword arg is set
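
A minimal construction sketch (assuming ikkuna.utils.load_dataset() returns train and test metadata; adjust to the actual return value of your installed version):

    from ikkuna.utils import load_dataset
    from train import Trainer

    # obtain DatasetMeta objects for the train and test splits
    train_meta, test_meta = load_dataset('MNIST')

    # loss defaults to torch.nn.CrossEntropyLoss if not given
    trainer = Trainer(train_meta, batch_size=128)
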
create_graph
current_batch

0-based batch index

Type:int
batches_per_epoch

Number of batches per epoch (only full batches are counted; a trailing partial batch is dropped)

Type:int
loss

The loss function in use

Type:torch.nn.Module
model

The model being trained

Type:torch.nn.Module
exporter

Exporter used during training

Type:ikkuna.export.Exporter
optimizer

Optimizer in use, if set

Type:torch.optim.Optimizer
add_subscriber(subscriber)[source]

Add a subscriber.

Parameters:subscriber (ikkuna.export.subscriber.Subscriber) –
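
For example, with one of the subscribers shipped with ikkuna (the class name and constructor arguments below are assumptions; check ikkuna.export.subscriber for what your version actually provides):

    # assumed API: a RatioSubscriber publishing update/weight ratios per layer
    from ikkuna.export.subscriber import RatioSubscriber

    trainer.add_subscriber(RatioSubscriber(['weight_updates', 'weights']))
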
optimize(name='Adam', **kwargs)[source]

Set the optimizer.

Parameters:
  • name (str) – Name of the optimizer (must exist in torch.optim)
  • **kwargs – All other kwargs are forwarded to the optimizer constructor
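
For example, to train with SGD and momentum (assuming a model has already been set via set_model(); lr and momentum are simply forwarded to torch.optim.SGD):

    # 'SGD' must name a class in torch.optim
    trainer.optimize(name='SGD', lr=0.01, momentum=0.9)
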
initialize(init)[source]

Run an initialization function on Trainer.model.

Parameters:init (function) – Initialization function to run on the model
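
A sketch of a weight-initialization function; whether init receives the whole model or is applied per module is not specified here, so this assumes it is called with the model:

    import torch.nn as nn

    def kaiming_init(model):
        # Kaiming-normal weights and zero biases for conv/linear layers
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    trainer.initialize(kaiming_init)
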
set_schedule(Scheduler, *args, **kwargs)[source]

Set a scheduler to anneal the learning rate.

Parameters:
  • Scheduler (type) – Class of the scheduler to use (e.g. torch.optim.lr_scheduler.LambdaLR)
  • *args (list) – Passed to the scheduler constructor
  • **kwargs (dict) – Passed to the scheduler constructor
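
For example, with torch.optim.lr_scheduler.LambdaLR (the Trainer presumably supplies the optimizer itself, so only the remaining constructor arguments are passed; call optimize() first so an optimizer exists):

    from torch.optim.lr_scheduler import LambdaLR

    # decay the learning rate by 5% per epoch
    trainer.set_schedule(LambdaLR, lr_lambda=lambda epoch: 0.95 ** epoch)
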
set_model(model_or_str)[source]

Set the model to train. This method will attempt to load from ikkuna.models if a string is passed.

Warning

This method automatically calls torch.nn.Module.cuda() if CUDA is available.

Parameters:model_or_str (torch.nn.Module or str) – Model or name of the model (must exist in ikkuna.models)
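
Either a ready-made module or a model name can be passed. The sketch below uses a plain torch.nn.Module, since the names available in ikkuna.models depend on the installed version:

    import torch.nn as nn

    # toy classifier; input size and class count must match the dataset
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 10),
    )
    trainer.set_model(model)
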
train_batch()[source]

Run through one batch of the training set. The data iterator wraps around and restarts at the beginning once the dataset is exhausted.
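
Putting the pieces together, a minimal training-loop sketch (model and optimizer set as in the examples above):

    n_epochs = 10
    for epoch in range(n_epochs):
        for _ in range(trainer.batches_per_epoch):
            trainer.train_batch()
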