ikkuna.export

Module Contents

class ikkuna.export.Exporter(depth, module_filter=None, message_bus=<ikkuna.export.messages.MessageBus object>)[source]

Bases: object

Class for managing publishing of data from model code.

An Exporter is used in model code either by explicitly registering modules for tracking with add_modules(), or by calling it with newly constructed modules, which are returned as-is and registered in the process.

e = Exporter(...)
features = nn.Sequential(
    nn.Linear(...),
    e(nn.Conv2d(...)),   # registered for tracking, then passed through unchanged
    nn.ReLU()
)

Modules will be tracked recursively unless specified otherwise, meaning the following is possible:

e = Exporter(...)
e.add_modules(extremely_complex_model)
# e will now track all layers of extremely_complex_model
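
The recursive tracking described above, limited by a depth, can be pictured with a small stand-alone sketch. It does not use ikkuna or PyTorch: ToyModule and collect() are hypothetical names, and the exact meaning ikkuna gives to depth is an assumption here (a module reached when the depth budget runs out is tracked as a single unit).

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyModule:
    """Hypothetical stand-in for a torch.nn.Module with child modules."""
    name: str
    children: List["ToyModule"] = field(default_factory=list)


def collect(module, depth, prefix=""):
    """Return the paths of all modules that would be tracked.

    Assumption: a module is tracked as a unit if it has no children
    or if the remaining depth budget is zero.
    """
    path = f"{prefix}/{module.name}" if prefix else module.name
    if depth == 0 or not module.children:
        return [path]
    tracked = []
    for child in module.children:
        tracked.extend(collect(child, depth - 1, path))
    return tracked


model = ToyModule("net", [
    ToyModule("features", [ToyModule("conv"), ToyModule("relu")]),
    ToyModule("classifier"),
])

shallow = collect(model, depth=1)  # stops at the direct children of net
deep = collect(model, depth=2)     # descends into features as well
```

With depth=1 the features container is tracked as one unit; with depth=2 its conv and relu children are tracked individually.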

Three further changes to the training code are necessary:

  1. set_model() must be called so that the Exporter can wire up the appropriate callbacks.
  2. set_loss() should be called with the loss function so that labels can be extracted during training, if any Subscribers rely on the 'input_labels' message.
  3. epoch_finished() should be called if any Subscribers rely on the 'epoch_finished' message.
_modules

All tracked modules

Type:dict(torch.nn.Module, ikkuna.utils.NamedModule)
_weight_cache

Cache for keeping the previous weights for computing differences

Type:dict
_bias_cache

see _weight_cache

Type:dict
_model
Type:torch.nn.Module
_train_step

Current batch index

Type:int
_global_step

Global step across all epochs

Type:int
_epoch

Current epoch

Type:int
_is_training

Flag enabling/disabling some messages during testing

Type:bool
_depth

Depth to which to traverse the module tree

Type:int
_module_filter

Set of modules to capture when calling add_modules(). Everything not in this list is ignored

Type:list(torch.nn.Module)
message_bus
modules

list(torch.nn.Module) - Modules tracked by this Exporter

named_modules

list(ikkuna.utils.NamedModule) - Named modules tracked by this Exporter

_add_module_by_name(named_module)[source]

Register a module with a name attached.

Parameters:named_module (ikkuna.utils.NamedModule) –
add_modules(module, recursive=True)[source]

Add modules to supervise. If the module has weight and/or bias members, updates to those will be tracked. Ignores any module in _module_filter.

Parameters:
Raises:

ValueError – If module is neither a tuple, nor a (subclass of) torch.nn.Module

__call__(module, recursive=True)[source]

Shorthand for add_modules() which returns its input unmodified.

Parameters:see Exporter.add_modules()

Returns:The input module
Return type:torch.nn.Module
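
The register-and-return behaviour of __call__() can be sketched without ikkuna. ToyRegistry below is a hypothetical stand-in, not the Exporter itself; it only illustrates why wrapping a layer inline leaves the surrounding model definition unchanged.

```python
class ToyRegistry:
    """Hypothetical stand-in illustrating the call-to-register pattern."""

    def __init__(self):
        self.tracked = []

    def add_modules(self, module):
        # Remember the module so that it can be observed later.
        self.tracked.append(module)

    def __call__(self, module):
        # Register the module, then hand back the very same object.
        self.add_modules(module)
        return module


registry = ToyRegistry()
layer = object()                        # placeholder for a real layer
layers = [object(), registry(layer)]    # only the wrapped layer is tracked
```

Because the same object comes back, the call can be dropped into any expression that expects the module.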
train(train=True)[source]

Switch to training mode. This will ensure all data is published.

test(test=True)[source]

Switch to testing mode. This will turn off all publishing.

new_loss(loss)[source]

Callback for publishing current training loss.

new_input_data(*args)[source]

Callback for new training input to the network.

Parameters:*args (tuple) – Network inputs
new_output_and_labels(network_output, labels)[source]

Callback for final network output.

Parameters:network_output (torch.Tensor) – The final layer’s output
new_activations(module, in_, out_)[source]

Callback for newly arriving activations. Registered as a hook to the tracked modules. Will trigger export of all new activation and weight/bias data.

Parameters:
new_layer_gradients(module, gradients)[source]

Callback for newly arriving layer gradients (loss wrt layer output). Registered as a hook to the tracked modules.

Warning

Currently, only layers with one output are supported.

Parameters:
Raises:

RuntimeError – If the module has multiple outputs

new_parameter_gradients(module, gradients)[source]

Callback for newly arriving gradients wrt weight and/or bias. Registered as a hook to the tracked modules. Will trigger export of all new gradient data.

Parameters:
set_model(model)[source]

Set the model for direct access for some metrics.

Parameters:model (torch.nn.Module) –
set_loss(loss_function)[source]

Add hook to loss function to extract labels.

Parameters:loss_function (torch.nn._Loss) –
step()[source]

Increase batch counter (per epoch) and the global step counter.

freeze_module(module)[source]

Convenience method for freezing training for a module.

Parameters:module (torch.nn.Module) – Module to freeze
epoch_finished()[source]

Increase the epoch counter and reset the batch counter.
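
How step() and epoch_finished() interact can be sketched with plain counters. ToyCounters is a hypothetical stand-in that mirrors only the documented counter behaviour; starting all counters at zero is an assumption.

```python
class ToyCounters:
    """Hypothetical stand-in mirroring the documented counter semantics."""

    def __init__(self):
        self.train_step = 0    # batch index within the current epoch
        self.global_step = 0   # never reset
        self.epoch = 0

    def step(self):
        # Called once per batch.
        self.train_step += 1
        self.global_step += 1

    def epoch_finished(self):
        # Called once per epoch: advance the epoch, restart batch counting.
        self.epoch += 1
        self.train_step = 0


counters = ToyCounters()
for _epoch in range(2):
    for _batch in range(3):
        counters.step()
    counters.epoch_finished()
```

After two epochs of three batches each, the global step has reached 6, the epoch counter is 2, and the per-epoch batch counter is back at 0.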

Submodules

ikkuna.export.messages

ikkuna.export.messages.META_KINDS = {'batch_finished', 'batch_started', 'epoch_finished', 'epoch_started', 'input_data', 'input_labels', 'loss', 'network_output'}

Message kinds which are not tied to any specific module. These topics are just the ones that ship with the library; others can be added to a specific MessageBus.

ikkuna.export.messages.DATA_KINDS = {'activations', 'bias_gradients', 'bias_updates', 'biases', 'layer_gradients', 'weight_gradients', 'weight_updates', 'weights'}

Message kinds which are tied to a specific module and always carry data. These topics are just the ones that ship with the library; others can be added to a specific MessageBus.

class ikkuna.export.messages.Message(tag, global_step, train_step, epoch, kind)[source]

Bases: abc.ABC

Base class for messages emitted from the Exporter.

These messages are assembled into MessageBundle objects in the Subscription.

__init__(tag, global_step, train_step, epoch, kind)[source]
Parameters:
  • tag (str) – Tag for this message
  • global_step (int) – Global train step
  • train_step (int) – Epoch-local train step
  • epoch (int) – Epoch index
  • kind (str) – Message topic
tag

The tag associated with this message

Type:str
global_step

Global sequence number. This counter should not reset after each epoch.

Type:int
train_step

Epoch-local sequence number (the current batch index)

Type:int
epoch

Current epoch number

Type:int
kind

Message kind

Type:str
data

This field is optional for NetworkMessage, but mandatory for ModuleMessage

Type:torch.Tensor, tuple(torch.Tensor) or None
key

A key used for grouping messages into MessageBundles

Type:object
class ikkuna.export.messages.NetworkMessage(tag, global_step, train_step, epoch, kind, data=None)[source]

Bases: ikkuna.export.messages.Message

A message with meta information not tied to any specific module. Can still carry tensor data, if necessary.

data

Optional data. Can be used e.g. for input to the network, labels or network output

Type:torch.Tensor, tuple, float, int or None
key

A key used for grouping messages into MessageBundles

Type:object
class ikkuna.export.messages.ModuleMessage(tag, global_step, train_step, epoch, kind, named_module, data)[source]

Bases: ikkuna.export.messages.Message

A message tied to a specific module, with tensor data attached.

module

Module emitting this data

Type:torch.nn.Module
key

A key used for grouping messages into MessageBundles

Type:object
class ikkuna.export.messages.MessageBundle(kinds)[source]

Bases: object

Data object for holding a set of artifacts for a module (or meta information) at one point during training. This data type can be used to buffer different kinds and check whether all expected kinds have been received for a module or meta information. The collection is enforced to be homogeneous with respect to global step, train step, epoch, and identifier (Message.key)

__init__(kinds)[source]
Parameters:kinds (str or list) – Single kind when a Subscription is used, or a list of Message kinds contained in this bundle for use with SynchronizedSubscription
key

An object denoting the common aspect of the collected messages (besides the steps). This can be the NamedModule emitting the data, or a string such as 'META' indicating that the messages do not belong to a module.

Type:str
kinds

Alias to expected_kinds

Type:list(str)
expected_kinds

The expected kinds of messages per iteration

Type:list(str)
data

The tensors received for each kind

Type:dict(str, torch.Tensor)
global_step

Global sequence number of the received messages (should match across all msgs in one iteration)

Type:int
train_step

Sequence number (training step) of the received messages (should match across all msgs in one iteration)

Type:int
epoch

Epoch of the received messages (should match across all msgs in one iteration)

Type:int
complete()[source]

Check if all expected messages have been received. This means the bundle can be released to subscribers.

Returns:
Return type:bool
check_message(message)[source]

Check consistency of sequence number, step and epoch or set if not set yet. Check consistency of identifier and check for duplication.

Parameters:message (ikkuna.export.messages.Message) –
Raises:ValueError – If message.(global_step|step|epoch|identifier) does not match the current (global_step|step|epoch|identifier) or in case a message of message.kind has already been received
add_message(message)[source]

Add a new message to this object. Will fail if the new message does not have the same sequence number and epoch.

Parameters:message (ikkuna.export.messages.Message) –
Raises:ValueError – see check_message()
__getattr__(name)[source]

Override to mimic a property for each kind of message in this data (e.g. message_bundle.activations instead of message_bundle.data['activations']).
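
The interplay of add_message(), complete() and the __getattr__ override can be sketched with a reduced bundle. ToyBundle is a hypothetical stand-in: it checks only the global step and duplicate kinds, and its add_message() takes loose arguments instead of a Message object.

```python
class ToyBundle:
    """Hypothetical, reduced stand-in for a message bundle."""

    def __init__(self, kinds):
        # Accept a single kind or a list of kinds.
        self.expected_kinds = [kinds] if isinstance(kinds, str) else list(kinds)
        self.data = {}
        self.global_step = None

    def add_message(self, kind, global_step, payload):
        if kind in self.data:
            raise ValueError(f"duplicate kind: {kind}")
        if self.global_step is None:
            self.global_step = global_step      # first message sets the step
        elif self.global_step != global_step:
            raise ValueError("global_step mismatch")
        self.data[kind] = payload

    def complete(self):
        # True once every expected kind has arrived.
        return all(kind in self.data for kind in self.expected_kinds)

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails.
        data = self.__dict__.get("data", {})
        if name in data:
            return data[name]
        raise AttributeError(name)


bundle = ToyBundle(["weights", "activations"])
bundle.add_message("weights", 7, [0.1, 0.2])
partial = bundle.complete()
bundle.add_message("activations", 7, [1.0])

try:
    bundle.add_message("weights", 7, [0.3, 0.4])
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```

Reading through self.__dict__ inside __getattr__ avoids infinite recursion if the attribute is requested before data has been assigned.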

class ikkuna.export.messages.MessageBus(name)[source]

Bases: object

A class which receives messages, registers subscribers and relays the former to the latter.

__init__(name)[source]
Parameters:name (str) – Identifier for this bus
register_meta_topic(kind)[source]

Register a topic so it can be subscribed.

deregister_meta_topic(kind)[source]

Unregister a topic so it can no longer be subscribed.

register_data_topic(kind)[source]

Register a topic so it can be subscribed.

deregister_data_topic(kind)[source]

Unregister a topic so it can no longer be subscribed.

name

The name of this bus

Type:str
register_subscriber(sub)[source]

Add a new subscriber to the set. Adding a subscriber multiple times will still only call it once per message.

Parameters:sub (ikkuna.export.subscriber.Subscriber) –
Raises:ValueError – If any of the kinds the Subscriber is interested in wasn’t previously registered
publish_network_message(global_step, train_step, epoch, kind, data=None, tag='default')[source]

Publish an update of type NetworkMessage to all registered subscribers.

Parameters:
  • global_step (int) – Global training step
  • train_step (int) – Epoch-relative training step
  • epoch (int) – Epoch index
  • kind (str) – Kind of message
  • data (torch.Tensor or None) – Payload, if necessary
publish_module_message(global_step, train_step, epoch, kind, named_module, data, tag='default')[source]

Publish an update of type ModuleMessage to all registered subscribers.

Parameters:
  • global_step (int) – Global training step
  • train_step (int) – Epoch-relative training step
  • epoch (int) – Epoch index
  • kind (str) – Kind of message
  • named_module (ikkuna.utils.NamedModule) – The module in question
  • data (torch.Tensor) – Payload
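
The registration rules described for this class (topics must exist before a subscriber can ask for them, and a subscriber added twice is still called once) can be sketched as follows. ToyBus and ToySubscriber are hypothetical stand-ins with simplified signatures; they are not the ikkuna classes.

```python
class ToySubscriber:
    """Hypothetical subscriber that records what it receives."""

    def __init__(self, kinds):
        self.kinds = list(kinds)
        self.received = []

    def receive(self, kind, data):
        self.received.append((kind, data))


class ToyBus:
    """Hypothetical, reduced stand-in for a message bus."""

    def __init__(self, name):
        self.name = name
        self.topics = {"loss", "weights"}   # topics known from the start
        self.subscribers = set()            # a set ignores repeated adds

    def register_topic(self, kind):
        self.topics.add(kind)

    def register_subscriber(self, sub):
        unknown = set(sub.kinds) - self.topics
        if unknown:
            raise ValueError(f"unregistered kinds: {sorted(unknown)}")
        self.subscribers.add(sub)

    def publish(self, kind, data):
        # Relay the message to every subscriber interested in this kind.
        for sub in self.subscribers:
            if kind in sub.kinds:
                sub.receive(kind, data)


bus = ToyBus("demo")
sub = ToySubscriber(["loss"])
bus.register_subscriber(sub)
bus.register_subscriber(sub)    # second registration has no effect
bus.publish("loss", 0.5)

custom = ToySubscriber(["gradient_norm"])
try:
    bus.register_subscriber(custom)
    rejected = False
except ValueError:
    rejected = True

bus.register_topic("gradient_norm")   # make the topic known first
bus.register_subscriber(custom)       # now succeeds
```

Keeping subscribers in a set is one simple way to get the "called once per message" behaviour; the actual implementation may differ.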
ikkuna.export.messages.get_default_bus()[source]

Get the default message bus which is always created when this module is loaded.

Returns:
Return type:MessageBus