1 change: 0 additions & 1 deletion README.md
@@ -4,7 +4,6 @@ Cortex is under heavy development. It's functional, but may not fit your needs y

 (some) Outstanding issues:
 
-* Reloading models does not reload hyperparameters (they need to be specified again when reloading)
 * Need custom data iterator functionality within custom models
 * torchtext integration needed
 * Missing unit tests
43 changes: 19 additions & 24 deletions cortex/_lib/optimizer.py
@@ -1,6 +1,4 @@
-'''Module for setting up the optimizer.
-
-'''
+"""Module for setting up the optimizer."""
 
 from collections import defaultdict
 import logging
@@ -52,8 +50,8 @@ def step(self, closure=None):
         """Performs a single optimization step.
 
         Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
+            closure (callable, optional): A closure that reevaluates the
+                model and returns the loss.
         """
         loss = super().step(closure=closure)
 
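For reference, the `closure` argument documented above follows the standard `torch.optim` convention. A minimal sketch, assuming `optimizer` is one of the wrapped optimizers this module creates and that `model`, `criterion`, `inputs`, and `targets` already exist:

```python
# Sketch only: every name here except `step` is an assumption, not taken from the diff.
def closure():
    optimizer.zero_grad()                     # reset gradients
    loss = criterion(model(inputs), targets)  # forward pass
    loss.backward()                           # backward pass
    return loss

loss = optimizer.step(closure=closure)  # re-evaluates the model and returns the loss
```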
@@ -70,7 +68,7 @@ def step(self, closure=None):
 def setup(model, optimizer='Adam', learning_rate=1.e-4,
           weight_decay={}, clipping={}, optimizer_options={},
           model_optimizer_options={}, scheduler=None, scheduler_options={}):
-    '''Optimizer entrypoint.
+    """Optimizer entrypoint.
 
     Args:
         optimizer: Optimizer type. See `torch.optim` for supported optimizers.
@@ -83,7 +81,7 @@ def setup(model, optimizer='Adam', learning_rate=1.e-4,
         scheduler: Optimizer learning rate scheduler.
         scheduler_options: Options for scheduler.
 
-    '''
+    """
 
     OPTIMIZERS.clear()
     SCHEDULERS.clear()
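As the docstring above describes, `learning_rate`, `weight_decay`, and `clipping` accept either a single value or a per-network dict. A minimal sketch of a call, assuming `model` is an already-built cortex model handler and the network names are illustrative:

```python
# Sketch only: `model` and the network names ('discriminator', 'generator') are assumptions.
from cortex._lib.optimizer import setup

setup(model,
      optimizer='Adam',
      learning_rate={'discriminator': 1e-4, 'generator': 2e-4},  # per-network learning rates
      weight_decay=1e-5,                                         # a scalar applies to every network
      clipping={'discriminator': 1.0})                           # clip gradients for one network only
```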
@@ -148,29 +146,26 @@ def setup(model, optimizer='Adam', learning_rate=1.e-4,
         for p in params:
             p.requires_grad = True
 
-        # Learning rates
-        if isinstance(learning_rate, dict):
-            eta = learning_rate[network_key]
-        else:
-            eta = learning_rate
+        def extract_value(dict_or_value, default=None):
+            if isinstance(dict_or_value, dict):
+                return dict_or_value.get(network_key, default)
+            return dict_or_value
 
+        # Learning rates
+        network_lr = extract_value(learning_rate)
         # Weight decay
-        if isinstance(weight_decay, dict):
-            wd = weight_decay.get(network_key, 0)
-        else:
-            wd = weight_decay
-
-        if isinstance(clipping, dict):
-            cl = clipping.get(network_key, None)
-        else:
-            cl = clipping
+        network_wd = extract_value(weight_decay, 0)
+        # Gradient clipping
+        network_cl = extract_value(clipping)
 
         # Update the optimizer options
         optimizer_options_ = dict((k, v) for k, v in optimizer_options.items())
-        optimizer_options_.update(weight_decay=wd, clipping=cl, lr=eta)
+        optimizer_options_.update(
+            weight_decay=network_wd, clipping=network_cl, lr=network_lr)
 
-        if network_key in model_optimizer_options.keys():
-            optimizer_options_.update(**model_optimizer_options)
+        if network_key in model_optimizer_options:
+            optimizer_options_.update(
+                **eval(model_optimizer_options[network_key]))
 
         # Create the optimizer
         op = wrap_optimizer(op)
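The refactor above replaces three `isinstance` branches with a single `extract_value` helper that closes over `network_key`. A small standalone sketch of how that resolution behaves (the network names and values are illustrative, not from the diff):

```python
# Standalone illustration of the dict-or-scalar resolution introduced in this hunk.
network_key = 'decoder'
learning_rate = {'encoder': 1e-4, 'decoder': 5e-4}  # per-network dict
weight_decay = 1e-5                                 # scalar: applies to every network
clipping = {'encoder': 1.0}                         # no entry for 'decoder'

def extract_value(dict_or_value, default=None):
    # Dicts are treated as per-network mappings; anything else is used as-is.
    if isinstance(dict_or_value, dict):
        return dict_or_value.get(network_key, default)
    return dict_or_value

assert extract_value(learning_rate) == 5e-4    # dict lookup by network_key
assert extract_value(weight_decay, 0) == 1e-5  # scalar passes through unchanged
assert extract_value(clipping) is None         # missing key falls back to the default
```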