Merge pull request #707 from alexer/fix-eq

Fix __eq__ and __ne__ for classes implementing them
This commit is contained in:
Roger Maitland 2024-09-22 10:10:13 -04:00 committed by GitHub
commit 720bee9fa0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 80 additions and 37 deletions

View file

@ -444,8 +444,10 @@ class Vector:
__str__ = __repr__
def __eq__(self, other: Vector) -> bool: # type: ignore[override]
def __eq__(self, other: object) -> bool:
"""Vectors equal operator =="""
if not isinstance(other, Vector):
return NotImplemented
return self.wrapped.IsEqual(other.wrapped, 0.00001, 0.00001)
def __hash__(self) -> int:
@ -670,7 +672,7 @@ class Axis(metaclass=AxisMeta):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Axis):
return False
return NotImplemented
return self.position == other.position and self.direction == other.direction
def located(self, new_location: Location):
@ -1468,10 +1470,10 @@ class Location:
def __pow__(self, exponent: int) -> Location:
return Location(self.wrapped.Powered(exponent))
def __eq__(self, other: Location) -> bool:
def __eq__(self, other: object) -> bool:
"""Compare Locations"""
if not isinstance(other, Location):
raise ValueError("other must be a Location")
return NotImplemented
quaternion1 = gp_Quaternion()
quaternion1.SetEulerAngles(
gp_EulerSequence.gp_Intrinsic_XYZ,
@ -2139,27 +2141,6 @@ class Plane(metaclass=PlaneMeta):
origin=self.origin + self.z_dir * amount, x_dir=self.x_dir, z_dir=self.z_dir
)
def _eq_iter(self, other: Plane):
"""Iterator to successively test equality
Args:
other: Plane to compare to
Returns:
Are planes equal
"""
# equality tolerances
eq_tolerance_origin = 1e-6
eq_tolerance_dot = 1e-6
yield isinstance(other, Plane) # comparison is with another Plane
# origins are the same
yield abs(self._origin - other.origin) < eq_tolerance_origin
# z-axis vectors are parallel (assumption: both are unit vectors)
yield abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot
# x-axis vectors are parallel (assumption: both are unit vectors)
yield abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot
def __copy__(self) -> Plane:
"""Return copy of self"""
return Plane(gp_Pln(self.wrapped.Position()))
@ -2168,13 +2149,23 @@ class Plane(metaclass=PlaneMeta):
"""Return deepcopy of self"""
return Plane(gp_Pln(self.wrapped.Position()))
def __eq__(self, other: Plane):
def __eq__(self, other: object):
"""Are planes equal operator =="""
return all(self._eq_iter(other))
if not isinstance(other, Plane):
return NotImplemented
def __ne__(self, other: Plane):
"""Are planes not equal operator !+"""
return not self.__eq__(other)
# equality tolerances
eq_tolerance_origin = 1e-6
eq_tolerance_dot = 1e-6
return (
# origins are the same
abs(self._origin - other.origin) < eq_tolerance_origin
# z-axis vectors are parallel (assumption: both are unit vectors)
and abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot
# x-axis vectors are parallel (assumption: both are unit vectors)
and abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot
)
def __neg__(self) -> Plane:
"""Reverse z direction of plane operator -"""

View file

