@mhuen mhuen commented Feb 24, 2025

The hexagonal convolution kernels could not properly propagate gradients through to their kernel weights because these weights were modified (via tf.stack/tf.concat) prior to the layer's forward pass. As a result, GradientTape was unable to find any gradients for the kernel weights. This PR introduces separate classes for hexagonal kernel creation that move weight initialization into the class's constructor and apply any transformations during the forward pass in __call__.

Additionally, the names given to created weights are now fully propagated through.
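For illustration, here is a minimal sketch of the pattern described above, assuming a TensorFlow 2 tf.Module-based setup; the class and variable names (HexKernel, hex_weight_i) are hypothetical and not the actual classes introduced by this PR:

```python
import tensorflow as tf


class HexKernel(tf.Module):
    """Creates raw kernel weights once; assembles the full kernel per call."""

    def __init__(self, num_weights, name=None):
        super().__init__(name=name)
        # Variables are created in the constructor with explicit names,
        # so they are tracked as trainable variables of the module.
        self.weight_vars = [
            tf.Variable(tf.random.normal([]), name=f"hex_weight_{i}")
            for i in range(num_weights)
        ]

    def __call__(self):
        # The transformation (stacking) happens in the forward pass,
        # so GradientTape can trace it back to the underlying variables.
        return tf.stack(self.weight_vars)


# Minimal check that gradients now reach the kernel weights:
kernel = HexKernel(num_weights=7, name="hex_kernel")
x = tf.random.normal([7])

with tf.GradientTape() as tape:
    y = tf.reduce_sum(kernel() * x)

grads = tape.gradient(y, kernel.trainable_variables)
assert all(g is not None for g in grads)
```

In the broken pattern, tf.stack/tf.concat was applied to the variables at construction time, so the tape only ever saw the pre-assembled tensor and returned None gradients for the underlying weights; deferring the assembly to __call__ keeps the variables visible to the tape.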

@mhuen mhuen merged commit aaaacb3 into master Feb 24, 2025
5 checks passed
@mhuen mhuen deleted the HexConvTF2 branch February 24, 2025 16:31
mhuen added a commit that referenced this pull request Feb 24, 2025
