Source code for ikkuna.export.subscriber.norm

from ikkuna.export.subscriber import PlotSubscriber, Subscription
from ikkuna.export.messages import get_default_bus


class NormSubscriber(PlotSubscriber):
    '''Plot and republish the ``order``-norm of a single kind of message.

    For every received message, the norm of ``message.data`` is added to the plotting
    backend under the module's name. A new message with the identifier
    ``{kind}_norm{order}`` is also published on the message bus.
    '''

    def __init__(self, kind, message_bus=None, tag='default', subsample=1, ylims=None,
                 backend='tb', order=2):
        '''
        Parameters
        ----------
        kind : str
            The single message kind to subscribe to.
        message_bus
            Bus handed to :class:`PlotSubscriber`. ``None`` (the default) means the bus
            returned by :func:`~ikkuna.export.messages.get_default_bus` when the
            subscriber is constructed.
        tag : str
            Forwarded to :class:`Subscription`.
        subsample : int
            Forwarded to :class:`Subscription`.
        ylims
            Y-axis limits placed into the plot configuration (``None`` leaves them unset).
        backend : str
            Plotting backend name forwarded to :class:`PlotSubscriber`.
        order : int or float
            The ``p`` of the norm. It is also used in the plot title, the y-axis label and
            the published identifier.

        Raises
        ------
        ValueError
            If ``kind`` is not a single string (e.g. a list of kinds).
        '''
        if not isinstance(kind, str):
            raise ValueError('NormSubscriber only accepts 1 kind')

        # Resolve the default here instead of in the signature: a call in a default
        # argument is evaluated only once, when the class body is executed.
        if message_bus is None:
            message_bus = get_default_bus()

        # Shared by the plot title and the kind of the published messages.
        identifier = f'{kind}_norm{order}'
        plot_config = {
            'title': identifier,
            'ylabel': f'L{order} Norm',
            'ylims': ylims,
            'xlabel': 'Train step',
        }
        subscription = Subscription(self, [kind], tag=tag, subsample=subsample)
        super().__init__([subscription], message_bus, plot_config, backend=backend)

        self._order = order
        self._add_publication(identifier, type='DATA')

    def compute(self, message):
        '''Compute the norm of a quantity.

        The ``order``-norm of ``message.data`` is added to the backend as a data point for
        the message's module at ``message.global_step``. A
        :class:`~ikkuna.export.messages.ModuleMessage` with the identifier
        ``{kind}_norm{order}`` is then published.

        Parameters
        ----------
        message
            Incoming message. Its ``key`` must unpack to ``(module, module_name)``. Its
            ``data`` must provide ``.norm(p=...)`` (presumably a ``torch.Tensor`` — confirm).
        '''
        _module, module_name = message.key
        norm = message.data.norm(p=self._order)
        self._backend.add_data(module_name, norm, message.global_step)

        kind = f'{message.kind}_norm{self._order}'
        self.message_bus.publish_module_message(message.global_step, message.train_step,
                                                message.epoch, kind, message.key, norm)