diff --git a/pyrtl/rtllib/pyrtlfloat/__init__.py b/pyrtl/rtllib/pyrtlfloat/__init__.py new file mode 100644 index 00000000..d9b64710 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/__init__.py @@ -0,0 +1,20 @@ +from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode +from .floatoperations import ( + BFloat16Operations, + Float16Operations, + Float32Operations, + Float64Operations, + FloatOperations, +) + +__all__ = [ + "FloatingPointType", + "FPTypeProperties", + "PyrtlFloatConfig", + "RoundingMode", + "FloatOperations", + "BFloat16Operations", + "Float16Operations", + "Float32Operations", + "Float64Operations", +] diff --git a/pyrtl/rtllib/pyrtlfloat/_add_sub.py b/pyrtl/rtllib/pyrtlfloat/_add_sub.py new file mode 100644 index 00000000..419d647f --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_add_sub.py @@ -0,0 +1,342 @@ +import pyrtl + +from ._float_utills import FloatUtils +from ._types import PyrtlFloatConfig, RoundingMode + + +class AddSubHelper: + @staticmethod + def add( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + fp_type_props = config.fp_type_properties + rounding_mode = config.rounding_mode + num_exp_bits = fp_type_props.num_exponent_bits + num_mant_bits = fp_type_props.num_mantissa_bits + total_bits = num_exp_bits + num_mant_bits + 1 + + operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a) + operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) + + # operand_smaller is the operand with the smaller absolute value and + # operand_larger is the operand with the larger absolute value + operand_smaller = pyrtl.WireVector(bitwidth=total_bits) + operand_larger = pyrtl.WireVector(bitwidth=total_bits) + + with pyrtl.conditional_assignment: + exponent_and_mantissa_len = num_mant_bits + num_exp_bits + with ( + operand_a_daz[:exponent_and_mantissa_len] + < operand_b_daz[:exponent_and_mantissa_len] + ): + operand_smaller |= operand_a_daz + operand_larger |= operand_b_daz + with pyrtl.otherwise: + operand_smaller |= operand_b_daz + operand_larger |= operand_a_daz + + smaller_operand_sign = FloatUtils.get_sign(fp_type_props, operand_smaller) + larger_operand_sign = FloatUtils.get_sign(fp_type_props, operand_larger) + smaller_operand_exponent = FloatUtils.get_exponent( + fp_type_props, operand_smaller + ) + larger_operand_exponent = FloatUtils.get_exponent(fp_type_props, operand_larger) + smaller_operand_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_smaller) + ) + larger_operand_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_larger) + ) + + exponent_diff = larger_operand_exponent - smaller_operand_exponent + smaller_mantissa_shifted = pyrtl.shift_right_logical( + smaller_operand_mantissa, exponent_diff + ) + grs = pyrtl.WireVector(bitwidth=3) # guard, round, sticky bits for rounding + with pyrtl.conditional_assignment: + with exponent_diff >= 2: + guard_and_round = pyrtl.shift_right_logical( + smaller_operand_mantissa, exponent_diff - 2 + )[:2] + mask = ( + pyrtl.shift_left_logical( + pyrtl.Const(1, bitwidth=num_mant_bits), exponent_diff - 2 + ) + - 1 + ) + sticky = (smaller_operand_mantissa & mask) != 0 + grs |= pyrtl.concat(guard_and_round, sticky) + with exponent_diff == 1: + grs |= pyrtl.concat( + smaller_operand_mantissa[0], pyrtl.Const(0, bitwidth=2) + ) + with pyrtl.otherwise: + grs |= 0 + smaller_mantissa_shifted_grs = pyrtl.concat(smaller_mantissa_shifted, grs) + larger_mantissa_extended = pyrtl.concat( + larger_operand_mantissa, pyrtl.Const(0, bitwidth=3) + ) + + sum_exponent, sum_mantissa, sum_grs, sum_carry = AddSubHelper._add_operands( + larger_operand_exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, + ) + + sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = ( + AddSubHelper._sub_operands( + num_mant_bits, + larger_operand_exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, + ) + ) + + # WireVectors for the raw addition or subtraction result, before handling + # special cases + raw_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + raw_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + if rounding_mode == RoundingMode.RNE: + raw_result_grs = pyrtl.WireVector(bitwidth=3) + + with pyrtl.conditional_assignment: + with smaller_operand_sign == larger_operand_sign: # add + raw_result_exponent |= sum_exponent + raw_result_mantissa |= sum_mantissa + if rounding_mode == RoundingMode.RNE: + raw_result_grs |= sum_grs + with pyrtl.otherwise: # sub + raw_result_exponent |= sub_exponent + raw_result_mantissa |= sub_mantissa + if rounding_mode == RoundingMode.RNE: + raw_result_grs |= sub_grs + + if rounding_mode == RoundingMode.RNE: + ( + raw_result_rounded_exponent, + raw_result_rounded_mantissa, + rounding_exponent_incremented, + ) = AddSubHelper._round( + num_mant_bits, + num_exp_bits, + raw_result_exponent, + raw_result_mantissa, + raw_result_grs, + ) + + smaller_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_smaller) + larger_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_larger) + smaller_operand_inf = FloatUtils.is_inf(fp_type_props, operand_smaller) + larger_operand_inf = FloatUtils.is_inf(fp_type_props, operand_larger) + smaller_operand_zero = FloatUtils.is_zero(fp_type_props, operand_smaller) + larger_operand_zero = FloatUtils.is_zero(fp_type_props, operand_larger) + + # WireVectors for the final result after handling special cases + final_result_sign = pyrtl.WireVector(bitwidth=1) + final_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + final_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + + # handle special cases + with pyrtl.conditional_assignment: + # if either operand is NaN or both operands are infinity of opposite signs, + # the result is NaN + with ( + smaller_operand_nan + | larger_operand_nan + | ( + smaller_operand_inf + & larger_operand_inf + & (larger_operand_sign != smaller_operand_sign) + ) + ): + final_result_sign |= larger_operand_sign + FloatUtils.make_output_NaN( + fp_type_props, final_result_exponent, final_result_mantissa + ) + # infinities + with smaller_operand_inf: + final_result_sign |= larger_operand_sign + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + with larger_operand_inf: + final_result_sign |= larger_operand_sign + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + # +num + -num = +0 + with ( + (smaller_operand_mantissa == larger_operand_mantissa) + & (smaller_operand_exponent == larger_operand_exponent) + & (larger_operand_sign != smaller_operand_sign) + ): + final_result_sign |= 0 + FloatUtils.make_output_zero( + final_result_exponent, final_result_mantissa + ) + with smaller_operand_zero: + final_result_sign |= larger_operand_sign + final_result_mantissa |= larger_operand_mantissa + final_result_exponent |= larger_operand_exponent + with larger_operand_zero: + final_result_sign |= smaller_operand_sign + final_result_mantissa |= smaller_operand_mantissa + final_result_exponent |= smaller_operand_exponent + # overflow and underflow + initial_larger_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2) + if rounding_mode == RoundingMode.RNE: + larger_exponent_max_value = ( + initial_larger_exponent_max_value + - sum_carry + - rounding_exponent_incremented + ) + else: + larger_exponent_max_value = ( + initial_larger_exponent_max_value - sum_carry + ) + initial_larger_exponent_min_value = pyrtl.Const(1) + if rounding_mode == RoundingMode.RNE: + larger_exponent_min_value = ( + initial_larger_exponent_min_value + + num_leading_zeros + - rounding_exponent_incremented + ) + else: + larger_exponent_min_value = ( + initial_larger_exponent_min_value + num_leading_zeros + ) + with (smaller_operand_sign == larger_operand_sign) & ( + larger_operand_exponent > larger_exponent_max_value + ): # detect overflow on addition + final_result_sign |= larger_operand_sign + if rounding_mode == RoundingMode.RNE: + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + else: + FloatUtils.make_output_largest_finite_number( + fp_type_props, final_result_exponent, final_result_mantissa + ) + with (smaller_operand_sign != larger_operand_sign) & ( + larger_operand_exponent < larger_exponent_min_value + ): # detect underflow on subtraction + final_result_sign |= larger_operand_sign + FloatUtils.make_output_zero( + final_result_exponent, final_result_mantissa + ) + with pyrtl.otherwise: + final_result_sign |= larger_operand_sign + if rounding_mode == RoundingMode.RNE: + final_result_exponent |= raw_result_rounded_exponent + final_result_mantissa |= raw_result_rounded_mantissa + else: + final_result_exponent |= raw_result_exponent + final_result_mantissa |= raw_result_mantissa + + return pyrtl.concat( + final_result_sign, final_result_exponent, final_result_mantissa + ) + + @staticmethod + def sub( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + num_exp_bits = config.fp_type_properties.num_exponent_bits + num_mant_bits = config.fp_type_properties.num_mantissa_bits + operand_b_negated = operand_b ^ pyrtl.concat( + pyrtl.Const(1, bitwidth=1), + pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits), + ) + return AddSubHelper.add(config, operand_a, operand_b_negated) + + @staticmethod + def _add_operands( + larger_operand_exponent: pyrtl.WireVector, + smaller_mantissa_shifted_grs: pyrtl.WireVector, + larger_mantissa_extended: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: + sum_mantissa_grs = pyrtl.WireVector() + sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs + sum_carry = sum_mantissa_grs[-1] + sum_mantissa = pyrtl.select( + sum_carry, sum_mantissa_grs[4:], sum_mantissa_grs[3:-1] + ) + sum_grs = pyrtl.select( + sum_carry, + pyrtl.concat(sum_mantissa_grs[2:4], sum_mantissa_grs[:2] != 0), + sum_mantissa_grs[:3], + ) + sum_exponent = pyrtl.select( + sum_carry, larger_operand_exponent + 1, larger_operand_exponent + ) + return sum_exponent, sum_mantissa, sum_grs, sum_carry + + @staticmethod + def _sub_operands( + num_mant_bits: int, + larger_operand_exponent: pyrtl.WireVector, + smaller_mantissa_shifted_grs: pyrtl.WireVector, + larger_mantissa_extended: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: + def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int): + out = pyrtl.WireVector( + bitwidth=pyrtl.infer_val_and_bitwidth(length - 1).bitwidth + ) + with pyrtl.conditional_assignment: + for i in range(wire.bitwidth - 1, wire.bitwidth - length - 1, -1): + with wire[i]: + out |= wire.bitwidth - i - 1 + return out + + sub_mantissa_grs = pyrtl.WireVector(bitwidth=num_mant_bits + 4) + sub_mantissa_grs <<= larger_mantissa_extended - smaller_mantissa_shifted_grs + num_leading_zeros = leading_zero_priority_encoder( + sub_mantissa_grs, num_mant_bits + 1 + ) + sub_mantissa_grs_shifted = pyrtl.shift_left_logical( + sub_mantissa_grs, num_leading_zeros + ) + sub_mantissa = sub_mantissa_grs_shifted[3:] + sub_grs = sub_mantissa_grs_shifted[:3] + sub_exponent = larger_operand_exponent - num_leading_zeros + return sub_exponent, sub_mantissa, sub_grs, num_leading_zeros + + @staticmethod + def _round( + num_mant_bits: int, + num_exp_bits: int, + raw_result_exponent: pyrtl.WireVector, + raw_result_mantissa: pyrtl.WireVector, + raw_result_grs: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: + last = raw_result_mantissa[0] + guard = raw_result_grs[2] + round = raw_result_grs[1] + sticky = raw_result_grs[0] + round_up = guard & (last | round | sticky) + raw_result_rounded_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + raw_result_rounded_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1) + with pyrtl.conditional_assignment: + with round_up: + with raw_result_mantissa == (1 << num_mant_bits) - 1: + raw_result_rounded_mantissa |= 0 + raw_result_rounded_exponent |= raw_result_exponent + 1 + rounding_exponent_incremented |= 1 + with pyrtl.otherwise: + raw_result_rounded_mantissa |= raw_result_mantissa + 1 + raw_result_rounded_exponent |= raw_result_exponent + rounding_exponent_incremented |= 0 + with pyrtl.otherwise: + raw_result_rounded_mantissa |= raw_result_mantissa + raw_result_rounded_exponent |= raw_result_exponent + rounding_exponent_incremented |= 0 + return ( + raw_result_rounded_exponent, + raw_result_rounded_mantissa, + rounding_exponent_incremented, + ) diff --git a/pyrtl/rtllib/pyrtlfloat/_float_utills.py b/pyrtl/rtllib/pyrtlfloat/_float_utills.py new file mode 100644 index 00000000..0ae58329 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_float_utills.py @@ -0,0 +1,104 @@ +import pyrtl + +from ._types import FPTypeProperties + + +class FloatUtils: + @staticmethod + def get_sign(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return wire[fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits] + + @staticmethod + def get_exponent( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return wire[ + fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits + + fp_prop.num_exponent_bits + ] + + @staticmethod + def get_mantissa( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return wire[: fp_prop.num_mantissa_bits] + + @staticmethod + def is_zero(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) == 0) & ( + FloatUtils.get_exponent(fp_prop, wire) == 0 + ) + + @staticmethod + def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) == 0) & ( + FloatUtils.get_exponent(fp_prop, wire) + == (1 << fp_prop.num_exponent_bits) - 1 + ) + + @staticmethod + def is_denormalized( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) != 0) & ( + FloatUtils.get_exponent(fp_prop, wire) == 0 + ) + + @staticmethod + def is_NaN(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) != 0) & ( + FloatUtils.get_exponent(fp_prop, wire) + == (1 << fp_prop.num_exponent_bits) - 1 + ) + + @staticmethod + def make_denormals_zero( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + out = pyrtl.WireVector( + bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1 + ) + with pyrtl.conditional_assignment: + with FloatUtils.get_exponent(fp_prop, wire) == 0: + out |= pyrtl.concat( + FloatUtils.get_sign(fp_prop, wire), + FloatUtils.get_exponent(fp_prop, wire), + pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits), + ) + with pyrtl.otherwise: + out |= wire + return out + + @staticmethod + def make_output_inf( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 1 + mantissa |= 0 + + @staticmethod + def make_output_NaN( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 1 + mantissa |= 1 << (fp_prop.num_mantissa_bits - 1) + + @staticmethod + def make_output_zero( + exponent: pyrtl.WireVector, mantissa: pyrtl.WireVector + ) -> None: + exponent |= 0 + mantissa |= 0 + + @staticmethod + def make_output_largest_finite_number( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 2 + mantissa |= (1 << fp_prop.num_mantissa_bits) - 1 diff --git a/pyrtl/rtllib/pyrtlfloat/_multiplication.py b/pyrtl/rtllib/pyrtlfloat/_multiplication.py new file mode 100644 index 00000000..a0de7d25 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_multiplication.py @@ -0,0 +1,161 @@ +import pyrtl + +from ._float_utills import FloatUtils +from ._types import PyrtlFloatConfig, RoundingMode + + +class MultiplicationHelper: + @staticmethod + def multiply( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + fp_type_props = config.fp_type_properties + rounding_mode = config.rounding_mode + num_exp_bits = fp_type_props.num_exponent_bits + num_mant_bits = fp_type_props.num_mantissa_bits + + operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a) + operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) + a_sign = FloatUtils.get_sign(fp_type_props, operand_a_daz) + b_sign = FloatUtils.get_sign(fp_type_props, operand_b_daz) + a_exponent = FloatUtils.get_exponent(fp_type_props, operand_a_daz) + b_exponent = FloatUtils.get_exponent(fp_type_props, operand_b_daz) + + exponent_bias = 2 ** (fp_type_props.num_exponent_bits - 1) - 1 + + result_sign = a_sign ^ b_sign + operand_exponent_sums = a_exponent + b_exponent + product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias) + + a_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_a_daz) + ) + b_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_b_daz) + ) + product_mantissa = a_mantissa * b_mantissa + + normalized_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) + normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + + need_to_normalize = product_mantissa[-1] + + if rounding_mode == RoundingMode.RNE: + guard = pyrtl.WireVector(bitwidth=1) + sticky = pyrtl.WireVector(bitwidth=1) + last = pyrtl.WireVector(bitwidth=1) + + with pyrtl.conditional_assignment: + with need_to_normalize: + normalized_product_mantissa |= product_mantissa[-num_mant_bits - 1 :] + normalized_product_exponent |= product_exponent + 1 + if rounding_mode == RoundingMode.RNE: + guard |= product_mantissa[-num_mant_bits - 2] + sticky |= product_mantissa[: -num_mant_bits - 2] != 0 + last |= product_mantissa[-num_mant_bits - 1] + with pyrtl.otherwise: + normalized_product_mantissa |= product_mantissa[-num_mant_bits - 2 : -1] + normalized_product_exponent |= product_exponent + if rounding_mode == RoundingMode.RNE: + guard |= product_mantissa[-num_mant_bits - 3] + sticky |= product_mantissa[: -num_mant_bits - 3] != 0 + last |= product_mantissa[-num_mant_bits - 2] + + if rounding_mode == RoundingMode.RNE: + rounded_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + rounded_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) + exponent_incremented = pyrtl.WireVector(bitwidth=1) + with pyrtl.conditional_assignment: + with guard & (last | sticky): + with normalized_product_mantissa == (1 << num_mant_bits) - 1: + rounded_product_mantissa |= 0 + rounded_product_exponent |= normalized_product_exponent + 1 + exponent_incremented |= 1 + with pyrtl.otherwise: + rounded_product_mantissa |= normalized_product_mantissa + 1 + rounded_product_exponent |= normalized_product_exponent + exponent_incremented |= 0 + with pyrtl.otherwise: + rounded_product_mantissa |= normalized_product_mantissa + rounded_product_exponent |= normalized_product_exponent + exponent_incremented |= 0 + + result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + + operand_a_nan = FloatUtils.is_NaN(fp_type_props, operand_a_daz) + operand_b_nan = FloatUtils.is_NaN(fp_type_props, operand_b_daz) + operand_a_inf = FloatUtils.is_inf(fp_type_props, operand_a_daz) + operand_b_inf = FloatUtils.is_inf(fp_type_props, operand_b_daz) + operand_a_zero = FloatUtils.is_zero(fp_type_props, operand_a_daz) + operand_b_zero = FloatUtils.is_zero(fp_type_props, operand_b_daz) + operand_a_denormalized = FloatUtils.is_denormalized( + fp_type_props, operand_a_daz + ) + operand_b_denormalized = FloatUtils.is_denormalized( + fp_type_props, operand_b_daz + ) + + # Overflow and underflow checks (only for normal cases) + sum_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2 + exponent_bias) + sum_exponent_min_value = pyrtl.Const(1 + exponent_bias) + if rounding_mode == RoundingMode.RNE: + exponent_max_value = ( + sum_exponent_max_value - need_to_normalize - exponent_incremented + ) + exponent_min_value = ( + sum_exponent_min_value - need_to_normalize - exponent_incremented + ) + else: + exponent_max_value = sum_exponent_max_value - need_to_normalize + exponent_min_value = sum_exponent_min_value - need_to_normalize + + if rounding_mode == RoundingMode.RNE: + raw_result_exponent = rounded_product_exponent[0:num_exp_bits] + raw_result_mantissa = rounded_product_mantissa + else: + raw_result_exponent = normalized_product_exponent[0:num_exp_bits] + raw_result_mantissa = normalized_product_mantissa + + with pyrtl.conditional_assignment: + # nan + with ( + operand_a_nan + | operand_b_nan + | (operand_a_inf & operand_b_zero) + | (operand_a_zero & operand_b_inf) + ): + FloatUtils.make_output_NaN( + fp_type_props, result_exponent, result_mantissa + ) + # infinity + with operand_a_inf | operand_b_inf: + FloatUtils.make_output_inf( + fp_type_props, result_exponent, result_mantissa + ) + # overflow + with operand_exponent_sums > exponent_max_value: + if rounding_mode == RoundingMode.RNE: + FloatUtils.make_output_inf( + fp_type_props, result_exponent, result_mantissa + ) + else: + FloatUtils.make_output_largest_finite_number( + fp_type_props, result_exponent, result_mantissa + ) + # zero or underflow + with ( + operand_a_zero + | operand_b_zero + | (operand_exponent_sums < exponent_min_value) + | operand_a_denormalized + | operand_b_denormalized + ): + FloatUtils.make_output_zero(result_exponent, result_mantissa) + with pyrtl.otherwise: + result_exponent |= raw_result_exponent + result_mantissa |= raw_result_mantissa + + return pyrtl.concat(result_sign, result_exponent, result_mantissa) diff --git a/pyrtl/rtllib/pyrtlfloat/_types.py b/pyrtl/rtllib/pyrtlfloat/_types.py new file mode 100644 index 00000000..15a1c811 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_types.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from enum import Enum + + +class RoundingMode(Enum): + RTZ = 1 # round towards zero (truncate) + RNE = 2 # round to nearest, ties to even (default mode) + + +@dataclass(frozen=True) +class FPTypeProperties: + num_exponent_bits: int + num_mantissa_bits: int + + +class FloatingPointType(Enum): + BFLOAT16 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=7) + FLOAT16 = FPTypeProperties(num_exponent_bits=5, num_mantissa_bits=10) + FLOAT32 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=23) + FLOAT64 = FPTypeProperties(num_exponent_bits=11, num_mantissa_bits=52) + + +@dataclass(frozen=True) +class PyrtlFloatConfig: + fp_type_properties: FPTypeProperties + rounding_mode: RoundingMode + + +class PyrtlFloatException(Exception): + pass diff --git a/pyrtl/rtllib/pyrtlfloat/floatoperations.py b/pyrtl/rtllib/pyrtlfloat/floatoperations.py new file mode 100644 index 00000000..ef081b0a --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/floatoperations.py @@ -0,0 +1,77 @@ +import pyrtl + +from ._add_sub import AddSubHelper +from ._multiplication import MultiplicationHelper +from ._types import FloatingPointType, PyrtlFloatConfig, RoundingMode + + +class FloatOperations: + default_rounding_mode = RoundingMode.RNE + + @staticmethod + def mul( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return MultiplicationHelper.multiply(config, operand_a, operand_b) + + @staticmethod + def add( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return AddSubHelper.add(config, operand_a, operand_b) + + @staticmethod + def sub( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return AddSubHelper.sub(config, operand_a, operand_b) + + +class _BaseTypedFloatOperations: + _fp_type: FloatingPointType = None + + @classmethod + def mul( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.mul(cls._get_config(), operand_a, operand_b) + + @classmethod + def add( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.add(cls._get_config(), operand_a, operand_b) + + @classmethod + def sub( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.sub(cls._get_config(), operand_a, operand_b) + + @classmethod + def _get_config(cls) -> PyrtlFloatConfig: + return PyrtlFloatConfig( + cls._fp_type.value, FloatOperations.default_rounding_mode + ) + + +class BFloat16Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.BFLOAT16 + + +class Float16Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT16 + + +class Float32Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT32 + + +class Float64Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT64 diff --git a/tests/rtllib/pyrtlfloat/test_add_sub.py b/tests/rtllib/pyrtlfloat/test_add_sub.py new file mode 100644 index 00000000..3f006da2 --- /dev/null +++ b/tests/rtllib/pyrtlfloat/test_add_sub.py @@ -0,0 +1,26 @@ +import unittest + +import pyrtl +from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode + + +class TestMultiplication(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + a = pyrtl.Input(bitwidth=16, name="a") + b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_add = pyrtl.Output(name="result_add") + result_add <<= Float16Operations.add(a, b) + result_sub = pyrtl.Output(name="result_sub") + result_sub <<= Float16Operations.sub(a, b) + self.sim = pyrtl.Simulation() + + def test_multiplication_simple(self): + self.sim.step({"a": 0b0100001000000000, "b": 0b0100010100000000}) + self.assertEqual(self.sim.inspect("result_add"), 0b0100100000000000) + self.assertEqual(self.sim.inspect("result_sub"), 0b1100000000000000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rtllib/pyrtlfloat/test_multiplication.py b/tests/rtllib/pyrtlfloat/test_multiplication.py new file mode 100644 index 00000000..439efabb --- /dev/null +++ b/tests/rtllib/pyrtlfloat/test_multiplication.py @@ -0,0 +1,27 @@ +import unittest + +import pyrtl +from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode + + +class TestMultiplication(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + a = pyrtl.Input(bitwidth=16, name="a") + b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.mul(a, b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.mul(a, b) + self.sim = pyrtl.Simulation() + + def test_multiplication_simple(self): + self.sim.step({"a": 0b0011111000000000, "b": 0b0011110000000001}) + self.assertEqual(self.sim.inspect("result_rne"), 0b0011111000000010) + self.assertEqual(self.sim.inspect("result_rtz"), 0b0011111000000001) + + +if __name__ == "__main__": + unittest.main()