
Make loss functions and regularizers classes that inherit from torch.nn #131

Description

@iancze

Currently our loss functions are coded as plain functions operating on torch inputs. Some loss functions take additional parameters that are fixed at initialization, for example,

def entropy(cube, prior_intensity): ...

where prior_intensity is a reference cube used when evaluating the actual target, cube.

This works fine, but can get cumbersome, especially when we are interfacing with a bunch of loss functions all at once (as in a cross-validation loop).

Can we take a page from the way PyTorch designs its loss functions and make most, if not all, of our loss functions classes that inherit from torch.nn.Module? This would create objects that can be instantiated with sensible default parameter values and would standardize the call signature across losses. For example, see torch.nn.MSELoss.
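A minimal sketch of what this could look like for the entropy example above. The class name, the stored prior_intensity attribute, and the relative-entropy expression in forward are all illustrative assumptions, not the package's actual API:

```python
import torch
from torch import nn


class Entropy(nn.Module):
    """Hypothetical class-based entropy loss.

    The reference cube is supplied once at construction time, so the
    forward call only needs the target cube.
    """

    def __init__(self, prior_intensity):
        super().__init__()
        # store the reference cube as state on the loss object
        self.prior_intensity = prior_intensity

    def forward(self, cube):
        # placeholder relative-entropy form; the real expression would
        # match whatever the current functional entropy() computes
        return torch.sum(cube * torch.log(cube / self.prior_intensity))


# usage mirrors torch.nn.MSELoss: instantiate once, call many times
prior = torch.ones(2, 2)
loss_fn = Entropy(prior)
loss = loss_fn(torch.ones(2, 2))
```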

This may have additional benefits (with the reduction argument, say) if we think about batching and multi-GPU applications.

Does it additionally make sense to include the lambda (regularization-strength) terms as parameters of the loss object, too? @kadri-nizam do you have any thoughts from experience with your VAE architecture?
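One way the lambda question could play out, sketched under the same nn.Module assumption. The WeightedLoss wrapper and the lam parameter name are hypothetical; the point is that baking the strength into the object makes a loop over many regularizers (as in a cross-validation run) uniform:

```python
import torch
from torch import nn


class WeightedLoss(nn.Module):
    """Illustrative wrapper that stores a regularization strength.

    `base_loss` can be any callable loss (function or nn.Module), and
    `lam` is the hypothetical lambda hyperparameter for that term.
    """

    def __init__(self, base_loss, lam=1.0):
        super().__init__()
        self.base_loss = base_loss
        self.lam = lam

    def forward(self, *args):
        # scale the underlying loss by its strength
        return self.lam * self.base_loss(*args)


# with every term an object, summing regularizers becomes generic:
cube = torch.ones(4)
regularizers = [
    WeightedLoss(torch.sum, lam=0.5),
    WeightedLoss(torch.mean, lam=2.0),
]
total = sum(term(cube) for term in regularizers)
```

The trade-off is that the lambdas then live inside the loss objects rather than in one central config, which matters if a hyperparameter search wants to vary them without rebuilding the objects.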
