diff --git a/README.md b/README.md
index 06e1897..f41d9cf 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,6 @@ Cortex is under heavy development. It's functional, but may not fit your needs y
 
 (some) Outstanding issues:
 
-* Reloading models does not reload hyperparameters (they need to be specified again when reloading)
 * Need custom data iterator functionality within custom models
 * torchtext integration needed
 * Missing unit tests
diff --git a/cortex/_lib/optimizer.py b/cortex/_lib/optimizer.py
index 0e3ed93..0926335 100644
--- a/cortex/_lib/optimizer.py
+++ b/cortex/_lib/optimizer.py
@@ -1,6 +1,4 @@
-'''Module for setting up the optimizer.
-
-'''
+"""Module for setting up the optimizer."""
 from collections import defaultdict
 import logging
 
@@ -52,8 +50,8 @@ def step(self, closure=None):
         """Performs a single optimization step.
 
         Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
+            closure (callable, optional): A closure that reevaluates the
+                model and returns the loss.
         """
         loss = super().step(closure=closure)
 
@@ -70,7 +68,7 @@ def step(self, closure=None):
 def setup(model, optimizer='Adam', learning_rate=1.e-4,
           weight_decay={}, clipping={}, optimizer_options={},
           model_optimizer_options={}, scheduler=None, scheduler_options={}):
-    '''Optimizer entrypoint.
+    """Optimizer entrypoint.
 
     Args:
         optimizer: Optimizer type. See `torch.optim` for supported optimizers.
@@ -83,7 +81,7 @@ def setup(model, optimizer='Adam', learning_rate=1.e-4,
         scheduler: Optimizer learning rate scheduler.
         scheduler_options: Options for scheduler.
 
-    '''
+    """
     OPTIMIZERS.clear()
     SCHEDULERS.clear()
 
@@ -148,29 +146,26 @@ def setup(model, optimizer='Adam', learning_rate=1.e-4,
         for p in params:
             p.requires_grad = True
 
-        # Learning rates
-        if isinstance(learning_rate, dict):
-            eta = learning_rate[network_key]
-        else:
-            eta = learning_rate
+        def extract_value(dict_or_value, default=None):
+            if isinstance(dict_or_value, dict):
+                return dict_or_value.get(network_key, default)
+            return dict_or_value
 
+        # Learning rates
+        network_lr = extract_value(learning_rate)
         # Weight decay
-        if isinstance(weight_decay, dict):
-            wd = weight_decay.get(network_key, 0)
-        else:
-            wd = weight_decay
-
-        if isinstance(clipping, dict):
-            cl = clipping.get(network_key, None)
-        else:
-            cl = clipping
+        network_wd = extract_value(weight_decay, 0)
+        # Gradient clipping
+        network_cl = extract_value(clipping)
 
         # Update the optimizer options
         optimizer_options_ = dict((k, v) for k, v in optimizer_options.items())
-        optimizer_options_.update(weight_decay=wd, clipping=cl, lr=eta)
+        optimizer_options_.update(
+            weight_decay=network_wd, clipping=network_cl, lr=network_lr)
 
-        if network_key in model_optimizer_options.keys():
-            optimizer_options_.update(**model_optimizer_options)
+        if network_key in model_optimizer_options:
+            optimizer_options_.update(
+                **eval(model_optimizer_options[network_key]))
 
         # Create the optimizer
         op = wrap_optimizer(op)
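
The optimizer.py hunk replaces three near-identical `isinstance` branches with a single `extract_value` closure that resolves a hyperparameter given either as one value shared by all networks or as a per-network dict. The sketch below is a minimal standalone illustration of that dict-or-scalar resolution, not part of the patch; the helper name `resolve_hyperparameter` and the network keys are hypothetical.

```python
def resolve_hyperparameter(dict_or_value, network_key, default=None):
    """Return the per-network value when given a dict, else the value itself."""
    if isinstance(dict_or_value, dict):
        return dict_or_value.get(network_key, default)
    return dict_or_value


# A single scalar applies to every network.
assert resolve_hyperparameter(1e-4, 'discriminator') == 1e-4

# A dict selects per network and falls back to the default for missing keys.
learning_rates = {'generator': 1e-4, 'discriminator': 4e-4}
assert resolve_hyperparameter(learning_rates, 'discriminator') == 4e-4
assert resolve_hyperparameter(learning_rates, 'encoder', default=None) is None
```

In the patch itself the helper is defined inside the per-network loop, so it closes over `network_key` instead of taking it as an argument; the behavior is otherwise the same.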