import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ikkuna.utils import create_optimizer
from ikkuna.export import Exporter
class Trainer:
'''Class to bundle all logic and parameters that go into training a model on some
dataset.
Attributes
----------
_dataset : torch.utils.data.Dataset
The dataset used for training
_num_classes : int
Number of target categories (inferred)
_shape : list
Shape of the input data (N, H, W, C)
_batch_size : int
Training batch size
    _loss_function : torch.nn.modules.loss._Loss
        Loss function instance for training
    _dataloader : torch.utils.data.DataLoader
        Loader for the training dataset
    _optimizer : torch.optim.Optimizer
        Optimizer for training; must be set via :meth:`~train.Trainer.optimize()`
    _scheduler : torch.optim.lr_scheduler._LRScheduler
        Optional learning rate scheduler
'''
    def __init__(self, dataset_meta, **kwargs):
        '''Create a new Trainer. Subscribers, model and optimizer are left uninitialised and must be
set with :meth:`~train.Trainer.add_subscriber()`, :meth:`~train.Trainer.set_model()` and
:meth:`~train.Trainer.optimize()` before calling :meth:`~train.Trainer.train_batch()`.
Parameters
----------
dataset_meta : ikkuna.utils.DatasetMeta
Train data, obtained via :func:`ikkuna.utils.load_dataset()`. Currently,
only full batches are used; if the batch size does not evenly divide the
number of examples, the last batch is dropped.
        batch_size : int
            Number of samples per training batch. Defaults to 1
        loss : torch.nn.Module
            Loss function instance. Defaults to :class:`torch.nn.CrossEntropyLoss`
depth : int
Depth to which to traverse the module tree. Ignored if ``exporter`` keyword arg
is set
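
        Examples
        --------
        A minimal sketch of the intended call order. The dataset and model names are
        illustrative assumptions, not fixed by this class::

            >>> from ikkuna.utils import load_dataset
            >>> train_meta, _ = load_dataset('MNIST')    # return shape is an assumption
            >>> trainer = Trainer(train_meta, batch_size=128)
            >>> trainer.set_model('AlexNetMini')
            >>> trainer.optimize(name='Adam', lr=1e-3)
            >>> trainer.train_batch()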
'''
############################################################################################
# Acquire parameters #
############################################################################################
self._dataset, self._num_classes, self._shape = dataset_meta
self._batch_size = kwargs.pop('batch_size', 1)
self._loss_function = kwargs.pop('loss', nn.CrossEntropyLoss())
self._dataloader = DataLoader(self._dataset, batch_size=self._batch_size,
pin_memory=True, shuffle=True, drop_last=True)
self._data_iter = iter(self._dataloader)
N_train = self._shape[0]
self._batches_per_epoch = N_train // self._batch_size
self._batch_counter = 0
self._global_counter = 0
self._epoch = 0
        self._optimizer = None
        self._scheduler = None
self._create_graph = kwargs.get('create_graph', False)
        # we peek one step ahead in the data iterator so that the end of an epoch is detected
        # during the epoch's final iteration, not at the beginning of the next one
self._next_X, self._next_Y = next(self._data_iter)
print(f'Number of classes: {self._num_classes}')
print(f'Data shape: {self._shape}')
        if 'exporter' in kwargs:
            self._exporter = kwargs['exporter']
        else:
            self._exporter = Exporter(kwargs.get('depth', -1))
self._exporter.set_loss(self._loss_function)
@property
    def create_graph(self):
        '''bool: Whether to construct the derivative graph during the backward pass (forwarded as
        ``create_graph`` to :meth:`torch.Tensor.backward()`)'''
        return self._create_graph
@create_graph.setter
def create_graph(self, value):
self._create_graph = value
@property
def current_batch(self):
'''int: 0-based batch index'''
return self._batch_counter
@property
def batches_per_epoch(self):
        '''int: number of full batches per epoch (the final incomplete batch is dropped)'''
return self._batches_per_epoch
@property
def loss(self):
'''torch.nn.Module: The loss function in use'''
return self._loss_function
@property
def model(self):
'''torch.nn.Module: Model'''
return self._model
@property
def exporter(self):
'''ikkuna.export.Exporter: Exporter used during training'''
return self._exporter
@property
def optimizer(self):
'''torch.optim.Optimizer: Optimizer in use, if set'''
return self._optimizer
    def add_subscriber(self, subscriber):
'''Add a subscriber.
Parameters
----------
subscriber : ikkuna.export.subscriber.Subscriber
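
        Examples
        --------
        Assuming a subscriber class such as
        :class:`~ikkuna.export.subscriber.TrainAccuracySubscriber` is available::

            >>> from ikkuna.export.subscriber import TrainAccuracySubscriber
            >>> trainer.add_subscriber(TrainAccuracySubscriber())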
'''
self._exporter.message_bus.register_subscriber(subscriber)
    def optimize(self, name='Adam', **kwargs):
'''Set the optimizer.
Parameters
----------
name : str
Name of the optimizer (must exist in :mod:`torch.optim`)
**kwargs
All other kwargs are forwarded to the optimizer constructor
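
        Examples
        --------
        Hyperparameter values are illustrative::

            >>> trainer.optimize(name='SGD', lr=0.1, momentum=0.9)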
'''
self._optimizer = create_optimizer(self._model, name, **kwargs)
print(f'Using {self._optimizer.__class__.__name__} optimizer')
    def initialize(self, init):
        '''Run an initialisation function on :attr:`Trainer.model`
Parameters
----------
init : function
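
        Examples
        --------
        A sketch applying Xavier initialisation to all linear layers::

            >>> def init_weights(module):
            ...     if isinstance(module, nn.Linear):
            ...         nn.init.xavier_uniform_(module.weight)
            >>> trainer.initialize(init_weights)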
'''
self._model.apply(init)
    def set_schedule(self, Scheduler, *args, **kwargs):
'''Set a scheduler to anneal the learning rate.
Parameters
----------
Scheduler : type
Class of the Scheduler to use (e.g.
:class:`~torch.optim.lr_scheduler.LambdaLR`)
*args : list
Passed to the scheduler constructor
**kwargs : dict
Passed to the scheduler constructor
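
        Examples
        --------
        Halve the learning rate every ten epochs (the decay factor is illustrative)::

            >>> from torch.optim.lr_scheduler import LambdaLR
            >>> trainer.set_schedule(LambdaLR, lambda epoch: 0.5 ** (epoch // 10))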
'''
        if self._optimizer is None:
raise ValueError('You must set the optimizer before setting the schedule.')
self._scheduler = Scheduler(self._optimizer, *args, **kwargs)
    def set_model(self, model_or_str):
'''Set the model to train. This method will attempt to load from :mod:`ikkuna.models` if a
string is passed.
.. warning::
            This method automatically calls :meth:`torch.nn.Module.cuda()` if CUDA is available.
Parameters
----------
model_or_str : torch.nn.Module or str
Model or name of the model (must exist in :mod:`ikkuna.models`)
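
        Examples
        --------
        The model name is illustrative::

            >>> trainer.set_model('AlexNetMini')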
'''
if isinstance(model_or_str, str):
from ikkuna.utils import get_model
self._model = get_model(model_or_str, self._shape[1:], num_classes=self._num_classes,
exporter=self._exporter)
else:
self._model = model_or_str
if torch.cuda.is_available():
self._model.cuda()
if self._exporter._model is None:
self._exporter.set_model(self._model)
    def train_batch(self):
        '''Run through one batch of the training set. The iterator wraps around and restarts
        at the beginning when exhausted.
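
        A full epoch can be run like so (assumes model and optimizer have been set)::

            >>> for _ in range(trainer.batches_per_epoch):
            ...     trainer.train_batch()
        '''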
        # to be safe, enable batch norm, dropout and the like. The training flag could have been
        # changed externally, so reset it before every batch
self._model.train(True)
        data, labels = self._next_X, self._next_Y
if torch.cuda.is_available():
            data, labels = data.cuda(non_blocking=True), labels.cuda(non_blocking=True)
self._optimizer.zero_grad()
output = self._model(data)
loss = self._loss_function(output, labels)
loss.backward(create_graph=self._create_graph)
self._optimizer.step()
        # the scheduler is stepped on every batch
        if self._scheduler is not None:
            self._scheduler.step(self._epoch)
        try:
            self._next_X, self._next_Y = next(self._data_iter)
        except StopIteration:
            # the epoch just ended: notify the exporter, reset the batch counter and
            # restart the data iterator
            self._exporter.epoch_finished()
            self._batch_counter = 0
            self._epoch += 1
            self._data_iter = iter(self._dataloader)
            self._next_X, self._next_Y = next(self._data_iter)
        else:
            self._batch_counter += 1
self._global_counter += 1