diff --git a/gemd/__version__.py b/gemd/__version__.py index 55f47b0..dd0edd7 100644 --- a/gemd/__version__.py +++ b/gemd/__version__.py @@ -1 +1 @@ -__version__ = "2.1.10" +__version__ = "2.1.11" diff --git a/gemd/entity/base_entity.py b/gemd/entity/base_entity.py index 06d169e..7b70d2c 100644 --- a/gemd/entity/base_entity.py +++ b/gemd/entity/base_entity.py @@ -207,11 +207,15 @@ def _cached_equals(this: "BaseEntity", # Note that this could violate transitivity -- Link(scope1) == obj == Link(scope2) def __eq__(self, other): from gemd.entity.link_by_uid import LinkByUID - if isinstance(other, LinkByUID): + from gemd.util import cached_isinstance + + if id(self) == id(other): + return True + elif cached_isinstance(other, LinkByUID): return self.uids.get(other.scope) == other.id - elif isinstance(other, tuple): + elif cached_isinstance(other, tuple): return len(other) == 2 and other[0] in self.uids and self.uids[other[0]] == other[1] - elif isinstance(other, BaseEntity): + elif cached_isinstance(other, BaseEntity): # We have to be a little clever for efficiency and to avoid infinite recursion return BaseEntity._cached_equals(self, other) else: diff --git a/gemd/entity/link_by_uid.py b/gemd/entity/link_by_uid.py index b8c10c3..0ca54a1 100644 --- a/gemd/entity/link_by_uid.py +++ b/gemd/entity/link_by_uid.py @@ -67,12 +67,16 @@ def from_entity(cls, entity: BaseEntityType, *, scope=None): # Note that this could violate transitivity def __eq__(self, other): from gemd.entity.base_entity import BaseEntity - if isinstance(other, BaseEntity): + from gemd.util import cached_isinstance + + if cached_isinstance(other, LinkByUID): + return self.scope == other.scope and self.id == other.id + elif cached_isinstance(other, BaseEntity): if self.scope in other.uids: return other.uids[self.scope] == self.id else: return False - elif isinstance(other, tuple): # Make them interchangeable in a dict + elif cached_isinstance(other, tuple): # Make them interchangeable in a dict return len(other) == 2 and (self.scope, self.id) == other else: return super().__eq__(other) diff --git a/tests/entity/test_link_by_uid.py b/tests/entity/test_link_by_uid.py index 40b6b17..6f35552 100644 --- a/tests/entity/test_link_by_uid.py +++ b/tests/entity/test_link_by_uid.py @@ -47,3 +47,4 @@ def test_equality(): assert link == ("foo", "bar") assert link != ("foo", "bar", "baz") assert link != ("foo", "rab") + assert link != "foo: rab"