Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions amaranth/lib/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,9 @@ def eq(self, other):
:class:`.Assign`
:py:`self.as_value().eq(other)`
"""
if isinstance(other, ValueCastable):
if not self.shape() == Layout.cast(other.shape()):
raise TypeError(f"Cannot assign value with shape {other.shape()} to view with layout {self.shape()}")
return self.as_value().eq(other)

def __getitem__(self, key):
Expand Down
3 changes: 3 additions & 0 deletions amaranth/lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def eq(self, other):
:class:`Assign`
``self.as_value().eq(other)``
"""
if isinstance(other, ValueCastable):
if not self.shape() == other.shape():
raise TypeError(f"Cannot assign value with shape {other.shape()} to value with shape {self.shape()}")
return self.as_value().eq(other)

def __add__(self, other):
Expand Down
36 changes: 29 additions & 7 deletions amaranth/lib/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import enum
import re
import warnings
import inspect
try:
import annotationlib # py3.14+
except ImportError:
Expand Down Expand Up @@ -1562,21 +1563,26 @@ def connect(m, *args, **kwargs):
(out_path, out_member), = out_kind
for (in_path, in_member) in in_kind:
def connect_value(*, out_path, in_path, src_loc_at):
in_value = Value.cast(_traverse_path(in_path, objects))
out_value = Value.cast(_traverse_path(out_path, objects))
assert type(in_value) in (Const, Signal)
in_value = _traverse_path(in_path, objects)
out_value = _traverse_path(out_path, objects)
# If the input is a constant, only a constant may be connected to it. Ensure that
# this is the case.
try:
in_value = Const.cast(in_value)
except TypeError:
pass
if type(in_value) is Const:
# If the output is not a constant, the connection is illegal.
if type(out_value) is not Const:
try:
out_value = Const.cast(out_value)
except TypeError:
raise ConnectionError(
f"Cannot connect input member {_format_path(in_path)} that has "
f"a constant value {in_value.value!r} to an output member "
f"{_format_path(out_path)} that has a varying value")
# If the output is a constant, the connection is legal only if the value is
# the same for both the input and the output.
if type(out_value) is Const and in_value.value != out_value.value:
if in_value.value != out_value.value:
raise ConnectionError(
f"Cannot connect input member {_format_path(in_path)} that has "
f"a constant value {in_value.value!r} to an output member "
Expand All @@ -1586,8 +1592,24 @@ def connect_value(*, out_path, in_path, src_loc_at):
# value (which is constant) is consistent with a connection that would have
# been made.
return
# A connection that is made at this point is guaranteed to be valid.
connections.append(in_value.eq(out_value, src_loc_at=src_loc_at + 1))
# If the input is a ValueCastable, it must implement `eq()`.
try:
eq = in_value.eq
except AttributeError:
raise ConnectionError(
f"Cannot connect input member {_format_path(in_path)} because the input "
f"value {in_value!r} does not support assignment")
# The `eq()` method may take a `src_loc_at` argument; provide it if it does.
if 'src_loc_at' in inspect.signature(eq).parameters:
kwargs = {'src_loc_at': src_loc_at + 1}
else:
kwargs = {}
try:
connections.append(eq(out_value, **kwargs))
except Exception as e:
raise ConnectionError(
f"Cannot connect input member {_format_path(in_path)} to output member "
f"{_format_path(out_path)} because assignment failed") from e
def connect_dimensions(dimensions, *, out_path, in_path, src_loc_at):
if not dimensions:
return connect_value(out_path=out_path, in_path=in_path, src_loc_at=src_loc_at)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_lib_wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,24 @@ class Cycle(enum.Enum, shape=2):
q=NS(signature=Signature({"a": In(Cycle)}),
a=Signal(Cycle)))

def test_shape_mismatch_layout(self):
class LastDelimited(data.Struct):
data: 8
last: 1
class FirstDelimited(data.Struct):
data: 8
first: 1

m = Module()
with self.assertRaisesRegex(ConnectionError,
r"^Cannot connect input member 'q\.a' to output member 'p\.a' because assignment "
r"failed$"):
connect(m,
p=NS(signature=Signature({"a": Out(LastDelimited)}),
a=Signal(LastDelimited)),
q=NS(signature=Signature({"a": In(FirstDelimited)}),
a=Signal(FirstDelimited)))

def test_init_mismatch(self):
m = Module()
with self.assertRaisesRegex(ConnectionError,
Expand Down
Loading