@ -1996,7 +1996,7 @@ class Shape(NodeMixin):
def __eq__(self, other) -> bool:
"""Are shapes same operator =="""
return self.is_same(other) if isinstance(other, Shape) else False
return self.is_same(other) if isinstance(other, Shape) else NotImplemented
def is_valid(self) -> bool:
"""Returns True if no defect is detected on the shape S or any of its
@ -3728,9 +3728,15 @@ class ShapeList(list[T]):
"""Filter by axis or geomtype operator |"""
return self.filter_by(filter_by)
def __eq__(self, other: ShapeList):
def __eq__(self, other: object):
"""ShapeLists equality operator =="""
return set(self) == set(other)
return set(self) == set(other) if isinstance(other, ShapeList) else NotImplemented
# Normally implementing __eq__ is enough, but ShapeList subclasses list,
# which already implements __ne__, so we need to override it, too
def __ne__(self, other: ShapeList):
"""ShapeLists inequality operator !="""
return set(self) != set(other) if isinstance(other, ShapeList) else NotImplemented
def __add__(self, other: ShapeList):
"""Combine two ShapeLists together operator +"""

View file

@ -95,6 +95,12 @@ DEG2RAD = math.pi / 180
RAD2DEG = 180 / math.pi
# Always equal to any other object, to test that __eq__ cooperation is working
class AlwaysEqual:
def __eq__(self, other):
return True
class DirectApiTestCase(unittest.TestCase):
def assertTupleAlmostEquals(
self,
@ -365,13 +371,13 @@ class TestAxis(DirectApiTestCase):
self.assertEqual(Axis.X, Axis.X)
self.assertEqual(Axis.Y, Axis.Y)
self.assertEqual(Axis.Z, Axis.Z)
self.assertEqual(Axis.X, AlwaysEqual())
def test_axis_not_equal(self):
self.assertNotEqual(Axis.X, Axis.Y)
random_obj = object()
self.assertNotEqual(Axis.X, random_obj)
class TestBoundBox(DirectApiTestCase):
def test_basic_bounding_box(self):
v = Vertex(1, 1, 1)
@ -1758,15 +1764,21 @@ class TestLocation(DirectApiTestCase):
self.assertVectorAlmostEquals(axis.position, (1, 2, 3), 6)
self.assertVectorAlmostEquals(axis.direction, (0, 1, 0), 6)
def test_eq(self):
def test_equal(self):
loc = Location((1, 2, 3), (4, 5, 6))
diff_position = Location((10, 20, 30), (4, 5, 6))
diff_orientation = Location((1, 2, 3), (40, 50, 60))
same = Location((1, 2, 3), (4, 5, 6))
self.assertEqual(loc, same)
self.assertEqual(loc, AlwaysEqual())
def test_not_equal(self):
loc = Location((1, 2, 3), (40, 50, 60))
diff_position = Location((3, 2, 1), (40, 50, 60))
diff_orientation = Location((1, 2, 3), (60, 50, 40))
self.assertNotEqual(loc, diff_position)
self.assertNotEqual(loc, diff_orientation)
self.assertNotEqual(loc, object())
def test_neg(self):
loc = Location((1, 2, 3), (0, 35, 127))
@ -2694,6 +2706,8 @@ class TestPlane(DirectApiTestCase):
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
)
# __eq__ cooperation
self.assertEqual(Plane.XY, AlwaysEqual())
def test_plane_not_equal(self):
# type difference
@ -2983,6 +2997,17 @@ class TestShape(DirectApiTestCase):
box = Solid.make_box(1, 1, 1)
self.assertTrue(box.is_equal(box))
def test_equal(self):
box = Solid.make_box(1, 1, 1)
self.assertEqual(box, box)
self.assertEqual(box, AlwaysEqual())
def test_not_equal(self):
box = Solid.make_box(1, 1, 1)
diff = Solid.make_box(1, 2, 3)
self.assertNotEqual(box, diff)
self.assertNotEqual(box, object())
def test_tessellate(self):
box123 = Solid.make_box(1, 2, 3)
verts, triangles = box123.tessellate(1e-6)
@ -3492,6 +3517,20 @@ class TestShapeList(DirectApiTestCase):
sl = ShapeList([Box(1, 2, 3), Vertex(1, 1, 1)])
self.assertAlmostEqual(sl.compound().volume, 1 * 2 * 3, 5)
def test_equal(self):
box = Box(1, 1, 1)
cyl = Cylinder(1, 1)
sl = ShapeList([box, cyl])
same = ShapeList([cyl, box])
self.assertEqual(sl, same)
self.assertEqual(sl, AlwaysEqual())
def test_not_equal(self):
sl = ShapeList([Box(1, 1, 1), Cylinder(1, 1)])
diff = ShapeList([Box(1, 1, 1), Box(1, 2, 3)])
self.assertNotEqual(sl, diff)
self.assertNotEqual(sl, object())
class TestShells(DirectApiTestCase):
def test_shell_init(self):
@ -3806,6 +3845,13 @@ class TestVector(DirectApiTestCase):
c = Vector(1, 2, 3.000001)
self.assertEqual(a, b)
self.assertEqual(a, c)
self.assertEqual(a, AlwaysEqual())
def test_vector_not_equal(self):
a = Vector(1, 2, 3)
b = Vector(3, 2, 1)
self.assertNotEqual(a, b)
self.assertNotEqual(a, object())
def test_vector_distance(self):
"""