diff --git a/prosit/layers.py b/prosit/layers.py index 9e642f9..608f2d3 100644 --- a/prosit/layers.py +++ b/prosit/layers.py @@ -32,7 +32,7 @@ def __init__( def build(self, input_shape): assert len(input_shape) == 3 self.W = self.add_weight( - (input_shape[-1],), + shape=(input_shape[-1],), initializer=self.init, name="{}_W".format(self.name), regularizer=self.W_regularizer, @@ -40,7 +40,7 @@ def build(self, input_shape): ) if self.bias: self.b = self.add_weight( - (input_shape[1],), + shape=(input_shape[1],), initializer="zero", name="{}_b".format(self.name), regularizer=self.b_regularizer, @@ -50,7 +50,7 @@ def build(self, input_shape): self.b = None if self.context: self.u = self.add_weight( - (input_shape[-1],), + shape=(input_shape[-1],), initializer=self.init, name="{}_u".format(self.name), regularizer=self.u_regularizer,