From 63272a1b8b07b8445addee4a40466649f97bc199 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Sun, 17 Aug 2025 23:38:50 +0530 Subject: [PATCH 01/33] Add positutils.py --- pyrtl/positutils.py | 152 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 pyrtl/positutils.py diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py new file mode 100644 index 00000000..269ae5e2 --- /dev/null +++ b/pyrtl/positutils.py @@ -0,0 +1,152 @@ +"""Implements utility functions for posit operations""" + +import pyrtl +from pyrtl.corecircuits import shift_right_logical, shift_left_logical + +#decode posit +def decode_posit(x, nbits, es): + sign = x[nbits - 1] + rest = [x[nbits - 2 - i] for i in range(nbits - 1)] + regime_bit = rest[0] + run_len = pyrtl.Const(0, bitwidth=nbits) + active = pyrtl.Const(1, bitwidth=1) + + for i in range(nbits - 1): + bit = rest[i] + is_same = bit == regime_bit + run_len = run_len + pyrtl.select( + active & is_same, + pyrtl.Const(1), + pyrtl.Const(0) + ) + active = active & is_same + + k_pos = run_len - pyrtl.Const(1, bitwidth=nbits) + k_neg = (~run_len) + pyrtl.Const(1, bitwidth=nbits) + k = pyrtl.select(regime_bit, k_pos, k_neg) + + exp = pyrtl.Const(0, bitwidth=es) + for i in range(nbits - 2): + exp = pyrtl.select(run_len == pyrtl.Const(i), rest[i + 1], exp) + + start_idx = run_len + pyrtl.Const(1 + es, bitwidth=nbits) + fraction_bits = [] + + for i in range(nbits - 1): + in_range = (i >= start_idx) & (i < (nbits - 1)) + bit_val = pyrtl.select(in_range, rest[i], pyrtl.Const(0)) + fraction_bits.append(bit_val) + + frac_result = pyrtl.concat_list(fraction_bits[::-1]) + fraction_length = ( + pyrtl.Const(nbits, bitwidth=nbits) + - run_len + - pyrtl.Const(es, bitwidth=nbits) + - pyrtl.Const(2, bitwidth=nbits) + ) + + return sign, k, exp, frac_result, fraction_length + +def get_upto_regime(k, n_val, sign_final): + precomputed_val = (1 << (n_val - 1)) - 1 + n_c = pyrtl.Const(n_val, bitwidth=n_val) + n_minus_1 = pyrtl.Const(n_val - 1, bitwidth=n_val) + n_minus_2 = pyrtl.Const(n_val - 2, bitwidth=n_val) + n_minus_3 = pyrtl.Const(n_val - 3, bitwidth=n_val) + + k_thresh = pyrtl.Const(1 << (n_val - 1), bitwidth=k.bitwidth) + abs_k = pyrtl.select( + k >= k_thresh, + ( + (~k + pyrtl.Const(1, bitwidth=k.bitwidth)) + & pyrtl.Const((1 << n_val) - 1, bitwidth=k.bitwidth) + ), + k, + ) + + large_neg_regime = abs_k >= n_minus_1 + large_pos_regime = abs_k >= n_minus_2 + + temp_rem1 = (n_c + k) - pyrtl.Const(2, bitwidth=n_val) + sign_case1_inner = shift_right_logical( + pyrtl.Const(1 << (n_val - 2), bitwidth=n_val), abs_k + ) + + rem_bits_case1 = pyrtl.select( + large_neg_regime, + pyrtl.Const(0, bitwidth=n_val), + temp_rem1, + ) + sign_case1 = pyrtl.select( + large_neg_regime, + pyrtl.Const(0, bitwidth=n_val), + sign_case1_inner, + ) + + temp_rem2 = n_minus_3 - k + shift_amt = k + pyrtl.Const(2, bitwidth=n_val) + ones = ( + shift_left_logical(pyrtl.Const(1, bitwidth=n_val), shift_amt) + - pyrtl.Const(2, bitwidth=n_val) + ) + shifted_case2 = shift_left_logical(ones, temp_rem2) + + rem_bits_case2 = pyrtl.select( + large_pos_regime, + pyrtl.Const(0, bitwidth=n_val), + temp_rem2, + ) + sign_case2 = pyrtl.select( + large_pos_regime, + pyrtl.Const(precomputed_val, bitwidth=n_val), + shifted_case2, + ) + + cond_k_ge = k >= k_thresh + rem_bits = pyrtl.select(cond_k_ge, rem_bits_case1, rem_bits_case2) + sign_w_regime = pyrtl.select(cond_k_ge, sign_case1, sign_case2) + + sign_w_regime_trimmed = sign_w_regime[: n_val - 1] + sign_w_regime_final = pyrtl.concat(sign_final, sign_w_regime_trimmed) + return rem_bits, sign_w_regime_final + +def frac_with_hidden_one(frac, frac_length, nbits): + one_table = [ + pyrtl.Const(1 << i, bitwidth=32) + for i in range(nbits + 1) + ] + one_shifted = pyrtl.Const(0, bitwidth=32) + + for i in range(nbits + 1): + one_shifted = pyrtl.select( + frac_length == pyrtl.Const(i, bitwidth=8), + one_table[i], + one_shifted, + ) + + frac_32 = pyrtl.concat( + pyrtl.Const(0, bitwidth=32 - (nbits - 1)), frac + ) + full = one_shifted + frac_32 + return full + +def remove_first_one(val): + found = pyrtl.Const(0, bitwidth=1) + result_bits = [] + + for i in range(val.bitwidth): + bit = val[val.bitwidth - 1 - i] # MSB first + new_bit = pyrtl.select( + (found == 0) & (bit == 1), + pyrtl.Const(0, bitwidth=1), + bit, + ) + result_bits.append(new_bit) + + found = pyrtl.select( + (found == 0) & (bit == 1), + pyrtl.Const(1, bitwidth=1), + found, + ) + + return pyrtl.concat_list(result_bits[::-1]) From b5604ea3f6e9b9a0d1cd13c9ef31b381c8d5c9b2 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Sun, 17 Aug 2025 23:48:22 +0530 Subject: [PATCH 02/33] Add positadder.py --- pyrtl/rtllib/positadder.py | 177 +++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 pyrtl/rtllib/positadder.py diff --git a/pyrtl/rtllib/positadder.py b/pyrtl/rtllib/positadder.py new file mode 100644 index 00000000..3466e582 --- /dev/null +++ b/pyrtl/rtllib/positadder.py @@ -0,0 +1,177 @@ +import pyrtl +from pyrtl.corecircuits import shift_left_logical, shift_right_logical +from pyrtl.positutils import decode_posit, get_upto_regime, frac_with_hidden_one, remove_first_one + +def posit_add(a, b, nbits, es): + # Decode input posits into regime (k), exponent, fraction, and fraction length + _, k_a, exp_a, frac_a, frac_len_a = decode_posit(a, nbits, es) + _, k_b, exp_b, frac_b, frac_len_b = decode_posit(b, nbits, es) + + # match bitwidths of fractional part + frac_a_aligned = pyrtl.WireVector(bitwidth=nbits) + frac_b_aligned = pyrtl.WireVector(bitwidth=nbits) + + frac_len_a_aligned = pyrtl.WireVector(bitwidth=nbits) + frac_len_b_aligned = pyrtl.WireVector(bitwidth=nbits) + + with pyrtl.conditional_assignment: + with frac_len_a > frac_len_b: + shift_amt = frac_len_a - frac_len_b + frac_b_aligned |= shift_left_logical(frac_b, shift_amt) + frac_a_aligned |= frac_a + frac_len_b_aligned |= frac_len_b + shift_amt + frac_len_a_aligned |= frac_len_a + + with frac_len_a < frac_len_b: + shift_amt = frac_len_b - frac_len_a + frac_a_aligned |= shift_left_logical(frac_a, shift_amt) + frac_b_aligned |= frac_b + frac_len_a_aligned |= frac_len_a + shift_amt + frac_len_b_aligned |= frac_len_b + + with pyrtl.otherwise: + frac_a_aligned |= frac_a + frac_b_aligned |= frac_b + frac_len_a_aligned |= frac_len_a + frac_len_b_aligned |= frac_len_b + + # Add hidden leading one to fractions + frac_a_full = frac_with_hidden_one(frac_a_aligned, frac_len_a_aligned, + nbits) + frac_b_full = frac_with_hidden_one(frac_b_aligned, frac_len_b_aligned, + nbits) + + # Compute scales (regime*k + exponent) + scale_a = shift_left_logical(k_a, es) + exp_a + scale_b = shift_left_logical(k_b, es) + exp_b + + offset = scale_a - scale_b + + is_neg_offset = pyrtl.select( + offset > pyrtl.Const(127, bitwidth=nbits), + truecase=pyrtl.Const(1, bitwidth=nbits), + falsecase=pyrtl.Const(0, bitwidth=nbits) + ) + + shifted_a = pyrtl.WireVector(bitwidth=frac_a_full.bitwidth) + shifted_b = pyrtl.WireVector(bitwidth=frac_b_full.bitwidth) + result_scale = pyrtl.WireVector(bitwidth=offset.bitwidth) + result_exp = pyrtl.WireVector(bitwidth=es) + + neg_offset = (~offset) + pyrtl.Const(1, bitwidth=offset.bitwidth) + + # Align fractions based on offset + with pyrtl.conditional_assignment: + with is_neg_offset == pyrtl.Const(0, bitwidth=nbits): + shifted_b |= frac_b_full + shifted_a |= shift_left_logical(frac_a_full, offset) + result_scale |= scale_a + result_exp |= exp_a + + with is_neg_offset == pyrtl.Const(1, bitwidth=nbits): + shifted_b |= shift_left_logical(frac_b_full, neg_offset) + shifted_a |= frac_a_full + result_scale |= scale_b + result_exp |= exp_b + + # Add shifted fractions + result_frac = shifted_a + shifted_b + + # checking for overflow, if overflow, increase scale + result_scale = pyrtl.select( + offset == pyrtl.Const(0, bitwidth=offset.bitwidth), + result_scale + 1, + result_scale + ) + result_k = shift_right_logical(result_scale, + pyrtl.Const(es, bitwidth=nbits)) + + # Extract regime bits + rem_bits, regime_bits = get_upto_regime(result_k, nbits, 0) + + # Extract exponent from scale + result_exp = result_scale - shift_left_logical(result_k, es) + result_exp = shift_left_logical( + result_exp, rem_bits - pyrtl.Const(es, bitwidth=nbits) + ) + + # Remaining fraction length + frac_len = rem_bits - es + + # handling rounding of fractional bits + count = pyrtl.Const(0, bitwidth=nbits) + rounded_frac = pyrtl.Const(0, bitwidth=nbits) + found = pyrtl.Const(0, bitwidth=nbits) + + for i in range(nbits): + bit = result_frac[nbits - 1 - i] + cond = pyrtl.select(bit == pyrtl.Const(1), 1, found) + found = found | cond + count = pyrtl.select(found == pyrtl.Const(1), + count + 1, count) + + # Exclude leading one + count = count - 1 + bits_to_shift = count - frac_len + + # Normalize fraction + truncated_frac = shift_right_logical(result_frac, bits_to_shift) + + # Guard, Round, Sticky bits + ground_bit = result_frac & shift_right_logical( + pyrtl.Const(1, bitwidth=nbits), (bits_to_shift - 1) + ) + round_bit = result_frac & shift_right_logical( + pyrtl.Const(1, bitwidth=nbits), (bits_to_shift - 2) + ) + sticky_bit = result_frac & shift_right_logical( + pyrtl.Const(1, bitwidth=nbits), (bits_to_shift - 3) + ) + + cond = ground_bit & (round_bit | sticky_bit) + rounded_frac = pyrtl.WireVector(bitwidth=nbits) + + # Apply rounding rules + with pyrtl.conditional_assignment: + with ground_bit == pyrtl.Const(0): + rounded_frac |= truncated_frac + with cond == pyrtl.Const(1): + rounded_frac |= truncated_frac + 1 + with cond == pyrtl.Const(0): + with truncated_frac[0] == 1: + rounded_frac |= truncated_frac + 1 + with truncated_frac[0] == 0: + rounded_frac |= truncated_frac + + # Remove hidden one from rounded fraction + rounded_frac = remove_first_one(rounded_frac) + + # Combine regime, exponent, and fraction + added_posit = ( + pyrtl.Const(0, bitwidth=nbits) + + regime_bits + + result_exp + + rounded_frac + ) + result_posit = pyrtl.WireVector(bitwidth=nbits) + + # Checking for special cases (NaR and 0) + isNar = ( + pyrtl.select(a == pyrtl.Const(1 << nbits - 1, bitwidth=nbits), 1, 0) | + pyrtl.select(b == pyrtl.Const(1 << nbits - 1, bitwidth=nbits), 1, 0) + ) + + with pyrtl.conditional_assignment: + with isNar == pyrtl.Const(1, bitwidth=nbits): + result_posit |= pyrtl.Const(1 << nbits - 1, bitwidth=nbits) + + with a == pyrtl.Const(0, bitwidth=nbits): + result_posit |= b + + with b == pyrtl.Const(0, bitwidth=nbits): + result_posit |= a + + with pyrtl.otherwise: + result_posit |= added_posit + + return result_posit \ No newline at end of file From 3785b780e20004045dbf59cd886c224cdcc35bdd Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Sun, 17 Aug 2025 23:53:45 +0530 Subject: [PATCH 03/33] Add positmatmul.py --- pyrtl/rtllib/positmatmul.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 pyrtl/rtllib/positmatmul.py diff --git a/pyrtl/rtllib/positmatmul.py b/pyrtl/rtllib/positmatmul.py new file mode 100644 index 00000000..1145e0b7 --- /dev/null +++ b/pyrtl/rtllib/positmatmul.py @@ -0,0 +1,37 @@ +import pyrtl +from pyrtl import PyrtlError +from pyrtl.rtllib.matrix import Matrix +from positadder import posit_add +from positmul import posit_mul + +def posit_matmul(x, y, nbits, es): + if not isinstance(x, Matrix): + msg = f"error: expecting a Matrix, got {type(x)} instead" + raise PyrtlError(msg) + if not isinstance(y, Matrix): + msg = f"error: expecting a Matrix, got {type(y)} instead" + raise PyrtlError(msg) + + if x.columns != y.rows: + msg = ( + f"error: rows and columns mismatch. Matrix a: {x.columns} columns, " + f"Matrix b: {y.rows} rows" + ) + raise PyrtlError(msg) + + result = Matrix( + x.rows, + y.columns, + nbits, + max_bits=x.max_bits, + ) + + for i in range(x.rows): + for j in range(y.columns): + acc = pyrtl.Const(0, bitwidth=nbits) + for k in range(x.columns): + prod = posit_mul(nbits, es, x[i, k], y[k, j]) + acc = posit_add(acc, prod, nbits, es) + result[i, j] = acc + + return result \ No newline at end of file From e8bc4a7dd4b991aa164b4d8e99d65e1d3b13e739 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Mon, 18 Aug 2025 09:02:15 +0530 Subject: [PATCH 04/33] Add docstrings --- pyrtl/positutils.py | 44 +++++++++++++++++++++++++++++++++++-- pyrtl/rtllib/positadder.py | 9 ++++++++ pyrtl/rtllib/positmatmul.py | 11 ++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index 269ae5e2..5724baab 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -3,8 +3,20 @@ import pyrtl from pyrtl.corecircuits import shift_right_logical, shift_left_logical -#decode posit def decode_posit(x, nbits, es): + """Decode posit into its components and return them as a :class:`tuple`. + + :param x: A :class:`WireVector` that represents the posit. + :param nbits: A :class:`int` that represents the bitwidth of the posit. + :param es: A :class:`int` that represents the exponent size of the posit. + + :return: A :class:`tuple` consisting of: + - :class:`WireVector` for sign + - :class:`WireVector` for k + - :class:`WireVector` for exponent + - :class:`WireVector` for fractional bits + - :class:`WireVector` for length of fraction + """ sign = x[nbits - 1] rest = [x[nbits - 2 - i] for i in range(nbits - 1)] regime_bit = rest[0] @@ -48,6 +60,17 @@ def decode_posit(x, nbits, es): return sign, k, exp, frac_result, fraction_length def get_upto_regime(k, n_val, sign_final): + """Calculates the remaining bits and the regime bits. + + :param k: A :class:`WireVector` that represents the k value. + :param n_val: A :class:`WireVector` that represents the bitwidth of + the posit. + :param sign_final: A :class:`WireVector` that represents the final sign. + + :return: A :class:`tuple` consisting of: + - :class:`WireVector` representing the remaining bits. + - :class:`WireVector` representing the regime bits with sign bit. + """ precomputed_val = (1 << (n_val - 1)) - 1 n_c = pyrtl.Const(n_val, bitwidth=n_val) n_minus_1 = pyrtl.Const(n_val - 1, bitwidth=n_val) @@ -111,6 +134,17 @@ def get_upto_regime(k, n_val, sign_final): return rem_bits, sign_w_regime_final def frac_with_hidden_one(frac, frac_length, nbits): + """Adds a hidden 1 to the fractional bits. + + :param frac: A :class:`WireVector` that represents the fractional bits. + :param frac_length: A :class:`WireVector` that represents the length of + the fractional bits. + :param nbits: A :class:`WireVector` that represents the bitwidth of the + posit. + + :return: A :class:`WireVector` that represents the fraction with the + hidden 1. + """ one_table = [ pyrtl.Const(1 << i, bitwidth=32) for i in range(nbits + 1) @@ -131,6 +165,12 @@ def frac_with_hidden_one(frac, frac_length, nbits): return full def remove_first_one(val): + """Removes the leading hidden bit of 1. + + :param val: A :class:`WireVector` that represents the fractional bits. + + :return: A :class:`WireVector` with the hidden bit of 1 removed. + """ found = pyrtl.Const(0, bitwidth=1) result_bits = [] @@ -149,4 +189,4 @@ def remove_first_one(val): found, ) - return pyrtl.concat_list(result_bits[::-1]) + return pyrtl.concat_list(result_bits[::-1]) \ No newline at end of file diff --git a/pyrtl/rtllib/positadder.py b/pyrtl/rtllib/positadder.py index 3466e582..20a5f8e7 100644 --- a/pyrtl/rtllib/positadder.py +++ b/pyrtl/rtllib/positadder.py @@ -3,6 +3,15 @@ from pyrtl.positutils import decode_posit, get_upto_regime, frac_with_hidden_one, remove_first_one def posit_add(a, b, nbits, es): + """Adds two numbers in posit format and returns their sum. + + :param a: A :class:`WireVector` to add. Bitwidths need to match. + :param b: A :class:`WireVector` to add. Bitwidths need to match. + :param nbits: A :class:`int` representing the total bitwidth of the posit. + :param es: A :class:`int` representing the exponent size of the posit. + + :return: A :class:`WireVector` that represents the sum of the two posits. + """ # Decode input posits into regime (k), exponent, fraction, and fraction length _, k_a, exp_a, frac_a, frac_len_a = decode_posit(a, nbits, es) _, k_b, exp_b, frac_b, frac_len_b = decode_posit(b, nbits, es) diff --git a/pyrtl/rtllib/positmatmul.py b/pyrtl/rtllib/positmatmul.py index 1145e0b7..0c94c8c2 100644 --- a/pyrtl/rtllib/positmatmul.py +++ b/pyrtl/rtllib/positmatmul.py @@ -5,6 +5,17 @@ from positmul import posit_mul def posit_matmul(x, y, nbits, es): + """Performs matrix multiplication on posits. + + :param x: A :class:`Matrix` to be multiplied. + :param y: A :class:`Matrix` to be multiplied. + :param nbits: A :class:`int` representing the bitwidth of each cell of + the matrix. + :param es: A :class:`int` representing the exponent size of the posit. + + :return: A :class:`Matrix` that represents the product of two posit + matrices. + """ if not isinstance(x, Matrix): msg = f"error: expecting a Matrix, got {type(x)} instead" raise PyrtlError(msg) From ac19689034f29f9af417fbb3131a0958a9620249 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Mon, 18 Aug 2025 14:19:24 +0530 Subject: [PATCH 05/33] Add type annotations --- pyrtl/positutils.py | 30 +++++++++++++--- pyrtl/rtllib/positadder.py | 68 ++++++++++++++++++++++++------------- pyrtl/rtllib/positmatmul.py | 32 +++++++++-------- 3 files changed, 86 insertions(+), 44 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index 5724baab..3131d881 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -3,7 +3,16 @@ import pyrtl from pyrtl.corecircuits import shift_right_logical, shift_left_logical -def decode_posit(x, nbits, es): + +def decode_posit( + x: pyrtl.WireVector, nbits: int, es: int +) -> tuple[ + pyrtl.WireVector, + pyrtl.WireVector, + pyrtl.WireVector, + pyrtl.WireVector, + pyrtl.WireVector, +]: """Decode posit into its components and return them as a :class:`tuple`. :param x: A :class:`WireVector` that represents the posit. @@ -29,7 +38,7 @@ def decode_posit(x, nbits, es): run_len = run_len + pyrtl.select( active & is_same, pyrtl.Const(1), - pyrtl.Const(0) + pyrtl.Const(0), ) active = active & is_same @@ -59,7 +68,12 @@ def decode_posit(x, nbits, es): return sign, k, exp, frac_result, fraction_length -def get_upto_regime(k, n_val, sign_final): + +def get_upto_regime( + k: pyrtl.WireVector, + n_val: pyrtl.WireVector, + sign_final: pyrtl.WireVector, +) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: """Calculates the remaining bits and the regime bits. :param k: A :class:`WireVector` that represents the k value. @@ -133,7 +147,12 @@ def get_upto_regime(k, n_val, sign_final): sign_w_regime_final = pyrtl.concat(sign_final, sign_w_regime_trimmed) return rem_bits, sign_w_regime_final -def frac_with_hidden_one(frac, frac_length, nbits): + +def frac_with_hidden_one( + frac: pyrtl.WireVector, + frac_length: pyrtl.WireVector, + nbits: int, +) -> pyrtl.WireVector: """Adds a hidden 1 to the fractional bits. :param frac: A :class:`WireVector` that represents the fractional bits. @@ -164,7 +183,8 @@ def frac_with_hidden_one(frac, frac_length, nbits): full = one_shifted + frac_32 return full -def remove_first_one(val): + +def remove_first_one(val: pyrtl.WireVector) -> pyrtl.WireVector: """Removes the leading hidden bit of 1. :param val: A :class:`WireVector` that represents the fractional bits. diff --git a/pyrtl/rtllib/positadder.py b/pyrtl/rtllib/positadder.py index 20a5f8e7..6030f8d2 100644 --- a/pyrtl/rtllib/positadder.py +++ b/pyrtl/rtllib/positadder.py @@ -1,8 +1,16 @@ import pyrtl from pyrtl.corecircuits import shift_left_logical, shift_right_logical -from pyrtl.positutils import decode_posit, get_upto_regime, frac_with_hidden_one, remove_first_one +from pyrtl.positutils import ( + decode_posit, + get_upto_regime, + frac_with_hidden_one, + remove_first_one, +) -def posit_add(a, b, nbits, es): + +def posit_add( + a: pyrtl.WireVector, b: pyrtl.WireVector, nbits: int, es: int +) -> pyrtl.WireVector: """Adds two numbers in posit format and returns their sum. :param a: A :class:`WireVector` to add. Bitwidths need to match. @@ -10,16 +18,16 @@ def posit_add(a, b, nbits, es): :param nbits: A :class:`int` representing the total bitwidth of the posit. :param es: A :class:`int` representing the exponent size of the posit. - :return: A :class:`WireVector` that represents the sum of the two posits. + :return: A :class:`WireVector` that represents the sum of the two posits. """ # Decode input posits into regime (k), exponent, fraction, and fraction length _, k_a, exp_a, frac_a, frac_len_a = decode_posit(a, nbits, es) _, k_b, exp_b, frac_b, frac_len_b = decode_posit(b, nbits, es) - # match bitwidths of fractional part + # Match bitwidths of fractional part frac_a_aligned = pyrtl.WireVector(bitwidth=nbits) frac_b_aligned = pyrtl.WireVector(bitwidth=nbits) - + frac_len_a_aligned = pyrtl.WireVector(bitwidth=nbits) frac_len_b_aligned = pyrtl.WireVector(bitwidth=nbits) @@ -45,10 +53,8 @@ def posit_add(a, b, nbits, es): frac_len_b_aligned |= frac_len_b # Add hidden leading one to fractions - frac_a_full = frac_with_hidden_one(frac_a_aligned, frac_len_a_aligned, - nbits) - frac_b_full = frac_with_hidden_one(frac_b_aligned, frac_len_b_aligned, - nbits) + frac_a_full = frac_with_hidden_one(frac_a_aligned, frac_len_a_aligned, nbits) + frac_b_full = frac_with_hidden_one(frac_b_aligned, frac_len_b_aligned, nbits) # Compute scales (regime*k + exponent) scale_a = shift_left_logical(k_a, es) + exp_a @@ -59,7 +65,7 @@ def posit_add(a, b, nbits, es): is_neg_offset = pyrtl.select( offset > pyrtl.Const(127, bitwidth=nbits), truecase=pyrtl.Const(1, bitwidth=nbits), - falsecase=pyrtl.Const(0, bitwidth=nbits) + falsecase=pyrtl.Const(0, bitwidth=nbits), ) shifted_a = pyrtl.WireVector(bitwidth=frac_a_full.bitwidth) @@ -86,14 +92,13 @@ def posit_add(a, b, nbits, es): # Add shifted fractions result_frac = shifted_a + shifted_b - # checking for overflow, if overflow, increase scale + # Checking for overflow, if overflow, increase scale result_scale = pyrtl.select( offset == pyrtl.Const(0, bitwidth=offset.bitwidth), result_scale + 1, - result_scale + result_scale, ) - result_k = shift_right_logical(result_scale, - pyrtl.Const(es, bitwidth=nbits)) + result_k = shift_right_logical(result_scale, pyrtl.Const(es, bitwidth=nbits)) # Extract regime bits rem_bits, regime_bits = get_upto_regime(result_k, nbits, 0) @@ -107,7 +112,7 @@ def posit_add(a, b, nbits, es): # Remaining fraction length frac_len = rem_bits - es - # handling rounding of fractional bits + # Handling rounding of fractional bits count = pyrtl.Const(0, bitwidth=nbits) rounded_frac = pyrtl.Const(0, bitwidth=nbits) found = pyrtl.Const(0, bitwidth=nbits) @@ -116,8 +121,7 @@ def posit_add(a, b, nbits, es): bit = result_frac[nbits - 1 - i] cond = pyrtl.select(bit == pyrtl.Const(1), 1, found) found = found | cond - count = pyrtl.select(found == pyrtl.Const(1), - count + 1, count) + count = pyrtl.select(found == pyrtl.Const(1), count + 1, count) # Exclude leading one count = count - 1 @@ -157,17 +161,17 @@ def posit_add(a, b, nbits, es): # Combine regime, exponent, and fraction added_posit = ( - pyrtl.Const(0, bitwidth=nbits) + - regime_bits + - result_exp + - rounded_frac + pyrtl.Const(0, bitwidth=nbits) + + regime_bits + + result_exp + + rounded_frac ) result_posit = pyrtl.WireVector(bitwidth=nbits) # Checking for special cases (NaR and 0) isNar = ( - pyrtl.select(a == pyrtl.Const(1 << nbits - 1, bitwidth=nbits), 1, 0) | - pyrtl.select(b == pyrtl.Const(1 << nbits - 1, bitwidth=nbits), 1, 0) + pyrtl.select(a == pyrtl.Const(1 << nbits - 1, bitwidth=nbits), 1, 0) + | pyrtl.select(b == pyrtl.Const(1 << nbits - 1, bitwidth=nbits), 1, 0) ) with pyrtl.conditional_assignment: @@ -183,4 +187,20 @@ def posit_add(a, b, nbits, es): with pyrtl.otherwise: result_posit |= added_posit - return result_posit \ No newline at end of file + return result_posit + +# Simulation +nbits = 8 +es = 1 + +a = pyrtl.Input(bitwidth=nbits, name='const_a') +b = pyrtl.Input(bitwidth=nbits, name='const_b') +posit = pyrtl.Output(bitwidth=nbits, name='posit') + +added_posit = posit_add(a, b, nbits, es) + +posit <<= added_posit + +sim = pyrtl.Simulation() +sim.step({'const_a': 0b01011100, 'const_b': 0b01100000}) # 3.5 + 4 = 7.5 +print("added posit =", format(sim.inspect('posit'), '08b')) \ No newline at end of file diff --git a/pyrtl/rtllib/positmatmul.py b/pyrtl/rtllib/positmatmul.py index 0c94c8c2..d14dba93 100644 --- a/pyrtl/rtllib/positmatmul.py +++ b/pyrtl/rtllib/positmatmul.py @@ -4,39 +4,41 @@ from positadder import posit_add from positmul import posit_mul -def posit_matmul(x, y, nbits, es): + +def posit_matmul(x: Matrix, y: Matrix, nbits: int, es: int) -> Matrix: """Performs matrix multiplication on posits. :param x: A :class:`Matrix` to be multiplied. :param y: A :class:`Matrix` to be multiplied. :param nbits: A :class:`int` representing the bitwidth of each cell of - the matrix. + the matrix. :param es: A :class:`int` representing the exponent size of the posit. :return: A :class:`Matrix` that represents the product of two posit - matrices. + matrices. """ if not isinstance(x, Matrix): msg = f"error: expecting a Matrix, got {type(x)} instead" raise PyrtlError(msg) + if not isinstance(y, Matrix): msg = f"error: expecting a Matrix, got {type(y)} instead" raise PyrtlError(msg) - + if x.columns != y.rows: msg = ( - f"error: rows and columns mismatch. Matrix a: {x.columns} columns, " - f"Matrix b: {y.rows} rows" - ) + f"error: rows and columns mismatch. " + f"Matrix a: {x.columns} columns, Matrix b: {y.rows} rows" + ) raise PyrtlError(msg) - + result = Matrix( - x.rows, - y.columns, - nbits, - max_bits=x.max_bits, - ) - + x.rows, + y.columns, + nbits, + max_bits=x.max_bits, + ) + for i in range(x.rows): for j in range(y.columns): acc = pyrtl.Const(0, bitwidth=nbits) @@ -44,5 +46,5 @@ def posit_matmul(x, y, nbits, es): prod = posit_mul(nbits, es, x[i, k], y[k, j]) acc = posit_add(acc, prod, nbits, es) result[i, j] = acc - + return result \ No newline at end of file From 7c8b80889ef8824d7bed27915f3386da0a998e59 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Mon, 18 Aug 2025 14:40:03 +0530 Subject: [PATCH 06/33] Add example for positadder.py --- pyrtl/rtllib/positadder.py | 47 ++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/pyrtl/rtllib/positadder.py b/pyrtl/rtllib/positadder.py index 6030f8d2..ddcda14f 100644 --- a/pyrtl/rtllib/positadder.py +++ b/pyrtl/rtllib/positadder.py @@ -13,10 +13,33 @@ def posit_add( ) -> pyrtl.WireVector: """Adds two numbers in posit format and returns their sum. - :param a: A :class:`WireVector` to add. Bitwidths need to match. + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> nbits = 8 + >>> es = 1 + + >>> a = pyrtl.Input(bitwidth=nbits, name='a') + >>> b = pyrtl.Input(bitwidth=nbits, name='b') + >>> posit = pyrtl.Output(bitwidth=nbits, name='posit') + + >>> added_posit = posit_add(a, b, nbits, es) + + >>> posit <<= added_posit + + >>> sim = pyrtl.Simulation() + >>> sim.step({'a': 0b01011100, 'b': 0b01100000}) # 3.5 + 4 = 7.5 + >>> format(sim.inspect('posit'), '08b') + '01100111' + + :param a: A :class:`.WireVector` to add. Bitwidths need to match. :param b: A :class:`WireVector` to add. Bitwidths need to match. - :param nbits: A :class:`int` representing the total bitwidth of the posit. - :param es: A :class:`int` representing the exponent size of the posit. + :param nbits: A :class:`.int` representing the total bitwidth of the posit. + :param es: A :class:`.int` representing the exponent size of the posit. :return: A :class:`WireVector` that represents the sum of the two posits. """ @@ -187,20 +210,4 @@ def posit_add( with pyrtl.otherwise: result_posit |= added_posit - return result_posit - -# Simulation -nbits = 8 -es = 1 - -a = pyrtl.Input(bitwidth=nbits, name='const_a') -b = pyrtl.Input(bitwidth=nbits, name='const_b') -posit = pyrtl.Output(bitwidth=nbits, name='posit') - -added_posit = posit_add(a, b, nbits, es) - -posit <<= added_posit - -sim = pyrtl.Simulation() -sim.step({'const_a': 0b01011100, 'const_b': 0b01100000}) # 3.5 + 4 = 7.5 -print("added posit =", format(sim.inspect('posit'), '08b')) \ No newline at end of file + return result_posit \ No newline at end of file From 9f2ecf867d1aa31039a18da042480b9c865bf4be Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Tue, 19 Aug 2025 18:55:55 +0530 Subject: [PATCH 07/33] Update positutils.py --- pyrtl/positutils.py | 159 +++++++++++++++++++------------------------- 1 file changed, 70 insertions(+), 89 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index 3131d881..78c4de4d 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -1,4 +1,4 @@ -"""Implements utility functions for posit operations""" +"""Implements utility functions for posit operations.""" import pyrtl from pyrtl.corecircuits import shift_right_logical, shift_left_logical @@ -13,18 +13,18 @@ def decode_posit( pyrtl.WireVector, pyrtl.WireVector, ]: - """Decode posit into its components and return them as a :class:`tuple`. - - :param x: A :class:`WireVector` that represents the posit. - :param nbits: A :class:`int` that represents the bitwidth of the posit. - :param es: A :class:`int` that represents the exponent size of the posit. - - :return: A :class:`tuple` consisting of: - - :class:`WireVector` for sign - - :class:`WireVector` for k - - :class:`WireVector` for exponent - - :class:`WireVector` for fractional bits - - :class:`WireVector` for length of fraction + """Decode posit into its components and return them as a tuple. + + :param x: A WireVector that represents the posit. + :param nbits: An int that represents the bitwidth of the posit. + :param es: An int that represents the exponent size of the posit. + + :return: A tuple consisting of: + - WireVector for sign + - WireVector for k + - WireVector for exponent + - WireVector for fractional bits + - WireVector for length of fraction """ sign = x[nbits - 1] rest = [x[nbits - 2 - i] for i in range(nbits - 1)] @@ -76,75 +76,64 @@ def get_upto_regime( ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: """Calculates the remaining bits and the regime bits. - :param k: A :class:`WireVector` that represents the k value. - :param n_val: A :class:`WireVector` that represents the bitwidth of - the posit. - :param sign_final: A :class:`WireVector` that represents the final sign. + :param k: A WireVector that represents the k value. + :param n_val: A WireVector that represents the bitwidth of the posit. + :param sign_final: A WireVector that represents the final sign. - :return: A :class:`tuple` consisting of: - - :class:`WireVector` representing the remaining bits. - - :class:`WireVector` representing the regime bits with sign bit. + :return: A tuple consisting of: + - WireVector representing the remaining bits. + - WireVector representing the regime bits with sign bit. """ precomputed_val = (1 << (n_val - 1)) - 1 - n_c = pyrtl.Const(n_val, bitwidth=n_val) - n_minus_1 = pyrtl.Const(n_val - 1, bitwidth=n_val) - n_minus_2 = pyrtl.Const(n_val - 2, bitwidth=n_val) - n_minus_3 = pyrtl.Const(n_val - 3, bitwidth=n_val) - - k_thresh = pyrtl.Const(1 << (n_val - 1), bitwidth=k.bitwidth) - abs_k = pyrtl.select( - k >= k_thresh, - ( - (~k + pyrtl.Const(1, bitwidth=k.bitwidth)) - & pyrtl.Const((1 << n_val) - 1, bitwidth=k.bitwidth) - ), - k, - ) - - large_neg_regime = abs_k >= n_minus_1 - large_pos_regime = abs_k >= n_minus_2 - - temp_rem1 = (n_c + k) - pyrtl.Const(2, bitwidth=n_val) - sign_case1_inner = shift_right_logical( - pyrtl.Const(1 << (n_val - 2), bitwidth=n_val), abs_k - ) - rem_bits_case1 = pyrtl.select( - large_neg_regime, - pyrtl.Const(0, bitwidth=n_val), - temp_rem1, - ) - sign_case1 = pyrtl.select( - large_neg_regime, - pyrtl.Const(0, bitwidth=n_val), - sign_case1_inner, - ) + n_c = pyrtl.Const(n_val) + n_minus_1 = pyrtl.Const(n_val - 1) + n_minus_2 = pyrtl.Const(n_val - 2) + n_minus_3 = pyrtl.Const(n_val - 3) - temp_rem2 = n_minus_3 - k - shift_amt = k + pyrtl.Const(2, bitwidth=n_val) - ones = ( - shift_left_logical(pyrtl.Const(1, bitwidth=n_val), shift_amt) - - pyrtl.Const(2, bitwidth=n_val) - ) - shifted_case2 = shift_left_logical(ones, temp_rem2) + rem_bits = pyrtl.WireVector(bitwidth=n_val) + sign_w_regime = pyrtl.WireVector(bitwidth=n_val) - rem_bits_case2 = pyrtl.select( - large_pos_regime, - pyrtl.Const(0, bitwidth=n_val), - temp_rem2, - ) - sign_case2 = pyrtl.select( - large_pos_regime, - pyrtl.Const(precomputed_val, bitwidth=n_val), - shifted_case2, + abs_k = pyrtl.WireVector(bitwidth=k.bitwidth) + abs_k <<= pyrtl.select( + k >= (1 << (n_val - 1)), + (~k + 1) & ((1 << n_val) - 1), + k, ) - cond_k_ge = k >= k_thresh - rem_bits = pyrtl.select(cond_k_ge, rem_bits_case1, rem_bits_case2) - sign_w_regime = pyrtl.select(cond_k_ge, sign_case1, sign_case2) + large_neg_regime = abs_k >= n_minus_1 + large_pos_regime = abs_k >= n_minus_2 - sign_w_regime_trimmed = sign_w_regime[: n_val - 1] + with pyrtl.conditional_assignment: + with k >= (1 << (n_val - 1)): + with large_neg_regime: + rem_bits |= 0 + sign_w_regime |= 0 + with ~large_neg_regime: + temp_rem = n_c + k - 2 + rem_bits |= temp_rem + sign_w_regime |= shift_right_logical( + pyrtl.Const(1 << (n_val - 2), bitwidth=n_val), abs_k + ) + + with k < (1 << (n_val - 1)): + with large_pos_regime: + rem_bits |= 0 + sign_w_regime |= pyrtl.Const(precomputed_val, bitwidth=n_val) + with ~large_pos_regime: + temp_rem = n_minus_3 - k + shift_amt = k + 2 + rem_bits |= temp_rem + ones = shift_left_logical( + pyrtl.Const(1, bitwidth=n_val), shift_amt + ) - pyrtl.Const(2, bitwidth=n_val) + shifted = shift_left_logical(ones, temp_rem) + sign_w_regime |= shifted + + sign_w_regime_trimmed = pyrtl.WireVector(bitwidth=n_val - 1) + sign_w_regime_trimmed <<= sign_w_regime[: n_val - 1] sign_w_regime_final = pyrtl.concat(sign_final, sign_w_regime_trimmed) + return rem_bits, sign_w_regime_final @@ -155,19 +144,13 @@ def frac_with_hidden_one( ) -> pyrtl.WireVector: """Adds a hidden 1 to the fractional bits. - :param frac: A :class:`WireVector` that represents the fractional bits. - :param frac_length: A :class:`WireVector` that represents the length of - the fractional bits. - :param nbits: A :class:`WireVector` that represents the bitwidth of the - posit. + :param frac: A WireVector that represents the fractional bits. + :param frac_length: A WireVector that represents the length of the fractional bits. + :param nbits: An int that represents the bitwidth of the posit. - :return: A :class:`WireVector` that represents the fraction with the - hidden 1. + :return: A WireVector that represents the fraction with the hidden 1. """ - one_table = [ - pyrtl.Const(1 << i, bitwidth=32) - for i in range(nbits + 1) - ] + one_table = [pyrtl.Const(1 << i, bitwidth=32) for i in range(nbits + 1)] one_shifted = pyrtl.Const(0, bitwidth=32) for i in range(nbits + 1): @@ -177,9 +160,7 @@ def frac_with_hidden_one( one_shifted, ) - frac_32 = pyrtl.concat( - pyrtl.Const(0, bitwidth=32 - (nbits - 1)), frac - ) + frac_32 = pyrtl.concat(pyrtl.Const(0, bitwidth=32 - (nbits - 1)), frac) full = one_shifted + frac_32 return full @@ -187,9 +168,9 @@ def frac_with_hidden_one( def remove_first_one(val: pyrtl.WireVector) -> pyrtl.WireVector: """Removes the leading hidden bit of 1. - :param val: A :class:`WireVector` that represents the fractional bits. + :param val: A WireVector that represents the fractional bits. - :return: A :class:`WireVector` with the hidden bit of 1 removed. + :return: A WireVector with the hidden bit of 1 removed. """ found = pyrtl.Const(0, bitwidth=1) result_bits = [] @@ -209,4 +190,4 @@ def remove_first_one(val: pyrtl.WireVector) -> pyrtl.WireVector: found, ) - return pyrtl.concat_list(result_bits[::-1]) \ No newline at end of file + return pyrtl.concat_list(result_bits[::-1]) From 30d48055cac7d4cfae5cb549b8b4fd4565302a7a Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Tue, 19 Aug 2025 18:59:04 +0530 Subject: [PATCH 08/33] Create positmul.py --- pyrtl/rtllib/positmul.py | 137 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 pyrtl/rtllib/positmul.py diff --git a/pyrtl/rtllib/positmul.py b/pyrtl/rtllib/positmul.py new file mode 100644 index 00000000..1051fbcc --- /dev/null +++ b/pyrtl/rtllib/positmul.py @@ -0,0 +1,137 @@ +import pyrtl +from pyrtl.corecircuits import shift_right_logical, shift_left_logical +from pyrtl.positutils import decode_posit, get_upto_regime + + +def posit_mul( + a: pyrtl.WireVector, b: pyrtl.WireVector, nbits: int, es: int +) -> pyrtl.WireVector: + """Multiplies two numbers in posit format and returns their product. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + >>> nbits, es = 8, 1 + >>> a = pyrtl.Input(bitwidth=nbits, name='a') + >>> b = pyrtl.Input(bitwidth=nbits, name='b') + >>> out = pyrtl.Output(bitwidth=nbits, name='out') + >>> out <<= posit_mul(a, b, nbits, es) + >>> sim = pyrtl.Simulation() + >>> sim.step({'a': 0b01011100, 'b': 0b01000000}) # 3.5 * 2 = 7.0 (approx) + >>> format(sim.inspect('out'), '08b') # doctest: +ELLIPSIS + '...' + + :param a: A WireVector posit multiplicand. + :param b: A WireVector posit multiplier. + :param nbits: Total bitwidth of the posit. + :param es: Exponent size of the posit. + :return: A WireVector representing the product of the two posits. + """ + + # Decode inputs + sign_a, k_a, exp_a, frac_a, fraclength_a = decode_posit(a, nbits, es) + sign_b, k_b, exp_b, frac_b, fraclength_b = decode_posit(b, nbits, es) + + # Handle multiplication of special cases + either_zero = (a == 0) | (b == 0) + either_inf = (a == (1 << (nbits - 1))) | (b == (1 << (nbits - 1))) + + final_value = pyrtl.WireVector(bitwidth=nbits, name='final_value') + result_zero = pyrtl.Const(0, bitwidth=nbits) + result_nar = pyrtl.Const(1 << (nbits - 1), bitwidth=nbits) + normal_case = ~(either_zero | either_inf) + + # Compute resultant sign + sign_final = sign_a ^ sign_b + + # Compute scale + scale_a = k_a * pyrtl.Const(2 ** es) + exp_a + scale_b = k_b * pyrtl.Const(2 ** es) + exp_b + scale_sum = scale_a + scale_b + + # Fraction multiplication with implicit 1 + one_table = [pyrtl.Const(1 << i, bitwidth=32) for i in range(nbits)] + one_shifted_a = pyrtl.Const(0, bitwidth=32) + one_shifted_b = pyrtl.Const(0, bitwidth=32) + for i in range(nbits): + one_shifted_a = pyrtl.select(fraclength_a == i, one_table[i], one_shifted_a) + one_shifted_b = pyrtl.select(fraclength_b == i, one_table[i], one_shifted_b) + + frac_a_32 = pyrtl.concat(pyrtl.Const(0, bitwidth=24), frac_a) + frac_b_32 = pyrtl.concat(pyrtl.Const(0, bitwidth=24), frac_b) + frac_a_full = one_shifted_a + frac_a_32 + frac_b_full = one_shifted_b + frac_b_32 + + frac_product = frac_a_full * frac_b_full + + # Normalize fraction + fraclen_total = fraclength_a + fraclength_b + threshold = shift_left_logical(pyrtl.Const(1, bitwidth=32), fraclen_total + 1) + frac = frac_product + scale = pyrtl.Const(0, bitwidth=8) + for _ in range(8): + shifted = shift_right_logical(frac, 1) + should_shift = frac >= threshold + frac = pyrtl.select(should_shift, shifted, frac) + scale = pyrtl.select(should_shift, scale + 1, scale) + normalized_frac = frac + normalized_scale = scale + + # Remove extra 1 + mask_table = [pyrtl.Const((1 << i) - 1, bitwidth=32) for i in range(33)] + mask_val = pyrtl.Const(0) + for i in range(1, 33): + mask_val = pyrtl.select(fraclen_total == i, mask_table[i], mask_val) + frac_result = normalized_frac & mask_val + + # Final scale + final_scale = scale_sum + normalized_scale + + # Extract k and exponent + resultk = shift_right_logical(final_scale, pyrtl.Const(es, bitwidth=8)) + mod_mask = pyrtl.Const((1 << es) - 1, bitwidth=final_scale.bitwidth) + resultExponent = final_scale & mod_mask + + # Get remaining bits and regime + rem_bits, sign_w_regime = get_upto_regime(resultk, nbits, sign_final) + + # Fraction bits with rounding + frac_bits = rem_bits - es + is_small = rem_bits <= es + + shift_amt_small = es - rem_bits + exp_shifted_small = shift_right_logical(resultExponent, shift_amt_small) + value_small = sign_w_regime + exp_shifted_small + + sum_fraclens = fraclength_a + fraclength_b + roundup_bit = pyrtl.Const(0, bitwidth=1) + cond_round = sum_fraclens > frac_bits + shift_amt1 = sum_fraclens - frac_bits - 1 + shift_amt2 = sum_fraclens - frac_bits + + roundup_candidate = shift_right_logical(frac_result, shift_amt1) & 1 + frac_shifted = shift_right_logical(frac_result, shift_amt2) + frac_shifted_else = shift_left_logical(frac_result, frac_bits - sum_fraclens) + + frac_final = pyrtl.WireVector(bitwidth=nbits) + frac_final <<= pyrtl.select(cond_round, frac_shifted, frac_shifted_else) + + roundup_bit = pyrtl.WireVector(bitwidth=1, name='roundup_bit') + roundup_bit <<= pyrtl.select(cond_round, roundup_candidate, pyrtl.Const(0, bitwidth=1)) + + exp_shifted_large = shift_left_logical(resultExponent, frac_bits) + value_large = sign_w_regime + exp_shifted_large + frac_final + all_ones = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) + value_rounded = pyrtl.select( + (roundup_bit & (value_large != all_ones)), value_large + 1, value_large + ) + + computed_value = pyrtl.select(is_small, value_small, value_rounded) + + # Select between normal values and special computed value + final_value <<= pyrtl.select( + either_zero, result_zero, pyrtl.select(either_inf, result_nar, computed_value) + ) + + return final_value From 2158520287631771789632281cd87a631d92240e Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Tue, 19 Aug 2025 22:22:45 +0530 Subject: [PATCH 09/33] fix: correctly extract exp field for es > 1 --- pyrtl/positutils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index 78c4de4d..a91b28e3 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -46,9 +46,18 @@ def decode_posit( k_neg = (~run_len) + pyrtl.Const(1, bitwidth=nbits) k = pyrtl.select(regime_bit, k_pos, k_neg) - exp = pyrtl.Const(0, bitwidth=es) - for i in range(nbits - 2): - exp = pyrtl.select(run_len == pyrtl.Const(i), rest[i + 1], exp) + exp_bits = [] + for j in range(es): + bit_val = pyrtl.Const(0, bitwidth=1) + for i in range(nbits - 2): + cond = run_len == pyrtl.Const(i, bitwidth=nbits) + # exponent bit is at rest[i + 1 + j] + target_idx = i + 1 + j + if target_idx < (nbits - 1): + bit_val = pyrtl.select(cond, rest[target_idx], bit_val) + exp_bits.append(bit_val) + + exp = pyrtl.concat_list(exp_bits[::-1]) if es > 0 else pyrtl.Const(0) start_idx = run_len + pyrtl.Const(1 + es, bitwidth=nbits) fraction_bits = [] From 5bcac4cdabb7f2df2ae273944d445fca6f7d74bf Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Wed, 20 Aug 2025 19:14:51 +0530 Subject: [PATCH 10/33] Update: Add examples to posit utility functions --- pyrtl/positutils.py | 106 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 2 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index a91b28e3..a0e2395d 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -15,6 +15,44 @@ def decode_posit( ]: """Decode posit into its components and return them as a tuple. + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + >>> nbits = 8 + >>> es = 2 + + >>> a = pyrtl.Input(bitwidth=nbits, name='a') + >>> sign_out = pyrtl.Output(bitwidth=nbits, name='sign_out') + >>> k_out = pyrtl.Output(bitwidth=nbits, name='k_out') + >>> exp_out = pyrtl.Output(bitwidth=es, name='exp_out') + >>> frac_bits_out = pyrtl.Output(bitwidth=nbits, name='frac_bits_out') + >>> frac_len_out = pyrtl.Output(bitwidth=nbits, name='frac_len_out') + + >>> sign, k, exp, frac_bits, frac_len = decode_posit(a, nbits, es) + + >>> sign_out <<= sign + >>> k_out <<= k + >>> exp_out <<= exp + >>> frac_bits_out <<= frac_bits + >>> frac_len_out <<= frac_len + + >>> sim = pyrtl.Simulation() + >>> sim.step({'a': 0b01011100}) + + >>> sim.inspect('sign_out') + '0' + >>> sim.inspect('k_out') + '0' + >>> sim.inspect('exp_out') + '3' + >>> format(sim.inspect('frac_bits_out'), '08b') + '00000100' + >>> sim.inspect('frac_len_out') + '3' + :param x: A WireVector that represents the posit. :param nbits: An int that represents the bitwidth of the posit. :param es: An int that represents the exponent size of the posit. @@ -80,13 +118,40 @@ def decode_posit( def get_upto_regime( k: pyrtl.WireVector, - n_val: pyrtl.WireVector, + n_val: int, sign_final: pyrtl.WireVector, ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: """Calculates the remaining bits and the regime bits. + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> nbits = 8 + >>> k_in = pyrtl.Input(bitwidth=nbits, name='k_in') + >>> sign_final = pyrtl.Input(bitwidth=1, name='sign_final') + + >>> rem_bits_out = pyrtl.Output(bitwidth=nbits, name='rem_bits_out') + >>> sign_w_regime_out = pyrtl.Output(bitwidth=nbits, name='sign_w_regime_out') + + >>> rem_bits, sign_w_regime = get_upto_regime(k_in, nbits, sign_final) + + >>> rem_bits_out <<= rem_bits + >>> sign_w_regime_out <<= sign_w_regime + + >>> sim = pyrtl.Simulation() + >>> sim.step({'k_in': 2, 'sign_final': 0}) + + >>> sim.inspect('rem_bits_out') + '3' + >>> format(sim.inspect('sign_w_regime_out'), '08b') + '01110000' + :param k: A WireVector that represents the k value. - :param n_val: A WireVector that represents the bitwidth of the posit. + :param n_val: A int that represents the bitwidth of the posit. :param sign_final: A WireVector that represents the final sign. :return: A tuple consisting of: @@ -153,6 +218,25 @@ def frac_with_hidden_one( ) -> pyrtl.WireVector: """Adds a hidden 1 to the fractional bits. + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + >>> nbits = 8 + >>> frac_in = pyrtl.Input(bitwidth=nbits-1, name='frac_in') + >>> frac_len_in = pyrtl.Input(bitwidth=nbits, name='frac_len_in') + >>> frac_out = pyrtl.Output(bitwidth=32, name='frac_out') + + >>> frac_out <<= frac_with_hidden_one(frac_in, frac_len_in, nbits) + + >>> sim = pyrtl.Simulation() + >>> sim.step({'frac_in': 0b0010101, 'frac_len_in': 5}) + + >>> format(sim.inspect('frac_out'), '08b') + '000110101' + :param frac: A WireVector that represents the fractional bits. :param frac_length: A WireVector that represents the length of the fractional bits. :param nbits: An int that represents the bitwidth of the posit. @@ -177,6 +261,24 @@ def frac_with_hidden_one( def remove_first_one(val: pyrtl.WireVector) -> pyrtl.WireVector: """Removes the leading hidden bit of 1. + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + >>> nbits = 8 + >>> frac_with_one = pyrtl.Input(bitwidth=nbits, name='frac_with_one') + >>> frac_removed = pyrtl.Output(bitwidth=nbits, name='frac_removed') + + >>> frac_removed <<= remove_first_one(frac_with_one) + + >>> sim = pyrtl.Simulation() + >>> sim.step({'frac_with_one': 0b10010110}) + + >>> format(sim.inspect('frac_removed'), '08b') + '00010110' + :param val: A WireVector that represents the fractional bits. :return: A WireVector with the hidden bit of 1 removed. From befcefe071ad6c2de7f92e04d9efff8459a37ab2 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Sun, 24 Aug 2025 16:52:08 +0530 Subject: [PATCH 11/33] update: add function to convert decimal to posit --- pyrtl/positutils.py | 66 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index a0e2395d..e434d750 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -1,5 +1,6 @@ """Implements utility functions for posit operations.""" +import math import pyrtl from pyrtl.corecircuits import shift_right_logical, shift_left_logical @@ -302,3 +303,68 @@ def remove_first_one(val: pyrtl.WireVector) -> pyrtl.WireVector: ) return pyrtl.concat_list(result_bits[::-1]) + + +def decimal_to_posit(x: float, nbits: int, es: int) -> int: + """Convert a decimal float to Posit representation. + + .. doctest only:: + + >>> import math + + Example:: + >>> nbits, es = 16, 2 + >>> format(decimal_to_posit(4992, nbits, es), '016b') + '0111100000111000' + + >>> nbits, es = 8, 1 + >>> format(decimal_to_posit(5000, nbits, es), '08b') + '01111111' + + :param x: The decimal float to be converted. + :param nbits: Total number of bits in the posit representation. + :param es: Maximum number of exponent bits. + :return: The integer representation of the posit encoding. + """ + if x == 0: + return 0 + + sign = 0 + if x < 0: + sign = 1 + x = -x + + useed = 2 ** (2 ** es) + k = int(math.floor(math.log(x, useed))) + regime_value = useed ** k + + remaining = x / regime_value + exponent = int(math.floor(math.log2(remaining))) + exponent = max(0, exponent) + remaining /= (2 ** exponent) + + fraction = remaining - 1.0 + frac_bits = [] + + for _ in range(nbits * 2): + fraction *= 2 + if fraction >= 1: + frac_bits.append("1") + fraction -= 1 + else: + frac_bits.append("0") + + if k >= 0: + regime_bits = "1" * (k + 1) + "0" + else: + regime_bits = "0" * (-k) + "1" + + bits = str(sign) + bits += regime_bits + exp_str = bin(exponent)[2:].zfill(es) + bits += exp_str + bits += "".join(frac_bits) + + bits = bits[:nbits].ljust(nbits, "0") + + return int(bits, 2) \ No newline at end of file From 1b081aa1e3c74c6e4b3dfd3fca46a3bfa9e6e122 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Sun, 24 Aug 2025 22:31:45 +0530 Subject: [PATCH 12/33] Add test_positadder.py --- tests/rtllib/test_positadder.py | 54 +++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/rtllib/test_positadder.py diff --git a/tests/rtllib/test_positadder.py b/tests/rtllib/test_positadder.py new file mode 100644 index 00000000..225ccfd3 --- /dev/null +++ b/tests/rtllib/test_positadder.py @@ -0,0 +1,54 @@ +import doctest +import random +import unittest + +import pyrtl +import pyrtl.rtllib.testingutils as utils +from pyrtl.rtllib.positadder import posit_add +from pyrtl.positutils import decimal_to_posit + +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_doctests(self): + failures, tests = doctest.testmod(m=pyrtl.rtllib.positadder) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + +class TestPositAdder(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + + def setUp(self): + pyrtl.reset_working_block() + + def tearDown(self): + pyrtl.reset_working_block() + + def test_posit_adder(self): + nbits_list = [8, 16, 32] + es_list = [0, 1, 2, 3, 4] + nbits, es = random.choice(nbits_list), random.choice(es_list) + a = pyrtl.Input(bitwidth=nbits, name="a") + b = pyrtl.Input(bitwidth=nbits, name="b") + out = pyrtl.Output(name="out") + + out <<= posit_add(a, b, nbits, es) + + useed = 2 ** (2 ** es) + maxpos = useed ** (nbits - 2) + + wires = [a, b] + vals_raw = [[random.randint(0, maxpos) for _ in range(7)] for _ in wires] + vals = [[decimal_to_posit(j, nbits, es) for j in i] for i in vals_raw] + + out_vals = utils.sim_and_ret_out(out, wires, vals) + true_result_raw = [x + y for x, y in zip(vals_raw[0], vals_raw[1])] + true_result = [decimal_to_posit(i, nbits, es) for i in true_result_raw] + + for sim, expected in zip(out_vals, true_result): + self.assertLessEqual(abs(sim - expected), 1) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 926e4b0f47745bd509152d1d79a80ebabe9a28ff Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Mon, 25 Aug 2025 12:56:52 +0530 Subject: [PATCH 13/33] fix: handle shifting bits when es = 0 --- pyrtl/rtllib/positadder.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/pyrtl/rtllib/positadder.py b/pyrtl/rtllib/positadder.py index ddcda14f..d3322602 100644 --- a/pyrtl/rtllib/positadder.py +++ b/pyrtl/rtllib/positadder.py @@ -80,8 +80,8 @@ def posit_add( frac_b_full = frac_with_hidden_one(frac_b_aligned, frac_len_b_aligned, nbits) # Compute scales (regime*k + exponent) - scale_a = shift_left_logical(k_a, es) + exp_a - scale_b = shift_left_logical(k_b, es) + exp_b + scale_a = (k_a if es == 0 else shift_left_logical(k_a, es)) + exp_a + scale_b = (k_b if es == 0 else shift_left_logical(k_b, es)) + exp_b offset = scale_a - scale_b @@ -94,7 +94,10 @@ def posit_add( shifted_a = pyrtl.WireVector(bitwidth=frac_a_full.bitwidth) shifted_b = pyrtl.WireVector(bitwidth=frac_b_full.bitwidth) result_scale = pyrtl.WireVector(bitwidth=offset.bitwidth) - result_exp = pyrtl.WireVector(bitwidth=es) + if es == 0: + result_exp = pyrtl.WireVector(bitwidth=1) # placeholder + else: + result_exp = pyrtl.WireVector(bitwidth=es) neg_offset = (~offset) + pyrtl.Const(1, bitwidth=offset.bitwidth) @@ -121,15 +124,24 @@ def posit_add( result_scale + 1, result_scale, ) - result_k = shift_right_logical(result_scale, pyrtl.Const(es, bitwidth=nbits)) + result_k = ( + result_scale + if es == 0 + else shift_right_logical(result_scale, pyrtl.Const(es, bitwidth=nbits)) + ) # Extract regime bits rem_bits, regime_bits = get_upto_regime(result_k, nbits, 0) # Extract exponent from scale - result_exp = result_scale - shift_left_logical(result_k, es) - result_exp = shift_left_logical( - result_exp, rem_bits - pyrtl.Const(es, bitwidth=nbits) + result_exp = result_scale - ( + result_k if es == 0 else shift_left_logical(result_k, es) + ) + shift_amt = rem_bits - pyrtl.Const(es, bitwidth=nbits) + result_exp = pyrtl.select( + shift_amt == pyrtl.Const(0), + result_exp, + shift_left_logical(result_exp, shift_amt), ) # Remaining fraction length From c34a811752529838fa38a4feea5d29e3832c1ccc Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Mon, 25 Aug 2025 12:59:17 +0530 Subject: [PATCH 14/33] fix: handle rounding error while converting decimal to posit --- pyrtl/positutils.py | 47 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index e434d750..364aa610 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -244,7 +244,7 @@ def frac_with_hidden_one( :return: A WireVector that represents the fraction with the hidden 1. """ - one_table = [pyrtl.Const(1 << i, bitwidth=32) for i in range(nbits + 1)] + one_table = [pyrtl.Const(1 << i, bitwidth=nbits + 1) for i in range(nbits + 1)] one_shifted = pyrtl.Const(0, bitwidth=32) for i in range(nbits + 1): @@ -313,6 +313,7 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: >>> import math Example:: + >>> nbits, es = 16, 2 >>> format(decimal_to_posit(4992, nbits, es), '016b') '0111100000111000' @@ -329,23 +330,25 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: if x == 0: return 0 + # Sign sign = 0 if x < 0: sign = 1 x = -x useed = 2 ** (2 ** es) + k = int(math.floor(math.log(x, useed))) regime_value = useed ** k - remaining = x / regime_value - exponent = int(math.floor(math.log2(remaining))) + + exponent = int(math.floor(math.log2(remaining))) if es > 0 else 0 exponent = max(0, exponent) - remaining /= (2 ** exponent) + remaining /= 2 ** exponent + # Fraction bits fraction = remaining - 1.0 frac_bits = [] - for _ in range(nbits * 2): fraction *= 2 if fraction >= 1: @@ -354,17 +357,41 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: else: frac_bits.append("0") + # Regime bits if k >= 0: regime_bits = "1" * (k + 1) + "0" else: regime_bits = "0" * (-k) + "1" - bits = str(sign) - bits += regime_bits - exp_str = bin(exponent)[2:].zfill(es) - bits += exp_str + bits = str(sign) + regime_bits + + # Exponent bits + if es > 0: + exp_str = format(exponent & ((1 << es) - 1), f"0{es}b") + bits += exp_str + bits += "".join(frac_bits) - bits = bits[:nbits].ljust(nbits, "0") + # Handle rounding if bits exceed nbits + if len(bits) > nbits: + main = bits[:nbits] + guard = bits[nbits] + roundb = bits[nbits + 1] if nbits + 1 < len(bits) else "0" + sticky = "1" if "1" in bits[nbits + 2:] else "0" + + increment = ( + (guard == "1") + and (roundb == "1" or sticky == "1" or main[-1] == "1") + ) + + if increment: + main_int = int(main, 2) + 1 + if main_int >= (1 << (nbits - 1)): + main_int = (1 << (nbits - 1)) - 1 + main = format(main_int, f"0{nbits}b") + + bits = main + else: + bits = bits.ljust(nbits, "0") return int(bits, 2) \ No newline at end of file From 5aa148c4dfaf683a7cec61410311d6a5da6e161c Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Tue, 26 Aug 2025 00:15:48 +0530 Subject: [PATCH 15/33] Add doctest examples for positmatmul.py --- pyrtl/rtllib/positmatmul.py | 41 +++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/pyrtl/rtllib/positmatmul.py b/pyrtl/rtllib/positmatmul.py index d14dba93..833eb319 100644 --- a/pyrtl/rtllib/positmatmul.py +++ b/pyrtl/rtllib/positmatmul.py @@ -1,13 +1,46 @@ import pyrtl from pyrtl import PyrtlError -from pyrtl.rtllib.matrix import Matrix -from positadder import posit_add -from positmul import posit_mul +from pyrtl.rtllib.matrix import Matrix, matrix_wv_to_list +from pyrtl.rtllib.positadder import posit_add +from pyrtl.rtllib.positmul import posit_mul def posit_matmul(x: Matrix, y: Matrix, nbits: int, es: int) -> Matrix: """Performs matrix multiplication on posits. + .. doctest only:: + + >>> import pyrtl + >>> from pyrtl.rtllib.matrix import Matrix, matrix_wv_to_list + >>> pyrtl.reset_working_block() + + Example:: + + >>> nbits = 8 + >>> es = 1 + + >>> matrix_x = [[0b01000000, 0b01010000], [0b01011000, 0b01100000]] + >>> test_x = Matrix(2, 2, bits=nbits, value=matrix_x) + + >>> matrix_y = [[0b01000000, 0b01010000], [0b01011000, 0b01100000]] + >>> test_y = Matrix(2, 2, bits=nbits, value=matrix_y) + + >>> result = posit_matmul(test_x, test_y, nbits, es) + + >>> output = pyrtl.Output(name='output') + >>> output <<= result.to_wirevector() + + >>> sim = pyrtl.Simulation() + >>> sim.step() + + >>> raw_matrix = matrix_wv_to_list( + ... sim.inspect("output"), result.rows, result.columns, result.bits + ... ) + + >>> pretty_matrix = [[format(val, '08b') for val in row] for row in raw_matrix] + >>> pretty_matrix + [['01100110', '01101010'], ['01101111', '01110001']] + :param x: A :class:`Matrix` to be multiplied. :param y: A :class:`Matrix` to be multiplied. :param nbits: A :class:`int` representing the bitwidth of each cell of @@ -43,7 +76,7 @@ def posit_matmul(x: Matrix, y: Matrix, nbits: int, es: int) -> Matrix: for j in range(y.columns): acc = pyrtl.Const(0, bitwidth=nbits) for k in range(x.columns): - prod = posit_mul(nbits, es, x[i, k], y[k, j]) + prod = posit_mul(x[i, k], y[k, j], nbits, es) acc = posit_add(acc, prod, nbits, es) result[i, j] = acc From a286ab8d573499a86520fd733c0a3d7ebfaa3a37 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Thu, 28 Aug 2025 22:06:34 +0530 Subject: [PATCH 16/33] Add test_positmatmul.py --- tests/rtllib/test_positmatmul.py | 97 ++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 tests/rtllib/test_positmatmul.py diff --git a/tests/rtllib/test_positmatmul.py b/tests/rtllib/test_positmatmul.py new file mode 100644 index 00000000..1d6e079b --- /dev/null +++ b/tests/rtllib/test_positmatmul.py @@ -0,0 +1,97 @@ +import doctest +import random +import unittest + +import pyrtl +import pyrtl.rtllib.testingutils as utils +from pyrtl.rtllib.matrix import Matrix, matrix_wv_to_list +from pyrtl.rtllib.positmatmul import posit_matmul +from pyrtl.positutils import decimal_to_posit + + +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_doctests(self): + failures, tests = doctest.testmod(m=pyrtl.rtllib.positadder) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + + +class PositMatrixTestBase(unittest.TestCase): + def check_against_expected(self, result, expected_output, rows, cols, nbits): + expected = Matrix(rows, cols, bits=nbits, value=expected_output) + + result_wv = pyrtl.Output(name='result') + expected_wv = pyrtl.Output(name='expected') + + result_wv <<= result.to_wirevector() + expected_wv <<= expected.to_wirevector() + + sim = pyrtl.Simulation() + sim.step({}) + + result_vals = matrix_wv_to_list(sim.inspect('result'), rows, cols, nbits) + expected_vals = matrix_wv_to_list(sim.inspect('expected'), rows, cols, nbits) + + for i in range(len(result_vals)): + for j in range(len(result_vals[0])): + self.assertLessEqual(abs(result_vals[i][j] - expected_vals[i][j]), 2) + + def generate_and_check(self, m, n, p, identity=False): + nbits_list = [8, 16] + es_list = [0, 1, 2] + nbits, es = random.choice(nbits_list), random.choice(es_list) + + useed = 2 ** (2 ** es) + maxpos = useed ** (nbits - 2) + + matrix_x_raw = [[random.randint(0, maxpos) for _ in range(n)] for _ in range(m)] + matrix_x = [[decimal_to_posit(val, nbits, es) for val in row] for row in matrix_x_raw] + test_x = Matrix(m, n, bits=nbits, value=matrix_x) + + if identity: + matrix_y_raw = [[1 if i == j else 0 for j in range(n)] for i in range(n)] + p = n + else: + matrix_y_raw = [[random.randint(0, maxpos) for _ in range(p)] for _ in range(n)] + matrix_y = [[decimal_to_posit(val, nbits, es) for val in row] for row in matrix_y_raw] + test_y = Matrix(n, p, bits=nbits, value=matrix_y) + + result = posit_matmul(test_x, test_y, nbits, es) + + expected_output_raw = [[0 for _ in range(p)] for _ in range(m)] + for i in range(m): + for j in range(p): + for k in range(n): + expected_output_raw[i][j] += matrix_x_raw[i][k] * matrix_y_raw[k][j] + expected_output = [[decimal_to_posit(val, nbits, es) for val in row] for row in expected_output_raw] + + self.check_against_expected(result, expected_output, m, p, nbits) + + +class TestPositMatmul(PositMatrixTestBase): + @classmethod + def setUpClass(cls): + random.seed(42) + + def setUp(self): + pyrtl.reset_working_block() + + def tearDown(self): + pyrtl.reset_working_block() + + def test_posit_matmul_identity(self): + m = random.randint(1, 5) + n = random.randint(1, 5) + self.generate_and_check(m, n, p=None, identity=True) + + def test_posit_matmul(self): + m = random.randint(1, 5) + n = random.randint(1, 5) + p = random.randint(1, 5) + self.generate_and_check(m, n, p, identity=False) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From da680dd2414610e34b4b1befd7562bdf629b7825 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Thu, 28 Aug 2025 22:08:03 +0530 Subject: [PATCH 17/33] fix: remove WireVector names to prevent duplication error --- pyrtl/rtllib/positmul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrtl/rtllib/positmul.py b/pyrtl/rtllib/positmul.py index 1051fbcc..e3f462a8 100644 --- a/pyrtl/rtllib/positmul.py +++ b/pyrtl/rtllib/positmul.py @@ -37,7 +37,7 @@ def posit_mul( either_zero = (a == 0) | (b == 0) either_inf = (a == (1 << (nbits - 1))) | (b == (1 << (nbits - 1))) - final_value = pyrtl.WireVector(bitwidth=nbits, name='final_value') + final_value = pyrtl.WireVector(bitwidth=nbits) result_zero = pyrtl.Const(0, bitwidth=nbits) result_nar = pyrtl.Const(1 << (nbits - 1), bitwidth=nbits) normal_case = ~(either_zero | either_inf) @@ -117,7 +117,7 @@ def posit_mul( frac_final = pyrtl.WireVector(bitwidth=nbits) frac_final <<= pyrtl.select(cond_round, frac_shifted, frac_shifted_else) - roundup_bit = pyrtl.WireVector(bitwidth=1, name='roundup_bit') + roundup_bit = pyrtl.WireVector(bitwidth=1) roundup_bit <<= pyrtl.select(cond_round, roundup_candidate, pyrtl.Const(0, bitwidth=1)) exp_shifted_large = shift_left_logical(resultExponent, frac_bits) From 0f0ba20de39fc83a60c4d4c3c4d9d5ae16e98f86 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Thu, 28 Aug 2025 22:23:30 +0530 Subject: [PATCH 18/33] update: format to pep8 --- tests/rtllib/test_positmatmul.py | 48 ++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/tests/rtllib/test_positmatmul.py b/tests/rtllib/test_positmatmul.py index 1d6e079b..8e80392a 100644 --- a/tests/rtllib/test_positmatmul.py +++ b/tests/rtllib/test_positmatmul.py @@ -31,12 +31,18 @@ def check_against_expected(self, result, expected_output, rows, cols, nbits): sim = pyrtl.Simulation() sim.step({}) - result_vals = matrix_wv_to_list(sim.inspect('result'), rows, cols, nbits) - expected_vals = matrix_wv_to_list(sim.inspect('expected'), rows, cols, nbits) + result_vals = matrix_wv_to_list( + sim.inspect('result'), rows, cols, nbits + ) + expected_vals = matrix_wv_to_list( + sim.inspect('expected'), rows, cols, nbits + ) for i in range(len(result_vals)): for j in range(len(result_vals[0])): - self.assertLessEqual(abs(result_vals[i][j] - expected_vals[i][j]), 2) + self.assertLessEqual( + abs(result_vals[i][j] - expected_vals[i][j]), 2 + ) def generate_and_check(self, m, n, p, identity=False): nbits_list = [8, 16] @@ -46,16 +52,30 @@ def generate_and_check(self, m, n, p, identity=False): useed = 2 ** (2 ** es) maxpos = useed ** (nbits - 2) - matrix_x_raw = [[random.randint(0, maxpos) for _ in range(n)] for _ in range(m)] - matrix_x = [[decimal_to_posit(val, nbits, es) for val in row] for row in matrix_x_raw] + matrix_x_raw = [ + [random.randint(0, maxpos) for _ in range(n)] for _ in range(m) + ] + matrix_x = [ + [decimal_to_posit(val, nbits, es) for val in row] + for row in matrix_x_raw + ] test_x = Matrix(m, n, bits=nbits, value=matrix_x) if identity: - matrix_y_raw = [[1 if i == j else 0 for j in range(n)] for i in range(n)] + matrix_y_raw = [ + [1 if i == j else 0 for j in range(n)] for i in range(n) + ] p = n else: - matrix_y_raw = [[random.randint(0, maxpos) for _ in range(p)] for _ in range(n)] - matrix_y = [[decimal_to_posit(val, nbits, es) for val in row] for row in matrix_y_raw] + matrix_y_raw = [ + [random.randint(0, maxpos) for _ in range(p)] + for _ in range(n) + ] + + matrix_y = [ + [decimal_to_posit(val, nbits, es) for val in row] + for row in matrix_y_raw + ] test_y = Matrix(n, p, bits=nbits, value=matrix_y) result = posit_matmul(test_x, test_y, nbits, es) @@ -64,8 +84,14 @@ def generate_and_check(self, m, n, p, identity=False): for i in range(m): for j in range(p): for k in range(n): - expected_output_raw[i][j] += matrix_x_raw[i][k] * matrix_y_raw[k][j] - expected_output = [[decimal_to_posit(val, nbits, es) for val in row] for row in expected_output_raw] + expected_output_raw[i][j] += ( + matrix_x_raw[i][k] * matrix_y_raw[k][j] + ) + + expected_output = [ + [decimal_to_posit(val, nbits, es) for val in row] + for row in expected_output_raw + ] self.check_against_expected(result, expected_output, m, p, nbits) @@ -94,4 +120,4 @@ def test_posit_matmul(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 0b8626803a3887c80f01d91fca0ce8cd540ffa82 Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Thu, 28 Aug 2025 22:37:47 +0530 Subject: [PATCH 19/33] Create test_positmul.py --- tests/rtllib/test_positmul.py | 54 +++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/rtllib/test_positmul.py diff --git a/tests/rtllib/test_positmul.py b/tests/rtllib/test_positmul.py new file mode 100644 index 00000000..e41a0f53 --- /dev/null +++ b/tests/rtllib/test_positmul.py @@ -0,0 +1,54 @@ +import doctest +import random +import unittest + +import pyrtl +import pyrtl.rtllib.testingutils as utils +from pyrtl.rtllib.positmul import posit_mul +from pyrtl.positutils import decimal_to_posit + +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_doctests(self): + failures, tests = doctest.testmod(m=pyrtl.rtllib.positmul) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + +class TestPositMultiplier(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + + def setUp(self): + pyrtl.reset_working_block() + + def tearDown(self): + pyrtl.reset_working_block() + + def test_posit_multiplier(self): + nbits_list = [8, 16, 32] + es_list = [0, 1, 2, 3, 4] + nbits, es = random.choice(nbits_list), random.choice(es_list) + a = pyrtl.Input(bitwidth=nbits, name="a") + b = pyrtl.Input(bitwidth=nbits, name="b") + out = pyrtl.Output(name="out") + + out <<= posit_mul(a, b, nbits, es) + + useed = 2 ** (2 ** es) + maxpos = useed ** (nbits - 2) + + wires = [a, b] + vals_raw = [[random.randint(0, maxpos) for _ in range(7)] for _ in wires] + vals = [[decimal_to_posit(j, nbits, es) for j in i] for i in vals_raw] + + out_vals = utils.sim_and_ret_out(out, wires, vals) + true_result_raw = [x * y for x, y in zip(vals_raw[0], vals_raw[1])] + true_result = [decimal_to_posit(i, nbits, es) for i in true_result_raw] + + for sim, expected in zip(out_vals, true_result): + self.assertLessEqual(abs(sim - expected), 1) + +if __name__ == "__main__": + unittest.main() From 13e65a35f48c263a3df1ead67b339190dfebb2cd Mon Sep 17 00:00:00 2001 From: Srijan Shivashankar Kotabagi <162854010+Srijansk17@users.noreply.github.com> Date: Thu, 4 Sep 2025 19:40:51 +0530 Subject: [PATCH 20/33] Add files via upload --- pyrtl/rtllib/Posits_Sub.py | 311 +++++++++++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 pyrtl/rtllib/Posits_Sub.py diff --git a/pyrtl/rtllib/Posits_Sub.py b/pyrtl/rtllib/Posits_Sub.py new file mode 100644 index 00000000..a1816f35 --- /dev/null +++ b/pyrtl/rtllib/Posits_Sub.py @@ -0,0 +1,311 @@ +import pyrtl +from pyrtl.corecircuits import shift_left_logical, shift_right_logical +import math + +def create_posit_subtractor_n_es(n, es): + """ + Generates a PyRTL hardware design for a posit subtractor. + + This function is a factory that creates a subtractor circuit customized + for the given posit configuration (n, es). + + Args: + n (int): The total number of bits in the posit format. + es (int): The number of exponent bits in the posit format. + """ + pyrtl.reset_working_block() + + # --- Derived Parameters --- + # These widths are calculated to be large enough for any valid n, es. + # Width for counters/lengths that go up to n (e.g., run_len). + n_bits_width = math.ceil(math.log2(n + 1)) if n > 0 else 1 + # Width for the regime value 'k'. Needs to be signed and can range approx -n to +n. + k_width = n_bits_width + 2 + # Width for the scale value (k << es + e). + scale_width = k_width + es + # Width for the offset (difference between two scales). + offset_width = scale_width + 1 + # A generous internal width for fraction arithmetic to prevent overflow during shifts. + # The maximum shift can be large, so 3*n provides a safe margin. + internal_width = 3 * n + + # --- Inputs / Outputs --- + a = pyrtl.Input(bitwidth=n, name='a') + b = pyrtl.Input(bitwidth=n, name='b') + result = pyrtl.Output(bitwidth=n, name='result') + + # --- Decode Posit --- + def decode_posit(x): + """ Decodes a posit into its sign, k, exponent, and fraction components. """ + is_zero = (x == 0) + is_nar = (x == (1 << (n - 1))) + + sign = x[n - 1] + + # Invert bits if sign is 1 (two's complement representation of posits) + x_abs = pyrtl.select(sign, ~x + 1, x) + + regime_bits = x_abs[0:n-1] + + # Find the end of the regime + # First bit determines if regime is 1s or 0s + regime_term_bit = ~regime_bits[n-2] + + run_len = pyrtl.WireVector(bitwidth=n_bits_width) + + # Priority encoder to find first regime terminating bit + found = pyrtl.Const(0, 1) + for i in reversed(range(n-1)): + is_terminator = (regime_bits[i] == regime_term_bit) + with pyrtl.conditional_assignment: + with ~found & is_terminator: + run_len |= (n-2-i) + found = found | is_terminator + + with pyrtl.conditional_assignment: + with ~found: # All bits are the same + run_len |= n-1 + + + regime_val = pyrtl.select(regime_term_bit, -run_len, run_len - 1) + k = pyrtl.as_signed(regime_val) + k.bitwidth = k_width + + # Calculate remaining bits after regime + rem_len = n - 2 - run_len + + # Extract exponent and fraction + exp = pyrtl.WireVector(bitwidth=es) + frac_max_len = n - 2 - es # Theoretical max + frac = pyrtl.WireVector(bitwidth=max(1, frac_max_len)) + + with pyrtl.conditional_assignment: + with rem_len >= es: + exp_and_frac = x_abs << (run_len + 2) + exp |= exp_and_frac[n-1-es:n-1] + frac_len = rem_len - es + frac_shifted = exp_and_frac << es + frac |= frac_shifted[n-1-frac_max_len: n-1] if frac_max_len > 0 else 0 + with rem_len < es: + exp_and_frac = x_abs << (run_len + 2) + exp |= exp_and_frac[n-1-rem_len:n-1] << (es - rem_len) if rem_len > 0 else 0 + frac_len = 0 + frac |= 0 + + # Handle special cases at the end + with pyrtl.conditional_assignment: + with is_zero: + k |= 0 + exp |= 0 + frac |= 0 + with is_nar: + # NaR doesn't have a valid k, exp, or frac + pass + + return sign, k, exp, frac, is_zero, is_nar + + + # --- Add the hidden bit to the fraction --- + def frac_with_hidden_one(frac): + """ Prepends the implicit '1' bit to the fraction for calculations. """ + # The fraction is always prepended with a '1' unless the value is zero/NaR + hidden_bit = pyrtl.Const(1, bitwidth=1) + # Combine with the explicit fraction bits. + full_frac = pyrtl.concat(hidden_bit, frac) + # Zero-extend to the internal working width. + return pyrtl.concat(pyrtl.Const(0, bitwidth=internal_width - full_frac.bitwidth), full_frac) + + + # --- Main Posit Subtractor Logic --- + def posit_sub(a_in, b_in): + """ Top-level function for the subtraction a - b. """ + s1, k1, e1, f1, is_a_zero, is_a_nar = decode_posit(a_in) + s2_orig, k2, e2, f2, is_b_zero, is_b_nar = decode_posit(b_in) + + # For subtraction a - b, we calculate a + (-b). + # The sign of -b is ~s2_orig. + s2 = ~s2_orig + + # Handle special cases immediately + nar_result = pyrtl.Const((1 << (n - 1)), bitwidth=n) + with pyrtl.conditional_assignment: + with is_a_nar | is_b_nar: + return nar_result + with is_a_zero: + return pyrtl.concat(~s2_orig, b_in[0:n-1]) # return -b + with is_b_zero: + return a_in # return a + + # Effective operation is addition if signs are the same, subtraction otherwise + op_is_add = (s1 == s2) + + # Calculate the total scale for each posit + scale1 = (pyrtl.as_signed(k1) << es) + pyrtl.as_signed(e1) + scale2 = (pyrtl.as_signed(k2) << es) + pyrtl.as_signed(e2) + scale1.bitwidth = scale_width + scale2.bitwidth = scale_width + + # Get fractions with the hidden bit prepended + frac1_full = frac_with_hidden_one(f1) + frac2_full = frac_with_hidden_one(f2) + + # Align fractions by shifting the one with the smaller scale + offset = pyrtl.as_signed(scale1) - pyrtl.as_signed(scale2) + offset.bitwidth = offset_width + + shifted1 = pyrtl.WireVector(bitwidth=internal_width) + shifted2 = pyrtl.WireVector(bitwidth=internal_width) + result_scale = pyrtl.WireVector(bitwidth=scale_width) + + offset_is_neg = offset[-1] + abs_offset = pyrtl.select(offset_is_neg, -offset, offset) + + with pyrtl.conditional_assignment: + with offset_is_neg: # scale1 < scale2, shift frac1 right + shifted1 |= frac1_full >> abs_offset + shifted2 |= frac2_full + result_scale |= scale2 + with ~offset_is_neg: # scale1 >= scale2, shift frac2 right + shifted1 |= frac1_full + shifted2 |= frac2_full >> abs_offset + result_scale |= scale1 + + # Perform addition or subtraction on aligned fractions + mag_result = pyrtl.WireVector(bitwidth=internal_width) + result_sign = pyrtl.WireVector(bitwidth=1) + + with pyrtl.conditional_assignment: + with op_is_add: + mag_result |= shifted1 + shifted2 + result_sign |= s1 + with ~op_is_add: + # Subtraction of magnitudes + with shifted1 >= shifted2: + mag_result |= shifted1 - shifted2 + result_sign |= s1 + with shifted1 < shifted2: + mag_result |= shifted2 - shifted1 + result_sign |= s2 + + is_result_zero = (mag_result == 0) + + # Normalize the result: find the MSB to adjust scale and fraction + msb_pos = pyrtl.WireVector(bitwidth=n_bits_width) + found_msb = pyrtl.Const(0, 1) + for i in reversed(range(internal_width)): + is_set = mag_result[i] + with pyrtl.conditional_assignment: + with ~found_msb & is_set: + msb_pos |= i + found_msb = found_msb | is_set + + # The alignment point of the hidden bit was at frac_max_len. + # It is now at msb_pos. The scale must be adjusted by the difference. + scale_adjustment = pyrtl.as_signed(msb_pos) - (max(1, frac_max_len) + 1) + final_scale = pyrtl.as_signed(result_scale) + scale_adjustment + + # --- Re-encode into Posit Format --- + # This is a simplified re-encoding and does not include rounding logic (e.g., round-to-nearest-even) + + final_k_signed = pyrtl.as_signed(final_scale) >> es + final_e = final_scale & ((1 << es) - 1) + + # Build the regime + k_is_neg = final_k_signed[-1] + abs_k = pyrtl.select(k_is_neg, -final_k_signed, final_k_signed) + + run_len = pyrtl.select(k_is_neg, abs_k, abs_k + 1) + + regime = pyrtl.WireVector(n-1) + with pyrtl.conditional_assignment: + with k_is_neg: # Regime of 0s followed by a 1 + # e.g., if run_len = 3, we want ...0001... + regime |= pyrtl.Const(1) << (n - 2 - run_len) + with ~k_is_neg: # Regime of 1s followed by a 0 + # e.g., if run_len = 3, we want ...1110... + regime |= ((pyrtl.Const(1) << run_len) -1 ) << (n-1-run_len) + + # Remove the new hidden bit from the fraction + frac_to_encode = mag_result << (internal_width - msb_pos) + + # Combine exponent and fraction + exp_and_frac = pyrtl.concat(final_e, frac_to_encode) + + # Number of available bits after the regime + rem_bits = n - 2 - run_len + + # Combine all parts + unsigned_result = pyrtl.WireVector(n-1) + with pyrtl.conditional_assignment: + with rem_bits > 0: + # Shift exp+frac into position + shifted_exp_frac = exp_and_frac >> (exp_and_frac.bitwidth - rem_bits) + unsigned_result |= regime | shifted_exp_frac + with rem_bits <= 0: + unsigned_result |= regime + + # Two's complement the result if the sign is negative + final_unsigned = pyrtl.concat(pyrtl.Const(0, 1), unsigned_result) + final_signed = pyrtl.select(result_sign, -final_unsigned, final_unsigned) + + #handle the the edge case of 1|0 and 0|1 + # Final result with special case handling + final_posit = pyrtl.select( + is_result_zero, + truecase=pyrtl.Const(0, bitwidth=n), + falsecase=final_signed + ) + return final_posit + + # --- Instantiate the subtractor circuit --- + result <<= posit_sub(a, b) + return pyrtl.working_block() + +# --- Simulation and Test --- +if __name__ == '__main__': + # --- Configuration --- + # You can change these values to test different posit formats + N_BITS = 8 + ES_BITS = 1 + + print(f"--- Testing Posit<{N_BITS}, {ES_BITS}> Subtractor ---") + + # Create the hardware design for the specified configuration + # This needs to be done before any simulation setup + try: + posit_sub_block = create_posit_subtractor_n_es(n=N_BITS, es=ES_BITS) + except pyrtl.PyrtlError as e: + print(f"Error creating PyRTL block: {e}") + print("Please ensure n and es values are valid (e.g., n > es + 2)") + exit() + + + # Setup simulation + sim_trace = pyrtl.SimulationTrace() + sim = pyrtl.Simulation(tracer=sim_trace, block=posit_sub_block) + + # --- Test Case --- + # a = 0b01000100 -> Posit<8,1> represents 0.5 + # b = 0b01101000 -> Posit<8,1> represents 4.0 + # a - b = 0.5 - 4.0 = -3.5 + # -3.5 in Posit<8,1> is approximately 0b10011100 (which is the encoding for -4.0, the closest value) + # The exact value -3.5 cannot be represented. Let's test with representable values. + # + # Test 2: 4.0 - 0.5 = 3.5 + # a = 4.0 -> 0b01101000 + # b = 0.5 -> 0b01000100 + # result should be 3.5 -> approx 0b01100110 (which is 3.0, the nearest value) + + a_val = 0b01101000 # 4.0 + b_val = 0b01000100 # 0.5 + + # Run the simulation step + sim.step({'a': a_val, 'b': b_val}) + + # Inspect the simulation output + result_val = sim.inspect('result') + + print(f"\nInput 'a' : {a_val:0{N_BITS}b} (Represents ~4.0)") + print(f"Input 'b' : {b_val:0{N_BITS}b} (Represents 0.5)") + print(f"Result a-b : {result_val:0{N_BITS}b} (Represents ~3.0)") + print(f"Expected approx: {0b01100110:0{N_BITS}b}") \ No newline at end of file From 1e1621356daf1409d94306cdeb4d20753f800079 Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Fri, 5 Sep 2025 11:58:41 +0530 Subject: [PATCH 21/33] Fix: Debug positsub --- pyrtl/rtllib/Posits_Sub.py | 502 ++++++++++++++----------------------- 1 file changed, 194 insertions(+), 308 deletions(-) diff --git a/pyrtl/rtllib/Posits_Sub.py b/pyrtl/rtllib/Posits_Sub.py index a1816f35..53bd7c90 100644 --- a/pyrtl/rtllib/Posits_Sub.py +++ b/pyrtl/rtllib/Posits_Sub.py @@ -1,311 +1,197 @@ import pyrtl from pyrtl.corecircuits import shift_left_logical, shift_right_logical -import math - -def create_posit_subtractor_n_es(n, es): - """ - Generates a PyRTL hardware design for a posit subtractor. - - This function is a factory that creates a subtractor circuit customized - for the given posit configuration (n, es). - - Args: - n (int): The total number of bits in the posit format. - es (int): The number of exponent bits in the posit format. - """ - pyrtl.reset_working_block() - - # --- Derived Parameters --- - # These widths are calculated to be large enough for any valid n, es. - # Width for counters/lengths that go up to n (e.g., run_len). - n_bits_width = math.ceil(math.log2(n + 1)) if n > 0 else 1 - # Width for the regime value 'k'. Needs to be signed and can range approx -n to +n. - k_width = n_bits_width + 2 - # Width for the scale value (k << es + e). - scale_width = k_width + es - # Width for the offset (difference between two scales). - offset_width = scale_width + 1 - # A generous internal width for fraction arithmetic to prevent overflow during shifts. - # The maximum shift can be large, so 3*n provides a safe margin. - internal_width = 3 * n - - # --- Inputs / Outputs --- - a = pyrtl.Input(bitwidth=n, name='a') - b = pyrtl.Input(bitwidth=n, name='b') - result = pyrtl.Output(bitwidth=n, name='result') - - # --- Decode Posit --- - def decode_posit(x): - """ Decodes a posit into its sign, k, exponent, and fraction components. """ - is_zero = (x == 0) - is_nar = (x == (1 << (n - 1))) - - sign = x[n - 1] - - # Invert bits if sign is 1 (two's complement representation of posits) - x_abs = pyrtl.select(sign, ~x + 1, x) - - regime_bits = x_abs[0:n-1] - - # Find the end of the regime - # First bit determines if regime is 1s or 0s - regime_term_bit = ~regime_bits[n-2] - - run_len = pyrtl.WireVector(bitwidth=n_bits_width) - - # Priority encoder to find first regime terminating bit - found = pyrtl.Const(0, 1) - for i in reversed(range(n-1)): - is_terminator = (regime_bits[i] == regime_term_bit) - with pyrtl.conditional_assignment: - with ~found & is_terminator: - run_len |= (n-2-i) - found = found | is_terminator - - with pyrtl.conditional_assignment: - with ~found: # All bits are the same - run_len |= n-1 - - - regime_val = pyrtl.select(regime_term_bit, -run_len, run_len - 1) - k = pyrtl.as_signed(regime_val) - k.bitwidth = k_width - - # Calculate remaining bits after regime - rem_len = n - 2 - run_len - - # Extract exponent and fraction - exp = pyrtl.WireVector(bitwidth=es) - frac_max_len = n - 2 - es # Theoretical max - frac = pyrtl.WireVector(bitwidth=max(1, frac_max_len)) - - with pyrtl.conditional_assignment: - with rem_len >= es: - exp_and_frac = x_abs << (run_len + 2) - exp |= exp_and_frac[n-1-es:n-1] - frac_len = rem_len - es - frac_shifted = exp_and_frac << es - frac |= frac_shifted[n-1-frac_max_len: n-1] if frac_max_len > 0 else 0 - with rem_len < es: - exp_and_frac = x_abs << (run_len + 2) - exp |= exp_and_frac[n-1-rem_len:n-1] << (es - rem_len) if rem_len > 0 else 0 - frac_len = 0 - frac |= 0 - - # Handle special cases at the end - with pyrtl.conditional_assignment: - with is_zero: - k |= 0 - exp |= 0 - frac |= 0 - with is_nar: - # NaR doesn't have a valid k, exp, or frac - pass - - return sign, k, exp, frac, is_zero, is_nar - - - # --- Add the hidden bit to the fraction --- - def frac_with_hidden_one(frac): - """ Prepends the implicit '1' bit to the fraction for calculations. """ - # The fraction is always prepended with a '1' unless the value is zero/NaR - hidden_bit = pyrtl.Const(1, bitwidth=1) - # Combine with the explicit fraction bits. - full_frac = pyrtl.concat(hidden_bit, frac) - # Zero-extend to the internal working width. - return pyrtl.concat(pyrtl.Const(0, bitwidth=internal_width - full_frac.bitwidth), full_frac) - - - # --- Main Posit Subtractor Logic --- - def posit_sub(a_in, b_in): - """ Top-level function for the subtraction a - b. """ - s1, k1, e1, f1, is_a_zero, is_a_nar = decode_posit(a_in) - s2_orig, k2, e2, f2, is_b_zero, is_b_nar = decode_posit(b_in) - - # For subtraction a - b, we calculate a + (-b). - # The sign of -b is ~s2_orig. - s2 = ~s2_orig - - # Handle special cases immediately - nar_result = pyrtl.Const((1 << (n - 1)), bitwidth=n) - with pyrtl.conditional_assignment: - with is_a_nar | is_b_nar: - return nar_result - with is_a_zero: - return pyrtl.concat(~s2_orig, b_in[0:n-1]) # return -b - with is_b_zero: - return a_in # return a - - # Effective operation is addition if signs are the same, subtraction otherwise - op_is_add = (s1 == s2) - - # Calculate the total scale for each posit - scale1 = (pyrtl.as_signed(k1) << es) + pyrtl.as_signed(e1) - scale2 = (pyrtl.as_signed(k2) << es) + pyrtl.as_signed(e2) - scale1.bitwidth = scale_width - scale2.bitwidth = scale_width - - # Get fractions with the hidden bit prepended - frac1_full = frac_with_hidden_one(f1) - frac2_full = frac_with_hidden_one(f2) - - # Align fractions by shifting the one with the smaller scale - offset = pyrtl.as_signed(scale1) - pyrtl.as_signed(scale2) - offset.bitwidth = offset_width - - shifted1 = pyrtl.WireVector(bitwidth=internal_width) - shifted2 = pyrtl.WireVector(bitwidth=internal_width) - result_scale = pyrtl.WireVector(bitwidth=scale_width) - - offset_is_neg = offset[-1] - abs_offset = pyrtl.select(offset_is_neg, -offset, offset) - - with pyrtl.conditional_assignment: - with offset_is_neg: # scale1 < scale2, shift frac1 right - shifted1 |= frac1_full >> abs_offset - shifted2 |= frac2_full - result_scale |= scale2 - with ~offset_is_neg: # scale1 >= scale2, shift frac2 right - shifted1 |= frac1_full - shifted2 |= frac2_full >> abs_offset - result_scale |= scale1 - - # Perform addition or subtraction on aligned fractions - mag_result = pyrtl.WireVector(bitwidth=internal_width) - result_sign = pyrtl.WireVector(bitwidth=1) - - with pyrtl.conditional_assignment: - with op_is_add: - mag_result |= shifted1 + shifted2 - result_sign |= s1 - with ~op_is_add: - # Subtraction of magnitudes - with shifted1 >= shifted2: - mag_result |= shifted1 - shifted2 - result_sign |= s1 - with shifted1 < shifted2: - mag_result |= shifted2 - shifted1 - result_sign |= s2 - - is_result_zero = (mag_result == 0) - - # Normalize the result: find the MSB to adjust scale and fraction - msb_pos = pyrtl.WireVector(bitwidth=n_bits_width) - found_msb = pyrtl.Const(0, 1) - for i in reversed(range(internal_width)): - is_set = mag_result[i] - with pyrtl.conditional_assignment: - with ~found_msb & is_set: - msb_pos |= i - found_msb = found_msb | is_set - - # The alignment point of the hidden bit was at frac_max_len. - # It is now at msb_pos. The scale must be adjusted by the difference. - scale_adjustment = pyrtl.as_signed(msb_pos) - (max(1, frac_max_len) + 1) - final_scale = pyrtl.as_signed(result_scale) + scale_adjustment - - # --- Re-encode into Posit Format --- - # This is a simplified re-encoding and does not include rounding logic (e.g., round-to-nearest-even) - - final_k_signed = pyrtl.as_signed(final_scale) >> es - final_e = final_scale & ((1 << es) - 1) - - # Build the regime - k_is_neg = final_k_signed[-1] - abs_k = pyrtl.select(k_is_neg, -final_k_signed, final_k_signed) - - run_len = pyrtl.select(k_is_neg, abs_k, abs_k + 1) - - regime = pyrtl.WireVector(n-1) - with pyrtl.conditional_assignment: - with k_is_neg: # Regime of 0s followed by a 1 - # e.g., if run_len = 3, we want ...0001... - regime |= pyrtl.Const(1) << (n - 2 - run_len) - with ~k_is_neg: # Regime of 1s followed by a 0 - # e.g., if run_len = 3, we want ...1110... - regime |= ((pyrtl.Const(1) << run_len) -1 ) << (n-1-run_len) - - # Remove the new hidden bit from the fraction - frac_to_encode = mag_result << (internal_width - msb_pos) - - # Combine exponent and fraction - exp_and_frac = pyrtl.concat(final_e, frac_to_encode) - - # Number of available bits after the regime - rem_bits = n - 2 - run_len - - # Combine all parts - unsigned_result = pyrtl.WireVector(n-1) - with pyrtl.conditional_assignment: - with rem_bits > 0: - # Shift exp+frac into position - shifted_exp_frac = exp_and_frac >> (exp_and_frac.bitwidth - rem_bits) - unsigned_result |= regime | shifted_exp_frac - with rem_bits <= 0: - unsigned_result |= regime - - # Two's complement the result if the sign is negative - final_unsigned = pyrtl.concat(pyrtl.Const(0, 1), unsigned_result) - final_signed = pyrtl.select(result_sign, -final_unsigned, final_unsigned) - - #handle the the edge case of 1|0 and 0|1 - # Final result with special case handling - final_posit = pyrtl.select( - is_result_zero, - truecase=pyrtl.Const(0, bitwidth=n), - falsecase=final_signed +from positutils import ( + decode_posit, + get_upto_regime, + frac_with_hidden_one, + remove_first_one, + twos_comp +) + + +def posit_sub(a, b, nbits, es): + + inf = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) + is_inf = (a == inf) | (b == inf) + is_a_zero = (a == 0) + is_b_zero = (b == 0) + + signbit1, k_a, exp_a, frac_a, frac_len_a = decode_posit(a, nbits, es) + signbit2, k_b, exp_b, frac_b, frac_len_b = decode_posit(b, nbits, es) + + sign_final = signbit1 ^ signbit2 + + neg_b = ~b + 1 + + res_inf = inf + res_a_zero = neg_b + res_b_zero = a + scale_a = (k_a if es == 0 else shift_left_logical(k_a, es)) + exp_a + scale_b = (k_b if es == 0 else shift_left_logical(k_b, es)) + exp_b + + frac_bits = pyrtl.select(frac_len_a > frac_len_b, frac_len_a, frac_len_b) + + shift_amt1 = frac_a - frac_b + shift_amt2 = frac_b - frac_a + + frac2_shifted = pyrtl.shift_left_logical(frac_b, shift_amt1) + frac1_shifted = pyrtl.shift_right_logical(frac_a, shift_amt2) + + + frac_a = pyrtl.select(frac_a > frac_b, frac_a, frac1_shifted) + frac_b = pyrtl.select(frac_a > frac_b, frac2_shifted, frac_b) + + offset = scale_a - scale_b + is_equal = (offset == 0) + + frac_diff = frac_a - frac_b + + resultFrac_default = pyrtl.Const(0, bitwidth=nbits) + resultScale_default = pyrtl.Const(0, bitwidth=nbits) + + resultFrac = pyrtl.select(is_equal, frac_diff, resultFrac_default) + resultScale = pyrtl.select(is_equal, scale_a, resultScale_default) + + is_offset_pos = offset > 0 + + one_shifted = shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits) + frac1_ext = frac_a + one_shifted + frac2_ext = frac_b + one_shifted + + + shifted_frac1 = shift_left_logical(frac1_ext, offset) + + frac_diff_pos = shifted_frac1 - frac2_ext + + resultFrac_default = pyrtl.Const(0, bitwidth=nbits) + resultScale_default = pyrtl.Const(0, bitwidth=nbits) + + resultFrac = pyrtl.select(is_offset_pos, frac_diff_pos, resultFrac_default) + resultScale = pyrtl.select(is_offset_pos, scale_a, resultScale_default) + + is_offset_neg = offset < 0 + + signbit1_flipped = signbit1 ^ pyrtl.Const(1, bitwidth=1) + + one_shifted = shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits) + frac1_ext = frac_a + one_shifted + frac2_ext = frac_b + one_shifted + + offset_neg = ~offset + pyrtl.Const(1, bitwidth=offset.bitwidth) + + offset_abs = pyrtl.select(is_offset_neg, offset_neg, offset) + + shifted_frac2 = shift_left_logical(frac2_ext, offset_abs) + + frac_diff_neg = shifted_frac2 - frac1_ext + + + resultFrac_default2 = pyrtl.Const(0, bitwidth=nbits) + resultScale_default2 = pyrtl.Const(0, bitwidth=nbits) + + is_offset_neg = offset < 0 + resultFrac = pyrtl.select(is_offset_neg, frac_diff_neg, resultFrac_default2) + resultScale = pyrtl.select(is_offset_neg, scale_b, resultScale_default2) + signbit1 = pyrtl.select(is_offset_neg, signbit1_flipped, signbit1) + + is_resultFrac_neg = resultFrac[-1] + + signbit1_flipped = signbit1 ^ pyrtl.Const(1, bitwidth=1) + + resultFrac_neg = ~resultFrac + pyrtl.Const(1, bitwidth=resultFrac.bitwidth) + + signbit1 = pyrtl.select(is_resultFrac_neg, signbit1_flipped, signbit1) + resultFrac = pyrtl.select(is_resultFrac_neg, resultFrac_neg, resultFrac) + + + bitlength = resultFrac.bitwidth + + needs_shift = frac_bits > bitlength + shift_amt = frac_bits - bitlength + + + shifted_resultFrac = shift_left_logical(resultFrac, shift_amt) + + resultFrac_if = pyrtl.select(needs_shift, shifted_resultFrac, resultFrac) + resultScale_if = pyrtl.select(needs_shift, resultScale - shift_amt, resultScale) + bitlength_if = pyrtl.select(needs_shift, frac_bits, bitlength) + + resultFrac = resultFrac_if + resultScale = resultScale_if + bitlength = bitlength_if + + offset_neg = ~offset + pyrtl.Const(1, bitwidth=offset.bitwidth) + offset_abs = pyrtl.select(offset < 0, offset_neg, offset) + + + resultScale = resultScale + (bitlength - pyrtl.Const(1, bitwidth=bitlength.bitwidth) + - offset_abs - frac_bits) + + mask = (1 << es) - 1 + mask_const = pyrtl.Const(mask, bitwidth=resultScale.bitwidth) + + scale_mod = resultScale & mask_const + + resultExponent = resultScale & scale_mod + + resultk = shift_right_logical(resultScale, es) + + rem_bits, regime = get_upto_regime(resultk, nbits, sign_final) + + frac_bits = rem_bits - es + bitlength_gt = bitlength > (frac_bits + 1) + bitlength_lt = bitlength < (frac_bits + 1) + + resultFrac_shifted = pyrtl.select(bitlength_gt, + pyrtl.shift_right_logical(resultFrac, bitlength - frac_bits - 1), + pyrtl.select(bitlength_lt, + pyrtl.shift_left_logical(resultFrac, frac_bits + 1 - bitlength), + resultFrac ) - return final_posit - - # --- Instantiate the subtractor circuit --- - result <<= posit_sub(a, b) - return pyrtl.working_block() - -# --- Simulation and Test --- -if __name__ == '__main__': - # --- Configuration --- - # You can change these values to test different posit formats - N_BITS = 8 - ES_BITS = 1 - - print(f"--- Testing Posit<{N_BITS}, {ES_BITS}> Subtractor ---") - - # Create the hardware design for the specified configuration - # This needs to be done before any simulation setup - try: - posit_sub_block = create_posit_subtractor_n_es(n=N_BITS, es=ES_BITS) - except pyrtl.PyrtlError as e: - print(f"Error creating PyRTL block: {e}") - print("Please ensure n and es values are valid (e.g., n > es + 2)") - exit() - - - # Setup simulation - sim_trace = pyrtl.SimulationTrace() - sim = pyrtl.Simulation(tracer=sim_trace, block=posit_sub_block) - - # --- Test Case --- - # a = 0b01000100 -> Posit<8,1> represents 0.5 - # b = 0b01101000 -> Posit<8,1> represents 4.0 - # a - b = 0.5 - 4.0 = -3.5 - # -3.5 in Posit<8,1> is approximately 0b10011100 (which is the encoding for -4.0, the closest value) - # The exact value -3.5 cannot be represented. Let's test with representable values. - # - # Test 2: 4.0 - 0.5 = 3.5 - # a = 4.0 -> 0b01101000 - # b = 0.5 -> 0b01000100 - # result should be 3.5 -> approx 0b01100110 (which is 3.0, the nearest value) - - a_val = 0b01101000 # 4.0 - b_val = 0b01000100 # 0.5 - - # Run the simulation step - sim.step({'a': a_val, 'b': b_val}) - - # Inspect the simulation output - result_val = sim.inspect('result') - - print(f"\nInput 'a' : {a_val:0{N_BITS}b} (Represents ~4.0)") - print(f"Input 'b' : {b_val:0{N_BITS}b} (Represents 0.5)") - print(f"Result a-b : {result_val:0{N_BITS}b} (Represents ~3.0)") - print(f"Expected approx: {0b01100110:0{N_BITS}b}") \ No newline at end of file + ) + + resultFrac_bitwidth = resultFrac.bitwidth + roundup = pyrtl.select(bitlength_gt, + (pyrtl.shift_right_logical(resultFrac, resultFrac_bitwidth - frac_bits - 2)) & 1, + pyrtl.Const(0, 1) + ) + + resultFrac_adj = resultFrac_shifted - pyrtl.shift_left_logical(pyrtl.Const(1, resultFrac_shifted.bitwidth), frac_bits) + value = regime + pyrtl.shift_left_logical(resultExponent, frac_bits) + resultFrac_adj + + rem_bits_le_es = rem_bits <= es + bitlength_gt = bitlength > (frac_bits + 1) + bitlength_lt = bitlength < (frac_bits + 1) + roundup_nonzero = roundup != 0 + value_not_max = value != ((1 << nbits) - 1) + + final_value = pyrtl.select(rem_bits_le_es, + regime + pyrtl.shift_right_logical(resultExponent, es - rem_bits), + + pyrtl.select(signbit1, + twos_comp( + pyrtl.select(roundup_nonzero & value_not_max, + value + 1, + value + ), nbits + ), + + pyrtl.select(roundup_nonzero & value_not_max, + value + 1, + value + ) + ) + ) + + result = pyrtl.select( + is_inf, res_inf, + pyrtl.select( + is_a_zero, res_a_zero, + pyrtl.select( + is_b_zero, res_b_zero, + final_value + ) + ) + ) + + return result From d38b9b6f3d8cb24f8b261099314fa0437d61fb73 Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Fri, 5 Sep 2025 11:59:09 +0530 Subject: [PATCH 22/33] rename Posits_Sub.py to positsub.py --- pyrtl/rtllib/{Posits_Sub.py => positsub.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pyrtl/rtllib/{Posits_Sub.py => positsub.py} (100%) diff --git a/pyrtl/rtllib/Posits_Sub.py b/pyrtl/rtllib/positsub.py similarity index 100% rename from pyrtl/rtllib/Posits_Sub.py rename to pyrtl/rtllib/positsub.py From c106c82d4e6c41326dfb93c6a797338a6d8e0588 Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Sat, 6 Sep 2025 12:09:44 +0530 Subject: [PATCH 23/33] Add twos compliment function --- pyrtl/positutils.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index 364aa610..05322c06 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -305,6 +305,42 @@ def remove_first_one(val: pyrtl.WireVector) -> pyrtl.WireVector: return pyrtl.concat_list(result_bits[::-1]) +def twos_comp(x: pyrtl.WireVector, n: int) -> pyrtl.WireVector: + """Compute the two's complement of an n-bit WireVector. + + Two's complement is the standard way of representing signed integers + in binary systems. The process is: + 1. Invert all the bits (one's complement). + 2. Add 1 to the result. + This function ensures the result is limited to 'n' bits. + + Example:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + >>> x = pyrtl.Const(5, bitwidth=4) # 0101 (decimal 5) + >>> result = twos_comp(x, 4) + >>> sim = pyrtl.Simulation() + >>> sim.step({}) + >>> format(sim.inspect(result), '04b') + '1011' # -5 in two's complement + + :param x: The input value as a PyRTL WireVector. + :param n: Bitwidth to operate on. + :return: The n-bit two's complement representation of x. + """ + + # Mask with n bits set to 1 + mask = pyrtl.Const((1 << n) - 1, bitwidth=n) + # Invert the bits of x + inverted = x ^ mask + # Add 1 to complete the two's complement process + added = inverted + 1 + # Ensure the result fits into exactly n bits + return added & mask + + + def decimal_to_posit(x: float, nbits: int, es: int) -> int: """Convert a decimal float to Posit representation. @@ -394,4 +430,4 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: else: bits = bits.ljust(nbits, "0") - return int(bits, 2) \ No newline at end of file + return int(bits, 2) From 40f0b5c6213b3909b7a78df363eb9f609ed8be64 Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Sat, 13 Sep 2025 03:19:41 +0530 Subject: [PATCH 24/33] Fix : Proper alignment of fraction bits and regime --- pyrtl/rtllib/positsub.py | 434 +++++++++++++++++++++++++-------------- 1 file changed, 278 insertions(+), 156 deletions(-) diff --git a/pyrtl/rtllib/positsub.py b/pyrtl/rtllib/positsub.py index 53bd7c90..984abba7 100644 --- a/pyrtl/rtllib/positsub.py +++ b/pyrtl/rtllib/positsub.py @@ -1,197 +1,319 @@ import pyrtl -from pyrtl.corecircuits import shift_left_logical, shift_right_logical -from positutils import ( +from pyrtl.corecircuits import ( + shift_left_logical, + shift_right_logical, + shift_right_arithmetic, +) +from pyrtl.positutils import ( decode_posit, get_upto_regime, - frac_with_hidden_one, - remove_first_one, - twos_comp + zero_ext, + resize, + unify_width, + absdiff, + bitlen_u, + twos_comp, + sign_ext, ) +from pyrtl.rtllib.positadd import posit_add + + +def posit_sub( + a: pyrtl.WireVector, b: pyrtl.WireVector, nbits: int, es: int +) -> pyrtl.WireVector: + """Subtracts two numbers in posit format and returns their difference. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> nbits = 8 + >>> es = 1 + >>> a = pyrtl.Input(bitwidth=nbits, name='a') + >>> b = pyrtl.Input(bitwidth=nbits, name='b') + >>> posit = pyrtl.Output(bitwidth=nbits, name='posit') + >>> result = posit_add(a, b, nbits, es) + >>> posit <<= result + >>> sim = pyrtl.Simulation() + >>> sim.step({'a': 0b01011100, 'b': 0b01100000}) # 4.5 - 2 = 2.5 + >>> format(sim.inspect('posit'), '08b') + '01100111' + + :param a: A :class:`.WireVector` to sub. Bitwidths need to match. + :param b: A :class:`WireVector` to sub. Bitwidths need to match. + :param nbits: A :class:`.int` representing the total bitwidth of the posit. + :param es: A :class:`.int` representing the exponent size of the posit. + + :return: A :class:`WireVector` that represents the differnece of the two + posits. + """ + # Subtraction of special cases + nar = pyrtl.Const(1 << (nbits - 1), bitwidth=nbits) + zero = pyrtl.Const(0, bitwidth=nbits) + mask = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) + + is_nar = (a == nar) | (b == nar) + neg_b = ((~b) + pyrtl.Const(1, bitwidth=nbits)) & mask + quick = pyrtl.select( + is_nar, + nar, + pyrtl.select( + a == zero, + neg_b, + pyrtl.select(b == zero, a, pyrtl.Const(0, bitwidth=nbits)), + ), + ) + have_quick = quick != pyrtl.Const(0, bitwidth=nbits) + # Decode input posits + sign1, k1, exponent1, frac1, fl1 = decode_posit(a, nbits, es) + sign2, k2, exponent2, frac2, fl2 = decode_posit(b, nbits, es) -def posit_sub(a, b, nbits, es): - - inf = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) - is_inf = (a == inf) | (b == inf) - is_a_zero = (a == 0) - is_b_zero = (b == 0) - - signbit1, k_a, exp_a, frac_a, frac_len_a = decode_posit(a, nbits, es) - signbit2, k_b, exp_b, frac_b, frac_len_b = decode_posit(b, nbits, es) - - sign_final = signbit1 ^ signbit2 - - neg_b = ~b + 1 - - res_inf = inf - res_a_zero = neg_b - res_b_zero = a - scale_a = (k_a if es == 0 else shift_left_logical(k_a, es)) + exp_a - scale_b = (k_b if es == 0 else shift_left_logical(k_b, es)) + exp_b - - frac_bits = pyrtl.select(frac_len_a > frac_len_b, frac_len_a, frac_len_b) - - shift_amt1 = frac_a - frac_b - shift_amt2 = frac_b - frac_a - - frac2_shifted = pyrtl.shift_left_logical(frac_b, shift_amt1) - frac1_shifted = pyrtl.shift_right_logical(frac_a, shift_amt2) - - - frac_a = pyrtl.select(frac_a > frac_b, frac_a, frac1_shifted) - frac_b = pyrtl.select(frac_a > frac_b, frac2_shifted, frac_b) - - offset = scale_a - scale_b - is_equal = (offset == 0) - - frac_diff = frac_a - frac_b - - resultFrac_default = pyrtl.Const(0, bitwidth=nbits) - resultScale_default = pyrtl.Const(0, bitwidth=nbits) - - resultFrac = pyrtl.select(is_equal, frac_diff, resultFrac_default) - resultScale = pyrtl.select(is_equal, scale_a, resultScale_default) - - is_offset_pos = offset > 0 - - one_shifted = shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits) - frac1_ext = frac_a + one_shifted - frac2_ext = frac_b + one_shifted - - - shifted_frac1 = shift_left_logical(frac1_ext, offset) - - frac_diff_pos = shifted_frac1 - frac2_ext - - resultFrac_default = pyrtl.Const(0, bitwidth=nbits) - resultScale_default = pyrtl.Const(0, bitwidth=nbits) - - resultFrac = pyrtl.select(is_offset_pos, frac_diff_pos, resultFrac_default) - resultScale = pyrtl.select(is_offset_pos, scale_a, resultScale_default) - - is_offset_neg = offset < 0 - - signbit1_flipped = signbit1 ^ pyrtl.Const(1, bitwidth=1) - - one_shifted = shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits) - frac1_ext = frac_a + one_shifted - frac2_ext = frac_b + one_shifted - - offset_neg = ~offset + pyrtl.Const(1, bitwidth=offset.bitwidth) - - offset_abs = pyrtl.select(is_offset_neg, offset_neg, offset) - - shifted_frac2 = shift_left_logical(frac2_ext, offset_abs) + # Opposite sign detection + opp = sign1 != sign2 - frac_diff_neg = shifted_frac2 - frac1_ext + neg_a = twos_comp(a, nbits) + neg_b = twos_comp(b, nbits) + # a positive, b negative : a + |b| + sum_posneg = posit_add(a, neg_b, nbits, es) - resultFrac_default2 = pyrtl.Const(0, bitwidth=nbits) - resultScale_default2 = pyrtl.Const(0, bitwidth=nbits) + # a negative, b positive : |a| + b, then negate the final sum + sum_negpos_pos = posit_add(neg_a, b, nbits, es) + sum_negpos = twos_comp(sum_negpos_pos, nbits) - is_offset_neg = offset < 0 - resultFrac = pyrtl.select(is_offset_neg, frac_diff_neg, resultFrac_default2) - resultScale = pyrtl.select(is_offset_neg, scale_b, resultScale_default2) - signbit1 = pyrtl.select(is_offset_neg, signbit1_flipped, signbit1) + # Final opposite-sign sum + sum_v = pyrtl.select( + sign1 == pyrtl.Const(0, bitwidth=1), sum_posneg, sum_negpos + ) - is_resultFrac_neg = resultFrac[-1] + SC_BW = max( + nbits + es + 6, + k1.bitwidth + es + 2, + k2.bitwidth + es + 2, + exponent1.bitwidth + 2, + exponent2.bitwidth + 2, + ) - signbit1_flipped = signbit1 ^ pyrtl.Const(1, bitwidth=1) + k1_se = sign_ext(k1, SC_BW) + k2_se = sign_ext(k2, SC_BW) + exp1_ze = zero_ext(exponent1, SC_BW) + exp2_ze = zero_ext(exponent2, SC_BW) + + # Compute scale = k*2^es * exponent + if es == 0: + scale1 = k1_se + exp1_ze + scale2 = k2_se + exp2_ze + else: + sh_es_sc = pyrtl.Const(es, bitwidth=SC_BW) + scale1 = shift_left_logical(k1_se, sh_es_sc) + exp1_ze + scale2 = shift_left_logical(k2_se, sh_es_sc) + exp2_ze + + # Align fraction precision to max(fl1, fl2) + frac_bits = pyrtl.select(fl1 > fl2, fl1, fl2) + shift12 = fl1 - fl2 + shift21 = fl2 - fl1 + f1a = pyrtl.select(fl1 >= fl2, frac1, shift_left_logical(frac1, shift21)) + f2a = pyrtl.select(fl2 >= fl1, frac2, shift_left_logical(frac2, shift12)) + + one_n = pyrtl.Const(1, bitwidth=nbits) + one_frac = shift_left_logical(one_n, frac_bits) + + # Compute offsets + offset = scale1 - scale2 + off_neg = offset[SC_BW - 1] + abs_off = pyrtl.select( + off_neg, + ((~offset) + pyrtl.Const(1, bitwidth=SC_BW)), + offset, + ) - resultFrac_neg = ~resultFrac + pyrtl.Const(1, bitwidth=resultFrac.bitwidth) + W = max(nbits * 2, one_frac.bitwidth + SC_BW + 2) + f1w = zero_ext(f1a, W) + f2w = zero_ext(f2a, W) + onew = zero_ext(one_frac, W) + offW = resize(abs_off, W) + + # Add hidden one to both sides and align by offset + a_sum = f1w + onew + b_sum = f2w + onew + a_sh = shift_left_logical(a_sum, offW[: a_sum.bitwidth]) + b_sh = shift_left_logical(b_sum, offW[: b_sum.bitwidth]) + + was_neg_A, diff_A = absdiff(f1w, f2w) + scale_A = scale1 + sign_A = sign1 ^ was_neg_A + + was_neg_B, diff_B = absdiff(a_sh, b_sum) + scale_B = scale1 + sign_B = pyrtl.select(was_neg_B, sign1 ^ pyrtl.Const(1, 1), sign1) + + was_neg_C, diff_C = absdiff(b_sh, a_sum) + scale_C = scale2 + base_sign_C = sign1 ^ pyrtl.Const(1, 1) + sign_C = pyrtl.select( + was_neg_C, base_sign_C ^ pyrtl.Const(1, 1), base_sign_C + ) - signbit1 = pyrtl.select(is_resultFrac_neg, signbit1_flipped, signbit1) - resultFrac = pyrtl.select(is_resultFrac_neg, resultFrac_neg, resultFrac) + is_zero_off = abs_off == pyrtl.Const(0, bitwidth=SC_BW) + res_frac0 = pyrtl.select( + is_zero_off, diff_A, pyrtl.select(off_neg, diff_C, diff_B) + ) + res_scale0 = pyrtl.select( + is_zero_off, scale_A, pyrtl.select(off_neg, scale_C, scale_B) + ) + sign0 = pyrtl.select( + is_zero_off, sign_A, pyrtl.select(off_neg, sign_C, sign_B) + ) - bitlength = resultFrac.bitwidth + same_fields = (k1 == k2) & (exponent1 == exponent2) & (frac1 == frac2) + is_exact_cancel = is_zero_off & same_fields + res_frac0 = pyrtl.select( + is_exact_cancel, pyrtl.Const(0, bitwidth=W), res_frac0 + ) + sign0 = pyrtl.select(is_exact_cancel, pyrtl.Const(0, bitwidth=1), sign0) + + # Normalize to target precision: same-sign => target = frac_bits + blen_bw = max(8, int(math.ceil(math.log2(W + 1)))) + bitlen0 = bitlen_u(res_frac0, blen_bw) + + fb_target = bitlen_u(frac_bits, blen_bw) + diff_needed = fb_target - bitlen0 + need_extend = ~diff_needed[blen_bw - 1] + extend_amt = diff_needed[:W] + max_shift = pyrtl.Const(W - 1, bitwidth=W) + extend_amt = pyrtl.select( + extend_amt > max_shift, max_shift, extend_amt + ) - needs_shift = frac_bits > bitlength - shift_amt = frac_bits - bitlength + res_frac1 = pyrtl.select( + need_extend, shift_left_logical(res_frac0, extend_amt), res_frac0 + ) + res_scale1 = pyrtl.select( + need_extend, res_scale0 - resize(extend_amt, SC_BW), res_scale0 + ) + bitlen1 = pyrtl.select(need_extend, fb_target, bitlen0) + + # Final scale tweak (same_sign): + (bitlength - 1 - |offset| - frac_bits) + adj1 = resize(bitlen1, SC_BW) - pyrtl.Const(1, bitwidth=SC_BW) + adj2 = adj1 - resize(abs_off, SC_BW) - resize(frac_bits, SC_BW) + scale_final = res_scale1 + adj2 + + # Extract k and exponent + if es == 0: + resultk = scale_final + resultExponent_sc = pyrtl.Const(0, bitwidth=SC_BW) + else: + shamt_sf = pyrtl.Const(es, bitwidth=SC_BW) + resultk = shift_right_arithmetic(scale_final, shamt_sf) + k_lsl_sc = shift_left_logical(resize(resultk, SC_BW), shamt_sf) + resultExponent_sc = scale_final - k_lsl_sc + + # Regime with sign=0 + rem_bits, regime = get_upto_regime( + resize(resultk, nbits), nbits, pyrtl.Const(0, bitwidth=1) + ) - - shifted_resultFrac = shift_left_logical(resultFrac, shift_amt) + # Small posit if no room for exponent+fraction + is_small = rem_bits <= pyrtl.Const(es, bitwidth=nbits) + shift_amt_small = pyrtl.Const(es, bitwidth=nbits) - rem_bits + exp_shifted_small = shift_right_logical( + resize(resultExponent_sc, nbits), shift_amt_small + ) + small_value = regime + exp_shifted_small - resultFrac_if = pyrtl.select(needs_shift, shifted_resultFrac, resultFrac) - resultScale_if = pyrtl.select(needs_shift, resultScale - shift_amt, resultScale) - bitlength_if = pyrtl.select(needs_shift, frac_bits, bitlength) + # normal form - frac_bits_avail = rem_bits - es + frac_bits_avail = rem_bits - pyrtl.Const(es, bitwidth=nbits) - resultFrac = resultFrac_if - resultScale = resultScale_if - bitlength = bitlength_if + sum_keep = ( + resize(frac_bits_avail, blen_bw) + pyrtl.Const(1, bitwidth=blen_bw) + ) - offset_neg = ~offset + pyrtl.Const(1, bitwidth=offset.bitwidth) - offset_abs = pyrtl.select(offset < 0, offset_neg, offset) + bitlen1_u, sum_keep_u, Wc = unify_width(bitlen1, sum_keep) + ge = bitlen1_u >= sum_keep_u + r_amt_wide = pyrtl.select( + ge, bitlen1_u - sum_keep_u, pyrtl.Const(0, bitwidth=Wc) + ) + r_amt = resize(r_amt_wide, W) + l_amt_wide = pyrtl.select( + ge, pyrtl.Const(0, bitwidth=Wc), sum_keep_u - bitlen1_u + ) + l_amt = resize(l_amt_wide, W) - resultScale = resultScale + (bitlength - pyrtl.Const(1, bitwidth=bitlength.bitwidth) - - offset_abs - frac_bits) + kept_plus_hidden = pyrtl.select( + ge, shift_right_logical(res_frac1, r_amt), shift_left_logical(res_frac1, l_amt) + ) - mask = (1 << es) - 1 - mask_const = pyrtl.Const(mask, bitwidth=resultScale.bitwidth) + r_amt_nonzero = r_amt_wide != pyrtl.Const(0, bitwidth=Wc) + r_amt_minus1 = resize(r_amt - pyrtl.Const(1, bitwidth=r_amt.bitwidth), W) + guard_src = shift_right_logical(res_frac1, r_amt_minus1) + guard_bit = pyrtl.select( + ge & r_amt_nonzero, guard_src & pyrtl.Const(1, bitwidth=W), pyrtl.Const(0, bitwidth=W) + ) + guard_is_one = guard_bit != pyrtl.Const(0, bitwidth=W) - scale_mod = resultScale & mask_const + # Remove hidden one + oneW = pyrtl.Const(1, bitwidth=W) + one_keep = shift_left_logical(oneW, resize(frac_bits_avail, W)) + frac_field_w = kept_plus_hidden - one_keep - resultExponent = resultScale & scale_mod + exp_shifted = shift_left_logical( + resize(resultExponent_sc, nbits), frac_bits_avail + ) + frac_mask = ( + shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits_avail) + - pyrtl.Const(1, bitwidth=nbits) + ) + frac_field = resize(frac_field_w, nbits) & frac_mask - resultk = shift_right_logical(resultScale, es) + value_large = regime + exp_shifted + frac_field + all_ones = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) + not_all_ones = value_large != all_ones - rem_bits, regime = get_upto_regime(resultk, nbits, sign_final) + # Rounding final posit + do_round = guard_is_one & not_all_ones + value_rounded = pyrtl.select(do_round, value_large + 1, value_large) - frac_bits = rem_bits - es - bitlength_gt = bitlength > (frac_bits + 1) - bitlength_lt = bitlength < (frac_bits + 1) + packed_pos = pyrtl.select(is_small, small_value, value_rounded) - resultFrac_shifted = pyrtl.select(bitlength_gt, - pyrtl.shift_right_logical(resultFrac, bitlength - frac_bits - 1), - pyrtl.select(bitlength_lt, - pyrtl.shift_left_logical(resultFrac, frac_bits + 1 - bitlength), - resultFrac - ) + packed_signed = pyrtl.select( + sign0 == pyrtl.Const(1, bitwidth=1), + ((~packed_pos) + pyrtl.Const(1, bitwidth=nbits)) & mask, + packed_pos, ) - resultFrac_bitwidth = resultFrac.bitwidth - roundup = pyrtl.select(bitlength_gt, - (pyrtl.shift_right_logical(resultFrac, resultFrac_bitwidth - frac_bits - 2)) & 1, - pyrtl.Const(0, 1) + exp_shifted = shift_left_logical( + resize(resultExponent_sc, nbits), frac_bits_avail ) + frac_mask = ( + shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits_avail) + - pyrtl.Const(1, bitwidth=nbits) + ) + frac_field = resize(frac_field_w, nbits) & frac_mask - resultFrac_adj = resultFrac_shifted - pyrtl.shift_left_logical(pyrtl.Const(1, resultFrac_shifted.bitwidth), frac_bits) - value = regime + pyrtl.shift_left_logical(resultExponent, frac_bits) + resultFrac_adj + value_large = regime + exp_shifted + frac_field + all_ones = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) + not_all_ones = value_large != all_ones - rem_bits_le_es = rem_bits <= es - bitlength_gt = bitlength > (frac_bits + 1) - bitlength_lt = bitlength < (frac_bits + 1) - roundup_nonzero = roundup != 0 - value_not_max = value != ((1 << nbits) - 1) + do_round = guard_is_one & not_all_ones + value_rounded = pyrtl.select(do_round, value_large + 1, value_large) - final_value = pyrtl.select(rem_bits_le_es, - regime + pyrtl.shift_right_logical(resultExponent, es - rem_bits), - - pyrtl.select(signbit1, - twos_comp( - pyrtl.select(roundup_nonzero & value_not_max, - value + 1, - value - ), nbits - ), + packed_pos = pyrtl.select(is_small, small_value, value_rounded) - pyrtl.select(roundup_nonzero & value_not_max, - value + 1, - value - ) - ) + packed_signed = pyrtl.select( + sign0 == pyrtl.Const(1, bitwidth=1), + ((~packed_pos) + pyrtl.Const(1, bitwidth=nbits)) & mask, + packed_pos, ) - result = pyrtl.select( - is_inf, res_inf, - pyrtl.select( - is_a_zero, res_a_zero, - pyrtl.select( - is_b_zero, res_b_zero, - final_value - ) - ) - ) + is_zero_res = (res_frac1 == pyrtl.Const(0, bitwidth=W)) | is_exact_cancel + same_sign_out = pyrtl.select(is_zero_res, zero, packed_signed) + nonquick = pyrtl.select(opp, sum_v, same_sign_out) + result = pyrtl.select(have_quick, quick, nonquick) return result From dd748ed9c204bc12dce26612986daf976e55a586 Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Sat, 13 Sep 2025 03:23:41 +0530 Subject: [PATCH 25/33] add posit helper fns for width align and abs diff --- pyrtl/positutils.py | 131 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 127 insertions(+), 4 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index 05322c06..1fcf06eb 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -330,15 +330,138 @@ def twos_comp(x: pyrtl.WireVector, n: int) -> pyrtl.WireVector: :return: The n-bit two's complement representation of x. """ - # Mask with n bits set to 1 mask = pyrtl.Const((1 << n) - 1, bitwidth=n) - # Invert the bits of x inverted = x ^ mask - # Add 1 to complete the two's complement process added = inverted + 1 - # Ensure the result fits into exactly n bits return added & mask +def zero_ext(x: pyrtl.WireVector, new_bw: int) -> pyrtl.WireVector: + """ + Zero-extend `x` to `new_bw` bits. + + Behavior: + - If `new_bw == x.bitwidth`, returns `x` unchanged (no copy). + - If `new_bw > x.bitwidth`, pads MSBs with zeros so the unsigned value is preserved. + - Truncation is NOT performed here (use `_resize` for that). + + Equivalent to concatenating (new_bw - x.bitwidth) zero bits above `x`. + """ + assert new_bw >= x.bitwidth + if new_bw == x.bitwidth: + return x + return pyrtl.concat(pyrtl.Const(0, bitwidth=new_bw - x.bitwidth), x) + + +def resize(x: pyrtl.WireVector, new_bw: int) -> pyrtl.WireVector: + """ + Resize `x` to exactly `new_bw` bits. + + Behavior: + - If `new_bw == x.bitwidth`, returns `x` unchanged. + - If `new_bw < x.bitwidth`, returns the lower `new_bw` bits (truncates MSBs). + - If `new_bw > x.bitwidth`, zero-extends to widen (unsigned semantics). + + Note: This is UNSIGNED resizing. If you need sign-preserving growth, + sign-extend externally instead of using `_resize`. + """ + if new_bw == x.bitwidth: + return x + if new_bw < x.bitwidth: + return x[:new_bw] + return zero_ext(x, new_bw) + + +def sign_ext(x: pyrtl.WireVector, new_bw: int) -> pyrtl.WireVector: + """ + Sign-extend `x` to bitwidth `new_bw` (two's-complement semantics). + + Behavior: + - If `new_bw <= x.bitwidth`, returns `x` unchanged (no truncation or copy). + - If `new_bw > x.bitwidth`, replicates the MSB (sign bit) into the new upper bits + so that the signed value is preserved under two's-complement interpretation. + + Use when: + - Widening **signed** quantities (e.g., regime `k`) before arithmetic or shifting. + For **unsigned** widening, prefer `zero_ext`. + + Implementation note: + - Captures `signbit = x[x.bitwidth-1]`, creates `new_bw - x.bitwidth` copies, + and concatenates them above `x`. + + :param x: Source WireVector interpreted as a signed two's-complement value. + :param new_bw: Target bitwidth to extend to. + :return: A WireVector of width `new_bw` with the same signed value as `x`. + """ + if new_bw <= x.bitwidth: + return x + signbit = x[x.bitwidth-1] + pad = pyrtl.concat_list([signbit for _ in range(new_bw - x.bitwidth)]) + return pyrtl.concat(pad, x) + + +def unify_width(a: pyrtl.WireVector, b: pyrtl.WireVector): + """ + Return `(a_w, b_w, W)` where both inputs are widened to the same bitwidth `W`. + + - `W = max(a.bitwidth, b.bitwidth)` + - Widening uses zero-extension (unsigned semantics). + - Useful before doing comparisons/arithmetic that expects matching widths. + + :return: (a_zero_extended, b_zero_extended, unified_bitwidth) + """ + W = max(a.bitwidth, b.bitwidth) + return zero_ext(a, W), zero_ext(b, W), W + + +def absdiff(a: pyrtl.WireVector, b: pyrtl.WireVector): + """ + Unsigned absolute difference between `a` and `b`. + + Steps (purely combinational): + 1) Zero-extend to a common width. + 2) Compute `a >= b` to pick subtraction order. + 3) Return `(was_neg, |a-b|)`, where: + - `was_neg` is 1 when `a < b` (i.e., subtraction would be negative), + - magnitude is the absolute difference. + + :return: (was_neg:1-bit, magnitude:WireVector) + """ + aa, bb, W = unify_width(a, b) + a_ge_b = aa >= bb + mag = pyrtl.select(a_ge_b, aa - bb, bb - aa) + was_neg = ~a_ge_b & pyrtl.Const(1, bitwidth=1) + return was_neg, mag + + +def bitlen_u(x: pyrtl.WireVector, out_bw: int) -> pyrtl.WireVector: + """ + Return the UNSIGNED bit-length of `x` (index of MSB + 1). + + Definition: + - bitlen(0b00010100) = 5 + - bitlen(0b00000000) = 0 + + Implementation detail: + - Counts leading zeros, then computes `bitlen = width - leading_zeros`. + - `out_bw` must be large enough to encode up to `x.bitwidth` + (use >= ceil(log2(x.bitwidth+1)) to be safe). + + :param x: WireVector to analyze (treated as unsigned). + :param out_bw: Bitwidth of the integer result (e.g., 8 or more). + :return: WireVector of width `out_bw` with the bit-length of `x`. + """ + bw = x.bitwidth + lz = pyrtl.Const(0, bitwidth=out_bw) + seen = pyrtl.Const(0, bitwidth=1) + + for i in range(bw): + bit = x[bw - 1 - i] + inc = (~seen) & (~bit) + lz = lz + pyrtl.select(inc, pyrtl.Const(1, bitwidth=out_bw), + pyrtl.Const(0, bitwidth=out_bw)) + seen = seen | bit + + return pyrtl.Const(bw, bitwidth=out_bw) - lz def decimal_to_posit(x: float, nbits: int, es: int) -> int: From a7245cc47ebe7891d8512c6217401860e760706c Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Sat, 13 Sep 2025 21:41:41 +0530 Subject: [PATCH 26/33] Fix: Proper sub for es = 0 --- pyrtl/rtllib/positsub.py | 249 +++++++++++++++++---------------------- 1 file changed, 110 insertions(+), 139 deletions(-) diff --git a/pyrtl/rtllib/positsub.py b/pyrtl/rtllib/positsub.py index 984abba7..c0ba5fd2 100644 --- a/pyrtl/rtllib/positsub.py +++ b/pyrtl/rtllib/positsub.py @@ -17,79 +17,61 @@ ) from pyrtl.rtllib.positadd import posit_add - def posit_sub( a: pyrtl.WireVector, b: pyrtl.WireVector, nbits: int, es: int ) -> pyrtl.WireVector: """Subtracts two numbers in posit format and returns their difference. - .. doctest only:: + .. doctest:: >>> import pyrtl + >>> from positutils import decimal_to_posit + >>> from positsub import posit_sub >>> pyrtl.reset_working_block() - - Example:: - - >>> nbits = 8 - >>> es = 1 + >>> nbits, es = 8, 1 >>> a = pyrtl.Input(bitwidth=nbits, name='a') >>> b = pyrtl.Input(bitwidth=nbits, name='b') - >>> posit = pyrtl.Output(bitwidth=nbits, name='posit') - >>> result = posit_add(a, b, nbits, es) - >>> posit <<= result + >>> out = pyrtl.Output(bitwidth=nbits, name='out') + >>> out <<= posit_sub(a, b, nbits, es) >>> sim = pyrtl.Simulation() - >>> sim.step({'a': 0b01011100, 'b': 0b01100000}) # 4.5 - 2 = 2.5 - >>> format(sim.inspect('posit'), '08b') - '01100111' - - :param a: A :class:`.WireVector` to sub. Bitwidths need to match. - :param b: A :class:`WireVector` to sub. Bitwidths need to match. - :param nbits: A :class:`.int` representing the total bitwidth of the posit. - :param es: A :class:`.int` representing the exponent size of the posit. - - :return: A :class:`WireVector` that represents the differnece of the two - posits. + >>> aval = decimal_to_posit(4.5, nbits, es) + >>> bval = decimal_to_posit(2.0, nbits, es) + >>> sim.step({'a': aval, 'b': bval}) # 4.5 - 2.0 = 2.5 + >>> sim.inspect('out') == decimal_to_posit(2.5, nbits, es) + True """ - # Subtraction of special cases - nar = pyrtl.Const(1 << (nbits - 1), bitwidth=nbits) + # special cases + nar = pyrtl.Const(1 << (nbits - 1), bitwidth=nbits) zero = pyrtl.Const(0, bitwidth=nbits) mask = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) + maxpos = pyrtl.Const((1 << (nbits - 1)) - 1, bitwidth=nbits) is_nar = (a == nar) | (b == nar) - neg_b = ((~b) + pyrtl.Const(1, bitwidth=nbits)) & mask + negb_quick = ((~b) + pyrtl.Const(1, bitwidth=nbits)) & mask quick = pyrtl.select( is_nar, nar, pyrtl.select( a == zero, - neg_b, - pyrtl.select(b == zero, a, pyrtl.Const(0, bitwidth=nbits)), + negb_quick, + pyrtl.select(b == zero, a, zero), ), ) - have_quick = quick != pyrtl.Const(0, bitwidth=nbits) + have_quick = is_nar | (a == zero) | (b == zero) - # Decode input posits + # Decode input posit sign1, k1, exponent1, frac1, fl1 = decode_posit(a, nbits, es) sign2, k2, exponent2, frac2, fl2 = decode_posit(b, nbits, es) - # Opposite sign detection + # check if inputs are opposite sign opp = sign1 != sign2 - neg_a = twos_comp(a, nbits) neg_b = twos_comp(b, nbits) + sum_posneg = posit_add(a, neg_b, nbits, es) + sum_negpos = twos_comp(posit_add(neg_a, b, nbits, es), nbits) + sum_v = pyrtl.select(sign1 == pyrtl.Const(0, 1), sum_posneg, sum_negpos) - # a positive, b negative : a + |b| - sum_posneg = posit_add(a, neg_b, nbits, es) - - # a negative, b positive : |a| + b, then negate the final sum - sum_negpos_pos = posit_add(neg_a, b, nbits, es) - sum_negpos = twos_comp(sum_negpos_pos, nbits) - - # Final opposite-sign sum - sum_v = pyrtl.select( - sign1 == pyrtl.Const(0, bitwidth=1), sum_posneg, sum_negpos - ) - + # Internal widths SC_BW = max( nbits + es + 6, k1.bitwidth + es + 2, @@ -103,7 +85,7 @@ def posit_sub( exp1_ze = zero_ext(exponent1, SC_BW) exp2_ze = zero_ext(exponent2, SC_BW) - # Compute scale = k*2^es * exponent + # compute scale = k*2^es + exponent if es == 0: scale1 = k1_se + exp1_ze scale2 = k2_se + exp2_ze @@ -112,7 +94,7 @@ def posit_sub( scale1 = shift_left_logical(k1_se, sh_es_sc) + exp1_ze scale2 = shift_left_logical(k2_se, sh_es_sc) + exp2_ze - # Align fraction precision to max(fl1, fl2) + #Align fraction precision to max(fl1, fl2) frac_bits = pyrtl.select(fl1 > fl2, fl1, fl2) shift12 = fl1 - fl2 shift21 = fl2 - fl1 @@ -122,14 +104,10 @@ def posit_sub( one_n = pyrtl.Const(1, bitwidth=nbits) one_frac = shift_left_logical(one_n, frac_bits) - # Compute offsets + # offset calculation between scales offset = scale1 - scale2 off_neg = offset[SC_BW - 1] - abs_off = pyrtl.select( - off_neg, - ((~offset) + pyrtl.Const(1, bitwidth=SC_BW)), - offset, - ) + abs_off = pyrtl.select(off_neg, ((~offset) + pyrtl.Const(1, bitwidth=SC_BW)), offset) W = max(nbits * 2, one_frac.bitwidth + SC_BW + 2) f1w = zero_ext(f1a, W) @@ -154,51 +132,45 @@ def posit_sub( was_neg_C, diff_C = absdiff(b_sh, a_sum) scale_C = scale2 base_sign_C = sign1 ^ pyrtl.Const(1, 1) - sign_C = pyrtl.select( - was_neg_C, base_sign_C ^ pyrtl.Const(1, 1), base_sign_C - ) + sign_C = pyrtl.select(was_neg_C, base_sign_C ^ pyrtl.Const(1, 1), base_sign_C) is_zero_off = abs_off == pyrtl.Const(0, bitwidth=SC_BW) - res_frac0 = pyrtl.select( - is_zero_off, diff_A, pyrtl.select(off_neg, diff_C, diff_B) - ) - res_scale0 = pyrtl.select( - is_zero_off, scale_A, pyrtl.select(off_neg, scale_C, scale_B) - ) - sign0 = pyrtl.select( - is_zero_off, sign_A, pyrtl.select(off_neg, sign_C, sign_B) - ) + # choose diff/scale/sign by offset sign + res_frac0 = pyrtl.select(is_zero_off, diff_A, pyrtl.select(off_neg, diff_C, diff_B)) + res_scale0 = pyrtl.select(is_zero_off, scale_A, pyrtl.select(off_neg, scale_C, scale_B)) + sign0 = pyrtl.select(is_zero_off, sign_A, pyrtl.select(off_neg, sign_C, sign_B)) + # exact cancel (same fields when abs_off==0) same_fields = (k1 == k2) & (exponent1 == exponent2) & (frac1 == frac2) is_exact_cancel = is_zero_off & same_fields - res_frac0 = pyrtl.select( - is_exact_cancel, pyrtl.Const(0, bitwidth=W), res_frac0 - ) - sign0 = pyrtl.select(is_exact_cancel, pyrtl.Const(0, bitwidth=1), sign0) + res_frac0 = pyrtl.select(is_exact_cancel, pyrtl.Const(0, bitwidth=W), res_frac0) + sign0 = pyrtl.select(is_exact_cancel, pyrtl.Const(0, bitwidth=1), sign0) + + # Only for es==0, same-sign, nonzero offset, and neg result. + same_sign = ~opp + nonzero_off = abs_off != pyrtl.Const(0, bitwidth=SC_BW) + es_is_zero = pyrtl.Const(1, bitwidth=1) if es == 0 else pyrtl.Const(0, bitwidth=1) + need_k_corr = same_sign & nonzero_off & es_is_zero & (sign0 == pyrtl.Const(1, bitwidth=1)) + res_scale0 = pyrtl.select(need_k_corr, res_scale0 + pyrtl.Const(1, bitwidth=SC_BW), res_scale0) - # Normalize to target precision: same-sign => target = frac_bits + # Normalize to target precision blen_bw = max(8, int(math.ceil(math.log2(W + 1)))) bitlen0 = bitlen_u(res_frac0, blen_bw) - fb_target = bitlen_u(frac_bits, blen_bw) + fb_target = resize(frac_bits, blen_bw) + pyrtl.Const(1, bitwidth=blen_bw) + diff_needed = fb_target - bitlen0 need_extend = ~diff_needed[blen_bw - 1] - extend_amt = diff_needed[:W] + extend_amt = resize(diff_needed[:W], W) max_shift = pyrtl.Const(W - 1, bitwidth=W) - extend_amt = pyrtl.select( - extend_amt > max_shift, max_shift, extend_amt - ) + extend_amt = pyrtl.select(extend_amt > max_shift, max_shift, extend_amt) - res_frac1 = pyrtl.select( - need_extend, shift_left_logical(res_frac0, extend_amt), res_frac0 - ) - res_scale1 = pyrtl.select( - need_extend, res_scale0 - resize(extend_amt, SC_BW), res_scale0 - ) - bitlen1 = pyrtl.select(need_extend, fb_target, bitlen0) + res_frac1 = pyrtl.select(need_extend, shift_left_logical(res_frac0, extend_amt), res_frac0) + res_scale1 = pyrtl.select(need_extend, res_scale0 - resize(extend_amt, SC_BW), res_scale0) + bitlen1 = pyrtl.select(need_extend, fb_target, bitlen0) - # Final scale tweak (same_sign): + (bitlength - 1 - |offset| - frac_bits) + # scale tweak for same-sign: + (bitlen - 1 - |offset| - frac_bits) adj1 = resize(bitlen1, SC_BW) - pyrtl.Const(1, bitwidth=SC_BW) adj2 = adj1 - resize(abs_off, SC_BW) - resize(frac_bits, SC_BW) scale_final = res_scale1 + adj2 @@ -213,36 +185,48 @@ def posit_sub( k_lsl_sc = shift_left_logical(resize(resultk, SC_BW), shamt_sf) resultExponent_sc = scale_final - k_lsl_sc - # Regime with sign=0 + # Regime packing rem_bits, regime = get_upto_regime( resize(resultk, nbits), nbits, pyrtl.Const(0, bitwidth=1) ) - # Small posit if no room for exponent+fraction - is_small = rem_bits <= pyrtl.Const(es, bitwidth=nbits) - shift_amt_small = pyrtl.Const(es, bitwidth=nbits) - rem_bits - exp_shifted_small = shift_right_logical( - resize(resultExponent_sc, nbits), shift_amt_small - ) + # Small posit path + es_nb = pyrtl.Const(es, bitwidth=nbits) + is_small = rem_bits <= es_nb + shift_amt_small = es_nb - rem_bits + exp_shifted_small = shift_right_logical(resize(resultExponent_sc, nbits), shift_amt_small) small_value = regime + exp_shifted_small - # normal form - frac_bits_avail = rem_bits - es - frac_bits_avail = rem_bits - pyrtl.Const(es, bitwidth=nbits) - - sum_keep = ( - resize(frac_bits_avail, blen_bw) + pyrtl.Const(1, bitwidth=blen_bw) + # For exponent drop (rem_bits < es) + shift_amt_small_nz = shift_amt_small != pyrtl.Const(0, bitwidth=nbits) + sam1 = pyrtl.select( + shift_amt_small_nz, + shift_amt_small - pyrtl.Const(1, bitwidth=nbits), + pyrtl.Const(0, bitwidth=nbits), ) + guard_src_small = shift_right_logical(resize(resultExponent_sc, SC_BW), resize(sam1, SC_BW)) + guard_exp = pyrtl.select(shift_amt_small_nz, guard_src_small[0], pyrtl.Const(0, bitwidth=1)) + + one_sc = pyrtl.Const(1, bitwidth=SC_BW) + lower_mask_small = pyrtl.select( + shift_amt_small_nz, + shift_left_logical(one_sc, resize(sam1, SC_BW)) - one_sc, + pyrtl.Const(0, bitwidth=SC_BW), + ) + sticky_exp = (resize(resultExponent_sc, SC_BW) & lower_mask_small) != pyrtl.Const(0, bitwidth=SC_BW) + + # Normal packing fields + rem_gt_es = rem_bits > es_nb + frac_bits_avail = pyrtl.select(rem_gt_es, rem_bits - es_nb, pyrtl.Const(0, bitwidth=nbits)) + + sum_keep = resize(frac_bits_avail, blen_bw) + pyrtl.Const(1, bitwidth=blen_bw) bitlen1_u, sum_keep_u, Wc = unify_width(bitlen1, sum_keep) ge = bitlen1_u >= sum_keep_u - r_amt_wide = pyrtl.select( - ge, bitlen1_u - sum_keep_u, pyrtl.Const(0, bitwidth=Wc) - ) + r_amt_wide = pyrtl.select(ge, bitlen1_u - sum_keep_u, pyrtl.Const(0, bitwidth=Wc)) + l_amt_wide = pyrtl.select(ge, pyrtl.Const(0, bitwidth=Wc), sum_keep_u - bitlen1_u) r_amt = resize(r_amt_wide, W) - l_amt_wide = pyrtl.select( - ge, pyrtl.Const(0, bitwidth=Wc), sum_keep_u - bitlen1_u - ) l_amt = resize(l_amt_wide, W) kept_plus_hidden = pyrtl.select( @@ -250,61 +234,44 @@ def posit_sub( ) r_amt_nonzero = r_amt_wide != pyrtl.Const(0, bitwidth=Wc) - r_amt_minus1 = resize(r_amt - pyrtl.Const(1, bitwidth=r_amt.bitwidth), W) - guard_src = shift_right_logical(res_frac1, r_amt_minus1) - guard_bit = pyrtl.select( - ge & r_amt_nonzero, guard_src & pyrtl.Const(1, bitwidth=W), pyrtl.Const(0, bitwidth=W) - ) - guard_is_one = guard_bit != pyrtl.Const(0, bitwidth=W) + r_amt_minus1 = resize(r_amt - pyrtl.Const(1, bitwidth=r_amt.bitwidth), W) + guard_src = shift_right_logical(res_frac1, r_amt_minus1) + guard_frac = pyrtl.select(ge & r_amt_nonzero, guard_src[0], pyrtl.Const(0, bitwidth=1)) - # Remove hidden one + trimmed = shift_right_logical(res_frac1, r_amt) + recon = shift_left_logical(trimmed, r_amt) + sticky_frac = pyrtl.select(ge & r_amt_nonzero, (res_frac1 != recon), pyrtl.Const(0, bitwidth=1)) + + # Remove hidden one to form fraction field oneW = pyrtl.Const(1, bitwidth=W) one_keep = shift_left_logical(oneW, resize(frac_bits_avail, W)) frac_field_w = kept_plus_hidden - one_keep - exp_shifted = shift_left_logical( - resize(resultExponent_sc, nbits), frac_bits_avail - ) - frac_mask = ( - shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits_avail) - - pyrtl.Const(1, bitwidth=nbits) - ) + exp_shifted = shift_left_logical(resize(resultExponent_sc, nbits), frac_bits_avail) + frac_mask = shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits_avail) - pyrtl.Const(1, bitwidth=nbits) frac_field = resize(frac_field_w, nbits) & frac_mask value_large = regime + exp_shifted + frac_field - all_ones = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) - not_all_ones = value_large != all_ones - # Rounding final posit - do_round = guard_is_one & not_all_ones - value_rounded = pyrtl.select(do_round, value_large + 1, value_large) + # Tie LSBs for rounding-to-even + lsb_large = value_large[0] + lsb_small = small_value[0] - packed_pos = pyrtl.select(is_small, small_value, value_rounded) + # Normal path rounding + round_up_large = guard_frac & (sticky_frac | lsb_large) + value_rounded = pyrtl.select((value_large != maxpos) & round_up_large, value_large + 1, value_large) - packed_signed = pyrtl.select( - sign0 == pyrtl.Const(1, bitwidth=1), - ((~packed_pos) + pyrtl.Const(1, bitwidth=nbits)) & mask, - packed_pos, - ) + any_frac = r_amt_nonzero & (guard_frac | sticky_frac) + guard_small_final = pyrtl.select(shift_amt_small_nz, guard_exp, guard_frac) + sticky_small_final = pyrtl.select(shift_amt_small_nz, (sticky_exp | any_frac), sticky_frac) - exp_shifted = shift_left_logical( - resize(resultExponent_sc, nbits), frac_bits_avail - ) - frac_mask = ( - shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits_avail) - - pyrtl.Const(1, bitwidth=nbits) - ) - frac_field = resize(frac_field_w, nbits) & frac_mask + round_up_small = guard_small_final & (sticky_small_final | lsb_small) + small_value_rounded = pyrtl.select((small_value != maxpos) & round_up_small, small_value + 1, small_value) - value_large = regime + exp_shifted + frac_field - all_ones = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) - not_all_ones = value_large != all_ones - - do_round = guard_is_one & not_all_ones - value_rounded = pyrtl.select(do_round, value_large + 1, value_large) - - packed_pos = pyrtl.select(is_small, small_value, value_rounded) + # Select packed (unsigned) posit + packed_pos = pyrtl.select(is_small, small_value_rounded, value_rounded) + # Apply sign of result packed_signed = pyrtl.select( sign0 == pyrtl.Const(1, bitwidth=1), ((~packed_pos) + pyrtl.Const(1, bitwidth=nbits)) & mask, @@ -314,6 +281,10 @@ def posit_sub( is_zero_res = (res_frac1 == pyrtl.Const(0, bitwidth=W)) | is_exact_cancel same_sign_out = pyrtl.select(is_zero_res, zero, packed_signed) + # Final select: special case / opposite sign / same sign nonquick = pyrtl.select(opp, sum_v, same_sign_out) result = pyrtl.select(have_quick, quick, nonquick) return result + + + From 801a15ef78edc63868f07404802ed85662e7c475 Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Sat, 13 Sep 2025 21:46:35 +0530 Subject: [PATCH 27/33] Fix decimal_to_posit to handle negative nos --- pyrtl/positutils.py | 80 +++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index 1fcf06eb..e05cbe89 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -488,8 +488,8 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: """ if x == 0: return 0 - - # Sign + + # handle sign at the end of twos comp sign = 0 if x < 0: sign = 1 @@ -497,18 +497,35 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: useed = 2 ** (2 ** es) - k = int(math.floor(math.log(x, useed))) - regime_value = useed ** k - remaining = x / regime_value + if x == float('inf'): + return (1 << (nbits - 1)) - 1 + if x == 0 or x < useed ** (-(nbits - 2)): + return 0 + + # regime bits + if x >= 1: + k = int(math.floor(math.log(x, useed))) + else: + k = int(math.floor(math.log(x, useed))) + + regime_scale = useed ** k + remaining = x / regime_scale - exponent = int(math.floor(math.log2(remaining))) if es > 0 else 0 - exponent = max(0, exponent) - remaining /= 2 ** exponent + # Exponent bits + exponent = 0 + if es > 0 and remaining > 0: + exponent = int(math.floor(math.log2(remaining))) + exponent = max(0, min(exponent, (1 << es) - 1)) + remaining /= 2 ** exponent - # Fraction bits + # Frcation bits fraction = remaining - 1.0 frac_bits = [] - for _ in range(nbits * 2): + + #Remaning bits + max_frac_bits = nbits - 1 + + for _ in range(max_frac_bits): fraction *= 2 if fraction >= 1: frac_bits.append("1") @@ -516,41 +533,34 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: else: frac_bits.append("0") - # Regime bits + # Build regime bits if k >= 0: regime_bits = "1" * (k + 1) + "0" else: regime_bits = "0" * (-k) + "1" - bits = str(sign) + regime_bits - - # Exponent bits + bits = "0" + regime_bits + + # Add exponent bits if es > 0: - exp_str = format(exponent & ((1 << es) - 1), f"0{es}b") + exp_str = format(exponent, f"0{es}b") bits += exp_str - + + # Add fraction bits bits += "".join(frac_bits) - # Handle rounding if bits exceed nbits + # Trim to nbits with rounding if len(bits) > nbits: - main = bits[:nbits] - guard = bits[nbits] - roundb = bits[nbits + 1] if nbits + 1 < len(bits) else "0" - sticky = "1" if "1" in bits[nbits + 2:] else "0" - - increment = ( - (guard == "1") - and (roundb == "1" or sticky == "1" or main[-1] == "1") - ) - - if increment: - main_int = int(main, 2) + 1 - if main_int >= (1 << (nbits - 1)): - main_int = (1 << (nbits - 1)) - 1 - main = format(main_int, f"0{nbits}b") - - bits = main + bits = bits[:nbits] else: bits = bits.ljust(nbits, "0") - return int(bits, 2) + # Convert to integer + result = int(bits, 2) + + # Apply twos complement for negative numbers + if sign: + mask = (1 << nbits) - 1 + result = ((~result) + 1) & mask + + return result From c636ae35ab284a323011fe5b0d49dc960974786c Mon Sep 17 00:00:00 2001 From: Sunidhi M Date: Sat, 13 Sep 2025 21:49:55 +0530 Subject: [PATCH 28/33] Create test_positsub.py --- tests/rtllib/test_positsub.py | 72 +++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/rtllib/test_positsub.py diff --git a/tests/rtllib/test_positsub.py b/tests/rtllib/test_positsub.py new file mode 100644 index 00000000..c2d11ec4 --- /dev/null +++ b/tests/rtllib/test_positsub.py @@ -0,0 +1,72 @@ +import doctest +import random +import unittest +import positsub +import pyrtl +import pyrtl.rtllib.testingutils as utils +from pyrtl.rtllib.positsub import posit_sub +from pyrtl.positutils import decimal_to_posit + + +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_doctests(self): + failures, tests = doctest.testmod(m=positsub) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + + +class TestPositSubtractor(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.seed = 42 + random.seed(cls.seed) + + def setUp(self): + pyrtl.reset_working_block() + + def tearDown(self): + pyrtl.reset_working_block() + + def test_posit_subtractor(self): + nbits_list = [8, 16, 32] + es_list = [0, 1, 2, 3, 4] + nbits, es = random.choice(nbits_list), random.choice(es_list) + + a = pyrtl.Input(bitwidth=nbits, name="a") + b = pyrtl.Input(bitwidth=nbits, name="b") + out = pyrtl.Output(name="out") + out <<= posit_sub(a, b, nbits, es) + + useed = 2 ** (2 ** es) + maxpos = useed ** (nbits - 2) + + wires = [a, b] + core = [[random.randint(0, maxpos) for _ in range(5)] for _ in wires] + + core[0][:2] = [0, 1] + core[1][:2] = [0, 1] + core[0].append(123) + core[1].append(123) + core[0].append(124) + core[1].append(123) + + vals_raw = core + vals = [[decimal_to_posit(j, nbits, es) for j in row] for row in vals_raw] + + out_vals = utils.sim_and_ret_out(out, wires, vals) + + true_result_raw = [x - y for x, y in zip(vals_raw[0], vals_raw[1])] + true_result = [decimal_to_posit(i, nbits, es) for i in true_result_raw] + + for idx, (sim, expected) in enumerate(zip(out_vals, true_result)): + delta = abs(sim - expected) + self.assertLessEqual( + delta, 1, + f"Mismatch at index {idx}: sim={sim}, expected={expected}, |sim-exp|={delta}" + ) + + +if __name__ == "__main__": + unittest.main() From 018a95ba05761de8dc4ec9c9a0a036d30864c4c2 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Sun, 14 Sep 2025 13:30:37 +0530 Subject: [PATCH 29/33] fix: handle large posit values correctly --- pyrtl/positutils.py | 88 ++++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index e05cbe89..f3bd86ce 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -488,8 +488,8 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: """ if x == 0: return 0 - - # handle sign at the end of twos comp + + # Sign sign = 0 if x < 0: sign = 1 @@ -497,35 +497,18 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: useed = 2 ** (2 ** es) - if x == float('inf'): - return (1 << (nbits - 1)) - 1 - if x == 0 or x < useed ** (-(nbits - 2)): - return 0 - - # regime bits - if x >= 1: - k = int(math.floor(math.log(x, useed))) - else: - k = int(math.floor(math.log(x, useed))) - - regime_scale = useed ** k - remaining = x / regime_scale + k = int(math.floor(math.log(x, useed))) + regime_value = useed ** k + remaining = x / regime_value - # Exponent bits - exponent = 0 - if es > 0 and remaining > 0: - exponent = int(math.floor(math.log2(remaining))) - exponent = max(0, min(exponent, (1 << es) - 1)) - remaining /= 2 ** exponent + exponent = int(math.floor(math.log2(remaining))) if es > 0 else 0 + exponent = max(0, exponent) + remaining /= 2 ** exponent - # Frcation bits + # Fraction bits fraction = remaining - 1.0 frac_bits = [] - - #Remaning bits - max_frac_bits = nbits - 1 - - for _ in range(max_frac_bits): + for _ in range(nbits * 2): fraction *= 2 if fraction >= 1: frac_bits.append("1") @@ -533,34 +516,51 @@ def decimal_to_posit(x: float, nbits: int, es: int) -> int: else: frac_bits.append("0") - # Build regime bits + # Regime bits if k >= 0: regime_bits = "1" * (k + 1) + "0" else: regime_bits = "0" * (-k) + "1" - bits = "0" + regime_bits - - # Add exponent bits + bits = "0" + regime_bits + + # Exponent bits if es > 0: - exp_str = format(exponent, f"0{es}b") + exp_str = format(exponent & ((1 << es) - 1), f"0{es}b") bits += exp_str - - # Add fraction bits + bits += "".join(frac_bits) - # Trim to nbits with rounding + # Handle rounding if bits exceed nbits if len(bits) > nbits: - bits = bits[:nbits] + main = bits[:nbits] + guard = bits[nbits] + roundb = bits[nbits + 1] if nbits + 1 < len(bits) else "0" + sticky = "1" if "1" in bits[nbits + 2:] else "0" + + increment = ( + (guard == "1") + and (roundb == "1" or sticky == "1" or main[-1] == "1") + ) + + if increment: + main_int = int(main, 2) + 1 + if main_int >= (1 << (nbits - 1)): + main_int = (1 << (nbits - 1)) - 1 + main = format(main_int, f"0{nbits}b") + + bits = main else: bits = bits.ljust(nbits, "0") - # Convert to integer - result = int(bits, 2) - - # Apply twos complement for negative numbers + ones_comp = "" if sign: - mask = (1 << nbits) - 1 - result = ((~result) + 1) & mask - - return result + for i in bits: + if i == "0": + ones_comp = ones_comp + "1" + else: + ones_comp = ones_comp + "0" + result = int(ones_comp, 2) + 1 + return result + + return int(bits, 2) From c55a3ce754b51d20c522b51e56c86a7d00ee1148 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Sun, 14 Sep 2025 13:31:52 +0530 Subject: [PATCH 30/33] fix: correct import path for math and posit_add --- pyrtl/rtllib/positsub.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyrtl/rtllib/positsub.py b/pyrtl/rtllib/positsub.py index c0ba5fd2..a3ca283d 100644 --- a/pyrtl/rtllib/positsub.py +++ b/pyrtl/rtllib/positsub.py @@ -1,4 +1,5 @@ import pyrtl +import math from pyrtl.corecircuits import ( shift_left_logical, shift_right_logical, @@ -15,7 +16,7 @@ twos_comp, sign_ext, ) -from pyrtl.rtllib.positadd import posit_add +from pyrtl.rtllib.positadder import posit_add def posit_sub( a: pyrtl.WireVector, b: pyrtl.WireVector, nbits: int, es: int @@ -25,8 +26,8 @@ def posit_sub( .. doctest:: >>> import pyrtl - >>> from positutils import decimal_to_posit - >>> from positsub import posit_sub + >>> from pyrtl.positutils import decimal_to_posit + >>> from pyrtl.rtllib.positsub import posit_sub >>> pyrtl.reset_working_block() >>> nbits, es = 8, 1 >>> a = pyrtl.Input(bitwidth=nbits, name='a') From 93260ecf51472f6ad5ada830dda6178c65872cd3 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Sun, 14 Sep 2025 13:33:28 +0530 Subject: [PATCH 31/33] fix: correct path for positsub --- tests/rtllib/test_positsub.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/rtllib/test_positsub.py b/tests/rtllib/test_positsub.py index c2d11ec4..a7c7dcab 100644 --- a/tests/rtllib/test_positsub.py +++ b/tests/rtllib/test_positsub.py @@ -1,7 +1,6 @@ import doctest import random import unittest -import positsub import pyrtl import pyrtl.rtllib.testingutils as utils from pyrtl.rtllib.positsub import posit_sub @@ -12,7 +11,7 @@ class TestDocTests(unittest.TestCase): """Test documentation examples.""" def test_doctests(self): - failures, tests = doctest.testmod(m=positsub) + failures, tests = doctest.testmod(m=pyrtl.rtllib.positsub) self.assertGreater(tests, 0) self.assertEqual(failures, 0) From 678ab9ca6e5e81c3e5a205cdb2c452009df5d23a Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Mon, 29 Sep 2025 08:25:53 +0530 Subject: [PATCH 32/33] update: remove posit sub --- pyrtl/rtllib/positsub.py | 291 ---------------------------------- tests/rtllib/test_positsub.py | 71 --------- 2 files changed, 362 deletions(-) delete mode 100644 pyrtl/rtllib/positsub.py delete mode 100644 tests/rtllib/test_positsub.py diff --git a/pyrtl/rtllib/positsub.py b/pyrtl/rtllib/positsub.py deleted file mode 100644 index a3ca283d..00000000 --- a/pyrtl/rtllib/positsub.py +++ /dev/null @@ -1,291 +0,0 @@ -import pyrtl -import math -from pyrtl.corecircuits import ( - shift_left_logical, - shift_right_logical, - shift_right_arithmetic, -) -from pyrtl.positutils import ( - decode_posit, - get_upto_regime, - zero_ext, - resize, - unify_width, - absdiff, - bitlen_u, - twos_comp, - sign_ext, -) -from pyrtl.rtllib.positadder import posit_add - -def posit_sub( - a: pyrtl.WireVector, b: pyrtl.WireVector, nbits: int, es: int -) -> pyrtl.WireVector: - """Subtracts two numbers in posit format and returns their difference. - - .. doctest:: - - >>> import pyrtl - >>> from pyrtl.positutils import decimal_to_posit - >>> from pyrtl.rtllib.positsub import posit_sub - >>> pyrtl.reset_working_block() - >>> nbits, es = 8, 1 - >>> a = pyrtl.Input(bitwidth=nbits, name='a') - >>> b = pyrtl.Input(bitwidth=nbits, name='b') - >>> out = pyrtl.Output(bitwidth=nbits, name='out') - >>> out <<= posit_sub(a, b, nbits, es) - >>> sim = pyrtl.Simulation() - >>> aval = decimal_to_posit(4.5, nbits, es) - >>> bval = decimal_to_posit(2.0, nbits, es) - >>> sim.step({'a': aval, 'b': bval}) # 4.5 - 2.0 = 2.5 - >>> sim.inspect('out') == decimal_to_posit(2.5, nbits, es) - True - """ - # special cases - nar = pyrtl.Const(1 << (nbits - 1), bitwidth=nbits) - zero = pyrtl.Const(0, bitwidth=nbits) - mask = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) - maxpos = pyrtl.Const((1 << (nbits - 1)) - 1, bitwidth=nbits) - - is_nar = (a == nar) | (b == nar) - negb_quick = ((~b) + pyrtl.Const(1, bitwidth=nbits)) & mask - quick = pyrtl.select( - is_nar, - nar, - pyrtl.select( - a == zero, - negb_quick, - pyrtl.select(b == zero, a, zero), - ), - ) - have_quick = is_nar | (a == zero) | (b == zero) - - # Decode input posit - sign1, k1, exponent1, frac1, fl1 = decode_posit(a, nbits, es) - sign2, k2, exponent2, frac2, fl2 = decode_posit(b, nbits, es) - - # check if inputs are opposite sign - opp = sign1 != sign2 - neg_a = twos_comp(a, nbits) - neg_b = twos_comp(b, nbits) - sum_posneg = posit_add(a, neg_b, nbits, es) - sum_negpos = twos_comp(posit_add(neg_a, b, nbits, es), nbits) - sum_v = pyrtl.select(sign1 == pyrtl.Const(0, 1), sum_posneg, sum_negpos) - - # Internal widths - SC_BW = max( - nbits + es + 6, - k1.bitwidth + es + 2, - k2.bitwidth + es + 2, - exponent1.bitwidth + 2, - exponent2.bitwidth + 2, - ) - - k1_se = sign_ext(k1, SC_BW) - k2_se = sign_ext(k2, SC_BW) - exp1_ze = zero_ext(exponent1, SC_BW) - exp2_ze = zero_ext(exponent2, SC_BW) - - # compute scale = k*2^es + exponent - if es == 0: - scale1 = k1_se + exp1_ze - scale2 = k2_se + exp2_ze - else: - sh_es_sc = pyrtl.Const(es, bitwidth=SC_BW) - scale1 = shift_left_logical(k1_se, sh_es_sc) + exp1_ze - scale2 = shift_left_logical(k2_se, sh_es_sc) + exp2_ze - - #Align fraction precision to max(fl1, fl2) - frac_bits = pyrtl.select(fl1 > fl2, fl1, fl2) - shift12 = fl1 - fl2 - shift21 = fl2 - fl1 - f1a = pyrtl.select(fl1 >= fl2, frac1, shift_left_logical(frac1, shift21)) - f2a = pyrtl.select(fl2 >= fl1, frac2, shift_left_logical(frac2, shift12)) - - one_n = pyrtl.Const(1, bitwidth=nbits) - one_frac = shift_left_logical(one_n, frac_bits) - - # offset calculation between scales - offset = scale1 - scale2 - off_neg = offset[SC_BW - 1] - abs_off = pyrtl.select(off_neg, ((~offset) + pyrtl.Const(1, bitwidth=SC_BW)), offset) - - W = max(nbits * 2, one_frac.bitwidth + SC_BW + 2) - f1w = zero_ext(f1a, W) - f2w = zero_ext(f2a, W) - onew = zero_ext(one_frac, W) - offW = resize(abs_off, W) - - # Add hidden one to both sides and align by offset - a_sum = f1w + onew - b_sum = f2w + onew - a_sh = shift_left_logical(a_sum, offW[: a_sum.bitwidth]) - b_sh = shift_left_logical(b_sum, offW[: b_sum.bitwidth]) - - was_neg_A, diff_A = absdiff(f1w, f2w) - scale_A = scale1 - sign_A = sign1 ^ was_neg_A - - was_neg_B, diff_B = absdiff(a_sh, b_sum) - scale_B = scale1 - sign_B = pyrtl.select(was_neg_B, sign1 ^ pyrtl.Const(1, 1), sign1) - - was_neg_C, diff_C = absdiff(b_sh, a_sum) - scale_C = scale2 - base_sign_C = sign1 ^ pyrtl.Const(1, 1) - sign_C = pyrtl.select(was_neg_C, base_sign_C ^ pyrtl.Const(1, 1), base_sign_C) - - is_zero_off = abs_off == pyrtl.Const(0, bitwidth=SC_BW) - - # choose diff/scale/sign by offset sign - res_frac0 = pyrtl.select(is_zero_off, diff_A, pyrtl.select(off_neg, diff_C, diff_B)) - res_scale0 = pyrtl.select(is_zero_off, scale_A, pyrtl.select(off_neg, scale_C, scale_B)) - sign0 = pyrtl.select(is_zero_off, sign_A, pyrtl.select(off_neg, sign_C, sign_B)) - - # exact cancel (same fields when abs_off==0) - same_fields = (k1 == k2) & (exponent1 == exponent2) & (frac1 == frac2) - is_exact_cancel = is_zero_off & same_fields - res_frac0 = pyrtl.select(is_exact_cancel, pyrtl.Const(0, bitwidth=W), res_frac0) - sign0 = pyrtl.select(is_exact_cancel, pyrtl.Const(0, bitwidth=1), sign0) - - # Only for es==0, same-sign, nonzero offset, and neg result. - same_sign = ~opp - nonzero_off = abs_off != pyrtl.Const(0, bitwidth=SC_BW) - es_is_zero = pyrtl.Const(1, bitwidth=1) if es == 0 else pyrtl.Const(0, bitwidth=1) - need_k_corr = same_sign & nonzero_off & es_is_zero & (sign0 == pyrtl.Const(1, bitwidth=1)) - res_scale0 = pyrtl.select(need_k_corr, res_scale0 + pyrtl.Const(1, bitwidth=SC_BW), res_scale0) - - # Normalize to target precision - blen_bw = max(8, int(math.ceil(math.log2(W + 1)))) - bitlen0 = bitlen_u(res_frac0, blen_bw) - - fb_target = resize(frac_bits, blen_bw) + pyrtl.Const(1, bitwidth=blen_bw) - - diff_needed = fb_target - bitlen0 - need_extend = ~diff_needed[blen_bw - 1] - extend_amt = resize(diff_needed[:W], W) - max_shift = pyrtl.Const(W - 1, bitwidth=W) - extend_amt = pyrtl.select(extend_amt > max_shift, max_shift, extend_amt) - - res_frac1 = pyrtl.select(need_extend, shift_left_logical(res_frac0, extend_amt), res_frac0) - res_scale1 = pyrtl.select(need_extend, res_scale0 - resize(extend_amt, SC_BW), res_scale0) - bitlen1 = pyrtl.select(need_extend, fb_target, bitlen0) - - # scale tweak for same-sign: + (bitlen - 1 - |offset| - frac_bits) - adj1 = resize(bitlen1, SC_BW) - pyrtl.Const(1, bitwidth=SC_BW) - adj2 = adj1 - resize(abs_off, SC_BW) - resize(frac_bits, SC_BW) - scale_final = res_scale1 + adj2 - - # Extract k and exponent - if es == 0: - resultk = scale_final - resultExponent_sc = pyrtl.Const(0, bitwidth=SC_BW) - else: - shamt_sf = pyrtl.Const(es, bitwidth=SC_BW) - resultk = shift_right_arithmetic(scale_final, shamt_sf) - k_lsl_sc = shift_left_logical(resize(resultk, SC_BW), shamt_sf) - resultExponent_sc = scale_final - k_lsl_sc - - # Regime packing - rem_bits, regime = get_upto_regime( - resize(resultk, nbits), nbits, pyrtl.Const(0, bitwidth=1) - ) - - # Small posit path - es_nb = pyrtl.Const(es, bitwidth=nbits) - is_small = rem_bits <= es_nb - shift_amt_small = es_nb - rem_bits - exp_shifted_small = shift_right_logical(resize(resultExponent_sc, nbits), shift_amt_small) - small_value = regime + exp_shifted_small - - # For exponent drop (rem_bits < es) - shift_amt_small_nz = shift_amt_small != pyrtl.Const(0, bitwidth=nbits) - sam1 = pyrtl.select( - shift_amt_small_nz, - shift_amt_small - pyrtl.Const(1, bitwidth=nbits), - pyrtl.Const(0, bitwidth=nbits), - ) - guard_src_small = shift_right_logical(resize(resultExponent_sc, SC_BW), resize(sam1, SC_BW)) - guard_exp = pyrtl.select(shift_amt_small_nz, guard_src_small[0], pyrtl.Const(0, bitwidth=1)) - - one_sc = pyrtl.Const(1, bitwidth=SC_BW) - lower_mask_small = pyrtl.select( - shift_amt_small_nz, - shift_left_logical(one_sc, resize(sam1, SC_BW)) - one_sc, - pyrtl.Const(0, bitwidth=SC_BW), - ) - sticky_exp = (resize(resultExponent_sc, SC_BW) & lower_mask_small) != pyrtl.Const(0, bitwidth=SC_BW) - - # Normal packing fields - rem_gt_es = rem_bits > es_nb - frac_bits_avail = pyrtl.select(rem_gt_es, rem_bits - es_nb, pyrtl.Const(0, bitwidth=nbits)) - - sum_keep = resize(frac_bits_avail, blen_bw) + pyrtl.Const(1, bitwidth=blen_bw) - - bitlen1_u, sum_keep_u, Wc = unify_width(bitlen1, sum_keep) - ge = bitlen1_u >= sum_keep_u - - r_amt_wide = pyrtl.select(ge, bitlen1_u - sum_keep_u, pyrtl.Const(0, bitwidth=Wc)) - l_amt_wide = pyrtl.select(ge, pyrtl.Const(0, bitwidth=Wc), sum_keep_u - bitlen1_u) - r_amt = resize(r_amt_wide, W) - l_amt = resize(l_amt_wide, W) - - kept_plus_hidden = pyrtl.select( - ge, shift_right_logical(res_frac1, r_amt), shift_left_logical(res_frac1, l_amt) - ) - - r_amt_nonzero = r_amt_wide != pyrtl.Const(0, bitwidth=Wc) - r_amt_minus1 = resize(r_amt - pyrtl.Const(1, bitwidth=r_amt.bitwidth), W) - guard_src = shift_right_logical(res_frac1, r_amt_minus1) - guard_frac = pyrtl.select(ge & r_amt_nonzero, guard_src[0], pyrtl.Const(0, bitwidth=1)) - - trimmed = shift_right_logical(res_frac1, r_amt) - recon = shift_left_logical(trimmed, r_amt) - sticky_frac = pyrtl.select(ge & r_amt_nonzero, (res_frac1 != recon), pyrtl.Const(0, bitwidth=1)) - - # Remove hidden one to form fraction field - oneW = pyrtl.Const(1, bitwidth=W) - one_keep = shift_left_logical(oneW, resize(frac_bits_avail, W)) - frac_field_w = kept_plus_hidden - one_keep - - exp_shifted = shift_left_logical(resize(resultExponent_sc, nbits), frac_bits_avail) - frac_mask = shift_left_logical(pyrtl.Const(1, bitwidth=nbits), frac_bits_avail) - pyrtl.Const(1, bitwidth=nbits) - frac_field = resize(frac_field_w, nbits) & frac_mask - - value_large = regime + exp_shifted + frac_field - - # Tie LSBs for rounding-to-even - lsb_large = value_large[0] - lsb_small = small_value[0] - - # Normal path rounding - round_up_large = guard_frac & (sticky_frac | lsb_large) - value_rounded = pyrtl.select((value_large != maxpos) & round_up_large, value_large + 1, value_large) - - any_frac = r_amt_nonzero & (guard_frac | sticky_frac) - guard_small_final = pyrtl.select(shift_amt_small_nz, guard_exp, guard_frac) - sticky_small_final = pyrtl.select(shift_amt_small_nz, (sticky_exp | any_frac), sticky_frac) - - round_up_small = guard_small_final & (sticky_small_final | lsb_small) - small_value_rounded = pyrtl.select((small_value != maxpos) & round_up_small, small_value + 1, small_value) - - # Select packed (unsigned) posit - packed_pos = pyrtl.select(is_small, small_value_rounded, value_rounded) - - # Apply sign of result - packed_signed = pyrtl.select( - sign0 == pyrtl.Const(1, bitwidth=1), - ((~packed_pos) + pyrtl.Const(1, bitwidth=nbits)) & mask, - packed_pos, - ) - - is_zero_res = (res_frac1 == pyrtl.Const(0, bitwidth=W)) | is_exact_cancel - same_sign_out = pyrtl.select(is_zero_res, zero, packed_signed) - - # Final select: special case / opposite sign / same sign - nonquick = pyrtl.select(opp, sum_v, same_sign_out) - result = pyrtl.select(have_quick, quick, nonquick) - return result - - - diff --git a/tests/rtllib/test_positsub.py b/tests/rtllib/test_positsub.py deleted file mode 100644 index a7c7dcab..00000000 --- a/tests/rtllib/test_positsub.py +++ /dev/null @@ -1,71 +0,0 @@ -import doctest -import random -import unittest -import pyrtl -import pyrtl.rtllib.testingutils as utils -from pyrtl.rtllib.positsub import posit_sub -from pyrtl.positutils import decimal_to_posit - - -class TestDocTests(unittest.TestCase): - """Test documentation examples.""" - - def test_doctests(self): - failures, tests = doctest.testmod(m=pyrtl.rtllib.positsub) - self.assertGreater(tests, 0) - self.assertEqual(failures, 0) - - -class TestPositSubtractor(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.seed = 42 - random.seed(cls.seed) - - def setUp(self): - pyrtl.reset_working_block() - - def tearDown(self): - pyrtl.reset_working_block() - - def test_posit_subtractor(self): - nbits_list = [8, 16, 32] - es_list = [0, 1, 2, 3, 4] - nbits, es = random.choice(nbits_list), random.choice(es_list) - - a = pyrtl.Input(bitwidth=nbits, name="a") - b = pyrtl.Input(bitwidth=nbits, name="b") - out = pyrtl.Output(name="out") - out <<= posit_sub(a, b, nbits, es) - - useed = 2 ** (2 ** es) - maxpos = useed ** (nbits - 2) - - wires = [a, b] - core = [[random.randint(0, maxpos) for _ in range(5)] for _ in wires] - - core[0][:2] = [0, 1] - core[1][:2] = [0, 1] - core[0].append(123) - core[1].append(123) - core[0].append(124) - core[1].append(123) - - vals_raw = core - vals = [[decimal_to_posit(j, nbits, es) for j in row] for row in vals_raw] - - out_vals = utils.sim_and_ret_out(out, wires, vals) - - true_result_raw = [x - y for x, y in zip(vals_raw[0], vals_raw[1])] - true_result = [decimal_to_posit(i, nbits, es) for i in true_result_raw] - - for idx, (sim, expected) in enumerate(zip(out_vals, true_result)): - delta = abs(sim - expected) - self.assertLessEqual( - delta, 1, - f"Mismatch at index {idx}: sim={sim}, expected={expected}, |sim-exp|={delta}" - ) - - -if __name__ == "__main__": - unittest.main() From ae2d8a883701b21bef585f3ebd81fc3f960357f1 Mon Sep 17 00:00:00 2001 From: arvindajaybharadwaj Date: Mon, 29 Sep 2025 08:40:34 +0530 Subject: [PATCH 33/33] update: remove unnecessary util functions --- pyrtl/positutils.py | 159 -------------------------------------------- 1 file changed, 159 deletions(-) diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py index f3bd86ce..6de448b6 100644 --- a/pyrtl/positutils.py +++ b/pyrtl/positutils.py @@ -305,165 +305,6 @@ def remove_first_one(val: pyrtl.WireVector) -> pyrtl.WireVector: return pyrtl.concat_list(result_bits[::-1]) -def twos_comp(x: pyrtl.WireVector, n: int) -> pyrtl.WireVector: - """Compute the two's complement of an n-bit WireVector. - - Two's complement is the standard way of representing signed integers - in binary systems. The process is: - 1. Invert all the bits (one's complement). - 2. Add 1 to the result. - This function ensures the result is limited to 'n' bits. - - Example:: - - >>> import pyrtl - >>> pyrtl.reset_working_block() - >>> x = pyrtl.Const(5, bitwidth=4) # 0101 (decimal 5) - >>> result = twos_comp(x, 4) - >>> sim = pyrtl.Simulation() - >>> sim.step({}) - >>> format(sim.inspect(result), '04b') - '1011' # -5 in two's complement - - :param x: The input value as a PyRTL WireVector. - :param n: Bitwidth to operate on. - :return: The n-bit two's complement representation of x. - """ - - mask = pyrtl.Const((1 << n) - 1, bitwidth=n) - inverted = x ^ mask - added = inverted + 1 - return added & mask - -def zero_ext(x: pyrtl.WireVector, new_bw: int) -> pyrtl.WireVector: - """ - Zero-extend `x` to `new_bw` bits. - - Behavior: - - If `new_bw == x.bitwidth`, returns `x` unchanged (no copy). - - If `new_bw > x.bitwidth`, pads MSBs with zeros so the unsigned value is preserved. - - Truncation is NOT performed here (use `_resize` for that). - - Equivalent to concatenating (new_bw - x.bitwidth) zero bits above `x`. - """ - assert new_bw >= x.bitwidth - if new_bw == x.bitwidth: - return x - return pyrtl.concat(pyrtl.Const(0, bitwidth=new_bw - x.bitwidth), x) - - -def resize(x: pyrtl.WireVector, new_bw: int) -> pyrtl.WireVector: - """ - Resize `x` to exactly `new_bw` bits. - - Behavior: - - If `new_bw == x.bitwidth`, returns `x` unchanged. - - If `new_bw < x.bitwidth`, returns the lower `new_bw` bits (truncates MSBs). - - If `new_bw > x.bitwidth`, zero-extends to widen (unsigned semantics). - - Note: This is UNSIGNED resizing. If you need sign-preserving growth, - sign-extend externally instead of using `_resize`. - """ - if new_bw == x.bitwidth: - return x - if new_bw < x.bitwidth: - return x[:new_bw] - return zero_ext(x, new_bw) - - -def sign_ext(x: pyrtl.WireVector, new_bw: int) -> pyrtl.WireVector: - """ - Sign-extend `x` to bitwidth `new_bw` (two's-complement semantics). - - Behavior: - - If `new_bw <= x.bitwidth`, returns `x` unchanged (no truncation or copy). - - If `new_bw > x.bitwidth`, replicates the MSB (sign bit) into the new upper bits - so that the signed value is preserved under two's-complement interpretation. - - Use when: - - Widening **signed** quantities (e.g., regime `k`) before arithmetic or shifting. - For **unsigned** widening, prefer `zero_ext`. - - Implementation note: - - Captures `signbit = x[x.bitwidth-1]`, creates `new_bw - x.bitwidth` copies, - and concatenates them above `x`. - - :param x: Source WireVector interpreted as a signed two's-complement value. - :param new_bw: Target bitwidth to extend to. - :return: A WireVector of width `new_bw` with the same signed value as `x`. - """ - if new_bw <= x.bitwidth: - return x - signbit = x[x.bitwidth-1] - pad = pyrtl.concat_list([signbit for _ in range(new_bw - x.bitwidth)]) - return pyrtl.concat(pad, x) - - -def unify_width(a: pyrtl.WireVector, b: pyrtl.WireVector): - """ - Return `(a_w, b_w, W)` where both inputs are widened to the same bitwidth `W`. - - - `W = max(a.bitwidth, b.bitwidth)` - - Widening uses zero-extension (unsigned semantics). - - Useful before doing comparisons/arithmetic that expects matching widths. - - :return: (a_zero_extended, b_zero_extended, unified_bitwidth) - """ - W = max(a.bitwidth, b.bitwidth) - return zero_ext(a, W), zero_ext(b, W), W - - -def absdiff(a: pyrtl.WireVector, b: pyrtl.WireVector): - """ - Unsigned absolute difference between `a` and `b`. - - Steps (purely combinational): - 1) Zero-extend to a common width. - 2) Compute `a >= b` to pick subtraction order. - 3) Return `(was_neg, |a-b|)`, where: - - `was_neg` is 1 when `a < b` (i.e., subtraction would be negative), - - magnitude is the absolute difference. - - :return: (was_neg:1-bit, magnitude:WireVector) - """ - aa, bb, W = unify_width(a, b) - a_ge_b = aa >= bb - mag = pyrtl.select(a_ge_b, aa - bb, bb - aa) - was_neg = ~a_ge_b & pyrtl.Const(1, bitwidth=1) - return was_neg, mag - - -def bitlen_u(x: pyrtl.WireVector, out_bw: int) -> pyrtl.WireVector: - """ - Return the UNSIGNED bit-length of `x` (index of MSB + 1). - - Definition: - - bitlen(0b00010100) = 5 - - bitlen(0b00000000) = 0 - - Implementation detail: - - Counts leading zeros, then computes `bitlen = width - leading_zeros`. - - `out_bw` must be large enough to encode up to `x.bitwidth` - (use >= ceil(log2(x.bitwidth+1)) to be safe). - - :param x: WireVector to analyze (treated as unsigned). - :param out_bw: Bitwidth of the integer result (e.g., 8 or more). - :return: WireVector of width `out_bw` with the bit-length of `x`. - """ - bw = x.bitwidth - lz = pyrtl.Const(0, bitwidth=out_bw) - seen = pyrtl.Const(0, bitwidth=1) - - for i in range(bw): - bit = x[bw - 1 - i] - inc = (~seen) & (~bit) - lz = lz + pyrtl.select(inc, pyrtl.Const(1, bitwidth=out_bw), - pyrtl.Const(0, bitwidth=out_bw)) - seen = seen | bit - - return pyrtl.Const(bw, bitwidth=out_bw) - lz - - def decimal_to_posit(x: float, nbits: int, es: int) -> int: """Convert a decimal float to Posit representation.