Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
1210 commits
Select commit Hold shift + click to select a range
a3e36fe
fix syntax
smilesun Jan 17, 2024
100c687
fix diva
smilesun Jan 17, 2024
2d70f93
Merge branch 'fbopt' into fbopt_matchdiva
smilesun Jan 17, 2024
a3f3559
code style
ntchen Jan 17, 2024
2ec560e
c_msel_val_top_k
ntchen Jan 17, 2024
96e0b44
c_msel_setpoin_delay
ntchen Jan 17, 2024
791d4b7
resolve too many attributes
ntchen Jan 17, 2024
30393a3
ignore too many local variables
ntchen Jan 17, 2024
3a077f6
code style
ntchen Jan 17, 2024
c39fb00
pylint
ntchen Jan 17, 2024
fb93437
pylint
ntchen Jan 17, 2024
b226ffb
Merge pull request #750 from marrlab/foptcodacy
smilesun Jan 18, 2024
898af0b
Merge branch 'fbopt' into fbopt_matchdiva
smilesun Jan 18, 2024
358c936
Merge branch 'master' into fbopt
smilesun Jan 18, 2024
7c92b5e
Merge pull request #746 from marrlab/fbopt_matchdiva
smilesun Jan 18, 2024
aa1d5cd
Merge branch 'master' into fbopt
smilesun Jan 19, 2024
2319105
Update pyproject.toml
smilesun Jan 19, 2024
04a6f25
Update ci.yml: revert how pytest runs: no poetry
smilesun Jan 19, 2024
1298519
Update pyproject.toml: rm tensorboard = "^2.14.0"
smilesun Jan 19, 2024
0248bb0
Merge branch 'master' into fbopt
smilesun Jan 20, 2024
5bd00d6
Update ci.yml
smilesun Jan 21, 2024
108c2e7
Update ci.yml
smilesun Jan 21, 2024
d615736
Merge branch 'master' into fbopt
smilesun Jan 21, 2024
00a30af
add tensorboard to pyproject
smilesun Jan 21, 2024
6dbe7d6
update poetry.lock file
smilesun Jan 22, 2024
486d81a
update dependencies
agisga Jan 23, 2024
6a39fb6
Merge pull request #765 from marrlab/issue_754
smilesun Jan 24, 2024
b7d21db
Merge branch 'master' into fbopt
smilesun Jan 24, 2024
2cd16d8
Merge branch 'master' into fbopt
smilesun Jan 29, 2024
1c4dfd1
reduce number of iterations
smilesun Jan 29, 2024
eb812c2
.
smilesun Jan 29, 2024
1374b33
use aug pac
smilesun Feb 1, 2024
b763838
Update pacs_diva_fbopt_alone_es1.yaml: batch 64 to 32
smilesun Feb 1, 2024
6d34ab4
Update pacs_diva_fbopt_alone_es1_autoki.yaml
smilesun Feb 1, 2024
a3fa382
es=10
smilesun Feb 6, 2024
964c892
.
smilesun Feb 6, 2024
01e03bb
merge conflict with master
smilesun Feb 6, 2024
f224039
Merge branch 'master' into fbopt
smilesun Feb 7, 2024
4c69a52
force setpoint change once
smilesun Feb 7, 2024
775e41f
Update pacs_diva_fbopt_alone_es1_autoki.yaml
smilesun Feb 8, 2024
600e39e
attempt to reproduce previous diva results on pacs data
agisga Feb 8, 2024
b58cbff
Merge pull request #778 from marrlab/diva_pacs_reproduce
smilesun Feb 8, 2024
1c3f26a
1e-5 type notation not supported for yaml file
agisga Feb 8, 2024
8b9d921
.
smilesun Feb 8, 2024
180f2b8
better yaml for benchmark
smilesun Feb 9, 2024
6541cd4
refine yaml file
smilesun Feb 9, 2024
e05081a
Rename pacs_diva_fbopt_alone_es1.yaml to pacs_diva_fbopt_alone_es1_ra…
smilesun Feb 13, 2024
5bf06e0
Merge branch 'master' into fbopt
smilesun Feb 14, 2024
b0e679f
add gamma-y sample
smilesun Feb 15, 2024
99c9ad2
minor improvements to generation of figures for the fbopt experiments
agisga Feb 15, 2024
e38e9b5
add class to put tensorboard data to text file
smilesun Feb 23, 2024
e6ffee6
make fbopt script into separate folder
smilesun Feb 23, 2024
9ecb520
improve code
smilesun Feb 23, 2024
eee82b4
Merge branch 'fbopt' into fbopt_visualizations
agisga Feb 23, 2024
7b13762
Merge pull request #782 from marrlab/fbopt_visualizations
agisga Feb 23, 2024
8cca9cf
minor adjustments to the visualizations
agisga Feb 23, 2024
6dd6b2d
fbopt figures: backup x and y data to txt files
agisga Feb 23, 2024
e538970
bug fix for the last commit
agisga Feb 23, 2024
76b73a3
fbopt plots: fixes txt saving
agisga Feb 23, 2024
f5c06a7
single run reproduce
smilesun Feb 26, 2024
ba1e19d
move benchmark submit back
smilesun Feb 26, 2024
e3ef1ab
change mode for submission script
smilesun Feb 26, 2024
98689ff
add output dir
smilesun Feb 27, 2024
eea73e3
added matplotlib from nutan
smilesun Feb 27, 2024
33c027a
logscale
smilesun Feb 27, 2024
d11da9b
complex math display not possible
smilesun Feb 27, 2024
3f1f04f
change output name
smilesun Feb 27, 2024
9907274
remove latex in filename
smilesun Feb 27, 2024
0ac873d
.
smilesun Feb 27, 2024
3edc169
todo
smilesun Feb 27, 2024
75eaaa3
.
smilesun Feb 28, 2024
c8f5f5f
skip draw
smilesun Feb 28, 2024
4e4f490
arrow width
smilesun Feb 28, 2024
cd800eb
plot len
smilesun Feb 28, 2024
1712cac
phase portrait arrow size automatic
smilesun Feb 28, 2024
d3b0a12
color bar for phase portrait
smilesun Feb 28, 2024
5603790
bounding box tight
smilesun Feb 28, 2024
9e1eb99
remove \ in filename
smilesun Feb 28, 2024
5364878
more latex
smilesun Feb 28, 2024
7385e81
.
smilesun Feb 28, 2024
0ed3a21
latex in plot
smilesun Feb 28, 2024
3b1bb09
phase portrait neg
smilesun Feb 28, 2024
b70dbd0
log scale to phase portrait plot after carla's suggestion
smilesun Feb 29, 2024
382005e
log single curve
smilesun Feb 29, 2024
f3e9d0f
comment how setpoint model selection works
smilesun Mar 1, 2024
1ae0ef7
.
smilesun Mar 8, 2024
56e06cb
.
smilesun Mar 8, 2024
5e7b9c9
Merge branch 'master' into fbopt
smilesun Mar 8, 2024
d88d0df
instructions for the reproduction of M-HOF plots as in the paper
agisga Mar 12, 2024
45ffb5d
instructions to reproduce an M-HOF run in the README.md
agisga Mar 12, 2024
969f101
minor clarification in README.md
agisga Mar 12, 2024
6457db3
Update README.md
smilesun Mar 12, 2024
605f09e
Update README.md
smilesun Mar 15, 2024
a7f022b
Update README.md
smilesun Mar 15, 2024
4e093c5
Update README.md
smilesun Mar 17, 2024
b566765
Update README.md
smilesun Mar 17, 2024
40f7093
Update README.md
smilesun Mar 19, 2024
68a8f1c
Update README.md
smilesun Mar 25, 2024
264343b
merged gamma_reg_collision
MatteoWohlrapp May 7, 2024
5a6ba4e
Added changes to the new model selection
MatteoWohlrapp May 7, 2024
d374a19
merged master
MatteoWohlrapp May 7, 2024
fa7b95f
Merged gamma_reg updated coverage
MatteoWohlrapp May 7, 2024
ed4b2d1
Merge branch 'master' into mhof_dev_merge
MatteoWohlrapp May 10, 2024
2b3573d
added new benchmark for pacs, dial, fopt, erm
MatteoWohlrapp May 10, 2024
f17e540
updated benchmark
MatteoWohlrapp May 10, 2024
2bf91a8
using diva instead of erm as it has hyper init
MatteoWohlrapp May 10, 2024
71b9d26
updated benchmark yaml
MatteoWohlrapp May 10, 2024
2d275a4
renamed the
MatteoWohlrapp May 13, 2024
0608f8d
Merge branch 'master' into mhof_dev_merge
MatteoWohlrapp May 13, 2024
e4d87d9
added hyper init and hyper update function and a new benchmark for fb…
MatteoWohlrapp May 13, 2024
5c44550
Fixed import of backpack
MatteoWohlrapp May 14, 2024
416389f
Fixed benchmark import not successful in erm
MatteoWohlrapp May 14, 2024
d59d37a
fixed indentation
MatteoWohlrapp May 14, 2024
a4fa31d
Added backpack check for fishr
MatteoWohlrapp May 14, 2024
a38c777
Debugging backpack in erm
MatteoWohlrapp May 14, 2024
870bc57
Adding more logging to erm
MatteoWohlrapp May 14, 2024
d66de08
irm benchmark
smilesun May 14, 2024
f7795a9
added list_str_multiplier
MatteoWohlrapp May 14, 2024
6c5ca27
Added directories to gitignore, adjusted fobt_fishr_erm benchmark
May 14, 2024
6738c2c
Changed batch size due to memory reasons and changed task naming
May 14, 2024
67691de
correct hyper spec
smilesun May 17, 2024
38fb5f2
Solved indexing issue in fbopt mu controller and added flag info to t…
May 22, 2024
caf32b7
removed prints
May 22, 2024
a098a68
Fixed indentation for convert4backpack
May 22, 2024
518d919
fixed codacity
May 28, 2024
e947f67
fixed codacity
May 28, 2024
94a1819
Merge branch 'master' into mhof_dev_merge
May 28, 2024
5b28c73
Added the hyperparameter naming to parts of the models novel on this …
May 28, 2024
e6a6f3e
Merge branch 'mhof_dev_merge' into erm_hyper_init
May 28, 2024
7124cc7
Fixed codacity
Jun 11, 2024
5288432
Merge branch 'mhof_dev_merge' into erm_hyper_init
MatteoWohlrapp Jun 11, 2024
d2ac388
retrigger run
Jun 11, 2024
4f1347a
fixed codacity
Jun 11, 2024
6a81b4c
Merge branch 'erm_hyper_init' of https://github.com/marrlab/DomainLab…
Jun 11, 2024
f2fa952
fixed codacity
Jun 11, 2024
279f510
Added test for erm functions
Jun 11, 2024
f8f4f0e
Disabling line too long for argument
Jun 11, 2024
c7caeb5
Merge branch 'mhof_dev_merge' into erm_hyper_init
MatteoWohlrapp Jul 2, 2024
530999b
Merge branch 'erm_hyper_init' into erm_hyper_init_updated_irm
smilesun Jul 2, 2024
d1ccf46
Update task_pacs_aug.py, update path of PACS
smilesun Jul 2, 2024
6acdfdd
Update pacs_fbopt_fishr_erm.yaml
smilesun Jul 2, 2024
d60a789
Update pacs_fbopt_fishr_erm.yaml
smilesun Jul 2, 2024
976c25a
Update task_pacs_aug.py, fix codacy
smilesun Jul 2, 2024
a99c9f5
Merge pull request #843 from marrlab/erm_hyper_init
smilesun Jul 2, 2024
33ed1cb
Merge pull request #845 from marrlab/mhof_dev_merge
smilesun Jul 2, 2024
f90e655
Merge branch 'mhof_dev' into erm_hyper_init_updated_irm
smilesun Jul 5, 2024
f4e6a0c
merge conflict
smilesun Jul 12, 2024
f1b413f
Merge branch 'mhof_dev' into erm_hyper_init_updated_irm
smilesun Jul 12, 2024
d90ffae
Merge branch 'master' into mhof_dev
smilesun Jul 14, 2024
0605ffc
Merge branch 'master' into mhof_dev
smilesun Jul 15, 2024
1c683ba
Update pacs_fbopt_dial_diva.yaml
smilesun Jul 16, 2024
dc7005a
Merge branch 'master' into mhof_dev
smilesun Jul 18, 2024
10442ba
better yaml for dial_diva
smilesun Jul 18, 2024
6f5ce78
Merge branch 'master' into mhof_dev
smilesun Jul 20, 2024
174799d
Merge branch 'master' into mhof_dev
smilesun Jul 25, 2024
580bd23
Merge branch 'master' into mhof_dev
smilesun Sep 17, 2024
4686aa4
Merge branch 'master' into mhof_dev
smilesun Oct 4, 2024
913873f
Merge branch 'mhof_dev' into erm_hyper_init_updated_irm
smilesun Oct 4, 2024
b2b204d
copy branch fbopt_vector_ki_gain to mhof_dev
smilesun Oct 4, 2024
0678950
rename yaml
smilesun Oct 5, 2024
9a01811
erm alone more hyper
smilesun Oct 5, 2024
9150621
Merge pull request #838 from marrlab/erm_hyper_init_updated_irm
smilesun Oct 5, 2024
ca9766e
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 5, 2024
cbdc505
Update aistat_irm_erm_mhof.yaml
smilesun Oct 6, 2024
731a144
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 6, 2024
144efba
Update aistat_irm_erm_mhof.yaml, fix grammar error
smilesun Oct 6, 2024
9e09396
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 6, 2024
68c1e36
Update sh_link_pacs_dataset.sh
smilesun Oct 6, 2024
b4ddafe
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 6, 2024
3a051fe
Update .gitignore, display slurm_errors,
smilesun Oct 6, 2024
d53c863
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 6, 2024
e731f27
Merge branch 'master' into mhof_dev
smilesun Oct 8, 2024
f15478c
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 8, 2024
1be1634
causalirl yaml
smilesun Oct 8, 2024
121ca55
.
smilesun Oct 8, 2024
6626475
use self.decoratee instead of self.model
smilesun Oct 8, 2024
d5d3a0b
cmd script to test mhof irm
smilesun Oct 8, 2024
e6c368b
script to test mhof_irm
smilesun Oct 8, 2024
6701cbf
enable grad for irm inside torch.no_grad for mhof
smilesun Oct 9, 2024
888d714
filter out zero reg loss in abstract trainer
smilesun Oct 9, 2024
64bcc9c
trainer behaves like model, now decoratte's cal_loss has to be changed
smilesun Oct 9, 2024
22f343c
overwrite multiplier from scheduler(default static scheduler then no …
smilesun Oct 9, 2024
23607fb
per domain irm to separate file
smilesun Oct 9, 2024
28bdfa9
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 9, 2024
63ad47b
dial mhof yaml
smilesun Oct 9, 2024
8ba7e11
number of batches to estimate ratio
smilesun Oct 9, 2024
6073403
Update aistat_irm_erm_only.yaml
smilesun Oct 9, 2024
8bd1163
Update aistat_irm_erm_mhof.yaml
smilesun Oct 9, 2024
7667cbc
.
smilesun Oct 9, 2024
5800ef1
Merge branch 'master' into mhof_dev
smilesun Oct 10, 2024
795478f
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 10, 2024
b64e5d0
Update aistat_irm_erm_only.yaml
smilesun Oct 10, 2024
88861ce
Merge branch 'master' into mhof_dev
smilesun Oct 10, 2024
6433f58
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 10, 2024
62183bc
Merge branch 'master' into mhof_dev
smilesun Oct 10, 2024
6f5bac1
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 10, 2024
ba68e27
fix num_batches for loss ratio estimate
smilesun Oct 10, 2024
c09879d
.
smilesun Oct 10, 2024
394cdd6
add back missing @property due to text insert
smilesun Oct 10, 2024
401e978
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 10, 2024
5585b4a
use correct hyper range
smilesun Oct 10, 2024
d176e1a
Update aistat_irm_erm_mhof.yaml
smilesun Oct 10, 2024
3a4226e
add unit test
smilesun Oct 10, 2024
011cccb
Update aistat_irm_erm_mhof.yaml
smilesun Oct 10, 2024
ae03219
Update and rename pacs_fbopt_dial_diva.yaml to aistat_pacs_mhof_dial_…
smilesun Oct 10, 2024
3166eb2
dial
smilesun Oct 10, 2024
d8cee2d
fix partially issue #777
smilesun Oct 11, 2024
9884640
detailed doc for mhof args, change str args to boolean, fix issue #777
smilesun Oct 11, 2024
cd15915
new irm yaml file
smilesun Oct 11, 2024
273422b
.
smilesun Oct 11, 2024
dbbda55
take square of irm loss, copy reg loss from decoratee to fbopt
smilesun Oct 11, 2024
1cb2428
.
smilesun Oct 11, 2024
155f0df
.
smilesun Oct 11, 2024
c41e92b
.
smilesun Oct 11, 2024
b8731cb
Milestone: feedforward works now with trainers
smilesun Oct 13, 2024
f4e4773
yaml for feedforward
smilesun Oct 13, 2024
e3ed8c6
tr_with_init_mu
smilesun Oct 15, 2024
e38fa75
doc
smilesun Oct 15, 2024
044e3cc
.
smilesun Oct 15, 2024
aea6e90
no ma for setpoint
smilesun Oct 17, 2024
e5197f3
logger
smilesun Oct 17, 2024
dc58a50
setpoint ada as argument
smilesun Oct 17, 2024
9ddd037
yaml file search setpoint ada
smilesun Oct 17, 2024
7bcda06
Merge pull request #880 from marrlab/mhof_dev_vector_ki_gain
smilesun Oct 17, 2024
80d6a73
correct yaml
smilesun Oct 17, 2024
733bb0b
update yaml
smilesun Oct 18, 2024
f123280
adamw benchmark
smilesun Nov 29, 2024
893a982
Merge branch 'master' into mhof_dev
smilesun Nov 29, 2024
28b31c2
Update aistat_trainer_combo_dial_irm_erm_mhof.yaml
smilesun Nov 29, 2024
0103913
benchmark lr_scheduler
smilesun Dec 4, 2024
b0279dc
Merge branch 'lr_scheduler' into mhof_dev_lr_scheduler
smilesun Dec 4, 2024
c29877e
Update ci.yml
smilesun Dec 4, 2024
6597acb
train_causalIRL.py self.cal_loss return a tuple
smilesun Dec 4, 2024
61d7d98
Merge branch 'master' into mhof_dev
smilesun Dec 4, 2024
3eed766
fix bug of deleting multipliers, now only do for erm
smilesun Dec 5, 2024
87cbc68
not allowing naked erm be combined with fbopt
smilesun Dec 5, 2024
7d8d081
fix typo new argument type
smilesun Dec 5, 2024
b296229
Update train_ema.py
smilesun Dec 5, 2024
3cef1a2
fix test_ma
smilesun Dec 5, 2024
6f08403
no_dump in test_fbopt.py
smilesun Dec 6, 2024
ae907b2
unit test no_dump
smilesun Dec 10, 2024
924f4cc
Merge branch 'master' into mhof_dev
smilesun Dec 10, 2024
7ac297a
Merge branch 'mhof_dev' into mhof_dev_lr_scheduler
smilesun Dec 10, 2024
c4d8bfc
Update aistat_trainer_combo_dial_irm_erm_mhof.yaml
smilesun Dec 10, 2024
78aeafd
Merge pull request #898 from marrlab/mhof_dev_lr_scheduler
smilesun Dec 10, 2024
ab19a3e
Update test_irm.py
smilesun Dec 10, 2024
40327c0
Update test_irm.py, no_dump to save disk space, last try
smilesun Dec 10, 2024
598a4cc
Merge branch 'master' into mhof_dev
smilesun Dec 10, 2024
30fe2f5
split test_fbopt into two files
smilesun Dec 11, 2024
e431d73
remove dial combo with jigen in test_mk_exp_jigen.py, issue #901
smilesun Dec 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ name: CI

on:
push:
branches: master
branches: mhof_dev
pull_request:
branches: master
branches: mhof_dev
workflow_dispatch:

jobs:
test:
name: Run tests
Expand Down
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ tests/__pycache__/
.vscode/
domainlab/zdata/pacs
/data/
/.snakemake/
/dist
/domainlab.egg-info
/runs
/slurm_errors.txt
69 changes: 69 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ For example, the following result (without any augmentation like flip) is for P

Source: https://arxiv.org/pdf/2403.14356.pdf

Citation:
```bibtex
@misc{sun2024domainlab,
title={DomainLab: A modular Python package for domain generalization in deep learning},
Expand All @@ -132,3 +133,71 @@ Source: https://arxiv.org/pdf/2403.14356.pdf
year={2024}
}
```

# M-HOF-Opt: Multi-Objective Hierarchical Output Feedback Optimization via Multiplier Induced Loss Landscape Scheduling
Source: https://arxiv.org/pdf/2403.13728.pdf

M-HOF-Opt is implemented in [DomainLab](https://github.com/marrlab/DomainLab). If you meet any problems, feel free to report them at https://github.com/marrlab/DomainLab/issues

## Dependencies and Data Preparation
#### Example dependencies installation
```
git checkout mhof # switch to mhof branch
conda create --name domainlab_py39 python=3.9 # create a virtual environment
conda activate domainlab_py39 # activate virtual environment
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
conda install torchmetrics==0.10.3
pip install -r requirements_notorch.txt
conda install tensorboard # install tensorboard
```

#### Data preparation: download the domain generalization dataset PACS

step 1:

Use the following script to download PACS to your local laptop, then upload it to your cluster:

https://github.com/marrlab/DomainLab/blob/fbopt/data/script/download_pacs.py

step 2:
make a symbolic link following the example script in https://github.com/marrlab/DomainLab/blob/master/sh_pacs.sh

where `mkdir -p data/pacs` is executed under the repository directory,

`ln -s /dir/to/yourdata/pacs/raw ./data/pacs/PACS`
will create a symbolic link under the repository directory

### M-HOF experiments reproduction

#### Run the experiment

To execute a single run of the M-HOF method, from the root folder run the command:

```
python main_out.py -c a_reproduce_pacs_diva.yaml
```

which uses the configuration file [a_reproduce_pacs_diva.yaml](https://github.com/marrlab/DomainLab/blob/mhof/a_reproduce_pacs_diva.yaml).

#### Visualization of the results

The results of the experiment are stored in the `runs` directory generated by Tensorboard.
The various loss curves with the corresponding setpoint change curves, as well as phase-portrait-like figures showing the loss dynamics between the task loss and the various regularization losses, can be obtained by running the script [script_generate_all_figures_diva.sh](https://github.com/marrlab/DomainLab/blob/mhof/script_generate_all_figures_diva.sh):

```
bash script_generate_all_figures_diva.sh
```

The resulting figures will be stored in the directory `figures_diva`, which can be changed by editing the top of the [script_generate_all_figures_diva.sh](https://github.com/marrlab/DomainLab/blob/mhof/script_generate_all_figures_diva.sh) file if needed.

Citation:
```bibtex
@misc{sun2024m,
title={M-HOF-Opt: Multi-Objective Hierarchical Output Feedback Optimization via Multiplier Induced Loss Landscape Scheduling},
author={Sun, Xudong and Chen, Nutan and Gossmann, Alexej and Xing, Yu and Dorigatti, Emilio and Drost, Felix and Feistner, Carla and Scarcella, Daniele and Beer, Lisa and Marr, Carsten},
journal={https://arxiv.org/pdf/2403.13728.pdf},
number={2403.13728},
year={2024},
publisher={https://arxiv.org/pdf/2403.13728.pdf}
}
```
24 changes: 24 additions & 0 deletions a_reproduce_pacs_diva.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Single-run configuration reproducing the M-HOF PACS/DIVA experiment
# (see README: `python main_out.py -c a_reproduce_pacs_diva.yaml`).
te_d: sketch  # held-out test domain of PACS
tpath: examples/tasks/task_pacs_aug.py  # task definition (PACS with augmentation)
bs: 32  # batch size
model: diva
trainer: fbopt  # M-HOF feedback-optimization trainer
gamma_y: 1.0
ini_setpoint_ratio: 0.99  # initial setpoint ratio — presumably relative to initial loss; confirm in trainer docs
str_diva_multiplier_type: gammad_recon
coeff_ma_output_state: 0.1  # moving-average coefficient for the output state
coeff_ma_setpoint: 0.9  # moving-average coefficient for the setpoint
exp_shoulder_clip: 5
mu_init: 0.000001  # initial multiplier value
k_i_gain_ratio: 0.5  # integral-gain ratio of the feedback controller
mu_clip: 10  # upper clip for the multiplier
epos: 1000  # maximum number of epochs
epos_min: 200  # minimum number of epochs before early stop can trigger
npath: examples/nets/resnet50domainbed.py
npath_dom: examples/nets/resnet50domainbed.py
es: 2  # early-stop patience
lr: 0.00005
zx_dim: 0
zy_dim: 64
zd_dim: 64
force_setpoint_change_once: True  # require at least one setpoint change before stopping
1 change: 1 addition & 0 deletions a_test_feedforward_irm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Smoke test: feed-forward hyper-scheduler trainer (hyperscheduler_irm_dial)
# with ERM on ColorMNIST, test domain 0, train domains 1 and 2.
python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=hyperscheduler_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=10 --epos_min=4 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999
1 change: 1 addition & 0 deletions a_test_mhof_irm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Smoke test: M-HOF feedback-optimization trainer (fbopt_irm_dial) with ERM
# on ColorMNIST, test domain 0, train domains 1 and 2.
python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.9 --nb4reg_over_task_ratio=0 --tr_with_init_mu --coeff_ma_setpoint=0.0 --str_setpoint_ada="SliderAnyComponent()"
9 changes: 7 additions & 2 deletions domainlab/algos/builder_diva.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.observers.c_obvisitor_gen import ObVisitorGen
Expand Down Expand Up @@ -35,7 +37,8 @@ def init_business(self, exp):
request = RequestVAEBuilderCHW(task.isize.c, task.isize.h, task.isize.w, args)
node = VAEChainNodeGetter(request)()
task.get_list_domains_tr_te(args.tr_d, args.te_d)
model = mk_diva(list_str_y=task.list_str_y)(
model = mk_diva(str_diva_multiplier_type=args.str_diva_multiplier_type, list_str_y=task.list_str_y)(

node,
zd_dim=args.zd_dim,
zy_dim=args.zy_dim,
Expand All @@ -48,7 +51,9 @@ def init_business(self, exp):
beta_d=args.beta_d,
)
device = get_device(args)
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es), val_threshold=args.val_threshold)
model_sel = MSelSetpointDelay(
MSelOracleVisitor(MSelValPerfTopK(max_es=args.es)), val_threshold=args.val_threshold
)
if not args.gen:
observer = ObVisitor(model_sel)
else:
Expand Down
21 changes: 21 additions & 0 deletions domainlab/algos/builder_fbopt_dial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
builder for feedback optimization of dial
"""
from domainlab.algos.builder_diva import NodeAlgoBuilderDIVA
from domainlab.algos.trainers.train_fbopt_b import TrainerFbOpt


class NodeAlgoBuilderFbOptDial(NodeAlgoBuilderDIVA):
    """
    Builder that reuses the DIVA algorithm construction and wraps the
    resulting trainer in a feedback-optimization (FbOpt) trainer for DIAL.
    """

    def init_business(self, exp):
        """
        Construct and return the (trainer, model, observer, device) tuple.

        The trainer handed back to the caller is a TrainerFbOpt that
        decorates the trainer produced by the parent DIVA builder.
        """
        inner_trainer, model, observer, device = super().init_business(exp)
        inner_trainer.init_business(model, exp.task, observer, device, exp.args)
        outer_trainer = TrainerFbOpt()
        outer_trainer.init_business(
            inner_trainer, exp.task, observer, device, exp.args
        )
        return outer_trainer, model, observer, device
4 changes: 3 additions & 1 deletion domainlab/algos/builder_jigen1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential
Expand All @@ -30,7 +32,7 @@ def init_business(self, exp):
task = exp.task
args = exp.args
device = get_device(args)
msel = MSelOracleVisitor(msel=MSelValPerf(max_es=args.es), val_threshold=args.val_threshold)
msel = MSelSetpointDelay(MSelOracleVisitor(MSelValPerfTopK(max_es=args.es)), val_threshold=args.val_threshold)
observer = ObVisitor(msel)
observer = ObVisitorCleanUp(observer)

Expand Down
6 changes: 6 additions & 0 deletions domainlab/algos/msels/a_model_sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ def sel_model_te_acc(self):
return self.msel.sel_model_te_acc
return -1

@property
def oracle_last_setpoint_sel_te_acc(self):
    """
    Delegate to the decoratee's test accuracy recorded at the last
    setpoint change; return -1 when no decoratee is set.
    """
    if self.msel is not None:
        return self.msel.oracle_last_setpoint_sel_te_acc
    return -1

@property
def model_selection_epoch(self):
"""
Expand Down
54 changes: 54 additions & 0 deletions domainlab/algos/msels/c_msel_setpoint_delay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
logs the best up-to-event selected model at each event when setpoint shrinks
"""
from domainlab.algos.msels.a_model_sel import AMSel
from domainlab.utils.logger import Logger


class MSelSetpointDelay(AMSel):
"""
This class decorate another model selection object, it logs the current
selected performance from the decoratee each time the setpoint shrinks
"""

def __init__(self, msel, val_threshold = None):
super().__init__(val_threshold)
# NOTE: super() has to come first always otherwise self.msel will be overwritten to be None
self.msel = msel
self._oracle_last_setpoint_sel_te_acc = 0.0

@property
def oracle_last_setpoint_sel_te_acc(self):
"""
return the last setpoint best acc
"""
return self._oracle_last_setpoint_sel_te_acc

def base_update(self, clear_counter=False):
"""
if the best model should be updated
currently, clear_counter is set via
flag = super().tr_epoch(epoch, self.flag_setpoint_updated)
"""
logger = Logger.get_logger()
logger.info(
f"setpoint selected current acc {self._oracle_last_setpoint_sel_te_acc}"
)
if clear_counter:
# for the current version of code, clear_counter = flag_setpoint_updated
log_message = (
f"setpoint msel te acc updated from "
# self._oracle_last_setpoint_sel_te_acc start from 0.0, and always saves
# the test acc when last setpoint decrease occurs
f"{self._oracle_last_setpoint_sel_te_acc} to "
# self.sel_model_te_acc defined as a property
# in a_msel, which returns self.msel.sel_model_te_acc
# is the validation acc based model selection, which
# does not take setpoint into account
f"{self.sel_model_te_acc}"
)
logger.info(log_message)
self._oracle_last_setpoint_sel_te_acc = self.sel_model_te_acc
# let decoratee decide if model should be selected or not
flag = self.msel.update(clear_counter)
return flag
61 changes: 61 additions & 0 deletions domainlab/algos/msels/c_msel_val_top_k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Model Selection should be decoupled from
"""
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.utils.logger import Logger


class MSelValPerfTopK(MSelValPerf):
"""
1. Model selection using validation performance
2. Visitor pattern to trainer
"""

def __init__(self, max_es, top_k=2):
super().__init__(max_es) # construct self.tr_obs (observer)
self.top_k = top_k
self.list_top_k_acc = [0.0 for _ in range(top_k)]

def update(self, clear_counter=False):
"""
if the best model should be updated
"""
flag_super = super().update(clear_counter)
metric_val_current = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
acc_min = min(self.list_top_k_acc)
if metric_val_current > acc_min:
# overwrite
logger = Logger.get_logger()
logger.info(
f"top k validation acc: {self.list_top_k_acc} \
overwriting/reset counter"
)
self.es_c = 0 # restore counter
ind = self.list_top_k_acc.index(acc_min)
# avoid having identical values
if metric_val_current not in self.list_top_k_acc:
self.list_top_k_acc[ind] = metric_val_current
logger.info(
f"top k validation acc updated: \
{self.list_top_k_acc}"
)
# overwrite to ensure consistency
# issue #569: initially self.list_top_k_acc will be [xx, 0] and it does not matter since 0 will be overwriten by second epoch validation acc.
# actually, after epoch 1, most often, sefl._best_val_acc will be the higher value of self.list_top_k_acc will overwriten by min(self.list_top_k_acc)
logger.info(
f"top-2 val sel: overwriting best val acc from {self._best_val_acc} to "
f"minimum of {self.list_top_k_acc} which is {min(self.list_top_k_acc)} "
f"to ensure consistency"
)
self._best_val_acc = min(self.list_top_k_acc)
# overwrite test acc, this does not depend on if val top-k acc has been overwritten or not
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
if self._sel_model_te_acc != metric_te_current:
# this can only happen if the validation acc has decreased and current val acc is only bigger than min(self.list_top_k_acc} but lower than max(self.list_top_k_acc)
logger.info(
f"top-2 val sel: overwriting selected model test acc from "
f"{self._sel_model_te_acc} to {metric_te_current} to ensure consistency"
)
self._sel_model_te_acc = metric_te_current
return True # if metric_val_current > acc_min:
return flag_super
27 changes: 24 additions & 3 deletions domainlab/algos/observers/b_obvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,22 @@ def __init__(self, model_sel):
self.metric_val = None
self.perf_metric = None

self.flag_setpoint_changed_once = False

@property
def str_metric4msel(self):
    """
    string representing the metric used for persisting models on the disk;
    delegated to the host trainer
    """
    return self.host_trainer.str_metric4msel

def update(self, epoch):
def reset(self):
    """
    reset this observer by resetting its model-selection visitor
    """
    self.model_sel.reset()

def update(self, epoch, flag_info=False):
logger = Logger.get_logger()
logger.info(f"epoch: {epoch}")
self.epo = epoch
Expand All @@ -53,13 +61,18 @@ def update(self, epoch):
self.loader_te, self.device
)
self.metric_te = metric_te
if self.model_sel.update(epoch):
if self.model_sel.update(epoch, flag_info):
logger.info("better model found")
self.host_trainer.model.save()
logger.info("persisted")
acc = self.metric_te.get("acc")
flag_stop = self.model_sel.if_stop(acc)
flag_enough = epoch >= self.host_trainer.aconf.epos_min

self.flag_setpoint_changed_once |= flag_info
if self.host_trainer.aconf.force_setpoint_change_once:
return flag_stop & flag_enough & self.flag_setpoint_changed_once

return flag_stop & flag_enough

def accept(self, trainer):
Expand Down Expand Up @@ -106,7 +119,15 @@ def after_all(self):
metric_te.update({"model_selection_epoch": self.model_sel.model_selection_epoch})
else:
metric_te.update({"acc_val": -1})
metric_te.update({"model_selection_epoch": -1})

if hasattr(self, "model_sel") and hasattr(
self.model_sel, "oracle_last_setpoint_sel_te_acc"
):
metric_te.update(
{"acc_setpoint": self.model_sel.oracle_last_setpoint_sel_te_acc}
)
else:
metric_te.update({"acc_setpoint": -1})
self.dump_prediction(model_ld, metric_te)
# save metric to one line in csv result file
self.host_trainer.model.visitor(metric_te)
Expand Down
Loading
Loading