from ikkuna.export.subscriber import PlotSubscriber, Subscription
from ikkuna.export.messages import get_default_bus
class NormSubscriber(PlotSubscriber):
    '''Subscriber which computes the norm of the data carried by each message it receives.

    For every message, the norm is added to the plotting backend under the name found in
    ``message.key`` and re-published on the message bus as a module message whose kind is
    ``{kind}_norm{order}``.
    '''

    def __init__(self, kind, message_bus=get_default_bus(), tag='default', subsample=1, ylims=None,
                 backend='tb', order=2):
        '''
        Parameters
        ----------
        kind    :   str
                    The single message kind to subscribe to
        message_bus
                    Bus handed to the :class:`PlotSubscriber` constructor. Note that the default
                    is evaluated once, when the class body is executed.
        tag :   str
                Forwarded to the :class:`Subscription`
        subsample   :   int
                        Forwarded to the :class:`Subscription`
        ylims
                Stored under the ``'ylims'`` key of the plot configuration
        backend :   str
                    Forwarded to the :class:`PlotSubscriber` constructor
        order   :   int
                    Passed as ``p`` to ``data.norm`` and used in the plot title, y-axis label and
                    published kind

        Raises
        ------
        ValueError
            If ``kind`` is not a :class:`str`
        '''
        if not isinstance(kind, str):
            raise ValueError('NormSubscriber only accepts 1 kind')

        # the same identifier serves as the plot title and as the kind of the published messages
        norm_kind = f'{kind}_norm{order}'
        plot_config = {
            'title': norm_kind,
            'ylabel': f'L{order} Norm',
            'ylims': ylims,
            'xlabel': 'Train step',
        }
        subscription = Subscription(self, [kind], tag=tag, subsample=subsample)
        super().__init__([subscription], message_bus, plot_config, backend=backend)
        self._order = order
        self._add_publication(norm_kind, type='DATA')

    def compute(self, message):
        '''Compute the norm of a quantity. A :class:`~ikkuna.export.messages.ModuleMessage`
        with the identifier ``{kind}_norm{order}`` is published.

        Parameters
        ----------
        message
            Message whose ``data`` provides a ``norm(p=...)`` method (presumably a tensor --
            confirm against the publisher) and whose ``key`` unpacks into exactly two items, the
            second of which is used as the name for the backend.
        '''
        # only the name half of the key is needed here; the full key is re-published below
        _, module_name = message.key
        norm = message.data.norm(p=self._order)
        self._backend.add_data(module_name, norm, message.global_step)

        kind = f'{message.kind}_norm{self._order}'
        self.message_bus.publish_module_message(message.global_step,
                                                message.train_step,
                                                message.epoch, kind,
                                                message.key, norm)