Source code for train.train

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ikkuna.utils import create_optimizer
from ikkuna.export import Exporter


class Trainer:
    '''Class to bundle all logic and parameters that go into training a model on some dataset.

    Attributes
    ----------
    _dataset : torch.utils.data.Dataset
        The dataset used for training
    _num_classes : int
        Number of target categories (inferred)
    _shape : list
        Shape of the input data (N, H, W, C)
    _batch_size : int
        Training batch size
    _loss_function : torch.nn._Loss
        Loss function instance for training
    _dataloader : torch.utils.data.DataLoader
        Loader for the training dataset
    _optimizer : torch.optim.Optimizer
        Optimizer, set via :meth:`~train.Trainer.optimize()`
    _scheduler : torch.optim.lr_scheduler._LRScheduler
        Optional learning rate scheduler, set via :meth:`~train.Trainer.set_schedule()`
    '''
    def __init__(self, dataset_meta, **kwargs):
        '''Create a new Trainer. Handlers, model and optimizer are left uninitialised and must be
        set with :meth:`~train.Trainer.add_subscriber()`, :meth:`~train.Trainer.set_model()` and
        :meth:`~train.Trainer.optimize()` before calling :meth:`~train.Trainer.train_batch()`.

        Parameters
        ----------
        dataset_meta : ikkuna.utils.DatasetMeta
            Train data, obtained via :func:`ikkuna.utils.load_dataset()`. Currently, only full
            batches are used; if the batch size does not evenly divide the number of examples,
            the last batch is dropped.
        batch_size : int
            Training batch size (defaults to 1)
        loss : function
            Defaults to :class:`torch.nn.CrossEntropyLoss`
        depth : int
            Depth to which to traverse the module tree. Ignored if the ``exporter`` keyword
            argument is set.
        create_graph : bool
            Whether to retain the gradient computation graph during backprop (defaults to
            ``False``)
        exporter : ikkuna.export.Exporter
            Exporter to use; a new one is created if unset
        '''
        ########################################################################################
        # Acquire parameters                                                                   #
        ########################################################################################
        self._dataset, self._num_classes, self._shape = dataset_meta
        self._batch_size = kwargs.pop('batch_size', 1)
        self._loss_function = kwargs.pop('loss', nn.CrossEntropyLoss())
        self._dataloader = DataLoader(self._dataset, batch_size=self._batch_size,
                                      pin_memory=True, shuffle=True, drop_last=True)
        self._data_iter = iter(self._dataloader)

        N_train = self._shape[0]
        self._batches_per_epoch = N_train // self._batch_size
        self._batch_counter = 0
        self._global_counter = 0
        self._epoch = 0
        # the optimizer must be initialised to None so that set_schedule() can detect a missing
        # optimizer instead of raising an AttributeError
        self._optimizer = None
        self._scheduler = None
        self._create_graph = kwargs.get('create_graph', False)

        # we use these to peek one step ahead in the data iterator to know an epoch has ended
        # already in the epoch's final iteration, not at the beginning of the next one
        self._next_X, self._next_Y = next(self._data_iter)

        print(f'Number of classes: {self._num_classes}')
        print(f'Data shape: {self._shape}')

        self._exporter = kwargs.get('exporter', Exporter(kwargs.get('depth', -1)))
        self._exporter.set_loss(self._loss_function)
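    # A minimal construction sketch (hedged: the exact return value of
    # ``ikkuna.utils.load_dataset`` is an assumption here, not verified by this module):
    #
    #   >>> from ikkuna.utils import load_dataset
    #   >>> dataset_meta = load_dataset('MNIST')       # hypothetical call yielding a DatasetMeta
    #   >>> trainer = Trainer(dataset_meta, batch_size=128)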
    @property
    def create_graph(self):
        '''bool: Whether to retain the computation graph during the backward pass (needed for
        higher-order gradients)'''
        return self._create_graph

    @create_graph.setter
    def create_graph(self, value):
        self._create_graph = value

    @property
    def current_batch(self):
        '''int: 0-based index of the current batch within the epoch'''
        return self._batch_counter

    @property
    def batches_per_epoch(self):
        '''int: Number of full batches in an epoch (incomplete final batches are dropped)'''
        return self._batches_per_epoch

    @property
    def loss(self):
        '''torch.nn.Module: The loss function in use'''
        return self._loss_function

    @property
    def model(self):
        '''torch.nn.Module: The model being trained'''
        return self._model

    @property
    def exporter(self):
        '''ikkuna.export.Exporter: Exporter used during training'''
        return self._exporter

    @property
    def optimizer(self):
        '''torch.optim.Optimizer: Optimizer in use, if set'''
        return self._optimizer
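    # The read-only properties above are meant for external progress reporting, e.g. (sketch,
    # assuming a constructed ``trainer``):
    #
    #   >>> print(f'batch {trainer.current_batch + 1}/{trainer.batches_per_epoch}'
    #   ...       f' of epoch {trainer._epoch}')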
    def add_subscriber(self, subscriber):
        '''Add a subscriber.

        Parameters
        ----------
        subscriber : ikkuna.export.subscriber.Subscriber
        '''
        self._exporter.message_bus.register_subscriber(subscriber)
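    # Sketch of registering a subscriber. ``RatioSubscriber`` is one of the subscribers shipped
    # with ikkuna, but its exact constructor arguments are an assumption here:
    #
    #   >>> from ikkuna.export.subscriber import RatioSubscriber
    #   >>> trainer.add_subscriber(RatioSubscriber(['weight_updates', 'weights']))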
    def optimize(self, name='Adam', **kwargs):
        '''Set the optimizer.

        Parameters
        ----------
        name : str
            Name of the optimizer (must exist in :mod:`torch.optim`)
        **kwargs
            All other kwargs are forwarded to the optimizer constructor
        '''
        self._optimizer = create_optimizer(self._model, name, **kwargs)
        print(f'Using {self._optimizer.__class__.__name__} optimizer')
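    # Usage sketch; any optimizer class name from :mod:`torch.optim` should work, with its
    # constructor arguments passed through as keyword arguments:
    #
    #   >>> trainer.optimize(name='SGD', lr=0.01, momentum=0.9)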
    def initialize(self, init):
        '''Run an initialization function on :attr:`Trainer.model`.

        Parameters
        ----------
        init : function
            Callable applied to every submodule via :meth:`torch.nn.Module.apply`
        '''
        self._model.apply(init)
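    # Sketch of an init function usable here. :meth:`torch.nn.Module.apply` calls it on every
    # submodule, so the function must itself dispatch on the module type:
    #
    #   >>> def kaiming_init(module):
    #   ...     if isinstance(module, (nn.Conv2d, nn.Linear)):
    #   ...         nn.init.kaiming_normal_(module.weight)
    #   >>> trainer.initialize(kaiming_init)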
    def set_schedule(self, Scheduler, *args, **kwargs):
        '''Set a scheduler to anneal the learning rate.

        Parameters
        ----------
        Scheduler : type
            Class of the scheduler to use (e.g. :class:`~torch.optim.lr_scheduler.LambdaLR`)
        *args : list
            Passed to the scheduler constructor
        **kwargs : dict
            Passed to the scheduler constructor
        '''
        if self._optimizer is None:
            raise ValueError('You must set the optimizer before setting the schedule.')
        self._scheduler = Scheduler(self._optimizer, *args, **kwargs)
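    # Usage sketch with a standard PyTorch scheduler; the class is passed uninstantiated and
    # receives the optimizer as its first constructor argument:
    #
    #   >>> from torch.optim.lr_scheduler import StepLR
    #   >>> trainer.set_schedule(StepLR, step_size=10, gamma=0.1)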
    def set_model(self, model_or_str):
        '''Set the model to train. This method will attempt to load from :mod:`ikkuna.models` if
        a string is passed.

        .. warning::
            The function automatically calls :meth:`torch.nn.Module.cuda()` if CUDA is available.

        Parameters
        ----------
        model_or_str : torch.nn.Module or str
            Model or name of the model (must exist in :mod:`ikkuna.models`)
        '''
        if isinstance(model_or_str, str):
            from ikkuna.utils import get_model
            self._model = get_model(model_or_str, self._shape[1:],
                                    num_classes=self._num_classes, exporter=self._exporter)
        else:
            self._model = model_or_str

        if torch.cuda.is_available():
            self._model.cuda()

        if self._exporter._model is None:
            self._exporter.set_model(self._model)
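    # Usage sketch; either variant works. The string variant assumes the named class exists in
    # :mod:`ikkuna.models`, and ``MyNet`` is a hypothetical user-defined module:
    #
    #   >>> trainer.set_model('AlexNetMini')           # looked up in ikkuna.models
    #   >>> trainer.set_model(MyNet())                 # or pass an instantiated module directly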
    def train_batch(self):
        '''Run through one batch of the training set. The iterator wraps around and restarts at
        the beginning when the epoch ends.'''
        # to be safe, enable batch-norm, dropout, and the like. Could be changed externally, so
        # do this before each batch
        self._model.train(True)

        data, labels = self._next_X, self._next_Y
        if torch.cuda.is_available():
            # ``async`` became a reserved word in Python 3.7; ``non_blocking`` is the equivalent
            # keyword for asynchronous host-to-device copies
            data, labels = data.cuda(non_blocking=True), labels.cuda(non_blocking=True)

        self._optimizer.zero_grad()
        output = self._model(data)
        loss = self._loss_function(output, labels)
        loss.backward(create_graph=self._create_graph)
        self._optimizer.step()

        # peek ahead: a StopIteration here means the batch just trained was the last of the epoch
        try:
            self._next_X, self._next_Y = next(self._data_iter)
        except StopIteration:
            if self._scheduler:
                self._scheduler.step(self._epoch)
            self._exporter.epoch_finished()
            self._batch_counter = 0
            self._epoch += 1
            self._data_iter = iter(self._dataloader)
            self._next_X, self._next_Y = next(self._data_iter)
        else:
            if self._scheduler:
                self._scheduler.step(self._epoch)
            self._batch_counter += 1

        self._global_counter += 1
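if __name__ == '__main__':
    # Minimal end-to-end sketch of driving the Trainer. Hedged: ``load_dataset`` is assumed to
    # return a (train, test) pair of DatasetMeta tuples, and 'AlexNetMini' is assumed to exist
    # in ikkuna.models; neither is guaranteed by this module alone.
    from ikkuna.utils import load_dataset

    train_meta, test_meta = load_dataset('MNIST')
    trainer = Trainer(train_meta, batch_size=64)
    trainer.set_model('AlexNetMini')
    trainer.optimize(name='Adam', lr=1e-3)

    # train for ten epochs; train_batch() wraps around the dataset automatically
    for _ in range(10 * trainer.batches_per_epoch):
        trainer.train_batch()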