Description
Currently our loss functions are coded as straightforward functions working on torch inputs. Some loss functions have additional parameters that are set at initialization, for example,
def entropy(cube, prior_intensity): ...
where prior_intensity is a reference cube used in the evaluation of the actual target, cube.
This works fine, but can get cumbersome, especially when we are interfacing with a bunch of loss functions all at once (as in a cross-validation loop).
Can we take a page from the way PyTorch designs its loss functions and make most, if not all, of our loss functions classes that inherit from torch.nn.Module? Each loss would then be an object that is instantiated once with its parameters (or their defaults) and afterwards called with the same signature as every other loss. For example, see PyTorch's MSELoss.
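To make the proposal concrete, here is a rough sketch of what a class-based entropy loss could look like. The class name `Entropy`, the `reduction` argument, and the formula inside `forward` are assumptions for illustration only; the formula is a generic relative-entropy placeholder, not necessarily what our current `entropy` function computes.

```python
import torch
from torch import nn


class Entropy(nn.Module):
    """Hypothetical class-based counterpart of entropy(cube, prior_intensity)."""

    def __init__(self, prior_intensity: torch.Tensor, reduction: str = "sum"):
        super().__init__()
        if reduction not in ("sum", "mean"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        # A buffer follows the module through .to(device) but is not trained.
        self.register_buffer("prior_intensity", prior_intensity)
        self.reduction = reduction

    def forward(self, cube: torch.Tensor) -> torch.Tensor:
        # Placeholder formula (assumption): relative entropy of cube w.r.t. the prior.
        terms = cube * torch.log(cube / self.prior_intensity)
        return terms.sum() if self.reduction == "sum" else terms.mean()


# The reference cube is fixed at construction; later calls only need the target.
prior = torch.ones(2, 3, 3)
loss_fn = Entropy(prior)
value = loss_fn(2.0 * torch.ones(2, 3, 3))
```

In a cross-validation loop, every loss configured this way could be called as `loss_fn(cube)` without the loop knowing which extra parameters each one needs.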
This may have additional benefits (with a reduction argument like the one PyTorch's losses take, say) if we think about batching and applications to multiple GPUs.
Does it additionally make sense to include the lambda terms as parameters of the loss object, too? @kadri-nizam do you have any thoughts from experience w/ your VAE architecture?
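One possible shape for this, purely as a sketch to discuss: keep the lambda as an attribute of a small wrapper module rather than inside each loss. `WeightedLoss` and `lam` are hypothetical names, and PyTorch's built-in `nn.MSELoss` and `nn.L1Loss` stand in for our own losses here.

```python
import torch
from torch import nn


class WeightedLoss(nn.Module):
    """Hypothetical wrapper that stores the lambda weight alongside a loss."""

    def __init__(self, loss: nn.Module, lam: float = 1.0):
        super().__init__()
        self.loss = loss
        self.lam = lam

    def forward(self, *args: torch.Tensor) -> torch.Tensor:
        # Scale the wrapped loss by its lambda term.
        return self.lam * self.loss(*args)


# Stand-in losses; in practice these would be our own loss classes.
terms = nn.ModuleList(
    [
        WeightedLoss(nn.MSELoss(), lam=1.0),
        WeightedLoss(nn.L1Loss(), lam=0.5),
    ]
)

pred = torch.tensor([1.0, 2.0])
target = torch.tensor([0.0, 0.0])

# Total loss is the lambda-weighted sum of the individual terms.
total = sum(term(pred, target) for term in terms)
```

The alternative is to give every loss class its own `lam` argument, which saves a wrapper but repeats the same scaling logic in each class.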