mirror of
https://github.com/gumyr/build123d.git
synced 2026-03-09 08:11:37 -07:00
Fix __eq__ and __ne__ for classes implementing them
Main issue, which concerns Vector, Location, and ShapeList: Comparison with an object of a different type should not cause an exception - they are simply not equal. Raising an exception in __eq__ can (and will*) break unrelated code that expects __eq__ to be well-behaved. (* I noticed this bug when cq-editor choked on it while trying to find a name for an object in a dictionary of local variables) There's a second more minor issue, which concerns the rest of the classes: When the other type in __eq__ is not supported, one should technically return NotImplemented instead of False, to allow the other type to take part in the comparison, in case they know about our type. (__ne__ should also not generally be implemented as just the negation of __eq__ because of this, but that's also a moot point because the __ne__ can just be removed - Python will automatically do the right thing based on __eq__ here) Technically, the __eq__ for Vector and Plane is also broken in another way: It's not transitive. >>> a, b, c = Vector(0), Vector(9e-6), Vector(18e-6) >>> a == b == c True >>> a == c False They should really eg. have a separate is_close() for approximate comparison, but this isn't fixed here, since I have no idea how many places it'd break, for one.
This commit is contained in:
parent
b2cb62fdac
commit
acbebfb017
3 changed files with 80 additions and 37 deletions
|
|
@ -444,8 +444,10 @@ class Vector:
|
|||
|
||||
__str__ = __repr__
|
||||
|
||||
def __eq__(self, other: Vector) -> bool: # type: ignore[override]
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Vectors equal operator =="""
|
||||
if not isinstance(other, Vector):
|
||||
return NotImplemented
|
||||
return self.wrapped.IsEqual(other.wrapped, 0.00001, 0.00001)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
|
|
@ -670,7 +672,7 @@ class Axis(metaclass=AxisMeta):
|
|||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Axis):
|
||||
return False
|
||||
return NotImplemented
|
||||
return self.position == other.position and self.direction == other.direction
|
||||
|
||||
def located(self, new_location: Location):
|
||||
|
|
@ -1468,10 +1470,10 @@ class Location:
|
|||
def __pow__(self, exponent: int) -> Location:
|
||||
return Location(self.wrapped.Powered(exponent))
|
||||
|
||||
def __eq__(self, other: Location) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare Locations"""
|
||||
if not isinstance(other, Location):
|
||||
raise ValueError("other must be a Location")
|
||||
return NotImplemented
|
||||
quaternion1 = gp_Quaternion()
|
||||
quaternion1.SetEulerAngles(
|
||||
gp_EulerSequence.gp_Intrinsic_XYZ,
|
||||
|
|
@ -2139,27 +2141,6 @@ class Plane(metaclass=PlaneMeta):
|
|||
origin=self.origin + self.z_dir * amount, x_dir=self.x_dir, z_dir=self.z_dir
|
||||
)
|
||||
|
||||
def _eq_iter(self, other: Plane):
|
||||
"""Iterator to successively test equality
|
||||
|
||||
Args:
|
||||
other: Plane to compare to
|
||||
|
||||
Returns:
|
||||
Are planes equal
|
||||
"""
|
||||
# equality tolerances
|
||||
eq_tolerance_origin = 1e-6
|
||||
eq_tolerance_dot = 1e-6
|
||||
|
||||
yield isinstance(other, Plane) # comparison is with another Plane
|
||||
# origins are the same
|
||||
yield abs(self._origin - other.origin) < eq_tolerance_origin
|
||||
# z-axis vectors are parallel (assumption: both are unit vectors)
|
||||
yield abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot
|
||||
# x-axis vectors are parallel (assumption: both are unit vectors)
|
||||
yield abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot
|
||||
|
||||
def __copy__(self) -> Plane:
|
||||
"""Return copy of self"""
|
||||
return Plane(gp_Pln(self.wrapped.Position()))
|
||||
|
|
@ -2168,13 +2149,23 @@ class Plane(metaclass=PlaneMeta):
|
|||
"""Return deepcopy of self"""
|
||||
return Plane(gp_Pln(self.wrapped.Position()))
|
||||
|
||||
def __eq__(self, other: Plane):
|
||||
def __eq__(self, other: object):
|
||||
"""Are planes equal operator =="""
|
||||
return all(self._eq_iter(other))
|
||||
if not isinstance(other, Plane):
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Plane):
|
||||
"""Are planes not equal operator !+"""
|
||||
return not self.__eq__(other)
|
||||
# equality tolerances
|
||||
eq_tolerance_origin = 1e-6
|
||||
eq_tolerance_dot = 1e-6
|
||||
|
||||
return (
|
||||
# origins are the same
|
||||
abs(self._origin - other.origin) < eq_tolerance_origin
|
||||
# z-axis vectors are parallel (assumption: both are unit vectors)
|
||||
and abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot
|
||||
# x-axis vectors are parallel (assumption: both are unit vectors)
|
||||
and abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot
|
||||
)
|
||||
|
||||
def __neg__(self) -> Plane:
|
||||
"""Reverse z direction of plane operator -"""
|
||||
|
|
|
|||
|
|
@ -1972,7 +1972,7 @@ class Shape(NodeMixin):
|
|||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""Are shapes same operator =="""
|
||||
return self.is_same(other) if isinstance(other, Shape) else False
|
||||
return self.is_same(other) if isinstance(other, Shape) else NotImplemented
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Returns True if no defect is detected on the shape S or any of its
|
||||
|
|
@ -3704,9 +3704,15 @@ class ShapeList(list[T]):
|
|||
"""Filter by axis or geomtype operator |"""
|
||||
return self.filter_by(filter_by)
|
||||
|
||||
def __eq__(self, other: ShapeList):
|
||||
def __eq__(self, other: object):
|
||||
"""ShapeLists equality operator =="""
|
||||
return set(self) == set(other)
|
||||
return set(self) == set(other) if isinstance(other, ShapeList) else NotImplemented
|
||||
|
||||
# Normally implementing __eq__ is enough, but ShapeList subclasses list,
|
||||
# which already implements __ne__, so we need to override it, too
|
||||
def __ne__(self, other: ShapeList):
|
||||
"""ShapeLists inequality operator !="""
|
||||
return set(self) != set(other) if isinstance(other, ShapeList) else NotImplemented
|
||||
|
||||
def __add__(self, other: ShapeList):
|
||||
"""Combine two ShapeLists together operator +"""
|
||||
|
|
|
|||
|
|
@ -93,6 +93,12 @@ DEG2RAD = math.pi / 180
|
|||
RAD2DEG = 180 / math.pi
|
||||
|
||||
|
||||
# Always equal to any other object, to test that __eq__ cooperation is working
|
||||
class AlwaysEqual:
|
||||
def __eq__(self, other):
|
||||
return True
|
||||
|
||||
|
||||
class DirectApiTestCase(unittest.TestCase):
|
||||
def assertTupleAlmostEquals(
|
||||
self,
|
||||
|
|
@ -363,13 +369,13 @@ class TestAxis(DirectApiTestCase):
|
|||
self.assertEqual(Axis.X, Axis.X)
|
||||
self.assertEqual(Axis.Y, Axis.Y)
|
||||
self.assertEqual(Axis.Z, Axis.Z)
|
||||
self.assertEqual(Axis.X, AlwaysEqual())
|
||||
|
||||
def test_axis_not_equal(self):
|
||||
self.assertNotEqual(Axis.X, Axis.Y)
|
||||
random_obj = object()
|
||||
self.assertNotEqual(Axis.X, random_obj)
|
||||
|
||||
|
||||
class TestBoundBox(DirectApiTestCase):
|
||||
def test_basic_bounding_box(self):
|
||||
v = Vertex(1, 1, 1)
|
||||
|
|
@ -1730,15 +1736,21 @@ class TestLocation(DirectApiTestCase):
|
|||
self.assertVectorAlmostEquals(axis.position, (1, 2, 3), 6)
|
||||
self.assertVectorAlmostEquals(axis.direction, (0, 1, 0), 6)
|
||||
|
||||
def test_eq(self):
|
||||
def test_equal(self):
|
||||
loc = Location((1, 2, 3), (4, 5, 6))
|
||||
diff_position = Location((10, 20, 30), (4, 5, 6))
|
||||
diff_orientation = Location((1, 2, 3), (40, 50, 60))
|
||||
same = Location((1, 2, 3), (4, 5, 6))
|
||||
|
||||
self.assertEqual(loc, same)
|
||||
self.assertEqual(loc, AlwaysEqual())
|
||||
|
||||
def test_not_equal(self):
|
||||
loc = Location((1, 2, 3), (40, 50, 60))
|
||||
diff_position = Location((3, 2, 1), (40, 50, 60))
|
||||
diff_orientation = Location((1, 2, 3), (60, 50, 40))
|
||||
|
||||
self.assertNotEqual(loc, diff_position)
|
||||
self.assertNotEqual(loc, diff_orientation)
|
||||
self.assertNotEqual(loc, object())
|
||||
|
||||
def test_neg(self):
|
||||
loc = Location((1, 2, 3), (0, 35, 127))
|
||||
|
|
@ -2666,6 +2678,8 @@ class TestPlane(DirectApiTestCase):
|
|||
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
|
||||
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
|
||||
)
|
||||
# __eq__ cooperation
|
||||
self.assertEqual(Plane.XY, AlwaysEqual())
|
||||
|
||||
def test_plane_not_equal(self):
|
||||
# type difference
|
||||
|
|
@ -2955,6 +2969,17 @@ class TestShape(DirectApiTestCase):
|
|||
box = Solid.make_box(1, 1, 1)
|
||||
self.assertTrue(box.is_equal(box))
|
||||
|
||||
def test_equal(self):
|
||||
box = Solid.make_box(1, 1, 1)
|
||||
self.assertEqual(box, box)
|
||||
self.assertEqual(box, AlwaysEqual())
|
||||
|
||||
def test_not_equal(self):
|
||||
box = Solid.make_box(1, 1, 1)
|
||||
diff = Solid.make_box(1, 2, 3)
|
||||
self.assertNotEqual(box, diff)
|
||||
self.assertNotEqual(box, object())
|
||||
|
||||
def test_tessellate(self):
|
||||
box123 = Solid.make_box(1, 2, 3)
|
||||
verts, triangles = box123.tessellate(1e-6)
|
||||
|
|
@ -3439,6 +3464,20 @@ class TestShapeList(DirectApiTestCase):
|
|||
sl = ShapeList([Box(1, 2, 3), Vertex(1, 1, 1)])
|
||||
self.assertAlmostEqual(sl.compound().volume, 1 * 2 * 3, 5)
|
||||
|
||||
def test_equal(self):
|
||||
box = Box(1, 1, 1)
|
||||
cyl = Cylinder(1, 1)
|
||||
sl = ShapeList([box, cyl])
|
||||
same = ShapeList([cyl, box])
|
||||
self.assertEqual(sl, same)
|
||||
self.assertEqual(sl, AlwaysEqual())
|
||||
|
||||
def test_not_equal(self):
|
||||
sl = ShapeList([Box(1, 1, 1), Cylinder(1, 1)])
|
||||
diff = ShapeList([Box(1, 1, 1), Box(1, 2, 3)])
|
||||
self.assertNotEqual(sl, diff)
|
||||
self.assertNotEqual(sl, object())
|
||||
|
||||
|
||||
class TestShells(DirectApiTestCase):
|
||||
def test_shell_init(self):
|
||||
|
|
@ -3753,6 +3792,13 @@ class TestVector(DirectApiTestCase):
|
|||
c = Vector(1, 2, 3.000001)
|
||||
self.assertEqual(a, b)
|
||||
self.assertEqual(a, c)
|
||||
self.assertEqual(a, AlwaysEqual())
|
||||
|
||||
def test_vector_not_equal(self):
|
||||
a = Vector(1, 2, 3)
|
||||
b = Vector(3, 2, 1)
|
||||
self.assertNotEqual(a, b)
|
||||
self.assertNotEqual(a, object())
|
||||
|
||||
def test_vector_distance(self):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue