From 16c81a32989ab95efa1442d3583b609ab10a7ef0 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Wed, 27 Jan 2021 22:53:58 +0000 Subject: [PATCH 01/17] Issue #9: add pre-commit configs --- .flake8 | 5 + .isort.cfg | 7 + .pre-commit-config.yaml | 48 ++++++ .prettierignore | 1 + .prettierrc | 6 + .pylintrc | 336 ++++++++++++++++++++++++++++++++++++++++ requirements.txt | 2 + 7 files changed, 405 insertions(+) create mode 100644 .flake8 create mode 100644 .isort.cfg create mode 100644 .pre-commit-config.yaml create mode 100644 .prettierignore create mode 100644 .prettierrc create mode 100644 .pylintrc create mode 100644 requirements.txt diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..bd71867 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +ignore = E203, E266, E501, W503 +max-line-length = 88 +max-complexity = 18 +select = B,C,E,F,W,T4,B9 diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..ae0e639 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,7 @@ +[settings] +known_third_party = +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 0 +use_parentheses = True +line_length = 88 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..723fe30 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,48 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.4.0 + hooks: + - id: check-ast # Simply check whether the files parse as valid python + - id: check-case-conflict # Check for files that would conflict in case-insensitive filesystems + - id: check-builtin-literals # Require literal syntax when initializing empty or zero Python builtin types + - id: check-docstring-first # Check a common error of defining a docstring after code + - id: check-merge-conflict # Check for files that contain merge conflict strings + - id: check-yaml # Check yaml files + - id: check-vcs-permalinks # Ensure that links to vcs websites are permalinks + - id: debug-statements # Check for debugger imports and py37+ `breakpoint()` calls in python source + - id: detect-private-key # Detect the presence of private keys + - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings. 
+ - id: end-of-file-fixer # Ensure that a file is either empty, or ends with one newline + - id: mixed-line-ending # Replace or checks mixed line ending + - id: trailing-whitespace # This hook trims trailing whitespace + - id: file-contents-sorter # Sort the lines in specified files + files: .*requirements.*\.txt$ + - repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + - repo: https://github.com/timothycrosley/isort + rev: 5.7.0 + hooks: + - id: isort + - repo: https://github.com/psf/black + rev: 20.8b1 + hooks: + - id: black + language_version: python3.7 + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v2.2.1 + hooks: + - id: prettier + - repo: https://gitlab.com/pycqa/flake8 + rev: 3.8.4 + hooks: + - id: flake8 + - repo: https://github.com/pycqa/pydocstyle + rev: 5.1.1 # pick a git hash / tag to point to + hooks: + - id: pydocstyle + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.800 + hooks: + - id: mypy diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 0000000..4467cb7 --- /dev/null +++ b/.prettierignore @@ -0,0 +1 @@ +docs/joss_paper/paper.md diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..6d2f1b0 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,6 @@ +{ + "printWidth": 88, + "proseWrap": "always", + "useTabs": false, + "tabWidth": 2 +} diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..7a9f360 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,336 @@ +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Profiled execution. +profile=no + +# Add files or directories to the denylist. They should be base names, not +# paths. +ignore=CVS + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins=pylint.extensions.docparams +accept-no-param-doc=no + +[MESSAGES CONTROL] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable=indexing-exception,old-raise-syntax + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,invalid-sequence-index,unexpected-keyword-arg,no-value-for-parameter + + +# Set the cache size for astng objects. +cache-size=500 + + +[REPORTS] + +# Set the output format. 
Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=yes + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Add a comment according to your evaluation note. This is used by the global +# evaluation report (RP0004). +comment=no + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# When zope mode is activated, add a predefined set of Zope acquired attributes +# to generated-members. +zope=no + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the beginning of the name of dummy variables +# (i.e. not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# List of builtins function names that should not be used, separated by a comma +bad-functions=apply,input,reduce + + +# Disable the report(s) with the given id(s). +# All non-Google reports are disabled by default. 
+disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 + +# Regular expression which should only match correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression which should only match correct module level names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression which should only match correct function names +function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct method names +method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct instance attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct attribute names in class +# bodies +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main) + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=1 + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=88 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x) + (^\s*(import|from)\s + |\$Id:\s\/\/depot\/.+#\d+\s\$ + |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') + |^\s*\#\ LINT\.ThenChange + |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ + |pylint + |""" + |\# + |lambda + |(https?|ftp):) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=y + +# List of optional constructs for which whitespace checking is disabled +no-space-check= + +# Maximum number of lines in a module +max-module-lines=1000 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes= + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets + +# Create a graph of every (i.e. 
internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[CLASSES] + +# List of interface methods to ignore, separated by a comma. This is used for +# instance to not check methods defines in Zope's Interface base class. +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls,class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception,StandardError,BaseException + + +[AST] + +# Maximum line length for lambdas +short-func-length=1 + +# List of module members that should be marked as deprecated. +# All of the string functions are listed in 4.1.4 Deprecated string functions +# in the Python 2.4 docs. +deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc + + +[DOCSTRING] + +default-docstring-type=sphinx +# List of exceptions that do not need to be mentioned in the Raises section of +# a docstring. +ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError,ValueError + + + +[TOKENS] + +# Number of spaces of indent required when the last token on the preceding line +# is an open (, [, or {. +indent-after-paren=4 + + +[GOOGLE LINES] + +# Regexp for a proper copyright notice. +copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\. 
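
The checks configured above are strict about documentation: [MASTER] loads the pylint.extensions.docparams plugin with accept-no-param-doc=no, [DOCSTRING] selects Sphinx-style docstrings, and both flake8 and pylint cap lines at 88 characters. A minimal sketch of the docstring shape these settings ask for; the function itself is illustrative only and not part of the repository:

    def scale_intensity(image_max: float, scale: float = 1.0) -> float:
        """
        Rescale a maximum intensity value.

        :param image_max: maximum voxel intensity of an image
        :param scale: multiplicative factor applied to the maximum
        :return: the rescaled maximum intensity
        """
        return image_max * scale
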
diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..122f894 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +mypy +pre-commit>=2.10.0 From dff40f122a83098956f15c4748595eac4a6b3a45 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Wed, 27 Jan 2021 23:53:38 +0000 Subject: [PATCH 02/17] Issue #9: add preprocess script --- .isort.cfg | 2 +- .pre-commit-config.yaml | 1 - requirements.txt | 1 + scripts/neuro_imaging_preprocess.py | 94 +++++++++++++++++++++++++++++ 4 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 scripts/neuro_imaging_preprocess.py diff --git a/.isort.cfg b/.isort.cfg index ae0e639..9db365b 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] -known_third_party = +known_third_party =torchio multi_line_output = 3 include_trailing_comma = True force_grid_wrap = 0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 723fe30..67b4bf1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,6 @@ repos: - id: check-vcs-permalinks # Ensure that links to vcs websites are permalinks - id: debug-statements # Check for debugger imports and py37+ `breakpoint()` calls in python source - id: detect-private-key # Detect the presence of private keys - - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings. - id: end-of-file-fixer # Ensure that a file is either empty, or ends with one newline - id: mixed-line-ending # Replace or checks mixed line ending - id: trailing-whitespace # This hook trims trailing whitespace diff --git a/requirements.txt b/requirements.txt index 122f894..f25e485 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ mypy pre-commit>=2.10.0 +torchio diff --git a/scripts/neuro_imaging_preprocess.py b/scripts/neuro_imaging_preprocess.py new file mode 100644 index 0000000..a92b52c --- /dev/null +++ b/scripts/neuro_imaging_preprocess.py @@ -0,0 +1,94 @@ +""" +This script is for performing skull stripping on the affine-aligned datasets. 
+ +The data are stored in multiple folders: +- matrices, storing the affine matrices for the affine registration +- mri, storing the MR images +- gif_parcellation, storing the parcellation of the MR images +- reference, storing one image parcellation pair + +The preprocessed files will be saved under +- preprocessed/images +- preprocessed/labels +- preprocessed/reference +""" +import glob +import os + +import torchio as tio + +SMALLEST_BRAIN_LABEL = 24 # from colour table +data_folder_path = "/raid/candi/Yunguan/DeepReg/neuroimaging" +output_folder_path = f"{data_folder_path}/preprocessed" + +# get file paths +image_file_paths = glob.glob(f"{data_folder_path}/mri/*.nii.gz") +label_file_paths = glob.glob(f"{data_folder_path}/gif_parcellation/*.nii.gz") +matrix_file_paths = glob.glob(f"{data_folder_path}/matrices/*.txt") + +assert len(image_file_paths) == len(label_file_paths) == len(matrix_file_paths) +num_images = len(image_file_paths) + +image_file_paths = sorted(image_file_paths) +label_file_paths = sorted(label_file_paths) +matrix_file_paths = sorted(matrix_file_paths) + +# get unique IDs +image_file_names = [ + os.path.split(x)[1].replace(".nii.gz", "") for x in image_file_paths +] +label_file_names = [ + os.path.split(x)[1].replace(".nii.gz", "") for x in label_file_paths +] +matrix_file_names = [os.path.split(x)[1].replace(".txt", "") for x in matrix_file_paths] + +# images have suffix "_t1_pre_on_mni" +# labels have suffix "_t1_pre_NeuroMorph_Parcellation" or "-T1_NeuroMorph_Parcellation" +# matrices have suffix "_t1_pre_to_mni" +# verify sorted filenames are matching +for i in range(num_images): + image_fname = image_file_names[i] + label_fname = label_file_names[i] + label_fname = label_fname.replace( + "_t1_pre_NeuroMorph_Parcellation", "_t1_pre_on_mni" + ) + label_fname = label_fname.replace("-T1_NeuroMorph_Parcellation", "_t1_pre_on_mni") + matrix_fname = matrix_file_names[i] + matrix_fname = matrix_fname.replace("_t1_pre_to_mni", "_t1_pre_on_mni") + assert image_fname == label_fname == matrix_fname + + +def preprocess(image_path: str, label_path: str, matrix_path: str): + """ + Preprocess one data sample. 
+ + Args: + image_path: file path for image + label_path: file path for parcellation + matrix_path: file path for affine matrix + """ + name = os.path.split(image_path)[1].replace("_pre_on_mni.nii.gz", "") + out_image_path = f"{output_folder_path}/images/{name}.nii.gz" + out_label_path = f"{output_folder_path}/labels/{name}.nii.gz" + + # resample parcellation to MNI + matrix = tio.io.read_matrix(matrix_path) + parcellation = tio.LabelMap(label_path, to_mni=matrix) + resample = tio.Resample(image_path, pre_affine_name="to_mni") + parcellation_mni = resample(parcellation) + parcellation_mni.save(out_label_path) + + # get brain mask + extract_brain = tio.Lambda(lambda x: (x >= SMALLEST_BRAIN_LABEL)) + brain_mask = extract_brain(parcellation_mni) + + # skull-stripping + mri = tio.ScalarImage(image_path) + mri.data[~brain_mask.data.bool()] = 0 + mri.save(out_image_path) + + +for image_path, label_path, matrix_path in zip( + image_file_paths, label_file_paths, matrix_file_paths +): + preprocess(image_path=image_path, label_path=label_path, matrix_path=matrix_path) From 057cdb07b48f4fceff6d6bde0ad770958faf486e Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Thu, 28 Jan 2021 00:00:35 +0000 Subject: [PATCH 03/17] Issue #9: create dirs and use tqdm --- .isort.cfg | 2 +- requirements.txt | 1 + scripts/neuro_imaging_preprocess.py | 10 ++++++++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.isort.cfg b/.isort.cfg index 9db365b..13c4c43 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] -known_third_party =torchio +known_third_party =torchio,tqdm multi_line_output = 3 include_trailing_comma = True force_grid_wrap = 0 diff --git a/requirements.txt b/requirements.txt index f25e485..9d4ed68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ mypy pre-commit>=2.10.0 torchio +tqdm diff --git a/scripts/neuro_imaging_preprocess.py b/scripts/neuro_imaging_preprocess.py index a92b52c..c31c9cc 100644 --- a/scripts/neuro_imaging_preprocess.py +++ b/scripts/neuro_imaging_preprocess.py @@ -16,11 +16,17 @@ import os import torchio as tio +from tqdm import tqdm SMALLEST_BRAIN_LABEL = 24 # from colour table data_folder_path = "/raid/candi/Yunguan/DeepReg/neuroimaging" output_folder_path = f"{data_folder_path}/preprocessed" +for folder_name in ["images", "labels"]: + _path = f"{output_folder_path}/{folder_name}" + if not os.path.exists(_path): + os.makedirs(_path) + # get file paths image_file_paths = glob.glob(f"{data_folder_path}/mri/*.nii.gz") label_file_paths = glob.glob(f"{data_folder_path}/gif_parcellation/*.nii.gz") @@ -88,7 +94,7 @@ def preprocess(image_path: str, label_path: str, matrix_path: str): mri.save(out_image_path) -for image_path, label_path, matrix_path in zip( - image_file_paths, label_file_paths, matrix_file_paths +for image_path, label_path, matrix_path in tqdm( + zip(image_file_paths, label_file_paths, matrix_file_paths) ): preprocess(image_path=image_path, label_path=label_path, matrix_path=matrix_path) From 3a7cc396e10d69069876336ac41d6decd4910990 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Thu, 28 Jan 2021 00:02:27 +0000 Subject: [PATCH 04/17] Issue #9: add known length for tqdm --- scripts/neuro_imaging_preprocess.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/neuro_imaging_preprocess.py b/scripts/neuro_imaging_preprocess.py index c31c9cc..b94ec7a 100644 --- a/scripts/neuro_imaging_preprocess.py +++ b/scripts/neuro_imaging_preprocess.py @@ -95,6 +95,7 @@ def preprocess(image_path: str, label_path: str, 
matrix_path: str): for image_path, label_path, matrix_path in tqdm( - zip(image_file_paths, label_file_paths, matrix_file_paths) + zip(image_file_paths, label_file_paths, matrix_file_paths), + total=num_images, ): preprocess(image_path=image_path, label_path=label_path, matrix_path=matrix_path) From 52a09ccc0e027c33cc3476b6d784496ff7ca595a Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Sun, 21 Feb 2021 00:47:52 +0000 Subject: [PATCH 05/17] Issue #9: add script for vm --- .isort.cfg | 2 +- benchmark/__init__.py | 1 + benchmark/balakrishnan2019/__init__.py | 1 + .../config_balakrishnan_2019.yaml | 49 ++++ .../updated_config_balakrishnan_2019.yaml | 51 ++++ .../voxel_morph_balakrishnan_2019.py | 217 ++++++++++++++++++ 6 files changed, 320 insertions(+), 1 deletion(-) create mode 100644 benchmark/__init__.py create mode 100644 benchmark/balakrishnan2019/__init__.py create mode 100644 benchmark/balakrishnan2019/config_balakrishnan_2019.yaml create mode 100644 benchmark/balakrishnan2019/updated_config_balakrishnan_2019.yaml create mode 100644 benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py diff --git a/.isort.cfg b/.isort.cfg index 13c4c43..04455a8 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] -known_third_party =torchio,tqdm +known_third_party =deepreg,tensorflow,torchio,tqdm multi_line_output = 3 include_trailing_comma = True force_grid_wrap = 0 diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 0000000..5e58a39 --- /dev/null +++ b/benchmark/__init__.py @@ -0,0 +1 @@ +"""Benchmark with other methods.""" diff --git a/benchmark/balakrishnan2019/__init__.py b/benchmark/balakrishnan2019/__init__.py new file mode 100644 index 0000000..3261ebc --- /dev/null +++ b/benchmark/balakrishnan2019/__init__.py @@ -0,0 +1 @@ +"""Reproduce https://arxiv.org/abs/1809.05231.""" diff --git a/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml b/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml new file mode 100644 index 0000000..3db6898 --- /dev/null +++ b/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml @@ -0,0 +1,49 @@ +dataset: + dir: + train: "/raid/candi/Yunguan/DeepReg/neuroimaging/preprocessed" # required + valid: + test: + format: "nifti" + type: "unpaired" # paired / unpaired / grouped + labeled: false # whether to use the labels if available, "true" or "false" + image_shape: [192, 229, 193] + +train: + # define neural network structure + method: "ddf" # options include "ddf", "dvf", "conditional" + backbone: + name: "vm_balakrishnan_2019" # options include "local", "unet" and "global" + num_channel_initial: 16 # number of initial channel in local net, controls the size of the network + depth: 4 + concat_skip: true + encode_num_channels: [16, 32, 32, 32, 32] + decode_num_channels: [32, 32, 32, 32, 32] + + # define the loss function for training + loss: + image: + name: "lncc" # other options include "lncc", "ssd" and "gmi", for local normalised cross correlation, + weight: 1.0 + label: + weight: 0.0 + name: "dice" # options include "dice", "cross-entropy", "mean-squared", "generalised_dice" and "jaccard" + regularization: + weight: 1.0 # weight of regularization loss + name: "gradient" # options include "bending", "gradient" + + # define the optimizer + optimizer: + name: "adam" # options include "adam", "sgd" and "rms" + adam: + learning_rate: 1.0e-4 + + # define the hyper-parameters for preprocessing + preprocess: + data_augmentation: + name: "affine" + batch_size: 2 + shuffle_buffer_num_batch: 1 # 
shuffle_buffer_size = batch_size * shuffle_buffer_num_batch + + # other training hyper-parameters + epochs: 2 # number of training epochs + save_period: 2 # the model will be saved every `save_period` epochs. diff --git a/benchmark/balakrishnan2019/updated_config_balakrishnan_2019.yaml b/benchmark/balakrishnan2019/updated_config_balakrishnan_2019.yaml new file mode 100644 index 0000000..e07eed2 --- /dev/null +++ b/benchmark/balakrishnan2019/updated_config_balakrishnan_2019.yaml @@ -0,0 +1,51 @@ +dataset: + dir: + test: /home/mathpluscode/Git/DeepReg/data/test/nifti/unpaired/test + train: /home/mathpluscode/Git/DeepReg/data/test/nifti/unpaired/train + valid: null + format: nifti + image_shape: + - 64 + - 64 + - 64 + labeled: false + type: unpaired +train: + backbone: + concat_skip: true + decode_num_channels: + - 32 + - 32 + - 32 + - 32 + - 32 + depth: 4 + encode_num_channels: + - 16 + - 32 + - 32 + - 32 + - 32 + name: vm_balakrishnan_2019 + num_channel_initial: 16 + epochs: 2 + loss: + image: + name: lncc + weight: 1.0 + label: + name: dice + weight: 0.0 + regularization: + name: gradient + weight: 1.0 + method: ddf + optimizer: + learning_rate: 0.0001 + name: Adam + preprocess: + batch_size: 2 + data_augmentation: + name: affine + shuffle_buffer_num_batch: 1 + save_period: 2 diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py new file mode 100644 index 0000000..5573567 --- /dev/null +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -0,0 +1,217 @@ +"""This script provides an example of using custom backbone for training.""" + +import argparse +from typing import Tuple, Union + +import tensorflow as tf +import tensorflow.keras.layers as tfkl +from deepreg.model.backbone import UNet +from deepreg.registry import REGISTRY +from deepreg.train import train + + +@REGISTRY.register_backbone(name="vm_balakrishnan_2019") +class VoxelMorphBalakrishnan2019(UNet): + """Reproduce https://arxiv.org/abs/1809.05231.""" + + def __init__(self, **kwargs): + """ + Init. + + Args: + **kwargs: + """ + super().__init__(**kwargs) + + self._out_ddf_upsampling = tf.keras.layers.UpSampling3D(size=2) + self._out_ddf_conv = tfkl.Conv3D( + filters=3, + kernel_size=3, + padding="same", + activation=tf.keras.layers.LeakyReLU(alpha=0.2), + ) + + def build_encode_conv_block( + self, filters: int, kernel_size: int, padding: str + ) -> Union[tf.keras.Model, tfkl.Layer]: + """ + Build a conv block for down-sampling. + + :param filters: number of channels for output + :param kernel_size: arg for conv3d + :param padding: arg for conv3d + :return: a block consists of one or multiple layers + """ + return tfkl.Conv3D( + filters=filters, + kernel_size=kernel_size, + padding=padding, + strides=2, + activation=tf.keras.layers.LeakyReLU(alpha=0.2), + ) + + def build_down_sampling_block( + self, filters: int, kernel_size: int, padding: str, strides: int + ) -> Union[tf.keras.Model, tfkl.Layer]: + """ + Return identity layer. + + :param filters: number of channels for output, arg for conv3d + :param kernel_size: arg for pool3d or conv3d + :param padding: arg for pool3d or conv3d + :param strides: arg for pool3d or conv3d + :return: a block consists of one or multiple layers + """ + return tfkl.Lambda(lambda x: x) + + def build_bottom_block( + self, filters: int, kernel_size: int, padding: str + ) -> Union[tf.keras.Model, tfkl.Layer]: + """ + Return down sample layer. 
+ + :param filters: number of channels for output + :param kernel_size: arg for conv3d + :param padding: arg for conv3d + :return: a block consists of one or multiple layers + """ + return tfkl.Conv3D( + filters=filters, + kernel_size=kernel_size, + padding=padding, + strides=2, + activation=tf.keras.layers.LeakyReLU(alpha=0.2), + ) + + def build_up_sampling_block( + self, + filters: int, + output_padding: int, + kernel_size: int, + padding: str, + strides: int, + output_shape: tuple, + ) -> Union[tf.keras.Model, tfkl.Layer]: + """ + Build a block for up-sampling. + + This block changes the tensor shape (width, height, depth), + but it does not changes the number of channels. + + :param filters: number of channels for output + :param output_padding: padding for output + :param kernel_size: arg for deconv3d + :param padding: arg for deconv3d + :param strides: arg for deconv3d + :param output_shape: shape of the output tensor + :return: a block consists of one or multiple layers + """ + return tf.keras.layers.UpSampling3D(size=strides) + + def build_decode_conv_block( + self, filters: int, kernel_size: int, padding: str + ) -> Union[tf.keras.Model, tfkl.Layer]: + """ + Build a conv block for up-sampling. + + :param filters: number of channels for output + :param kernel_size: arg for conv3d + :param padding: arg for conv3d + :return: a block consists of one or multiple layers + """ + return tfkl.Conv3D( + filters=filters, + kernel_size=kernel_size, + padding=padding, + strides=1, + activation=tf.keras.layers.LeakyReLU(alpha=0.2), + ) + + def build_output_block( + self, + image_size: Tuple[int], + extract_levels: Tuple[int], + out_channels: int, + out_kernel_initializer: str, + out_activation: str, + ) -> Union[tf.keras.Model, tfkl.Layer]: + """ + Build a block for output. + + The input to this block is a list of tensors. + + :param image_size: such as (dim1, dim2, dim3) + :param extract_levels: number of extraction levels. + :param out_channels: number of channels for the extractions + :param out_kernel_initializer: initializer to use for kernels. + :param out_activation: activation to use at end layer. + :return: a block consists of one or multiple layers + """ + return tf.keras.Sequential( + [ + tfkl.Lambda(lambda x: x[-1]), # take the last one / depth 0 + tfkl.Conv3D( + filters=self.num_channel_initial, + kernel_size=3, + padding="same", + activation=tf.keras.layers.LeakyReLU(alpha=0.2), + ), + tfkl.Conv3D( + filters=self.num_channel_initial, + kernel_size=3, + padding="same", + activation=tf.keras.layers.LeakyReLU(alpha=0.2), + ), + ] + ) + + def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: + """ + Build LocalNet graph based on built layers. + + :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch) + :param training: None or bool. + :param mask: None or tf.Tensor. + :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels) + """ + output = super().call(inputs=inputs, training=training, mask=mask) + # upsample again + output = self._out_ddf_upsampling(output) + output = tf.concat([inputs, output], axis=4) + output = self._out_ddf_conv(output) + return output + + +def main(args=None): + """ + Launch training. + + Args: + args: + + """ + parser = argparse.ArgumentParser() + + parser.add_argument( + "--gpu", + "-g", + help="GPU index for training." 
+ '-g "" for using CPU' + '-g "0" for using GPU 0' + '-g "0,1" for using GPU 0 and 1.', + type=str, + required=True, + ) + args = parser.parse_args(args) + + config_path = "config_balakrishnan_2019.yaml" + train( + gpu=args.gpu, + config_path=config_path, + gpu_allow_growth=True, + ckpt_path="", + ) + + +if __name__ == "__main__": + main() From 789564f02666c1dfddc4c07018bdd80e624884a7 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Sun, 21 Feb 2021 00:59:01 +0000 Subject: [PATCH 06/17] Issue #8: remove duplicated config --- .../updated_config_balakrishnan_2019.yaml | 51 ------------------- 1 file changed, 51 deletions(-) delete mode 100644 benchmark/balakrishnan2019/updated_config_balakrishnan_2019.yaml diff --git a/benchmark/balakrishnan2019/updated_config_balakrishnan_2019.yaml b/benchmark/balakrishnan2019/updated_config_balakrishnan_2019.yaml deleted file mode 100644 index e07eed2..0000000 --- a/benchmark/balakrishnan2019/updated_config_balakrishnan_2019.yaml +++ /dev/null @@ -1,51 +0,0 @@ -dataset: - dir: - test: /home/mathpluscode/Git/DeepReg/data/test/nifti/unpaired/test - train: /home/mathpluscode/Git/DeepReg/data/test/nifti/unpaired/train - valid: null - format: nifti - image_shape: - - 64 - - 64 - - 64 - labeled: false - type: unpaired -train: - backbone: - concat_skip: true - decode_num_channels: - - 32 - - 32 - - 32 - - 32 - - 32 - depth: 4 - encode_num_channels: - - 16 - - 32 - - 32 - - 32 - - 32 - name: vm_balakrishnan_2019 - num_channel_initial: 16 - epochs: 2 - loss: - image: - name: lncc - weight: 1.0 - label: - name: dice - weight: 0.0 - regularization: - name: gradient - weight: 1.0 - method: ddf - optimizer: - learning_rate: 0.0001 - name: Adam - preprocess: - batch_size: 2 - data_augmentation: - name: affine - shuffle_buffer_num_batch: 1 - save_period: 2 From ab3f568945f4864145c332560327ad61bbba254d Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Sat, 6 Mar 2021 23:08:11 +0000 Subject: [PATCH 07/17] Issue #9: fix bug --- benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index 5573567..ae58742 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -149,7 +149,7 @@ def build_output_block( """ return tf.keras.Sequential( [ - tfkl.Lambda(lambda x: x[-1]), # take the last one / depth 0 + tfkl.Lambda(lambda x: x[0]), # take the first one / depth 0 tfkl.Conv3D( filters=self.num_channel_initial, kernel_size=3, From 9e2a93e24f96714b4cc7f9aea7a9f3265e42c5f5 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Sun, 21 Mar 2021 22:03:05 +0000 Subject: [PATCH 08/17] define grad loss as VM --- .../voxel_morph_balakrishnan_2019.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index ae58742..546906f 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -4,6 +4,7 @@ from typing import Tuple, Union import tensorflow as tf +import tensorflow.keras.backend as K import tensorflow.keras.layers as tfkl from deepreg.model.backbone import UNet from deepreg.registry import REGISTRY @@ -182,6 +183,72 @@ def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: return output 
+@REGISTRY.register_loss(name="gradient-vm")
+class GradientNorm(tf.keras.layers.Layer):
+    """
+    Calculate the L1/L2 norm of ddf using forward finite difference.
+
+    The input has to be at least a 5d tensor, including the batch axis.
+    """
+
+    def __init__(self, l1: bool = False, name: str = "GradientNorm"):
+        """
+        Init.
+
+        :param l1: bool true if calculate L1 norm, otherwise L2 norm
+        :param name: name of the loss
+        """
+        super().__init__(name=name)
+        self.l1 = l1
+
+    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
+        """
+        Return a scalar loss.
+
+        :param inputs: shape = (batch, m_dim1, m_dim2, m_dim3, 3)
+        :param kwargs: additional arguments.
+        :return: shape = ()
+        """
+        assert len(inputs.shape) == 5
+        tf.debugging.check_numerics(inputs, "GRADIENT ddf value NAN/INF", name=None)
+        ddf = inputs
+
+        if self.l1:
+            df = [tf.reduce_mean(tf.abs(f)) for f in self._diffs(ddf)]
+        else:
+            assert self.penalty == "l2", (
+                "penalty can only be l1 or l2. Got: %s" % self.penalty
+            )
+            df = [tf.reduce_mean(f * f) for f in self._diffs(ddf)]
+        return tf.add_n(df) / len(df)
+
+    def get_config(self) -> dict:
+        """Return the config dictionary for recreating this class."""
+        config = super().get_config()
+        config["l1"] = self.l1
+        return config
+
+    def _diffs(self, y):
+        vol_shape = y.get_shape().as_list()[1:-1]
+        ndims = len(vol_shape)
+
+        df = []
+        for i in range(ndims):
+            d = i + 1
+            # permute dimensions to put the ith dimension first
+            r = [d, *range(d), *range(d + 1, ndims + 2)]
+            y = K.permute_dimensions(y, r)
+            dfi = y[1:, ...] - y[:-1, ...]
+
+            # permute back
+            # note: this might not be necessary for this loss specifically,
+            # since the results are just summed over anyway.
+            r = [*range(1, d + 1), 0, *range(d + 1, ndims + 2)]
+            df.append(K.permute_dimensions(dfi, r))
+
+        return df
+
+
 def main(args=None):
     """
     Launch training.

From e529462804c9c34b20381993b791cd422c4a56d3 Mon Sep 17 00:00:00 2001
From: Yunguan Fu
Date: Sun, 21 Mar 2021 22:04:41 +0000
Subject: [PATCH 09/17] fix bug: remove assert on undefined self.penalty

---
 benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py
index 546906f..0d0b18b 100644
--- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py
+++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py
@@ -216,9 +216,6 @@ def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
         if self.l1:
             df = [tf.reduce_mean(tf.abs(f)) for f in self._diffs(ddf)]
         else:
-            assert self.penalty == "l2", (
-                "penalty can only be l1 or l2. 
Got: %s" % self.penalty - ) df = [tf.reduce_mean(f * f) for f in self._diffs(ddf)] return tf.add_n(df) / len(df) From 032cf19a8c00c3b962990221e3595caa4b7f16a1 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Mon, 22 Mar 2021 22:26:18 +0000 Subject: [PATCH 10/17] use relu as activation --- .../voxel_morph_balakrishnan_2019.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index 0d0b18b..8af884e 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): filters=3, kernel_size=3, padding="same", - activation=tf.keras.layers.LeakyReLU(alpha=0.2), + activation=self.get_activation(), ) def build_encode_conv_block( @@ -48,7 +48,7 @@ def build_encode_conv_block( kernel_size=kernel_size, padding=padding, strides=2, - activation=tf.keras.layers.LeakyReLU(alpha=0.2), + activation=self.get_activation(), ) def build_down_sampling_block( @@ -81,7 +81,7 @@ def build_bottom_block( kernel_size=kernel_size, padding=padding, strides=2, - activation=tf.keras.layers.LeakyReLU(alpha=0.2), + activation=self.get_activation(), ) def build_up_sampling_block( @@ -125,7 +125,7 @@ def build_decode_conv_block( kernel_size=kernel_size, padding=padding, strides=1, - activation=tf.keras.layers.LeakyReLU(alpha=0.2), + activation=self.get_activation(), ) def build_output_block( @@ -155,13 +155,13 @@ def build_output_block( filters=self.num_channel_initial, kernel_size=3, padding="same", - activation=tf.keras.layers.LeakyReLU(alpha=0.2), + activation=self.get_activation(), ), tfkl.Conv3D( filters=self.num_channel_initial, kernel_size=3, padding="same", - activation=tf.keras.layers.LeakyReLU(alpha=0.2), + activation=self.get_activation(), ), ] ) @@ -182,6 +182,10 @@ def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: output = self._out_ddf_conv(output) return output + def get_activation(self) -> tf.keras.layers.Layer: + """Return activation layer.""" + return tf.keras.layers.ReLU() + @REGISTRY.register_loss(name="gradient-vm") class GradientNorm(tf.keras.layers.Layer): From 2377ebeae2e264599587b1e9fb04e03511a0f372 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Tue, 23 Mar 2021 00:36:33 +0000 Subject: [PATCH 11/17] update config --- benchmark/balakrishnan2019/config_balakrishnan_2019.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml b/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml index 3db6898..781e5a3 100644 --- a/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml +++ b/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml @@ -6,7 +6,7 @@ dataset: format: "nifti" type: "unpaired" # paired / unpaired / grouped labeled: false # whether to use the labels if available, "true" or "false" - image_shape: [192, 229, 193] + image_shape: [192, 224, 192] train: # define neural network structure @@ -47,3 +47,4 @@ train: # other training hyper-parameters epochs: 2 # number of training epochs save_period: 2 # the model will be saved every `save_period` epochs. 
+ update_freq: 50 From a2fee5933a707ff34b0088182667343a81cd782a Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Tue, 23 Mar 2021 00:56:45 +0000 Subject: [PATCH 12/17] remove activation of ddf and change back to leaky relu --- benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index 8af884e..c89247a 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -161,7 +161,6 @@ def build_output_block( filters=self.num_channel_initial, kernel_size=3, padding="same", - activation=self.get_activation(), ), ] ) @@ -184,7 +183,7 @@ def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: def get_activation(self) -> tf.keras.layers.Layer: """Return activation layer.""" - return tf.keras.layers.ReLU() + return tf.keras.layers.LeakyReLU(alpha=0.2) @REGISTRY.register_loss(name="gradient-vm") From 985690b58e77d7ebfa6bfb91715af6932fab1f2f Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Wed, 24 Mar 2021 23:17:54 +0000 Subject: [PATCH 13/17] add kernel init for the network --- benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index c89247a..442a0a5 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -161,6 +161,9 @@ def build_output_block( filters=self.num_channel_initial, kernel_size=3, padding="same", + kernel_initializer=tf.keras.initializers.RandomNormal( + mean=0.0, stddev=1e-5 + ), ), ] ) From 71dc18a4b63df2aa512a129313c0a30bf0c5f2d0 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Thu, 25 Mar 2021 00:43:25 +0000 Subject: [PATCH 14/17] add check numeric --- benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index 442a0a5..2a6a466 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -10,6 +10,8 @@ from deepreg.registry import REGISTRY from deepreg.train import train +tf.debugging.enable_check_numerics() + @REGISTRY.register_backbone(name="vm_balakrishnan_2019") class VoxelMorphBalakrishnan2019(UNet): From a32cb0db2da7ae0b695e9dfeb26a259ed94a74bb Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Sat, 27 Mar 2021 00:10:10 +0000 Subject: [PATCH 15/17] refactor model --- .../voxel_morph_balakrishnan_2019.py | 113 +++++------------- 1 file changed, 29 insertions(+), 84 deletions(-) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index 2a6a466..d345fd6 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -4,14 +4,11 @@ from typing import Tuple, Union import tensorflow as tf -import tensorflow.keras.backend as K import tensorflow.keras.layers as tfkl from deepreg.model.backbone import UNet from deepreg.registry import REGISTRY from deepreg.train import train -tf.debugging.enable_check_numerics() - 
 @REGISTRY.register_backbone(name="vm_balakrishnan_2019")
 class VoxelMorphBalakrishnan2019(UNet):
@@ -150,25 +147,36 @@ def build_output_block(
         :param out_activation: activation to use at end layer.
         :return: a block consists of one or multiple layers
         """
-        return tf.keras.Sequential(
-            [
-                tfkl.Lambda(lambda x: x[0]),  # take the first one / depth 0
-                tfkl.Conv3D(
-                    filters=self.num_channel_initial,
-                    kernel_size=3,
-                    padding="same",
-                    activation=self.get_activation(),
-                ),
-                tfkl.Conv3D(
-                    filters=self.num_channel_initial,
-                    kernel_size=3,
-                    padding="same",
-                    kernel_initializer=tf.keras.initializers.RandomNormal(
-                        mean=0.0, stddev=1e-5
-                    ),
-                ),
-            ]
-        )
+
+        class OutputBlock(tf.keras.Model):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = (
+                    tfkl.Conv3D(
+                        filters=self.num_channel_initial,
+                        kernel_size=3,
+                        padding="same",
+                        activation=self.get_activation(),
+                    ),
+                )
+                self.conv2 = (
+                    tfkl.Conv3D(
+                        filters=self.num_channel_initial,
+                        kernel_size=3,
+                        padding="same",
+                        kernel_initializer=tf.keras.initializers.RandomNormal(
+                            mean=0.0, stddev=1e-5
+                        ),
+                    ),
+                )
+
+            def call(self, inputs, training=None, mask=None):
+                x = inputs[0]
+                x = self.conv1(x)
+                x = self.conv2(x)
+                return x
+
+        return OutputBlock()
 
     def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
         """
@@ -191,69 +199,6 @@ def get_activation(self) -> tf.keras.layers.Layer:
         return tf.keras.layers.LeakyReLU(alpha=0.2)
 
 
-@REGISTRY.register_loss(name="gradient-vm")
-class GradientNorm(tf.keras.layers.Layer):
-    """
-    Calculate the L1/L2 norm of ddf using forward finite difference.
-
-    The input has to be at least a 5d tensor, including the batch axis.
-    """
-
-    def __init__(self, l1: bool = False, name: str = "GradientNorm"):
-        """
-        Init.
-
-        :param l1: bool true if calculate L1 norm, otherwise L2 norm
-        :param name: name of the loss
-        """
-        super().__init__(name=name)
-        self.l1 = l1
-
-    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
-        """
-        Return a scalar loss.
-
-        :param inputs: shape = (batch, m_dim1, m_dim2, m_dim3, 3)
-        :param kwargs: additional arguments.
-        :return: shape = ()
-        """
-        assert len(inputs.shape) == 5
-        tf.debugging.check_numerics(inputs, "GRADIENT ddf value NAN/INF", name=None)
-        ddf = inputs
-
-        if self.l1:
-            df = [tf.reduce_mean(tf.abs(f)) for f in self._diffs(ddf)]
-        else:
-            df = [tf.reduce_mean(f * f) for f in self._diffs(ddf)]
-        return tf.add_n(df) / len(df)
-
-    def get_config(self) -> dict:
-        """Return the config dictionary for recreating this class."""
-        config = super().get_config()
-        config["l1"] = self.l1
-        return config
-
-    def _diffs(self, y):
-        vol_shape = y.get_shape().as_list()[1:-1]
-        ndims = len(vol_shape)
-
-        df = []
-        for i in range(ndims):
-            d = i + 1
-            # permute dimensions to put the ith dimension first
-            r = [d, *range(d), *range(d + 1, ndims + 2)]
-            y = K.permute_dimensions(y, r)
-            dfi = y[1:, ...] - y[:-1, ...]
-
-            # permute back
-            # note: this might not be necessary for this loss specifically,
-            # since the results are just summed over anyway.
-            r = [*range(1, d + 1), 0, *range(d + 1, ndims + 2)]
-            df.append(K.permute_dimensions(dfi, r))
-
-        return df
-
-
 def main(args=None):
     """
     Launch training. 
From e934c94c6a9375103c85cadd8b1e7fb24d8d8996 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Sat, 27 Mar 2021 00:17:15 +0000 Subject: [PATCH 16/17] fix bug --- .../voxel_morph_balakrishnan_2019.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index d345fd6..763f644 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -149,19 +149,19 @@ def build_output_block( """ class OutputBlock(tf.keras.Model): - def __init__(self): + def __init__(self, num_channel_initial, activation): super().__init__() self.conv1 = ( tfkl.Conv3D( - filters=self.num_channel_initial, + filters=num_channel_initial, kernel_size=3, padding="same", - activation=self.get_activation(), + activation=activation, ), ) self.conv2 = ( tfkl.Conv3D( - filters=self.num_channel_initial, + filters=num_channel_initial, kernel_size=3, padding="same", kernel_initializer=tf.keras.initializers.RandomNormal( @@ -176,7 +176,10 @@ def call(self, inputs, training=None, mask=None): x = self.conv2(x) return x - return OutputBlock() + return OutputBlock( + num_channel_initial=self.num_channel_initial, + activation=self.get_activation(), + ) def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: """ From 246605d99c10c9f16283042113129fd1d3a54591 Mon Sep 17 00:00:00 2001 From: Yunguan Fu Date: Sat, 27 Mar 2021 00:19:36 +0000 Subject: [PATCH 17/17] fix bug --- .../voxel_morph_balakrishnan_2019.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py index 763f644..808c4a2 100644 --- a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py +++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py @@ -151,22 +151,18 @@ def build_output_block( class OutputBlock(tf.keras.Model): def __init__(self, num_channel_initial, activation): super().__init__() - self.conv1 = ( - tfkl.Conv3D( - filters=num_channel_initial, - kernel_size=3, - padding="same", - activation=activation, - ), + self.conv1 = tfkl.Conv3D( + filters=num_channel_initial, + kernel_size=3, + padding="same", + activation=activation, ) - self.conv2 = ( - tfkl.Conv3D( - filters=num_channel_initial, - kernel_size=3, - padding="same", - kernel_initializer=tf.keras.initializers.RandomNormal( - mean=0.0, stddev=1e-5 - ), + self.conv2 = tfkl.Conv3D( + filters=num_channel_initial, + kernel_size=3, + padding="same", + kernel_initializer=tf.keras.initializers.RandomNormal( + mean=0.0, stddev=1e-5 ), )
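
The refactor in PATCH 15 introduced two problems that these last two commits address: inside the nested OutputBlock, `self.num_channel_initial` and `self.get_activation()` resolved against OutputBlock's own `self` rather than the backbone (PATCH 16 passes both in as constructor arguments), and the trailing comma after each tfkl.Conv3D(...) left conv1 and conv2 as one-element tuples instead of layers (PATCH 17 drops it). A minimal standalone sketch of the tuple pitfall, in plain TensorFlow with no DeepReg dependency:

    import tensorflow as tf
    import tensorflow.keras.layers as tfkl

    # The trailing comma wraps the layer in a one-element tuple,
    # mirroring the bug fixed in PATCH 17.
    conv_tuple = (
        tfkl.Conv3D(filters=3, kernel_size=3, padding="same"),
    )
    print(type(conv_tuple))  # <class 'tuple'>

    # Without the trailing comma, the variable is the layer itself.
    conv = tfkl.Conv3D(filters=3, kernel_size=3, padding="same")
    x = tf.zeros((1, 8, 8, 8, 1))
    print(conv(x).shape)  # (1, 8, 8, 8, 3)

Calling conv_tuple(x) would raise `TypeError: 'tuple' object is not callable`, which is the failure mode the final commit removes.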