train
Module Contents
-
class train.Trainer(dataset_meta, **kwargs)
Bases: object
Class to bundle all logic and parameters that go into training a model on some dataset.
-
_dataset
The dataset used for training
Type: torch.utils.data.Dataset
-
_loss_function
Loss function instance for training
Type: torch.nn._Loss
-
_dataloader
Loader for the training dataset
Type: torch.utils.data.DataLoader
-
_optimizer
Type: torch.optim.Optimizer
-
_scheduler
Type: torch.optim.lr_scheduler._LRScheduler
-
__init__(dataset_meta, **kwargs)
Create a new Trainer. Handlers, model and optimizer are left uninitialised and must be set with add_subscriber(), set_model() and optimize() before calling train_batch().
Parameters:
- dataset_meta (ikkuna.utils.DatasetMeta) – Train data, obtained via ikkuna.utils.load_dataset(). Currently, only full batches are used; if the batch size does not evenly divide the number of examples, the last batch is dropped.
- batch_size (int) –
- loss (function) – Defaults to torch.nn.CrossEntropyLoss
- depth (int) – Depth to which to traverse the module tree. Ignored if exporter keyword arg is set
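The effect of using only full batches can be worked out with integer arithmetic. The following is a minimal sketch; the helper name `full_batches` is hypothetical and not part of the `train` module.

```python
def full_batches(n_examples, batch_size):
    """Return (number of full batches, number of examples dropped).

    Hypothetical helper for illustration only; it mirrors the documented
    behaviour that an incomplete final batch is discarded.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_batches = n_examples // batch_size   # only complete batches count
    n_dropped = n_examples % batch_size    # leftover examples are skipped
    return n_batches, n_dropped


# 60000 examples with a batch size of 128 leave 96 examples unused per epoch.
batches, dropped = full_batches(60000, 128)
```

If the batch size divides the dataset size evenly, nothing is dropped.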
-
create_graph
-
loss
The loss function in use
Type: torch.nn.Module
-
model
Model
Type: torch.nn.Module
-
exporter
Exporter used during training
Type: ikkuna.export.Exporter
-
optimizer
Optimizer in use, if set
Type: torch.optim.Optimizer
-
add_subscriber(subscriber)
Add a subscriber.
Parameters: subscriber (ikkuna.export.subscriber.Subscriber) –
-
optimize(name='Adam', **kwargs)
Set the optimizer.
Parameters:
- name (str) – Name of the optimizer (must exist in torch.optim)
- **kwargs – All other kwargs are forwarded to the optimizer constructor
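Selecting an optimizer by name and forwarding keyword arguments might look roughly like the sketch below. It assumes the name is resolved with `getattr` on `torch.optim`; the helper `build_optimizer` is hypothetical and the actual implementation of `optimize()` may differ.

```python
import torch


def build_optimizer(params, name='Adam', **kwargs):
    """Hypothetical helper: look up an optimizer class in torch.optim by name."""
    try:
        optimizer_cls = getattr(torch.optim, name)
    except AttributeError:
        raise ValueError(f"torch.optim has no optimizer named {name!r}")
    # All remaining keyword arguments go straight to the constructor.
    return optimizer_cls(params, **kwargs)


model = torch.nn.Linear(4, 2)
opt = build_optimizer(model.parameters(), 'SGD', lr=0.1, momentum=0.9)
```

With the real class, the equivalent call would presumably be `trainer.optimize(name='SGD', lr=0.1, momentum=0.9)` after a model has been set.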
-
initialize(init)
Run an initialization function on Trainer.model
Parameters: init (function) –
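An initialization function of the kind that could be passed here is sketched below. It assumes the function receives one module at a time, as with `torch.nn.Module.apply()`; whether `initialize()` applies it this way is not stated in this reference. The name `init_weights` is hypothetical.

```python
import torch


def init_weights(module):
    """Hypothetical init function: Xavier weights and zero biases for linear layers."""
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        torch.nn.init.zeros_(module.bias)


model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)
# Module.apply() calls the function on every submodule and on the model itself.
model.apply(init_weights)
```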
-
set_schedule(Scheduler, *args, **kwargs)
Set a scheduler to anneal the learning rate.
Parameters:
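The signature suggests that a scheduler class is passed together with its constructor arguments. The sketch below assumes the class is instantiated with the current optimizer as its first argument, which is how the schedulers in `torch.optim.lr_scheduler` are constructed; the helper `make_scheduler` is hypothetical.

```python
import torch
from torch.optim.lr_scheduler import StepLR


def make_scheduler(Scheduler, optimizer, *args, **kwargs):
    """Hypothetical helper: build a scheduler bound to an existing optimizer."""
    return Scheduler(optimizer, *args, **kwargs)


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Halve the learning rate every two epochs.
scheduler = make_scheduler(StepLR, optimizer, step_size=2, gamma=0.5)

lrs = []
for epoch in range(4):
    optimizer.step()                              # stands in for one epoch of training
    lrs.append(optimizer.param_groups[0]['lr'])   # learning rate used this epoch
    scheduler.step()                              # anneal after the epoch
```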
-
set_model(model_or_str)
Set the model to train. This method will attempt to load from ikkuna.models if a string is passed.
Warning
The function automatically calls torch.nn.Module.cuda() if cuda is available.
Parameters: model_or_str (torch.nn.Module or str) – Model or name of the model (must exist in ikkuna.models)
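The warning above means that the model's parameters may end up on the GPU without an explicit request. A minimal sketch of that behaviour, assuming a plain availability check, is shown below; the helper `to_cuda_if_available` is hypothetical, and the lookup by name in `ikkuna.models` is not covered.

```python
import torch


def to_cuda_if_available(model):
    """Hypothetical helper: move the model to the GPU only when CUDA is usable."""
    if torch.cuda.is_available():
        model = model.cuda()
    return model


model = to_cuda_if_available(torch.nn.Linear(4, 2))
# Input tensors must live on the same device as the model's parameters.
device = next(model.parameters()).device
```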
-