diff --git a/test/test_conv.py b/test/test_conv.py index 5daf603..dc4b8d5 100644 --- a/test/test_conv.py +++ b/test/test_conv.py @@ -1,5 +1,3 @@ -from __future__ import division, print_function - import unittest import numpy as np import tensorflow as tf diff --git a/test/test_hex.py b/test/test_hex.py new file mode 100644 index 0000000..515c8e2 --- /dev/null +++ b/test/test_hex.py @@ -0,0 +1,814 @@ +import unittest +import numpy as np +import tensorflow as tf + +from tfscripts.hex import conv +from tfscripts.hex import icecube +from tfscripts.hex import rotation + + +class TestHexKernels(unittest.TestCase): + + def setUp(self): + self.random_state = np.random.RandomState(42) + + def test_hex_kernel(self): + """Test HexKernel""" + + test_cases = [ + { + "filter_size": [2, 0, 3, 2], + "n_vars": 7, + "kernel": np.array( + [ + [ + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [ + [-1.4257220029830933, -1.0143787860870361], + [0.39387720823287964, -1.9629442691802979], + [-0.10814568400382996, 0.11988475173711777], + ], + [ + [0.4779662489891052, -1.6896843910217285], + [0.4115297496318817, -0.6118844151496887], + [-0.8208955526351929, -1.7109564542770386], + ], + ], + [ + [ + [-0.019794760271906853, 0.40792471170425415], + [-0.11573483049869537, -0.30879053473472595], + [-0.9291547536849976, 0.2432696521282196], + ], + [ + [-0.4925185739994049, 0.31435173749923706], + [-0.9397227168083191, -0.4897875487804413], + [-0.3880728483200073, 0.27374207973480225], + ], + [ + [0.7232375741004944, 0.48548829555511475], + [0.5194052457809448, -0.5773778557777405], + [0.7949565052986145, 1.7222315073013306], + ], + ], + [ + [ + [-0.03933761268854141, 0.44203484058380127], + [-0.4911632537841797, -0.012474085204303265], + [0.18765877187252045, -1.074103832244873], + ], + [ + [0.9184949398040771, -1.2781122922897339], + [0.4061416685581207, 0.4695652723312378], + [-1.096517562866211, 0.9345183968544006], + ], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ], + ] + ), + }, + { + "filter_size": [3, 1, 2], + "n_vars": 25, + "kernel": np.array( + [ + [ + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [-1.4257220029830933, -1.0143787860870361], + [0.0, 0.0], + [0.0, 0.0], + ], + [ + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.4779662489891052, -1.6896843910217285], + [-0.019794760271906853, 0.40792471170425415], + [-0.4925185739994049, 0.31435173749923706], + [0.7232375741004944, 0.48548829555511475], + ], + [ + [0.0, 0.0], + [-0.03933761268854141, 0.44203484058380127], + [0.9184949398040771, -1.2781122922897339], + [-0.6554914116859436, 0.7017678022384644], + [-1.0997180938720703, -0.7472116351127625], + [-0.8416294455528259, 0.2393561452627182], + [0.0, 0.0], + ], + [ + [0.0, 0.0], + [0.5353959202766418, -0.6356369256973267], + [0.011429687030613422, -1.0444294214248657], + [0.8294717073440552, -1.1644237041473389], + [-0.5676082372665405, 1.6229828596115112], + [-1.7003906965255737, -1.0814650058746338], + [0.0, 0.0], + ], + [ + [0.0, 0.0], + [0.5154496431350708, -0.6644119024276733], + [0.4813891053199768, -1.372441291809082], + [-1.166320562362671, -0.726149320602417], + [-0.06755267083644867, 1.1745551824569702], + [-0.39131826162338257, -0.7624845504760742], + [0.0, 0.0], + ], + [ + [0.9889782667160034, -1.0488485097885132], + [0.09571779519319534, 0.8825100064277649], + [-0.209646537899971, -0.11976470053195953], + [1.8110568523406982, 0.26369625329971313], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ], + [ + [0.0, 0.0], + [0.0, 0.0], + [0.2035822570323944, -0.5185909867286682], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ], + ] + ), + }, + { + "filter_size": [4, 0], + "n_vars": 37, + "kernel": np.array( + [ + [ + 0.0, + 0.0, + 0.0, + -1.4257220029830933, + 0.4779662489891052, + -0.019794760271906853, + -0.4925185739994049, + ], + [ + 0.0, + 0.0, + 0.7232375741004944, + -0.03933761268854141, + 0.9184949398040771, + -0.6554914116859436, + -1.0997180938720703, + ], + [ + 0.0, + -0.8416294455528259, + 0.5353959202766418, + 0.011429687030613422, + 0.8294717073440552, + -0.5676082372665405, + -1.7003906965255737, + ], + [ + 0.5154496431350708, + 0.4813891053199768, + -1.166320562362671, + -0.06755267083644867, + -0.39131826162338257, + 0.9889782667160034, + 0.09571779519319534, + ], + [ + -0.209646537899971, + 1.8110568523406982, + 0.2035822570323944, + 1.2006571292877197, + -1.4237172603607178, + 1.2166802883148193, + 0.0, + ], + [ + -0.5895527601242065, + -1.7983052730560303, + -0.7603991031646729, + -1.3157434463500977, + 0.6037954688072205, + 0.0, + 0.0, + ], + [ + -1.0926811695098877, + -1.2689539194107056, + -0.5327851176261902, + 0.8136177659034729, + 0.0, + 0.0, + 0.0, + ], + ] + ), + }, + ] + + for test in test_cases: + tf.random.set_seed(42) + kernel_obj = conv.HexKernel( + filter_size=test["filter_size"], + get_ones=False, + float_precision="float32", + seed=42, + name="HexKernel", + ) + kernel = kernel_obj() + var_list = kernel_obj.var_list + + self.assertEqual(len(var_list), test["n_vars"]) + self.assertTrue(np.allclose(test["kernel"], kernel, atol=1e-6)) + + def test_icecube_kernel(self): + """Test IceCubeKernel""" + + test_cases = [ + { + "filter_size": [1, 2], + "n_vars": 78, + "kernel": np.array( + [ + [ + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[-1.4257220029830933, -1.0143787860870361]], + [[0.4779662489891052, -1.6896843910217285]], + [[-0.019794760271906853, 0.40792471170425415]], + [[-0.4925185739994049, 0.31435173749923706]], + [[0.7232375741004944, 0.48548829555511475]], + [[-0.03933761268854141, 0.44203484058380127]], + ], + [ + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.9184949398040771, -1.2781122922897339]], + [[-0.6554914116859436, 0.7017678022384644]], + [[-1.0997180938720703, -0.7472116351127625]], + [[-0.8416294455528259, 0.2393561452627182]], + [[0.5353959202766418, -0.6356369256973267]], + [[0.011429687030613422, -1.0444294214248657]], + [[0.8294717073440552, -1.1644237041473389]], + ], + [ + [[0.0, 0.0]], + [[0.0, 0.0]], + [[-0.5676082372665405, 1.6229828596115112]], + [[-1.7003906965255737, -1.0814650058746338]], + [[0.5154496431350708, -0.6644119024276733]], + [[0.4813891053199768, -1.372441291809082]], + [[-1.166320562362671, -0.726149320602417]], + [[-0.06755267083644867, 1.1745551824569702]], + [[-0.39131826162338257, -0.7624845504760742]], + [[0.9889782667160034, -1.0488485097885132]], + ], + [ + [[0.0, 0.0]], + [[0.09571779519319534, 0.8825100064277649]], + [[-0.209646537899971, -0.11976470053195953]], + [[1.8110568523406982, 0.26369625329971313]], + [[0.2035822570323944, -0.5185909867286682]], + [[1.2006571292877197, 0.2735150456428528]], + [[-1.4237172603607178, 0.4937574565410614]], + [[1.2166802883148193, 0.6768909692764282]], + [[-0.5895527601242065, 1.2136154174804688]], + [[-1.7983052730560303, -1.4686884880065918]], + ], + [ + [[-0.7603991031646729, 0.6917804479598999]], + [[-1.3157434463500977, -1.3364534378051758]], + [[0.6037954688072205, -0.4373283088207245]], + [[-1.0926811695098877, 0.4675874412059784]], + [[-1.2689539194107056, -0.5594342947006226]], + [[-0.5327851176261902, -0.8017807602882385]], + [[0.8136177659034729, -0.8163193464279175]], + [[1.0433070659637451, 0.6938974857330322]], + [[-0.48861029744148254, 0.7178099155426025]], + [[-1.5572731494903564, 0.30047574639320374]], + ], + [ + [[1.3027359247207642, 0.5769704580307007]], + [[-0.30500108003616333, 0.6116410493850708]], + [[0.993407666683197, 0.5634545683860779]], + [[0.5683414936065674, -0.038891635835170746]], + [[0.18263010680675507, -0.9309208393096924]], + [[-0.30102670192718506, -0.4619215428829193]], + [[-0.22547315061092377, 0.6620444655418396]], + [[0.18600444495677948, 0.3983207643032074]], + [[-1.3858344554901123, 1.4597420692443848]], + [[-0.4150679409503937, -0.09204770624637604]], + ], + [ + [[-0.21830973029136658, 0.4716184139251709]], + [[-1.6000971794128418, 0.016703147441148758]], + [[1.486158847808838, 0.38790279626846313]], + [[1.3315898180007935, -0.04911533743143082]], + [[0.8701032996177673, -0.11769035458564758]], + [[0.8997407555580139, -0.6535756587982178]], + [[0.2050367146730423, 0.11578050255775452]], + [[-0.6097397208213806, -0.32293230295181274]], + [[0.36548155546188354, -0.3372085690498352]], + [[0.0, 0.0]], + ], + [ + [[0.1818438619375229, -0.4313610792160034]], + [[0.22320835292339325, 0.021276840940117836]], + [[-1.1691735982894897, 0.35153236985206604]], + [[0.2640562653541565, -0.17022287845611572]], + [[-1.3610483407974243, -0.2239222377538681]], + [[-0.516293466091156, -0.22145582735538483]], + [[0.03286828100681305, -0.16413533687591553]], + [[-0.5490068793296814, 0.349784791469574]], + [[0.0, 0.0]], + [[0.0, 0.0]], + ], + [ + [[1.0159858465194702, -0.5081911683082581]], + [[0.8340901136398315, -0.21344462037086487]], + [[-1.8715938329696655, -0.501074492931366]], + [[-1.5713595151901245, -0.3148786127567291]], + [[0.4968576431274414, -0.6774446368217468]], + [[0.44643816351890564, 1.2766880989074707]], + [[1.2446485757827759, -0.9975428581237793]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + ], + [ + [[0.5764395594596863, -1.1536171436309814]], + [[1.8420573472976685, -0.009640185162425041]], + [[-0.07150302827358246, -0.016818424686789513]], + [[1.0523600578308105, 0.0008560324786230922]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + ], + ] + ), + }, + { + "filter_size": [], + "n_vars": 78, + "kernel": np.array( + [ + [ + 0.0, + 0.0, + 0.0, + 0.0, + -1.4257220029830933, + 0.4779662489891052, + -0.019794760271906853, + -0.4925185739994049, + 0.7232375741004944, + -0.03933761268854141, + ], + [ + 0.0, + 0.0, + 0.0, + 0.9184949398040771, + -0.6554914116859436, + -1.0997180938720703, + -0.8416294455528259, + 0.5353959202766418, + 0.011429687030613422, + 0.8294717073440552, + ], + [ + 0.0, + 0.0, + -0.5676082372665405, + -1.7003906965255737, + 0.5154496431350708, + 0.4813891053199768, + -1.166320562362671, + -0.06755267083644867, + -0.39131826162338257, + 0.9889782667160034, + ], + [ + 0.0, + 0.09571779519319534, + -0.209646537899971, + 1.8110568523406982, + 0.2035822570323944, + 1.2006571292877197, + -1.4237172603607178, + 1.2166802883148193, + -0.5895527601242065, + -1.7983052730560303, + ], + [ + -0.7603991031646729, + -1.3157434463500977, + 0.6037954688072205, + -1.0926811695098877, + -1.2689539194107056, + -0.5327851176261902, + 0.8136177659034729, + 1.0433070659637451, + -0.48861029744148254, + -1.5572731494903564, + ], + [ + 1.3027359247207642, + -0.30500108003616333, + 0.993407666683197, + 0.5683414936065674, + 0.18263010680675507, + -0.30102670192718506, + -0.22547315061092377, + 0.18600444495677948, + -1.3858344554901123, + -0.4150679409503937, + ], + [ + -0.21830973029136658, + -1.6000971794128418, + 1.486158847808838, + 1.3315898180007935, + 0.8701032996177673, + 0.8997407555580139, + 0.2050367146730423, + -0.6097397208213806, + 0.36548155546188354, + 0.0, + ], + [ + 0.1818438619375229, + 0.22320835292339325, + -1.1691735982894897, + 0.2640562653541565, + -1.3610483407974243, + -0.516293466091156, + 0.03286828100681305, + -0.5490068793296814, + 0.0, + 0.0, + ], + [ + 1.0159858465194702, + 0.8340901136398315, + -1.8715938329696655, + -1.5713595151901245, + 0.4968576431274414, + 0.44643816351890564, + 1.2446485757827759, + 0.0, + 0.0, + 0.0, + ], + [ + 0.5764395594596863, + 1.8420573472976685, + -0.07150302827358246, + 1.0523600578308105, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + ] + ), + }, + ] + + for test in test_cases: + tf.random.set_seed(42) + kernel_obj = icecube.IceCubeKernel( + shape=test["filter_size"], + get_ones=False, + float_precision="float32", + seed=42, + name="HexKernel", + ) + kernel = kernel_obj() + var_list = kernel_obj.var_list + + self.assertEqual(len(var_list), test["n_vars"]) + self.assertTrue(np.allclose(test["kernel"], kernel, atol=1e-6)) + + def test_rotated_hex_kernel(self): + """Test RotatedHexKernel""" + + test_cases = [ + { + "filter_size": [2, 0, 2, 1], + "num_rotations": 3, + "n_vars": 8, + "kernel": np.array( + [ + [ + [[-0.0, 0.0, -0.0], [0.0, 0.0, 0.0]], + [ + [ + -0.6020655632019043, + 0.27684009075164795, + 0.015805857256054878, + ], + [ + 0.6445701718330383, + 0.2416737973690033, + -0.03702780604362488, + ], + ], + [ + [ + -0.31330278515815735, + -0.015057608485221863, + 0.3932696282863617, + ], + [ + 0.3354213237762451, + -0.013144878670573235, + -0.9212984442710876, + ], + ], + ], + [ + [ + [ + -0.4740760326385498, + -0.0075770169496536255, + -0.7334061861038208, + ], + [ + 0.5075448155403137, + -0.006614527199417353, + 1.7181239128112793, + ], + ], + [ + [ + 0.934548556804657, + -0.5457363128662109, + 1.1384203433990479, + ], + [ + -1.0005258321762085, + -0.4764128029346466, + -2.666935920715332, + ], + ], + [ + [ + 0.012975295074284077, + 0.35158050060272217, + -0.5774956941604614, + ], + [ + -0.013891325332224369, + 0.3069201111793518, + 1.352878212928772, + ], + ], + ], + [ + [ + [ + -0.4740760326385498, + -0.0075770169496536255, + -0.7334061861038208, + ], + [ + 0.5075448155403137, + -0.006614527199417353, + 1.7181239128112793, + ], + ], + [ + [ + 0.32284170389175415, + 0.18295539915561676, + 0.031410567462444305, + ], + [ + -0.34563368558883667, + 0.15971504151821136, + -0.07358439266681671, + ], + ], + [[-0.0, 0.0, -0.0], [0.0, 0.0, 0.0]], + ], + ] + ), + }, + { + "filter_size": [3, 1, 1, 2], + "num_rotations": 1, + "n_vars": 26, + "kernel": np.array( + [ + [ + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.2444324940443039, 0.05568281188607216]], + [[0.0, 0.0]], + [[0.0, 0.0]], + ], + [ + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.5779832601547241, 0.13166716694831848]], + [[-1.4003510475158691, -0.3190062344074249]], + [[-0.08110759407281876, -0.01847667247056961]], + [[-0.46983906626701355, -0.1070314347743988]], + ], + [ + [[0.0, 0.0]], + [[2.1744582653045654, 0.49535128474235535]], + [[0.6188783049583435, 0.1409832388162613]], + [[1.102797508239746, 0.2512221932411194]], + [[0.5738735795021057, 0.13073095679283142]], + [[-0.7870204448699951, -0.17928676307201385]], + [[0.0, 0.0]], + ], + [ + [[0.0, 0.0]], + [[-2.04158616065979, -0.46508243680000305]], + [[-0.04723098501563072, -0.010759429074823856]], + [[-1.7118033170700073, -0.38995641469955444]], + [[-0.023766720667481422, -0.005414164625108242]], + [[-1.3203843832015991, -0.3007894456386566]], + [[0.0, 0.0]], + ], + [ + [[0.0, 0.0]], + [[-0.6815028786659241, -0.15524938702583313]], + [[0.8683603405952454, 0.19781635701656342]], + [[-0.5913459658622742, -0.13471123576164246]], + [[-1.010508418083191, -0.23019830882549286]], + [[1.187423825263977, 0.27050042152404785]], + [[0.0, 0.0]], + ], + [ + [[-0.2517136037349701, -0.05734148249030113]], + [[0.9959111213684082, 0.2268729954957962]], + [[0.013723134994506836, 0.0031261914409697056]], + [[0.6428269147872925, 0.1464388370513916]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + ], + [ + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.11492425203323364, 0.02618025802075863]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + ], + ] + ), + }, + ] + + for test in test_cases: + tf.random.set_seed(42) + kernel_obj = rotation.RotatedHexKernel( + filter_size=test["filter_size"], + num_rotations=test["num_rotations"], + float_precision="float32", + seed=42, + name="RotatedHexKernel", + ) + kernel = kernel_obj() + var_list = kernel_obj.var_list + + self.assertEqual(len(var_list), test["n_vars"]) + self.assertTrue(np.allclose(test["kernel"], kernel, atol=1e-6)) + + def test_dynamic_rotated_hex_kernel(self): + """Test DynamicRotationHexKernel""" + + test_cases = [ + { + "filter_size": [2, 0, 2, 1], + "azimuth": 60.0, + "n_vars": 2, + "kernel": np.array( + [ + [ + [[0.0], [0.0]], + [[-1.8633298873901367], [1.8340147733688354]], + [[0.8756545186042786], [-0.9188050627708435]], + ], + [ + [[-0.8208955526351929], [-1.7109564542770386]], + [[-1.4257220029830933], [-1.0143787860870361]], + [[0.4779662489891052], [-1.6896843910217285]], + ], + [ + [[-0.8208955526351929], [-1.7109564542770386]], + [[0.4115297496318817], [-0.6118844151496887]], + [[0.0], [0.0]], + ], + ] + ), + }, + { + "filter_size": [3, 1, 1, 2], + "azimuth": np.array(243.0), + "n_vars": 9, + "kernel": np.array( + [ + [ + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.614751935005188, 0.4126650393009186]], + [[0.1084856390953064, 0.07282324880361557]], + [[0.0, 0.0]], + ], + [ + [[0.0, 0.0]], + [[0.0, 0.0]], + [[-0.07387778908014297, 0.04715276136994362]], + [[-0.10614082217216492, -0.23711901903152466]], + [[-0.8478127717971802, 0.1880636364221573]], + [[-0.8326864838600159, 0.2814864218235016]], + [[-0.03343696892261505, 0.37572962045669556]], + ], + [ + [[0.0, 0.0]], + [[-0.41864079236984253, 0.26719897985458374]], + [[-0.03404032811522484, 0.24718335270881653]], + [[0.414851576089859, -0.665774405002594]], + [[-0.7592743039131165, -1.6560028791427612]], + [[0.26231634616851807, -0.5828783512115479]], + [[-0.005900642368942499, 0.06630522757768631]], + ], + [ + [[0.0, 0.0]], + [[-0.08913522213697433, -1.225508689880371]], + [[0.4978506565093994, -1.6511404514312744]], + [[-1.4257220029830933, -1.0143787860870361]], + [[-0.5993744134902954, -0.17593613266944885]], + [[-0.365479052066803, 0.6292445063591003]], + [[0.0, 0.0]], + ], + [ + [[-0.16495771706104279, -0.11208175122737885]], + [[0.5046157836914062, -1.3580321073532104]], + [[0.7387052774429321, -0.7811640501022339]], + [[-1.7995491027832031, 1.7375566959381104]], + [[0.12104412913322449, -1.0844377279281616]], + [[0.7807207107543945, -1.086395502090454]], + [[0.0, 0.0]], + ], + [ + [[-0.9347603917121887, -0.6351298689842224]], + [[0.042605914175510406, -0.19909705221652985]], + [[0.9114534854888916, 1.712262511253357]], + [[-0.6210527420043945, -0.8963721394538879]], + [[0.1377742439508438, -0.19171684980392456]], + [[0.0, 0.0]], + [[0.0, 0.0]], + ], + [ + [[0.0, 0.0]], + [[-0.09832371771335602, 0.10526517778635025]], + [[-0.5571677088737488, 0.5965026021003723]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + [[0.0, 0.0]], + ], + ] + ), + }, + ] + + for test in test_cases: + tf.random.set_seed(42) + kernel_obj = rotation.DynamicRotationHexKernel( + filter_size=test["filter_size"], + float_precision="float32", + seed=42, + name="DynamicRotationHexKernel", + ) + kernel = kernel_obj(azimuth=test["azimuth"]) + var_list = kernel_obj.var_list + + self.assertEqual(len(var_list), test["n_vars"]) + self.assertTrue(np.allclose(test["kernel"], kernel, atol=1e-6)) diff --git a/test/test_layers.py b/test/test_layers.py index 5b3ffd8..38eaedb 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -1,5 +1,3 @@ -from __future__ import division, print_function - import unittest import numpy as np import tensorflow as tf diff --git a/tfscripts/hex/conv.py b/tfscripts/hex/conv.py index d4ba4fc..7f2e2a0 100644 --- a/tfscripts/hex/conv.py +++ b/tfscripts/hex/conv.py @@ -14,10 +14,10 @@ # tfscripts specific imports from tfscripts.utils import SeedCounter -from tfscripts.weights import new_weights, new_biases +from tfscripts.weights import new_weights from tfscripts.hex.visual import print_hex_data from tfscripts.hex import rotation -from tfscripts.hex.icecube import get_icecube_kernel +from tfscripts.hex.icecube import IceCubeKernel from tfscripts.conv import dynamic_conv, conv4d_stacked # constants @@ -73,189 +73,200 @@ def hex_distance(h1, h2): return (abs(a1 - a2) + abs(b1 - b2) + abs(c1 - c2)) / 2 -def get_hex_kernel( - filter_size, - print_kernel=False, - get_ones=False, - float_precision=FLOAT_PRECISION, - seed=None, -): - """Get hexagonal convolution kernel +class HexKernel(tf.Module): + """Hexagonal convolution kernel""" - Create Weights for a hexagonal kernel. - The Kernel will be of a hexagonal shape in the first two dimensions, - while the other dimensions are normal. - The hexagonal kernel is off the shape: - [kernel_edge_points, kernel_edge_points, *filter_size[2:]] - But elements with coordinates in the first two dimensions, that don't belong - to the hexagon are set to a tf.Constant 0. + def __init__( + self, + filter_size, + get_ones=False, + float_precision=FLOAT_PRECISION, + seed=None, + name="HexKernel", + ): + """Get hexagonal convolution kernel - The hexagon is defined by filter_size[0:2]. - filter_size[0] defines the size of the hexagon and - filter_size[1] the orientation. + Create Weights for a hexagonal kernel. + The Kernel will be of a hexagonal shape in the first two dimensions, + while the other dimensions are normal. + The hexagonal kernel is of the shape: + [kernel_edge_points, kernel_edge_points, *filter_size[2:]] + But elements with coordinates in the first two dimensions, that don't belong + to the hexagon are set to a tf.Constant 0. - Parameters - ---------- - filter_size : A list of int - filter_size = [s, o, 3. dim(e.g. z), 4. dim(e.g. t),...] - s: size of hexagon - o: orientation of hexagon + The hexagon is defined by filter_size[0:2]. + filter_size[0] defines the size of the hexagon and + filter_size[1] the orientation. - Examples: + Parameters + ---------- + filter_size : A list of int + filter_size = [s, o, 3. dim(e.g. z), 4. dim(e.g. t),...] + s: size of hexagon + o: orientation of hexagon - s = 2, o = 0: - 1 1 0 1 1 + Examples: - 1 1 1 1 1 1 + s = 2, o = 0: + 1 1 0 1 1 - 0 1 1 1 1 + 1 1 1 1 1 1 - s = 3, o = 2: - 0 1 0 0 0 0 0 1 + 0 1 1 1 1 - 0 1 1 1 1 0 0 1 1 1 1 + s = 3, o = 2: + 0 1 0 0 0 0 0 1 - 1 1 1 1 1 0 0 1 1 1 1 1 + 0 1 1 1 1 0 0 1 1 1 1 - 0 1 1 1 1 1 0 1 1 1 1 1 + 1 1 1 1 1 0 0 1 1 1 1 1 - 0 0 1 1 1 1 1 1 1 1 1 1 + 0 1 1 1 1 1 0 1 1 1 1 1 - 0 0 1 1 1 1 0 1 1 1 1 + 0 0 1 1 1 1 1 1 1 1 1 1 - 0 0 0 0 0 1 0 1 + 0 0 1 1 1 1 0 1 1 1 1 - print_kernel : bool. - True: print first two dimensions of kernel. - 0 represents a const 0 Tensor of shape filter_size[2:] - 1 represents a trainable Tensor of shape filter_size[2:] - This can be used to verify the shape of the hex kernel - False: do not print - - get_ones : bool, optional - If True, returns constant ones for elements in hexagon. - If False, return trainable tf.tensor for elements in hexagon. - In both cases, constant zeros are returned for elements outside of - hexagon. - float_precision : tf.dtype, optional - The tensorflow dtype describing the float precision to use. - seed : int, optional - Seed for the random number generator. + 0 0 0 0 0 1 0 1 - Returns - ------- - tf.Tensor - A Tensor with shape: [ s, s, *filter_size[2:] ] - where s = 2*filter_size[0] -1 if x == o - [hexagon is parallel to axis of first dimension] - = 2*filter_size[0] +1 if x != o - [hexagon is tilted to axis of first dimension] - list of tf.Variable - A list of tensorflow variables created in this function + get_ones : bool, optional + If True, returns constant ones for elements in hexagon. + If False, return trainable tf.tensor for elements in hexagon. + In both cases, constant zeros are returned for elements outside of + hexagon. + float_precision : tf.dtype, optional + The tensorflow dtype describing the float precision to use. + seed : int, optional + Seed for the random number generator. - Raises - ------ - ValueError - Description - """ - # create seed counter - cnt = SeedCounter(seed) + Returns + ------- + tf.Tensor + A Tensor with shape: [ s, s, *filter_size[2:] ] + where s = 2*filter_size[0] -1 if x == o + [hexagon is parallel to axis of first dimension] + = 2*filter_size[0] +1 if x != o + [hexagon is tilted to axis of first dimension] + list of tf.Variable + A list of tensorflow variables created in this function + + Raises + ------ + ValueError + Description + """ + # create seed counter + cnt = SeedCounter(seed) - k = filter_size[0] - x = filter_size[1] + k = filter_size[0] + x = filter_size[1] - if x >= k: - raise ValueError( - "get_hex_kernel: filter_size (k,x,z) must fulfill " - "x < k: ({}, {}, {})".format(k, x, filter_size[2]) - ) - if x == 0: - kernel_edge_points = 2 * k - 1 - else: - kernel_edge_points = 2 * k + 1 - - zeros = tf.zeros(filter_size[2:], dtype=float_precision) - ones = tf.ones(filter_size[2:], dtype=float_precision) - - var_list = [] - a_list = [] - test_hex_dict = {} - for a in range(kernel_edge_points): - b_list = [] - for b in range(kernel_edge_points): - - # ------------------------- - # regular aligned hexagons - # ------------------------- - if x == 0: - if a + b < k - 1 or a + b > 3 * k - 3: - weights = zeros - test_hex_dict[(a, b)] = 0 - else: - if get_ones: - weights = ones + if x >= k: + raise ValueError( + "HexKernel: filter_size (k,x,z) must fulfill " + "x < k: ({}, {}, {})".format(k, x, filter_size[2]) + ) + if x == 0: + kernel_edge_points = 2 * k - 1 + else: + kernel_edge_points = 2 * k + 1 + + zeros = tf.zeros(filter_size[2:], dtype=float_precision) + ones = tf.ones(filter_size[2:], dtype=float_precision) + + self.var_list = [] + self.a_list = [] + self.test_hex_dict = {} + for a in range(kernel_edge_points): + b_list = [] + for b in range(kernel_edge_points): + + # ------------------------- + # regular aligned hexagons + # ------------------------- + if x == 0: + if a + b < k - 1 or a + b > 3 * k - 3: + weights = zeros + self.test_hex_dict[(a, b)] = 0 else: - weights = new_weights( - filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(weights) - test_hex_dict[(a, b)] = 1 - - # ------------------------- - # tilted hexagons - # ------------------------- - else: - inHexagon = False - # check if inside normal k.0 aligned hexagon - # |----inside normal k.0 rhombus -----------| - if ( - (a > 0 and a < 2 * k) - and (b > 0 and b < 2 * k) - and - # |--in k.0 aligned hexagon-| - (a + b > k and a + b < 3 * k) - ): - - if a + b > k and a + b < 3 * k: - inHexagon = True + if get_ones: + weights = ones + else: + weights = new_weights( + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + f"_weights_{a}_{b}", + ) + self.var_list.append(weights) + self.test_hex_dict[(a, b)] = 1 + + # ------------------------- + # tilted hexagons + # ------------------------- else: - # add 6 additional edges outside of k.0 aligned hexagon - if a == 2 * k - x and b == 0: # Edge 1 - inHexagon = True - elif a == k - x and b == x: # Edge 2 - inHexagon = True - elif a == 0 and b == k + x: # Edge 3 - inHexagon = True - elif a == x and b == 2 * k: # Edge 4 - inHexagon = True - elif a == k + x and b == 2 * k - x: # Edge 5 - inHexagon = True - elif a == 2 * k and b == k - x: # Edge 6 - inHexagon = True - # get weights or constant 0 depending on if point is in hexagon - if inHexagon: - if get_ones: - weights = ones + inHexagon = False + # check if inside normal k.0 aligned hexagon + # |----inside normal k.0 rhombus -----------| + if ( + (a > 0 and a < 2 * k) + and (b > 0 and b < 2 * k) + and + # |--in k.0 aligned hexagon-| + (a + b > k and a + b < 3 * k) + ): + + if a + b > k and a + b < 3 * k: + inHexagon = True else: - weights = new_weights( - filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(weights) - test_hex_dict[(a, b)] = 1 - else: - weights = zeros - test_hex_dict[(a, b)] = 0 + # add 6 additional edges outside of k.0 aligned hexagon + if a == 2 * k - x and b == 0: # Edge 1 + inHexagon = True + elif a == k - x and b == x: # Edge 2 + inHexagon = True + elif a == 0 and b == k + x: # Edge 3 + inHexagon = True + elif a == x and b == 2 * k: # Edge 4 + inHexagon = True + elif a == k + x and b == 2 * k - x: # Edge 5 + inHexagon = True + elif a == 2 * k and b == k - x: # Edge 6 + inHexagon = True + # get weights or constant 0 depending on if point is in hexagon + if inHexagon: + if get_ones: + weights = ones + else: + weights = new_weights( + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + f"_weights_{a}_{b}", + ) + self.var_list.append(weights) + self.test_hex_dict[(a, b)] = 1 + else: + weights = zeros + self.test_hex_dict[(a, b)] = 0 - b_list.append(weights) - a_list.append(tf.stack(b_list)) - hexKernel = tf.stack(a_list) - if print_kernel: - print_hex_data(test_hex_dict) - return hexKernel, var_list + b_list.append(weights) + self.a_list.append(b_list) + + def print_kernel(self): + """Print the hexagonal kernel + + Print first two dimensions of kernel. + 0 represents a const 0 Tensor of shape filter_size[2:] + 1 represents a trainable Tensor of shape filter_size[2:] + This can be used to verify the shape of the hex kernel + """ + print_hex_data(self.test_hex_dict) + + def __call__(self): + """Get the hexagonal kernel""" + a_list = [tf.stack(b_list) for b_list in self.a_list] + hex_kernel = tf.stack(a_list) + return hex_kernel class ConvHex(tf.Module): @@ -273,7 +284,7 @@ def __init__( zero_out=False, kernel=None, var_list=None, - azimuth=None, + turn_azimuth=False, float_precision=FLOAT_PRECISION, seed=None, name=None, @@ -328,7 +339,7 @@ def __init__( [1, 1, 1, 1, 1]: a stride of 1 is used along all axes. [1, 1, 2, 1, 1]: a stride of 2 is used along the y axis. num_rotations : int, optional - If num_rotations >= 1: weights of a kernel will be shared over + If num_rotations > 1: weights of a kernel will be shared over 'num_rotations' many rotated versions of that kernel. dilation_rate : None or list of int, optional The dilation rate to be used for the layer. @@ -347,9 +358,10 @@ def __init__( var_list : list of tf.Variables, optional A list of Variables of which the kernel is created from. This must only be provided (and only if) the parameter 'kernel' is not None. - azimuth : float or scalar float tf.Tensor - Hexagonal kernel is turned by the angle 'azimuth' - [given in degrees] in counterclockwise direction + turn_azimuth : bool, optional + If True, the kernel will be turned by the angle 'azimuth' in + counterclockwise direction. The azimuth is given in degrees + and must be provided in the __call__ method. float_precision : tf.dtype, optional The tensorflow dtype describing the float precision to use. seed : int, optional @@ -363,6 +375,7 @@ def __init__( Input data. """ super(ConvHex, self).__init__(name=name) + self.turn_azimuth = turn_azimuth # make sure it is a 2d or 3d convolution assert len(input_shape) == 4 or len(input_shape) == 5 @@ -371,27 +384,36 @@ def __init__( num_channels = input_shape[-1] if kernel is None: - if azimuth is not None and filter_size[:2] != [1, 0]: - kernel, var_list = rotation.get_dynamic_rotation_hex_kernel( + if self.turn_azimuth and filter_size[:2] != [1, 0]: + kernel_obj = rotation.DynamicRotationHexKernel( filter_size + [num_channels, num_filters], - azimuth, float_precision=float_precision, seed=seed, + name=self.name + "_kernel", ) + var_list = kernel_obj.var_list else: if num_rotations > 1: - kernel, var_list = rotation.get_rotated_hex_kernel( + kernel_obj = rotation.RotatedHexKernel( filter_size + [num_channels, num_filters], num_rotations, float_precision=float_precision, seed=seed, + name=self.name + "_kernel", ) + var_list = kernel_obj.var_list else: - kernel, var_list = get_hex_kernel( + kernel_obj = HexKernel( filter_size + [num_channels, num_filters], float_precision=float_precision, seed=seed, + name=self.name + "_kernel", ) + var_list = kernel_obj.var_list + else: + + def kernel_obj(): + return kernel self.num_filters = num_filters self.filter_size = filter_size @@ -400,19 +422,21 @@ def __init__( self.num_rotations = num_rotations self.dilation_rate = dilation_rate self.zero_out = zero_out - self.azimuth = azimuth self.float_precision = float_precision self.seed = seed - self.kernel = kernel + self.kernel_obj = kernel_obj self.kernel_var_list = var_list - def __call__(self, inputs): + def __call__(self, inputs, azimuth=None): """Apply ConvHex Module. Parameters ---------- inputs : tf.Tensor Input tensor. + azimuth : float or scalar float tf.Tensor + Hexagonal kernel is turned by the angle 'azimuth' + [given in degrees] in counterclockwise direction Returns ------- @@ -424,10 +448,10 @@ def __call__(self, inputs): # make sure it is a 2d or 3d convolution assert len(inputs.get_shape()) == 4 or len(inputs.get_shape()) == 5 - if self.azimuth is not None and self.filter_size[:2] != [1, 0]: + if self.turn_azimuth and self.filter_size[:2] != [1, 0]: result = dynamic_conv( inputs=inputs, - filters=self.kernel, + filters=self.kernel_obj(azimuth), strides=self.strides[1:-1], padding=self.padding, dilation_rate=self.dilation_rate, @@ -435,7 +459,7 @@ def __call__(self, inputs): else: result = tf.nn.convolution( input=inputs, - filters=self.kernel, + filters=self.kernel_obj(), strides=self.strides[1:-1], padding=self.padding, dilations=self.dilation_rate, @@ -449,12 +473,15 @@ def __call__(self, inputs): logger = logging.getLogger(__name__) logger.warning("Assuming IceCube shape for layer", result) - zero_out_matrix, var_list = get_icecube_kernel( + kernel_obj = IceCubeKernel( result.get_shape().as_list()[3:], get_ones=True, float_precision=self.float_precision, seed=self.seed, + name=self.name, ) + zero_out_matrix = kernel_obj() + var_list = kernel_obj.var_list result = result * zero_out_matrix # Make sure there were no extra variables created. @@ -463,7 +490,7 @@ def __call__(self, inputs): else: # Generic hexagonal shape - zero_out_matrix, var_list = get_hex_kernel( + kernel_obj = HexKernel( [ (result.get_shape().as_list()[1] + 1) // 2, 0, @@ -474,6 +501,8 @@ def __call__(self, inputs): float_precision=self.float_precision, seed=self.seed, ) + zero_out_matrix = kernel_obj() + var_list = kernel_obj.var_list # Make sure there were no extra variables created. # These would have to be saved to tf.Module, to allow tracking @@ -507,7 +536,7 @@ def __init__( dilation_rate=None, kernel=None, var_list=None, - azimuth=None, + turn_azimuth=False, stack_axis=None, zero_out=False, float_precision=FLOAT_PRECISION, @@ -581,9 +610,10 @@ def __init__( var_list : list of tf.Variables, optional A list of Variables of which the kernel is created from. This must only be provided (and only if) the parameter 'kernel' is not None. - azimuth : float or scalar float tf.Tensor - Hexagonal kernel is turned by the angle 'azimuth' - [given in degrees] in counterclockwise direction + turn_azimuth : bool, optional + If True, the kernel will be turned by the angle 'azimuth' in + counterclockwise direction. The azimuth is given in degrees + and must be provided in the __call__ method. stack_axis : Int Axis along which the convolutions will be stacked. By default the axis with the lowest output dimensionality will be @@ -606,6 +636,8 @@ def __init__( """ super(ConvHex4d, self).__init__(name=name) + self.turn_azimuth = turn_azimuth + # make sure it is a 4d convolution assert len(input_shape) == 6 @@ -613,27 +645,36 @@ def __init__( num_channels = input_shape[5] if kernel is None: - if azimuth is not None: - kernel, var_list = rotation.get_dynamic_rotation_hex_kernel( + if self.turn_azimuth: + kernel_obj = HexKernel( filter_size + [num_channels, num_filters], - azimuth, float_precision=float_precision, seed=seed, + name=self.name + "_kernel", ) + var_list = kernel_obj.var_list else: if num_rotations > 1: - kernel, var_list = rotation.get_rotated_hex_kernel( + kernel_obj = rotation.RotatedHexKernel( filter_size + [num_channels, num_filters], num_rotations, float_precision=float_precision, seed=seed, + name=self.name + "_kernel", ) + var_list = kernel_obj.var_list else: - kernel, var_list = get_hex_kernel( + kernel_obj = HexKernel( filter_size + [num_channels, num_filters], float_precision=float_precision, seed=seed, + name=self.name + "_kernel", ) + var_list = kernel_obj.var_list + else: + + def kernel_obj(): + return kernel self.num_filters = num_filters self.filter_size = filter_size @@ -641,21 +682,23 @@ def __init__( self.strides = strides self.num_rotations = num_rotations self.dilation_rate = dilation_rate - self.azimuth = azimuth self.stack_axis = stack_axis self.zero_out = zero_out self.float_precision = float_precision self.seed = seed - self.kernel = kernel + self.kernel_obj = kernel_obj self.kernel_var_list = var_list - def __call__(self, inputs): + def __call__(self, inputs, azimuth=None): """Apply ConvHex4d Module. Parameters ---------- inputs : tf.Tensor Input tensor. + azimuth : float or scalar float tf.Tensor + Hexagonal kernel is turned by the angle 'azimuth' + [given in degrees] in counterclockwise direction Returns ------- @@ -667,10 +710,15 @@ def __call__(self, inputs): # make sure it is a 4d convolution assert len(inputs.get_shape()) == 6 + if self.turn_azimuth is not None: + kernel = self.kernel_obj(azimuth) + else: + kernel = self.kernel_obj() + # convolve with tf conv4d_stacked result = conv4d_stacked( input=inputs, - filter=self.kernel, + filter=kernel, strides=self.strides, padding=self.padding, dilation_rate=self.dilation_rate, @@ -679,7 +727,7 @@ def __call__(self, inputs): # zero out elements that don't belong on hexagon if self.zero_out: - zero_out_matrix, var_list = get_hex_kernel( + kernel_obj = HexKernel( [ int((result.get_shape().as_list()[1] + 1) / 2), 0, @@ -691,6 +739,8 @@ def __call__(self, inputs): float_precision=self.float_precision, seed=self.seed, ) + zero_out_matrix = kernel_obj() + var_list = kernel_obj.var_list # Make sure there were no extra variables created. # These would have to be saved to tf.Module, to allow tracking @@ -708,137 +758,3 @@ def __call__(self, inputs): ) return result - - -def create_conv_hex_layers_weights( - num_input_channels, - filter_size_list, - num_filters_list, - num_rotations_list=1, - azimuth_list=None, - float_precision=FLOAT_PRECISION, - seed=None, -): - """Create weights and biases for conv hex n-dimensional layers with n >= 2 - - Parameters - ---------- - num_input_channels : int - Number of channels of input layer. - filter_size_list : list of int or list of list of int - A list of filter sizes. - If only one filter_size is given, this will be used for all layers. - filter_size : A list of int - filter_size = [s, o, 3. dim(e.g. z), 4. dim(e.g. t),...] - s: size of hexagon - o: orientation of hexagon - - Examples: - - s = 2, o = 0: - 1 1 0 1 1 - - 1 1 1 1 1 1 - - 0 1 1 1 1 - - s = 3, o = 2: - 0 1 0 0 0 0 0 1 - - 0 1 1 1 1 0 0 1 1 1 1 - - 1 1 1 1 1 0 0 1 1 1 1 1 - - 0 1 1 1 1 1 0 1 1 1 1 1 - - 0 0 1 1 1 1 1 1 1 1 1 1 - - 0 0 1 1 1 1 0 1 1 1 1 - - 0 0 0 0 0 1 0 1 - num_filters_list : list of int - A list of int where each int denotes the number of filters in - that layer. - num_rotations_list : int or list of int, optional - The number of rotations to use for each layer. - If num_rotations >= 1: weights of a kernel will be shared over - 'num_rotations' many rotated versions of that kernel. - If only a single number is given, the same number of rotations will be - used for all layers. - azimuth_list : None, optional - A list of floats or scalar tf.tensors denoting the azimuth angle by - which the kernel of each layer is rotated. - Hexagonal kernel is turned by the angle 'azimuth' [given in degrees] - in counterclockwise direction. - If only a single azimuth angle is given, the same rotation is used for - all layers. - If azimuth is None, the hexagonal kernel is not rotated. - float_precision : tf.dtype, optional - The tensorflow dtype describing the float precision to use. - seed : int, optional - Seed for the random number generator. - - Returns - ------- - list of tf.Tensor - List of weight tensors for each layer. - list of tf.Tensor - List of bias tensors for each layer. - list of tf.Variable - A list of tensorflow variables created in this function - """ - # create seed counter - cnt = SeedCounter(seed) - - # create num_rotations_list - if isinstance(num_rotations_list, int): - num_rotations_list = [ - num_rotations_list for i in range(len(num_filters_list)) - ] - # create azimuth_list - if azimuth_list is None or tf.is_tensor(azimuth_list): - azimuth_list = [azimuth_list for i in range(len(num_filters_list))] - - weights_list = [] - biases_list = [] - variable_list = [] - for filter_size, num_filters, num_rotations, azimuth in zip( - filter_size_list, - num_filters_list, - num_rotations_list, - azimuth_list, - ): - if azimuth is not None: - kernel, var_list = rotation.get_dynamic_rotation_hex_kernel( - filter_size, - azimuth, - float_precision=float_precision, - seed=cnt(), - ) - else: - if num_rotations > 1: - kernel, var_list = rotation.get_rotated_hex_kernel( - filter_size + [num_input_channels, num_filters], - num_rotations, - float_precision=float_precision, - seed=cnt(), - ) - else: - kernel, var_list = get_hex_kernel( - filter_size + [num_input_channels, num_filters], - float_precision=float_precision, - seed=cnt(), - ) - - variable_list.extend(var_list) - weights_list.append(kernel) - biases_list.append( - new_biases( - length=num_filters * num_rotations, - float_precision=float_precision, - seed=cnt(), - ) - ) - num_input_channels = num_filters - - return weights_list, biases_list, variable_list diff --git a/tfscripts/hex/icecube.py b/tfscripts/hex/icecube.py index 2dbcac2..6de34f1 100644 --- a/tfscripts/hex/icecube.py +++ b/tfscripts/hex/icecube.py @@ -239,62 +239,76 @@ def get_icecube_string_from_hex_coord(a, b): return hex_string_coord_dict[(a, b)] -def get_icecube_kernel( - shape, get_ones=False, float_precision=FLOAT_PRECISION, seed=None -): - """ - Get a kernel of shape 'shape' for IceCube where coordinates of no real - strings are set to constant zeros. +class IceCubeKernel(tf.Module): - Parameters - ---------- - shape : list of int - The shape of the desired kernel. - get_ones : bool, optional - If True, returns constant ones for real DOMs, zeros for virtual DOMs. - If False, return trainable parameter for real DOMs, - zeros for virtual DOMs - float_precision : tf.dtype, optional - The tensorflow dtype describing the float precision to use. - seed : int, optional - Seed for the random number generator. + def __init__( + self, + shape, + get_ones=False, + float_precision=FLOAT_PRECISION, + seed=None, + name="IceCubeKernel", + ): + """ + Get a kernel of shape 'shape' for IceCube where coordinates of no real + strings are set to constant zeros. - Returns - ------- - tf.Tensor - The icecube kernel with the desired shape. - list of tf.Variable - A list of tensorflow variables created in this function - """ - # create seed counter - cnt = SeedCounter(seed) + Parameters + ---------- + shape : list of int + The shape of the desired kernel. + get_ones : bool, optional + If True, returns constant ones for real DOMs, zeros for virtual DOMs. + If False, return trainable parameter for real DOMs, + zeros for virtual DOMs + float_precision : tf.dtype, optional + The tensorflow dtype describing the float precision to use. + seed : int, optional + Seed for the random number generator. + name : str, optional + The name of the kernel. - zeros = tf.zeros(shape, dtype=float_precision) - ones = tf.ones(shape, dtype=float_precision) + Returns + ------- + tf.Tensor + The icecube kernel with the desired shape. + list of tf.Variable + A list of tensorflow variables created in this function + """ + # create seed counter + cnt = SeedCounter(seed) - var_list = [] - a_list = [] - for a in range(-4, 6): + zeros = tf.zeros(shape, dtype=float_precision) + ones = tf.ones(shape, dtype=float_precision) - b_list = [] - for b in range(-5, 5): + self.var_list = [] + self.a_list = [] + for a in range(-4, 6): - if (a, b) in hex_string_coord_dict.keys(): - # String exists - if get_ones: - weights = ones + b_list = [] + for b in range(-5, 5): + + if (a, b) in hex_string_coord_dict.keys(): + # String exists + if get_ones: + weights = ones + else: + weights = new_weights( + shape, + float_precision=float_precision, + seed=cnt(), + name=name + f"_weights_{a}_{b}", + ) + self.var_list.append(weights) else: - weights = new_weights( - shape, - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(weights) - else: - # virtual string, string does not actually exist - weights = zeros + # virtual string, string does not actually exist + weights = zeros + + b_list.append(weights) + self.a_list.append(b_list) - b_list.append(weights) - a_list.append(tf.stack(b_list)) - icecube_kernel = tf.stack(a_list) - return icecube_kernel, var_list + def __call__(self): + """Get the icecube kernel""" + a_list = [tf.stack(b_list) for b_list in self.a_list] + icecube_kernel = tf.stack(a_list) + return icecube_kernel diff --git a/tfscripts/hex/rotation.py b/tfscripts/hex/rotation.py index 5bd2015..4452bb3 100644 --- a/tfscripts/hex/rotation.py +++ b/tfscripts/hex/rotation.py @@ -74,9 +74,7 @@ def tf_get_rotated_corner_weights(corner_weights, azimuth): num_dims = len(corner_weights.get_shape().as_list()[1:]) degree_steps = 360.0 / size - a = tf.reshape( - azimuth % degree_steps, [tf.shape(input=azimuth)[0]] + [1] * num_dims - ) + a = tf.reshape(azimuth % degree_steps, [1] * num_dims) b = tf.cast(azimuth / degree_steps, tf.int32) rotatedcorner_weights = [] for i in range(size): @@ -97,480 +95,611 @@ def tf_get_rotated_corner_weights(corner_weights, azimuth): return rotatedcorner_weights -def get_dynamic_rotation_hex_kernel( - filter_size, - azimuth, - float_precision=FLOAT_PRECISION, - seed=None, -): - """Dynamically azimuthally rotated hexagonal kernels. +class DynamicRotationHexKernel(tf.Module): + """Dynamically azimuthally rotated hexagonal kernels.""" - Create Weights for a hexagonal kernel. - The Kernel is dynamically rotated by the 'azimuth' angle. - The Kernel will be of a hexagonal shape in the first two dimensions, - while the other dimensions are normal. - The hexagonal kernel is of the shape: - [kernel_edge_points, kernel_edge_points, *filter_size[2:]] - But elements with coordinates in the first two dimensions, that don't belong - to the hexagon are set to a tf.Constant 0. - - The hexagon is defined by filter_size[0:2]. - filter_size[0] defines the size of the hexagon and - filter_size[1] the orientation. - - Parameters - ---------- - filter_size : A list of int - filter_size = [s, o, 3. dim(e.g. z), 4. dim(e.g. t),...] - filter_size[-2:] = [no_in_channels, no_out_channels] - s: size of hexagon - o: orientation of hexagon - - Examples: - - s = 2, o = 0: - 1 1 0 1 1 + def __init__( + self, + filter_size, + float_precision=FLOAT_PRECISION, + seed=None, + name="DynamicRotationHexKernel", + ): + """Dynamically azimuthally rotated hexagonal kernels. + + Create Weights for a hexagonal kernel. + The Kernel is dynamically rotated by the 'azimuth' angle. + The Kernel will be of a hexagonal shape in the first two dimensions, + while the other dimensions are normal. + The hexagonal kernel is of the shape: + [kernel_edge_points, kernel_edge_points, *filter_size[2:]] + But elements with coordinates in the first two dimensions, that don't belong + to the hexagon are set to a tf.Constant 0. + + The hexagon is defined by filter_size[0:2]. + filter_size[0] defines the size of the hexagon and + filter_size[1] the orientation. + + Parameters + ---------- + filter_size : A list of int + filter_size = [s, o, 3. dim(e.g. z), 4. dim(e.g. t),...] + filter_size[-2:] = [no_in_channels, no_out_channels] + s: size of hexagon + o: orientation of hexagon + + Examples: + + s = 2, o = 0: + 1 1 0 1 1 + + 1 1 1 1 1 1 + + 0 1 1 1 1 + + float_precision : tf.dtype, optional + The tensorflow dtype describing the float precision to use. + seed : int, optional + Seed for the random number generator. + name : str, optional + The name of the operation. + + Returns + ------- + tf.Tensor + A Tensor with shape: [ s, s, *filter_size[2:]] + where s = 2*filter_size[0] -1 if x == o + [hexagon is parallel to axis of first dimension] + = 2*filter_size[0] +1 if x != o + [hexagon is tilted to axis of first dimension] - 1 1 1 1 1 1 + Raises + ------ + ValueError + Description - 0 1 1 1 1 + """ + self.float_precision = float_precision + self.filter_size = filter_size - azimuth : tf tensor - A scalar float tf.Tensor denoting the angle by which the kernel will - be dynamically rotated. Azimuth angle is given in degrees. - float_precision : tf.dtype, optional - The tensorflow dtype describing the float precision to use. - seed : int, optional - Seed for the random number generator. + # create seed counter + cnt = SeedCounter(seed) - Returns - ------- - tf.Tensor - A Tensor with shape: [ s, s, *filter_size[2:]] - where s = 2*filter_size[0] -1 if x == o - [hexagon is parallel to axis of first dimension] - = 2*filter_size[0] +1 if x != o - [hexagon is tilted to axis of first dimension] - - Raises - ------ - ValueError - Description + self.var_list = [] + self.no_of_dims = len(filter_size) - """ - # create seed counter - cnt = SeedCounter(seed) - - var_list = [] - no_of_dims = len(filter_size) - - Z = tf.zeros( - [tf.shape(input=azimuth)[0]] + filter_size[2:], dtype=float_precision - ) - center_weight = new_weights( - [1] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(center_weight) - multiples = [tf.shape(input=azimuth)[0]] + [1] * (no_of_dims - 2) - center_weight = tf.tile(center_weight, multiples) - - # HARDCODE MAGIC... ToDo: Generalize and clean up - if filter_size[0:2] == [2, 0]: - # hexagonal 2,0 Filter - corner_weights1 = new_weights( - [6] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights1) - elif filter_size[0:2] == [2, 1]: - # hexagonal 2,1 Filter - corner_weights1 = new_weights( - [6] + filter_size[2:], + self.Z = tf.zeros(filter_size[2:], dtype=float_precision) + self.center_weight = new_weights( + filter_size[2:], float_precision=float_precision, seed=cnt(), + name=name + "_center_weight", ) - var_list.append(corner_weights1) - corner_weights2 = [] - for i in range(6): - weights = new_weights( - filter_size[2:], + self.var_list.append(self.center_weight) + + self.corner_weights1 = None + self.corner_weights2 = None + self.corner_weights3 = None + + # HARDCODE MAGIC... ToDo: Generalize and clean up + if filter_size[0:2] == [2, 0]: + # hexagonal 2,0 Filter + self.corner_weights1 = new_weights( + [6] + filter_size[2:], float_precision=float_precision, seed=cnt(), + name=name + "_corner_weights1", ) - var_list.append(weights) - corner_weights2.extend([Z, weights]) - corner_weights2 = tf.stack(corner_weights2) - elif filter_size[0:2] == [3, 0]: - # hexagonal 3,0 Filter - corner_weights1 = new_weights( - [6] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights1) - corner_weights2 = new_weights( - [12] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights2) - elif filter_size[0:2] == [3, 1]: - # hexagonal 3,1 Filter - corner_weights1 = new_weights( - [6] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights1) - corner_weights2 = new_weights( - [12] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights2) - corner_weights3 = [] - for i in range(6): - weights = new_weights( - filter_size[2:], + self.var_list.append(self.corner_weights1) + elif filter_size[0:2] == [2, 1]: + # hexagonal 2,1 Filter + self.corner_weights1 = new_weights( + [6] + filter_size[2:], float_precision=float_precision, seed=cnt(), + name=name + "_corner_weights1", ) - var_list.append(weights) - corner_weights3.extend([Z, weights, Z]) - corner_weights3 = tf.stack(corner_weights3) - elif filter_size[0:2] == [3, 2]: - # hexagonal 3,2 Filter - corner_weights1 = new_weights( - [6] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights1) - corner_weights2 = new_weights( - [12] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights2) - corner_weights3 = [] - for i in range(6): - weights = new_weights( - filter_size[2:], + self.var_list.append(self.corner_weights1) + self.corner_weights2 = [] + for i in range(6): + weights = new_weights( + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights2", + ) + self.var_list.append(weights) + self.corner_weights2.extend([self.Z, weights]) + elif filter_size[0:2] == [3, 0]: + # hexagonal 3,0 Filter + self.corner_weights1 = new_weights( + [6] + filter_size[2:], float_precision=float_precision, seed=cnt(), + name=name + "_corner_weights1", ) - var_list.append(weights) - corner_weights3.extend([Z, Z, weights]) - corner_weights3 = tf.stack(corner_weights3) - elif filter_size[0:2] == [4, 0]: - # hexagonal 4,0 Filter - corner_weights1 = new_weights( - [6] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights1) - corner_weights2 = new_weights( - [12] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights2) - corner_weights3 = new_weights( - [18] + filter_size[2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(corner_weights3) - else: - raise ValueError( - "get_dynamic_rotation_hex_kernel: Unsupported " - "hexagonal filter_size: {!r}".format(filter_size[0:2]) - ) - - rotated_kernel_rows = [] - if filter_size[0:2] == [2, 0]: - # hexagonal 2,0 Filter - A = tf_get_rotated_corner_weights(corner_weights1, azimuth) - rotated_kernel_rows.append(tf.stack([Z, A[5], A[0]], axis=1)) - rotated_kernel_rows.append( - tf.stack([A[3], center_weight, A[1]], axis=1) - ) - rotated_kernel_rows.append(tf.stack([A[3], A[2], Z], axis=1)) - elif filter_size[0:2] == [2, 1] or filter_size[0:2] == [3, 0]: - # hexagonal 2,1 and 3,0 Filter - A = tf_get_rotated_corner_weights(corner_weights1, azimuth) - B = tf_get_rotated_corner_weights(corner_weights2, azimuth) - rotated_kernel_rows.append( - tf.stack([Z, Z, B[9], B[10], B[11]], axis=1) - ) - rotated_kernel_rows.append( - tf.stack([Z, B[8], A[5], A[0], B[0]], axis=1) - ) - rotated_kernel_rows.append( - tf.stack([B[7], A[4], center_weight, A[1], B[1]], axis=1) - ) - rotated_kernel_rows.append( - tf.stack([B[6], A[3], A[2], B[2], Z], axis=1) - ) - rotated_kernel_rows.append(tf.stack([B[5], B[4], B[3], Z, Z], axis=1)) - elif ( - filter_size[0:2] == [3, 1] - or filter_size[0:2] == [3, 2] - or filter_size[0:2] == [4, 0] - ): - # hexagonal 3,1 3,2 and 4,0 filter - A = tf_get_rotated_corner_weights(corner_weights1, azimuth) - B = tf_get_rotated_corner_weights(corner_weights2, azimuth) - C = tf_get_rotated_corner_weights(corner_weights3, azimuth) - rotated_kernel_rows.append( - tf.stack([Z, Z, Z, C[15], C[16], C[17], C[0]], axis=1) - ) - rotated_kernel_rows.append( - tf.stack([Z, Z, C[14], B[9], B[10], B[11], C[1]], axis=1) - ) - rotated_kernel_rows.append( - tf.stack([Z, C[13], B[8], A[5], A[0], B[0], C[2]], axis=1) - ) - rotated_kernel_rows.append( - tf.stack( - [C[12], B[7], A[4], center_weight, A[1], B[1], C[3]], axis=1 + self.var_list.append(self.corner_weights1) + self.corner_weights2 = new_weights( + [12] + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights2", + ) + self.var_list.append(self.corner_weights2) + elif filter_size[0:2] == [3, 1]: + # hexagonal 3,1 Filter + self.corner_weights1 = new_weights( + [6] + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights1", + ) + self.var_list.append(self.corner_weights1) + self.corner_weights2 = new_weights( + [12] + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights2", + ) + self.var_list.append(self.corner_weights2) + self.corner_weights3 = [] + for i in range(6): + weights = new_weights( + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + f"_corner_weights3_{i}", + ) + self.var_list.append(weights) + self.corner_weights3.extend([self.Z, weights, self.Z]) + elif filter_size[0:2] == [3, 2]: + # hexagonal 3,2 Filter + self.corner_weights1 = new_weights( + [6] + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights1", + ) + self.var_list.append(self.corner_weights1) + self.corner_weights2 = new_weights( + [12] + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights2", + ) + self.var_list.append(self.corner_weights2) + self.corner_weights3 = [] + for i in range(6): + weights = new_weights( + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + f"_corner_weights3_{i}", + ) + self.var_list.append(weights) + self.corner_weights3.extend([self.Z, self.Z, weights]) + elif filter_size[0:2] == [4, 0]: + # hexagonal 4,0 Filter + self.corner_weights1 = new_weights( + [6] + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights1", + ) + self.var_list.append(self.corner_weights1) + self.corner_weights2 = new_weights( + [12] + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights2", + ) + self.var_list.append(self.corner_weights2) + self.corner_weights3 = new_weights( + [18] + filter_size[2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_corner_weights3", + ) + self.var_list.append(self.corner_weights3) + else: + raise ValueError( + "DynamicRotationHexKernel: Unsupported " + "hexagonal filter_size: {!r}".format(filter_size[0:2]) ) - ) - rotated_kernel_rows.append( - tf.stack([C[11], B[6], A[3], A[2], B[2], C[4], Z], axis=1) - ) - rotated_kernel_rows.append( - tf.stack([C[10], B[5], B[4], B[3], C[5], Z, Z], axis=1) - ) - rotated_kernel_rows.append( - tf.stack([C[9], C[8], C[7], C[6], Z, Z, Z], axis=1) - ) - else: - raise ValueError( - "get_dynamic_rotation_hex_kernel: Unsupported " - "hexagonal filter_size: {!r}".format(filter_size[0:2]) - ) - - rotated_kernel = tf.stack(rotated_kernel_rows, axis=1) - - return rotated_kernel, var_list - - -# ------------------------------------------------------------------------- -# hexagonal azimuth rotated filters -# ------------------------------------------------------------------------- -def get_rotated_hex_kernel( - filter_size, - num_rotations, - float_precision=FLOAT_PRECISION, - seed=None, -): - """ - Create Weights for a hexagonal kernel. - The kernel is rotated 'num_rotations' many times. - Weights are shared over rotated versions. - The Kernel will be of a hexagonal shape in the first two dimensions, - while the other dimensions are normal. - The hexagonal kernel is of the shape: - [kernel_edge_points, kernel_edge_points, *filter_size[2:]] - But elements with coordinates in the first two dimensions, that don't belong - to the hexagon are set to a tf.Constant 0. - - The hexagon is defined by filter_size[0:2]. - filter_size[0] defines the size of the hexagon and - filter_size[1] the orientation. - - Parameters - ---------- - filter_size : A list of int - filter_size = [s, o, 3. dim(e.g. z), 4. dim(e.g. t),...] - filter_size[-2:] = [no_in_channels, no_out_channels] - s: size of hexagon - o: orientation of hexagon - - Examples: - - s = 2, o = 0: - 1 1 0 1 1 - - 1 1 1 1 1 1 - 0 1 1 1 1 + def __call__(self, azimuth): + """Dynamically azimuthally rotated hexagonal kernels. - num_rotations : int. - number of rotational kernels to create. - Kernels will be rotated by 360 degrees / num_rotations - float_precision : tf.dtype, optional - The tensorflow dtype describing the float precision to use. - seed : int, optional - Seed for the random number generator. + Parameters + ---------- + azimuth : tf tensor + A scalar float tf.Tensor denoting the angle by which the kernel will + be dynamically rotated. Azimuth angle is given in degrees. - Returns - ------- - tf.Tensor - A Tensor with shape: - [ s, s, *filter_size[2:-1], filter_size[-1]*num_rotations ] + Returns + ------- + tf.Tensor + A Tensor with shape: [ s, s, *filter_size[2:]] where s = 2*filter_size[0] -1 if x == o - [hexagon is parallel to axis of first dimension] + [hexagon is parallel to axis of first dimension] = 2*filter_size[0] +1 if x != o - [hexagon is tilted to axis of first dimension] - - Raises - ------ - ValueError - Description - - """ - # create seed counter - cnt = SeedCounter(seed) - - # define function to get new weights with correct shape - var_list = [] - - def get_new_weights(var_list): - weights = new_weights( - filter_size[2:-2], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(weights) - return weights - - no_of_dims = len(filter_size) - azimuths = np.linspace(0, 360, num_rotations + 1)[:-1] - Z = tf.zeros(filter_size[2:-2], dtype=float_precision) - center_weight = get_new_weights(var_list) - - # HARDCODE MAGIC... ToDo: Generalize - if filter_size[0:2] == [2, 0]: - # hexagonal 2,0 Filter - corner_weights1 = [get_new_weights(var_list) for i in range(6)] - - elif filter_size[0:2] == [2, 1]: - # hexagonal 2,1 Filter - corner_weights1 = [get_new_weights(var_list) for i in range(6)] - - corner_weights2 = [] - for i in range(6): - corner_weights2.extend([Z, get_new_weights(var_list)]) - - elif filter_size[0:2] == [3, 0]: - # hexagonal 3,0 Filter - corner_weights1 = [get_new_weights(var_list) for i in range(6)] - corner_weights2 = [get_new_weights(var_list) for i in range(12)] - - elif filter_size[0:2] == [3, 1]: - # hexagonal 3,1 Filter - corner_weights1 = [get_new_weights(var_list) for i in range(6)] - corner_weights2 = [get_new_weights(var_list) for i in range(12)] - - corner_weights3 = [] - for i in range(6): - corner_weights3.extend([Z, get_new_weights(var_list), Z]) - - elif filter_size[0:2] == [3, 2]: - # hexagonal 3,2 Filter - corner_weights1 = [get_new_weights(var_list) for i in range(6)] - corner_weights2 = [get_new_weights(var_list) for i in range(12)] - - corner_weights3 = [] - for i in range(6): - corner_weights3.extend([Z, Z, get_new_weights(var_list)]) - - elif filter_size[0:2] == [4, 0]: - # hexagonal 4,0 Filter - corner_weights1 = [get_new_weights(var_list) for i in range(6)] - corner_weights2 = [get_new_weights(var_list) for i in range(12)] - corner_weights3 = [get_new_weights(var_list) for i in range(18)] - - else: - raise ValueError( - "get_rotated_hex_kernel: Unsupported " - "hexagonal filter_size: {!r}".format(filter_size[0:2]) - ) + [hexagon is tilted to axis of first dimension] + """ + azimuth = tf.cast(azimuth, self.float_precision) + assert azimuth.get_shape().as_list() == [] + + multiples = [1] * (self.no_of_dims - 2) + center_weight = tf.tile(self.center_weight, multiples) + + # Get corner weights + corner_weights1 = self.corner_weights1 + if self.filter_size[0:2] == [2, 1]: + corner_weights2 = tf.stack(self.corner_weights2) + else: + corner_weights2 = self.corner_weights2 - rotated_kernels = [] - in_out_channel_weights = new_weights( - [num_rotations] + filter_size[-2:], - float_precision=float_precision, - seed=cnt(), - ) - var_list.append(in_out_channel_weights) + if self.filter_size[0:2] in ([3, 1], [3, 2]): + corner_weights3 = tf.stack(self.corner_weights3) + else: + corner_weights3 = self.corner_weights3 - for i, azimuth in enumerate(azimuths): + # combine kernel by rotating corner weights rotated_kernel_rows = [] - if filter_size[0:2] == [2, 0]: + if self.filter_size[0:2] == [2, 0]: # hexagonal 2,0 Filter - A = get_rotated_corner_weights(corner_weights1, azimuth) - rotated_kernel_rows.append(tf.stack([Z, A[5], A[0]])) - rotated_kernel_rows.append(tf.stack([A[3], center_weight, A[1]])) - rotated_kernel_rows.append(tf.stack([A[3], A[2], Z])) - elif filter_size[0:2] == [2, 1] or filter_size[0:2] == [3, 0]: + A = tf_get_rotated_corner_weights(corner_weights1, azimuth) + rotated_kernel_rows.append(tf.stack([self.Z, A[5], A[0]], axis=0)) + rotated_kernel_rows.append( + tf.stack([A[3], center_weight, A[1]], axis=0) + ) + rotated_kernel_rows.append(tf.stack([A[3], A[2], self.Z], axis=0)) + elif self.filter_size[0:2] == [2, 1] or self.filter_size[0:2] == [ + 3, + 0, + ]: # hexagonal 2,1 and 3,0 Filter - A = get_rotated_corner_weights(corner_weights1, azimuth) - B = get_rotated_corner_weights(corner_weights2, azimuth) - rotated_kernel_rows.append(tf.stack([Z, Z, B[9], B[10], B[11]])) - rotated_kernel_rows.append(tf.stack([Z, B[8], A[5], A[0], B[0]])) + A = tf_get_rotated_corner_weights(corner_weights1, azimuth) + B = tf_get_rotated_corner_weights(corner_weights2, azimuth) + rotated_kernel_rows.append( + tf.stack([self.Z, self.Z, B[9], B[10], B[11]], axis=0) + ) + rotated_kernel_rows.append( + tf.stack([self.Z, B[8], A[5], A[0], B[0]], axis=0) + ) rotated_kernel_rows.append( - tf.stack([B[7], A[4], center_weight, A[1], B[1]]) + tf.stack([B[7], A[4], center_weight, A[1], B[1]], axis=0) + ) + rotated_kernel_rows.append( + tf.stack([B[6], A[3], A[2], B[2], self.Z], axis=0) + ) + rotated_kernel_rows.append( + tf.stack([B[5], B[4], B[3], self.Z, self.Z], axis=0) ) - rotated_kernel_rows.append(tf.stack([B[6], A[3], A[2], B[2], Z])) - rotated_kernel_rows.append(tf.stack([B[5], B[4], B[3], Z, Z])) elif ( - filter_size[0:2] == [3, 1] - or filter_size[0:2] == [3, 2] - or filter_size[0:2] == [4, 0] + self.filter_size[0:2] == [3, 1] + or self.filter_size[0:2] == [3, 2] + or self.filter_size[0:2] == [4, 0] ): # hexagonal 3,1 3,2 and 4,0 filter - A = get_rotated_corner_weights(corner_weights1, azimuth) - B = get_rotated_corner_weights(corner_weights2, azimuth) - C = get_rotated_corner_weights(corner_weights3, azimuth) + A = tf_get_rotated_corner_weights(corner_weights1, azimuth) + B = tf_get_rotated_corner_weights(corner_weights2, azimuth) + C = tf_get_rotated_corner_weights(corner_weights3, azimuth) rotated_kernel_rows.append( - tf.stack([Z, Z, Z, C[15], C[16], C[17], C[0]]) + tf.stack( + [self.Z, self.Z, self.Z, C[15], C[16], C[17], C[0]], axis=0 + ) ) rotated_kernel_rows.append( - tf.stack([Z, Z, C[14], B[9], B[10], B[11], C[1]]) + tf.stack( + [self.Z, self.Z, C[14], B[9], B[10], B[11], C[1]], axis=0 + ) ) rotated_kernel_rows.append( - tf.stack([Z, C[13], B[8], A[5], A[0], B[0], C[2]]) + tf.stack([self.Z, C[13], B[8], A[5], A[0], B[0], C[2]], axis=0) ) rotated_kernel_rows.append( - tf.stack([C[12], B[7], A[4], center_weight, A[1], B[1], C[3]]) + tf.stack( + [C[12], B[7], A[4], center_weight, A[1], B[1], C[3]], + axis=0, + ) ) rotated_kernel_rows.append( - tf.stack([C[11], B[6], A[3], A[2], B[2], C[4], Z]) + tf.stack([C[11], B[6], A[3], A[2], B[2], C[4], self.Z], axis=0) ) rotated_kernel_rows.append( - tf.stack([C[10], B[5], B[4], B[3], C[5], Z, Z]) + tf.stack( + [C[10], B[5], B[4], B[3], C[5], self.Z, self.Z], axis=0 + ) ) rotated_kernel_rows.append( - tf.stack([C[9], C[8], C[7], C[6], Z, Z, Z]) + tf.stack( + [C[9], C[8], C[7], C[6], self.Z, self.Z, self.Z], axis=0 + ) ) else: raise ValueError( - "get_rotated_hex_kernel: Unsupported hexagonal " - "filter_size: {!r}".format(filter_size[0:2]) + "DynamicRotationHexKernel: Unsupported " + "hexagonal filter_size: {!r}".format(self.filter_size[0:2]) + ) + + rotated_kernel = tf.stack(rotated_kernel_rows, axis=0) + + return rotated_kernel + + +# ------------------------------------------------------------------------- +# hexagonal azimuth rotated filters +# ------------------------------------------------------------------------- +class RotatedHexKernel(tf.Module): + """Rotated hexagonal kernels.""" + + def __init__( + self, + filter_size, + num_rotations, + float_precision=FLOAT_PRECISION, + seed=None, + name="RotatedHexKernel", + ): + """ + Create Weights for a hexagonal kernel. + The kernel is rotated 'num_rotations' many times. + Weights are shared over rotated versions. + The Kernel will be of a hexagonal shape in the first two dimensions, + while the other dimensions are normal. + The hexagonal kernel is of the shape: + [kernel_edge_points, kernel_edge_points, *filter_size[2:]] + But elements with coordinates in the first two dimensions, that don't belong + to the hexagon are set to a tf.Constant 0. + + The hexagon is defined by filter_size[0:2]. + filter_size[0] defines the size of the hexagon and + filter_size[1] the orientation. + + Parameters + ---------- + filter_size : A list of int + filter_size = [s, o, 3. dim(e.g. z), 4. dim(e.g. t),...] + filter_size[-2:] = [no_in_channels, no_out_channels] + s: size of hexagon + o: orientation of hexagon + + Examples: + + s = 2, o = 0: + 1 1 0 1 1 + + 1 1 1 1 1 1 + + 0 1 1 1 1 + + num_rotations : int. + number of rotational kernels to create. + Kernels will be rotated by 360 degrees / num_rotations + float_precision : tf.dtype, optional + The tensorflow dtype describing the float precision to use. + seed : int, optional + Seed for the random number generator. + name : str, optional + The name of the operation. + + Returns + ------- + tf.Tensor + A Tensor with shape: + [ s, s, *filter_size[2:-1], filter_size[-1]*num_rotations ] + where s = 2*filter_size[0] -1 if x == o + [hexagon is parallel to axis of first dimension] + = 2*filter_size[0] +1 if x != o + [hexagon is tilted to axis of first dimension] + + Raises + ------ + ValueError + Description + + """ + self.float_precision = float_precision + self.filter_size = filter_size + + # create seed counter + cnt = SeedCounter(seed) + + # define function to get new weights with correct shape + self.var_list = [] + + def get_new_weights(): + weights = new_weights( + filter_size[2:-2], + float_precision=float_precision, + seed=cnt(), + name=name, ) - rotated_kernel_single = tf.stack(rotated_kernel_rows) + self.var_list.append(weights) + return weights - # Add free parameters for in and out channel - # tile to correct format - rotated_kernel_single = tf.expand_dims(rotated_kernel_single, -1) - rotated_kernel_single = tf.expand_dims(rotated_kernel_single, -1) + self.no_of_dims = len(filter_size) + self.azimuths = np.linspace(0, 360, num_rotations + 1)[:-1] + self.Z = tf.zeros(filter_size[2:-2], dtype=float_precision) + self.center_weight = get_new_weights() - multiples = [1 for i in range(no_of_dims - 2)] + filter_size[-2:] - rotated_kernel_tiled = tf.tile(rotated_kernel_single, multiples) + self.corner_weights1 = None + self.corner_weights2 = None + self.corner_weights3 = None - # multiply weights to make in and out channels independent - rotated_kernel = rotated_kernel_tiled * in_out_channel_weights[i] + # HARDCODE MAGIC... ToDo: Generalize + if filter_size[0:2] == [2, 0]: + # hexagonal 2,0 Filter + self.corner_weights1 = [get_new_weights() for i in range(6)] + + elif filter_size[0:2] == [2, 1]: + # hexagonal 2,1 Filter + self.corner_weights1 = [get_new_weights() for i in range(6)] + + self.corner_weights2 = [] + for i in range(6): + self.corner_weights2.extend([self.Z, get_new_weights()]) + + elif filter_size[0:2] == [3, 0]: + # hexagonal 3,0 Filter + self.corner_weights1 = [get_new_weights() for i in range(6)] + self.corner_weights2 = [get_new_weights() for i in range(12)] + + elif filter_size[0:2] == [3, 1]: + # hexagonal 3,1 Filter + self.corner_weights1 = [get_new_weights() for i in range(6)] + self.corner_weights2 = [get_new_weights() for i in range(12)] + + self.corner_weights3 = [] + for i in range(6): + self.corner_weights3.extend( + [self.Z, get_new_weights(), self.Z] + ) + + elif filter_size[0:2] == [3, 2]: + # hexagonal 3,2 Filter + self.corner_weights1 = [get_new_weights() for i in range(6)] + self.corner_weights2 = [get_new_weights() for i in range(12)] + + self.corner_weights3 = [] + for i in range(6): + self.corner_weights3.extend( + [self.Z, self.Z, get_new_weights()] + ) + + elif filter_size[0:2] == [4, 0]: + # hexagonal 4,0 Filter + self.corner_weights1 = [get_new_weights() for i in range(6)] + self.corner_weights2 = [get_new_weights() for i in range(12)] + self.corner_weights3 = [get_new_weights() for i in range(18)] - rotated_kernels.append(rotated_kernel) + else: + raise ValueError( + "RotatedHexKernel: Unsupported " + "hexagonal filter_size: {!r}".format(filter_size[0:2]) + ) - rotated_kernels = tf.concat( - values=rotated_kernels, axis=len(filter_size) - 1 - ) - return rotated_kernels, var_list + self.in_out_channel_weights = new_weights( + [num_rotations] + filter_size[-2:], + float_precision=float_precision, + seed=cnt(), + name=name + "_in_out_channel_weights", + ) + self.var_list.append(self.in_out_channel_weights) + + def __call__(self): + """Rotated hexagonal kernels. + + Returns + ------- + tf.Tensor + A Tensor with shape: + [ s, s, *filter_size[2:-1], filter_size[-1]*num_rotations ] + where s = 2*filter_size[0] -1 if x == o + [hexagon is parallel to axis of first dimension] + = 2*filter_size[0] +1 if x != o + [hexagon is tilted to axis of first dimension] + """ + rotated_kernels = [] + for i, azimuth in enumerate(self.azimuths): + rotated_kernel_rows = [] + if self.filter_size[0:2] == [2, 0]: + # hexagonal 2,0 Filter + A = get_rotated_corner_weights(self.corner_weights1, azimuth) + rotated_kernel_rows.append(tf.stack([self.Z, A[5], A[0]])) + rotated_kernel_rows.append( + tf.stack([A[3], self.center_weight, A[1]]) + ) + rotated_kernel_rows.append(tf.stack([A[3], A[2], self.Z])) + elif self.filter_size[0:2] == [2, 1] or self.filter_size[0:2] == [ + 3, + 0, + ]: + # hexagonal 2,1 and 3,0 Filter + A = get_rotated_corner_weights(self.corner_weights1, azimuth) + B = get_rotated_corner_weights(self.corner_weights2, azimuth) + rotated_kernel_rows.append( + tf.stack([self.Z, self.Z, B[9], B[10], B[11]]) + ) + rotated_kernel_rows.append( + tf.stack([self.Z, B[8], A[5], A[0], B[0]]) + ) + rotated_kernel_rows.append( + tf.stack([B[7], A[4], self.center_weight, A[1], B[1]]) + ) + rotated_kernel_rows.append( + tf.stack([B[6], A[3], A[2], B[2], self.Z]) + ) + rotated_kernel_rows.append( + tf.stack([B[5], B[4], B[3], self.Z, self.Z]) + ) + elif ( + self.filter_size[0:2] == [3, 1] + or self.filter_size[0:2] == [3, 2] + or self.filter_size[0:2] == [4, 0] + ): + # hexagonal 3,1 3,2 and 4,0 filter + A = get_rotated_corner_weights(self.corner_weights1, azimuth) + B = get_rotated_corner_weights(self.corner_weights2, azimuth) + C = get_rotated_corner_weights(self.corner_weights3, azimuth) + rotated_kernel_rows.append( + tf.stack( + [self.Z, self.Z, self.Z, C[15], C[16], C[17], C[0]] + ) + ) + rotated_kernel_rows.append( + tf.stack([self.Z, self.Z, C[14], B[9], B[10], B[11], C[1]]) + ) + rotated_kernel_rows.append( + tf.stack([self.Z, C[13], B[8], A[5], A[0], B[0], C[2]]) + ) + rotated_kernel_rows.append( + tf.stack( + [ + C[12], + B[7], + A[4], + self.center_weight, + A[1], + B[1], + C[3], + ] + ) + ) + rotated_kernel_rows.append( + tf.stack([C[11], B[6], A[3], A[2], B[2], C[4], self.Z]) + ) + rotated_kernel_rows.append( + tf.stack([C[10], B[5], B[4], B[3], C[5], self.Z, self.Z]) + ) + rotated_kernel_rows.append( + tf.stack([C[9], C[8], C[7], C[6], self.Z, self.Z, self.Z]) + ) + else: + raise ValueError( + "RotatedHexKernel: Unsupported hexagonal " + "filter_size: {!r}".format(self.filter_size[0:2]) + ) + rotated_kernel_single = tf.stack(rotated_kernel_rows) + + # Add free parameters for in and out channel + # tile to correct format + rotated_kernel_single = tf.expand_dims(rotated_kernel_single, -1) + rotated_kernel_single = tf.expand_dims(rotated_kernel_single, -1) + + multiples = [ + 1 for i in range(self.no_of_dims - 2) + ] + self.filter_size[-2:] + rotated_kernel_tiled = tf.tile(rotated_kernel_single, multiples) + + # multiply weights to make in and out channels independent + rotated_kernel = ( + rotated_kernel_tiled * self.in_out_channel_weights[i] + ) + + rotated_kernels.append(rotated_kernel) + + rotated_kernels = tf.concat( + values=rotated_kernels, axis=len(self.filter_size) - 1 + ) + return rotated_kernels diff --git a/tfscripts/layers.py b/tfscripts/layers.py index f12dbb8..f38b3a3 100644 --- a/tfscripts/layers.py +++ b/tfscripts/layers.py @@ -158,7 +158,6 @@ def __init__( biases=None, trafo=None, hex_num_rotations=1, - hex_azimuth=None, hex_zero_out=False, float_precision=FLOAT_PRECISION, seed=None, @@ -279,13 +278,8 @@ def __init__( patch. hex_num_rotations : int, optional Only used if method == 'hex_convolution'. - If num_rotations >= 1: weights of a kernel will be shared over + If num_rotations > 1: weights of a kernel will be shared over 'num_rotations' many rotated versions of that kernel. - hex_azimuth : None or float or scalar float tf.Tensor - Only used if method == 'hex_convolution'. - Hexagonal kernel is turned by the angle 'azimuth' - [given in degrees] in counterclockwise direction. - If azimuth is None, the kernel will not be rotated dynamically. hex_zero_out : bool, optional Only used if method == 'hex_convolution'. If True, elements in result tensor which are not part of hexagon or @@ -372,6 +366,7 @@ def __init__( shape=shape, float_precision=float_precision, seed=self.cnt(), + name=self.name + "_weights", ) # Create new biases, one for each filter. @@ -380,6 +375,7 @@ def __init__( length=num_filters, float_precision=float_precision, seed=self.cnt(), + name=self.name + "_biases", ) if num_dims == 1 or num_dims == 2 or num_dims == 3: @@ -423,13 +419,13 @@ def temp_func(inputs): padding=padding, strides=strides, num_rotations=hex_num_rotations, - azimuth=hex_azimuth, dilation_rate=dilation_rate, zero_out=hex_zero_out, kernel=weights, var_list=var_list, float_precision=float_precision, seed=self.cnt(), + name=self.name + "_hex_conv", ) elif num_dims == 4: self.conv_layer = hx.ConvHex4d( @@ -439,13 +435,13 @@ def temp_func(inputs): padding=padding, strides=strides, num_rotations=hex_num_rotations, - azimuth=hex_azimuth, dilation_rate=dilation_rate, zero_out=hex_zero_out, kernel=weights, var_list=var_list, float_precision=float_precision, seed=self.cnt(), + name=self.name + "_hex_conv", ) # Create new biases, one for each filter. @@ -454,6 +450,7 @@ def temp_func(inputs): length=num_filters * hex_num_rotations, float_precision=float_precision, seed=self.cnt(), + name=self.name + "_biases", ) # ------------------- @@ -569,6 +566,7 @@ def temp_func(inputs): length=num_filters, float_precision=float_precision, seed=self.cnt(), + name=self.name + "_biases", ) else: @@ -857,12 +855,14 @@ def __init__( shape=[num_inputs, num_outputs], float_precision=float_precision, seed=seed, + name=self.name + "_weights", ) if biases is None: biases = new_biases( length=num_outputs, float_precision=float_precision, seed=seed, + name=self.name + "_biases", ) self.biases = biases @@ -1070,12 +1070,14 @@ def __init__( shape=[num_channels, num_inputs, num_outputs], float_precision=float_precision, seed=seed, + name=self.name + "_weights", ) if biases is None: biases = new_weights( shape=[num_outputs, num_channels], float_precision=float_precision, seed=seed, + name=self.name + "_biases", ) self.biases = biases @@ -1367,7 +1369,10 @@ def __init__( name="{}_{:03d}".format(name, i), ) if verbose: - print("{}_{:03d}".format(name, i), layer_i.output_shape) + print( + " {}_{:03d}: ".format(name, i), + list(layer_i.output_shape), + ) self.layers.append(layer_i) def __call__(self, inputs, is_training, keep_prob=None): @@ -1424,7 +1429,6 @@ def __init__( biases_list=None, trafo_list=None, hex_num_rotations_list=1, - hex_azimuth_list=None, hex_zero_out_list=False, float_precision=FLOAT_PRECISION, seed=None, @@ -1553,16 +1557,9 @@ def __init__( If only one trafo method is given, it will be used for all layers. hex_num_rotations_list : int or list of int, optional Only used if method == 'hex_convolution'. - If num_rotations >= 1: weights of a kernel will be shared over + If num_rotations > 1: weights of a kernel will be shared over 'num_rotations' many rotated versions of that kernel. If only one int is give, it will apply to all layers. - hex_azimuth_list : list of float or list scalar tf.Tensor, optional - Only used if method == 'hex_convolution'. - Hexagonal kernel is turned by the angle 'azimuth' - [given in degrees] in counterclockwise direction. - If azimuth is None, the kernel will not be rotated dynamically. - If only one azimuth angle is given, all layers will be turned by - the same angle. hex_zero_out_list : bool or list of bool, optional Only used if method == 'hex_convolution'. If True, elements in result tensor which are not part of hexagon or @@ -1733,10 +1730,6 @@ def __init__( hex_num_rotations_list for i in range(num_layers) ] - # create hex_azimuth_list - if hex_azimuth_list is None or tf.is_tensor(hex_azimuth_list): - hex_azimuth_list = [hex_azimuth_list for i in range(num_layers)] - # create hex_zero out array if isinstance(hex_zero_out_list, bool): hex_zero_out_list = [hex_zero_out_list for i in range(num_layers)] @@ -1770,14 +1763,16 @@ def __init__( biases=biases_list[i], trafo=trafo_list[i], hex_num_rotations=hex_num_rotations_list[i], - hex_azimuth=hex_azimuth_list[i], hex_zero_out=hex_zero_out_list[i], float_precision=float_precision, seed=self.cnt(), name="{}_{:03d}".format(name, i), ) if verbose: - print("{}_{:03d}".format(name, i), layer_i.output_shape) + print( + " {}_{:03d}:".format(name, i), + list(layer_i.output_shape), + ) self.layers.append(layer_i) def __call__(self, inputs, is_training, keep_prob=None):