Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.13
3.14
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ tests:

# Run `pytest` in an isolated virtual environment, with the earliest
# version of Python supported by PyRTL.
uv run --python=3.9 --isolated pytest -n auto
uv run --python=3.10 --isolated pytest -n auto

# Run `ruff format` to check that code is formatted properly.
#
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ authors = [
description = "RTL-level Hardware Design and Simulation Toolkit"
readme = "README.md"
license = {file = "LICENSE.md"}
requires-python = ">=3.9"
requires-python = ">=3.10"
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
Expand Down
3 changes: 1 addition & 2 deletions pyrtl/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import subprocess
import sys
import tempfile
from collections.abc import Iterable
from typing import Callable
from collections.abc import Callable, Iterable

from pyrtl.core import Block, LogicNet, working_block
from pyrtl.helperfuncs import _currently_in_jupyter_notebook, _print_netlist_latex
Expand Down
6 changes: 3 additions & 3 deletions pyrtl/helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def fn(n):
msg = f"bitpattern field {exc.args[0]} was not provided in named_field list"
raise PyrtlError(msg) from exc

fmap = dict(zip(lifo, intfields))
fmap = dict(zip(lifo, intfields, strict=True))
for c in bitpattern[::-1]:
if c == "0" or c == "1":
bitlist.append(c)
Expand Down Expand Up @@ -582,7 +582,7 @@ def chop(w: WireVector, *segment_widths: int) -> list[WireVector]:
n_segments = len(segment_widths)
starts = [sum(segment_widths[i + 1 :]) for i in range(n_segments)]
ends = [sum(segment_widths[i:]) for i in range(n_segments)]
return [w[s:e] for s, e in zip(starts, ends)]
return [w[s:e] for s, e in zip(starts, ends, strict=True)]


def input_list(
Expand Down Expand Up @@ -706,7 +706,7 @@ def wirevector_list(
)

wirelist = []
for fullname, bw in zip(names, bitwidth):
for fullname, bw in zip(names, bitwidth, strict=True):
try:
name, bw = fullname.split("/")
except ValueError:
Expand Down
4 changes: 2 additions & 2 deletions pyrtl/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import collections
import numbers
import types
from collections.abc import Sequence
from typing import Callable, NamedTuple
from collections.abc import Callable, Sequence
from typing import NamedTuple

from pyrtl.core import Block, LogicNet, _NameIndexer, working_block
from pyrtl.corecircuits import as_wires
Expand Down
2 changes: 1 addition & 1 deletion pyrtl/rtllib/adders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import itertools
import math
from typing import Callable
from collections.abc import Callable

import pyrtl

Expand Down
2 changes: 1 addition & 1 deletion pyrtl/rtllib/multipliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

import math
from typing import Callable
from collections.abc import Callable

import pyrtl
from pyrtl.rtllib import adders
Expand Down
4 changes: 2 additions & 2 deletions pyrtl/rtllib/muxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _add_signal(self, data_signals):
)
raise pyrtl.PyrtlError(msg)

for dw, sig in zip(self.dest_wires, data_signals):
for dw, sig in zip(self.dest_wires, data_signals, strict=True):
data_signal = pyrtl.as_wires(sig, dw.bitwidth)
self.dest_instrs_info[dw].append(data_signal)

Expand All @@ -268,7 +268,7 @@ def finalize(self):
self._final = True

for dest_w, values in self.dest_instrs_info.items():
mux_vals = dict(zip(self.instructions, values))
mux_vals = dict(zip(self.instructions, values, strict=False))
dest_w <<= sparse_mux(self.signal_wire, mux_vals)


Expand Down
7 changes: 4 additions & 3 deletions pyrtl/rtllib/testingutils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import random
from typing import Callable
from collections.abc import Callable

import pyrtl

Expand Down Expand Up @@ -62,7 +62,8 @@ def make_inputs_and_values(
random_dist=dist,
)
for i in range(num_wires)
)
),
strict=True,
)
)
return wires, vals
Expand Down Expand Up @@ -165,7 +166,7 @@ def sim_and_ret_outws(
:class:`list` of its values in each cycle.
"""
sim = pyrtl.Simulation()
sim.step_multiple(provided_inputs=dict(zip(inwires, invals)))
sim.step_multiple(provided_inputs=dict(zip(inwires, invals, strict=True)))
return sim.tracer.trace


Expand Down
3 changes: 1 addition & 2 deletions pyrtl/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import re
import sys
import warnings
from collections.abc import Mapping
from typing import Callable
from collections.abc import Callable, Mapping

from pyrtl.core import Block, PostSynthBlock, _PythonSanitizer, working_block
from pyrtl.helperfuncs import (
Expand Down
3 changes: 2 additions & 1 deletion pyrtl/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from __future__ import annotations

import collections
from typing import TYPE_CHECKING, Callable
from collections.abc import Callable
from typing import TYPE_CHECKING

from pyrtl.core import Block, LogicNet, working_block
from pyrtl.pyrtlexceptions import PyrtlError, PyrtlInternalError
Expand Down
3 changes: 1 addition & 2 deletions pyrtl/wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import numbers
import re
import traceback
from typing import Union

from pyrtl import core # needed for _setting_keep_wirevector_call_stack
from pyrtl.core import Block, LogicNet, _NameIndexer, working_block
Expand Down Expand Up @@ -1456,7 +1455,7 @@ def _extend_with_bit(self, bitwidth, extbit):
return concat(extvector, self)


WireVectorLike = Union[WireVector, int, str, bool]
WireVectorLike = WireVector | int | str | bool
"""Alias for types that can be coerced to :class:`WireVector` by :func:`as_wires`."""


Expand Down
4 changes: 2 additions & 2 deletions tests/rtllib/test_adders.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def adder_t_base(self, adder_func, **kwargs):
outwire <<= adder_func(*wires)

out_vals = utils.sim_and_ret_out(outwire, wires, vals)
true_result = [sum(cycle_vals) for cycle_vals in zip(*vals)]
true_result = [sum(cycle_vals) for cycle_vals in zip(*vals, strict=True)]
self.assertEqual(out_vals, true_result)

def test_kogge_stone_1(self):
Expand All @@ -61,7 +61,7 @@ def test_fast_group_adder_1(self):
outwire <<= adders.fast_group_adder(wires)

out_vals = utils.sim_and_ret_out(outwire, wires, vals)
true_result = [sum(cycle_vals) for cycle_vals in zip(*vals)]
true_result = [sum(cycle_vals) for cycle_vals in zip(*vals, strict=True)]
self.assertEqual(out_vals, true_result)


Expand Down
4 changes: 2 additions & 2 deletions tests/rtllib/test_libutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def test_partition_sim(self):
wires, vals = utils.make_inputs_and_values(exact_bitwidth=32, num_wires=1)
out_wires = [pyrtl.Output(8, "o" + str(i)) for i in range(4)]
partitioned_w = libutils.partition_wire(wires[0], 8)
for p_wire, o_wire in zip(partitioned_w, out_wires):
for p_wire, o_wire in zip(partitioned_w, out_wires, strict=True):
o_wire <<= p_wire

out_vals = utils.sim_and_ret_outws(wires, vals)
partitioned_vals = [
[(val >> i) & 0xFF for i in (0, 8, 16, 24)] for val in vals[0]
]
true_vals = tuple(zip(*partitioned_vals))
true_vals = tuple(zip(*partitioned_vals, strict=True))
for index, wire in enumerate(out_wires):
self.assertEqual(tuple(out_vals[wire.name]), true_vals[index])

Expand Down
12 changes: 6 additions & 6 deletions tests/rtllib/test_multipliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def mult_t_base(self, len_a, len_b):

xvals = [int(random.uniform(0, 2**len_a - 1)) for i in range(20)]
yvals = [int(random.uniform(0, 2**len_b - 1)) for i in range(20)]
true_result = [i * j for i, j in zip(xvals, yvals)]
true_result = [i * j for i, j in zip(xvals, yvals, strict=True)]
mult_results = []

for x_val, y_val in zip(xvals, yvals):
for x_val, y_val in zip(xvals, yvals, strict=True):
sim = pyrtl.Simulation()
sim.step({a: x_val, b: y_val, reset: 1})
while not sim.inspect("done"):
Expand Down Expand Up @@ -103,10 +103,10 @@ def mult_t_base(self, len_a, len_b, shifts):

xvals = [int(random.uniform(0, 2**len_a - 1)) for i in range(20)]
yvals = [int(random.uniform(0, 2**len_b - 1)) for i in range(20)]
true_result = [i * j for i, j in zip(xvals, yvals)]
true_result = [i * j for i, j in zip(xvals, yvals, strict=True)]
mult_results = []

for x_val, y_val in zip(xvals, yvals):
for x_val, y_val in zip(xvals, yvals, strict=True):
sim = pyrtl.Simulation()
sim.step({a: x_val, b: y_val, reset: 1})
while not sim.inspect("done"):
Expand Down Expand Up @@ -137,7 +137,7 @@ def mult_t_base(self, len_a, len_b, **mult_args):
# creating the testing values and the correct results
xvals = [int(random.uniform(0, 2**len_a - 1)) for i in range(20)]
yvals = [int(random.uniform(0, 2**len_b - 1)) for i in range(20)]
true_result = [i * j for i, j in zip(xvals, yvals)]
true_result = [i * j for i, j in zip(xvals, yvals, strict=True)]

# Setting up and running the tests
sim = pyrtl.Simulation()
Expand Down Expand Up @@ -234,7 +234,7 @@ def mult_t_base(self, len_a, len_b, **mult_args):
bound_b = 2 ** (len_b - 1) - 1
xvals = [int(random.uniform(-bound_a, bound_a)) for i in range(20)]
yvals = [int(random.uniform(-bound_b, bound_b)) for i in range(20)]
true_result = [i * j for i, j in zip(xvals, yvals)]
true_result = [i * j for i, j in zip(xvals, yvals, strict=True)]

# Setting up and running the tests
sim = pyrtl.Simulation()
Expand Down
58 changes: 41 additions & 17 deletions tests/rtllib/test_muxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def test_select_with_2_wires(self):
out <<= muxes.prioritized_mux(sels, mux_ins)
actual = utils.sim_and_ret_out(out, sels + mux_ins, sel_vals + vals)
expected = [
pri_mux_actual(sel, val) for sel, val in zip(zip(*sel_vals), zip(*vals))
pri_mux_actual(sel, val)
for sel, val in zip(
zip(*sel_vals, strict=True), zip(*vals, strict=True), strict=True
)
]
self.assertEqual(actual, expected)

Expand All @@ -82,7 +85,10 @@ def test_select_with_5_wires(self):
out <<= muxes.prioritized_mux(sels, mux_ins)
actual = utils.sim_and_ret_out(out, sels + mux_ins, sel_vals + vals)
expected = [
pri_mux_actual(sel, val) for sel, val in zip(zip(*sel_vals), zip(*vals))
pri_mux_actual(sel, val)
for sel, val in zip(
zip(*sel_vals, strict=True), zip(*vals, strict=True), strict=True
)
]
self.assertEqual(actual, expected)

Expand Down Expand Up @@ -169,7 +175,7 @@ def test_two_vals(self):
in_vals = [sel_vals, a1_vals, a2_vals]
out_res = utils.sim_and_ret_out(res, [sel, a1, a2], in_vals)

expected_out = [e2 if sel else e1 for sel, e1, e2 in zip(*in_vals)]
expected_out = [e2 if sel else e1 for sel, e1, e2 in zip(*in_vals, strict=True)]
self.assertEqual(out_res, expected_out)

def test_two_vals_big(self):
Expand All @@ -186,7 +192,8 @@ def test_two_vals_big(self):
)

expected_out = [
e2 if sel else e1 for sel, e1, e2 in zip(sel_vals, a1_vals, a2_vals)
e2 if sel else e1
for sel, e1, e2 in zip(sel_vals, a1_vals, a2_vals, strict=True)
]
self.assertEqual(out_res, expected_out)

Expand All @@ -213,7 +220,8 @@ def test_two_big_close(self):
)

expected_out = [
e2 if sel else e1 for sel, e1, e2 in zip(sel_vals, a1_vals, a2_vals)
e2 if sel else e1
for sel, e1, e2 in zip(sel_vals, a1_vals, a2_vals, strict=True)
]
self.assertEqual(out_res, expected_out)

Expand All @@ -236,7 +244,9 @@ def test_default(self):

expected_out = [
e2 if sel == 6 else e1 if sel == 5 else d
for sel, e1, e2, d in zip(sel_vals, a1_vals, a2_vals, default_vals)
for sel, e1, e2, d in zip(
sel_vals, a1_vals, a2_vals, default_vals, strict=True
)
]
self.assertEqual(out_res, expected_out)

Expand Down Expand Up @@ -300,10 +310,12 @@ def test_really_simple(self):
[sel_vals, i1_0_vals, i1_1_vals, i2_0_vals, i2_1_vals],
)
expected_i1_out = [
v1 if s else v0 for s, v0, v1 in zip(sel_vals, i1_0_vals, i1_1_vals)
v1 if s else v0
for s, v0, v1 in zip(sel_vals, i1_0_vals, i1_1_vals, strict=True)
]
expected_i2_out = [
v1 if s else v0 for s, v0, v1 in zip(sel_vals, i2_0_vals, i2_1_vals)
v1 if s else v0
for s, v0, v1 in zip(sel_vals, i2_0_vals, i2_1_vals, strict=True)
]

self.assertEqual(actual_outputs[i1_out.name], expected_i1_out)
Expand All @@ -312,9 +324,15 @@ def test_really_simple(self):
def test_simple(self):
sel, sel_vals = gen_in(2)

x1s, x1_vals = (list(x) for x in zip(*(gen_in(8) for i in range(4))))
x2s, x2_vals = (list(x) for x in zip(*(gen_in(8) for i in range(4))))
x3s, x3_vals = (list(x) for x in zip(*(gen_in(8) for i in range(4))))
x1s, x1_vals = (
list(x) for x in zip(*(gen_in(8) for i in range(4)), strict=True)
)
x2s, x2_vals = (
list(x) for x in zip(*(gen_in(8) for i in range(4)), strict=True)
)
x3s, x3_vals = (
list(x) for x in zip(*(gen_in(8) for i in range(4)), strict=True)
)

i1_out = pyrtl.Output(name="i1_out")
i2_out = pyrtl.Output(name="i2_out")
Expand All @@ -328,9 +346,15 @@ def test_simple(self):
vals = [sel_vals, *x1_vals, *x2_vals, *x3_vals]
actual_outputs = utils.sim_and_ret_outws(wires, vals)

expected_i1_out = [v[s] for s, v in zip(sel_vals, zip(*x1_vals))]
expected_i2_out = [v[s] for s, v in zip(sel_vals, zip(*x2_vals))]
expected_i3_out = [v[s] for s, v in zip(sel_vals, zip(*x3_vals))]
expected_i1_out = [
v[s] for s, v in zip(sel_vals, zip(*x1_vals, strict=True), strict=True)
]
expected_i2_out = [
v[s] for s, v in zip(sel_vals, zip(*x2_vals, strict=True), strict=True)
]
expected_i3_out = [
v[s] for s, v in zip(sel_vals, zip(*x3_vals, strict=True), strict=True)
]

self.assertEqual(actual_outputs[i1_out.name], expected_i1_out)
self.assertEqual(actual_outputs[i2_out.name], expected_i2_out)
Expand All @@ -345,7 +369,7 @@ def test_simple_demux(self):
in_w, in_vals = utils.an_input_and_vals(2)
outs = (pyrtl.Output(name="output_" + str(i)) for i in range(4))
demux_outs = pyrtl.rtllib.muxes.demux(in_w)
for out_w, demux_out in zip(outs, demux_outs):
for out_w, demux_out in zip(outs, demux_outs, strict=True):
out_w <<= demux_out
traces = utils.sim_and_ret_outws((in_w,), (in_vals,))

Expand All @@ -357,7 +381,7 @@ def test_demux_2(self):
in_w, in_vals = utils.an_input_and_vals(1)
outs = (pyrtl.Output(name="output_" + str(i)) for i in range(2))
demux_outs = pyrtl.rtllib.muxes._demux_2(in_w)
for out_w, demux_out in zip(outs, demux_outs):
for out_w, demux_out in zip(outs, demux_outs, strict=True):
out_w <<= demux_out
traces = utils.sim_and_ret_outws((in_w,), (in_vals,))

Expand All @@ -369,7 +393,7 @@ def test_large_demux(self):
in_w, in_vals = utils.an_input_and_vals(5)
outs = (pyrtl.Output(name="output_" + str(i)) for i in range(32))
demux_outs = pyrtl.rtllib.muxes.demux(in_w)
for out_w, demux_out in zip(outs, demux_outs):
for out_w, demux_out in zip(outs, demux_outs, strict=True):
self.assertEqual(len(demux_out), 1)
out_w <<= demux_out
traces = utils.sim_and_ret_outws((in_w,), (in_vals,))
Expand Down
Loading