diff --git a/src/models/mutual_self_attention.py b/src/models/mutual_self_attention.py index a4d5543..5a03cbf 100644 --- a/src/models/mutual_self_attention.py +++ b/src/models/mutual_self_attention.py @@ -45,8 +45,8 @@ def __init__( style_fidelity, reference_attn, reference_adain, - fusion_blocks, batch_size=batch_size, + fusion_blocks=fusion_blocks, ) def register_reference_hooks(