mypy passing

This commit is contained in:
Roger Maitland 2026-04-08 10:39:21 -04:00
parent 41716a1b35
commit 665eff0d67

View file

@ -7,7 +7,7 @@ from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from itertools import combinations
from math import acos, sqrt
from typing import Literal
from typing import Any, Literal, TypeAlias, overload
import numpy as np
from build123d import (
@ -26,9 +26,10 @@ from build123d import (
Vector,
TOL_DIGITS,
)
from sklearn.cluster import DBSCAN
from sklearn.cluster import DBSCAN # type: ignore[import-untyped]
EPS = 1e-9
EdgeKey: TypeAlias = tuple[tuple[float, float, float], tuple[float, float, float]]
# Data model
@ -113,7 +114,7 @@ class MeshIndex:
for edge in face.edges():
edge_to_face_indices[_edge_key(edge)].add(index)
adjacency = {index: set() for index in range(len(self.faces))}
adjacency: dict[int, set] = {index: set() for index in range(len(self.faces))}
for face_indices in edge_to_face_indices.values():
for face_index in face_indices:
adjacency[face_index].update(face_indices - {face_index})
@ -170,7 +171,7 @@ def _cluster_points(
return [np.asarray(labels == label) for label in sorted(set(labels)) if label != -1]
def _edge_key(edge) -> tuple[tuple[float, float, float], tuple[float, float, float]]:
def _edge_key(edge) -> EdgeKey:
vertices = edge.vertices()
ends = sorted(_rounded_vertex_key(vertex.center()) for vertex in vertices)
return ends[0], ends[1]
@ -182,6 +183,17 @@ def _face_key(face: Face) -> tuple[tuple[float, float, float], ...]:
)
def _as_face(value: Any, context: str) -> Face:
if isinstance(value, Face):
return value
face_method = getattr(value, "face", None)
if callable(face_method):
face = face_method()
if isinstance(face, Face):
return face
raise RuntimeError(f"Expected Face while building {context}")
def _plane_basis(normal: Vector) -> tuple[Vector, Vector]:
helper = Vector(1.0, 0.0, 0.0)
if abs(helper.dot(normal)) > 0.9:
@ -314,8 +326,8 @@ def _indices_from_sewn_component(mesh_index: MeshIndex, component) -> list[int]:
def _build_face_edge_midpoint_adjacency(
mesh_index: MeshIndex,
) -> dict[int, list[tuple[int, Vector]]]:
edge_to_faces: defaultdict[object, list[int]] = defaultdict(list)
edge_midpoints: dict[object, Vector] = {}
edge_to_faces: defaultdict[EdgeKey, list[int]] = defaultdict(list)
edge_midpoints: dict[EdgeKey, Vector] = {}
for index, face in enumerate(mesh_index.faces):
for edge in face.edges():
@ -328,7 +340,9 @@ def _build_face_edge_midpoint_adjacency(
(vertices[0].Z + vertices[1].Z) / 2.0,
)
adjacency = {index: [] for index in range(len(mesh_index.faces))}
adjacency: dict[int, list[tuple[int, Vector]]] = {
index: [] for index in range(len(mesh_index.faces))
}
for edge_key, face_indices in edge_to_faces.items():
if len(face_indices) != 2:
continue
@ -351,7 +365,10 @@ def build_plane_face(patch: PlanePatch) -> Face:
v_center = (patch.v_min + patch.v_max) / 2.0
u_vec, _v_vec = _plane_basis(patch.normal)
plane = Plane(origin=patch.origin, x_dir=u_vec, z_dir=patch.normal)
return (plane * Pos(u_center, v_center, 0) * Rectangle(u_size, v_size)).face()
return _as_face(
plane * Pos(u_center, v_center, 0) * Rectangle(u_size, v_size),
"plane primitive",
)
def build_cylinder_face(patch: CylinderPatch, support_faces: Sequence[Face]) -> Face:
@ -369,23 +386,25 @@ def build_cylinder_face(patch: CylinderPatch, support_faces: Sequence[Face]) ->
axis_min = min(axis_values)
axis_max = max(axis_values)
radius = _median_scalar(radial_distances)
return (
(
Plane(
origin=patch.axis_point + patch.axis_direction * axis_min,
z_dir=patch.axis_direction,
)
* Cylinder(radius, axis_max - axis_min, align=None)
cylinder_shape = Plane(
origin=patch.axis_point + patch.axis_direction * axis_min,
z_dir=patch.axis_direction,
) * Cylinder(radius, axis_max - axis_min, align=Align.NONE)
cylinder_faces = getattr(cylinder_shape, "faces", None)
if not callable(cylinder_faces):
raise RuntimeError("Expected cylinder shape to provide faces()")
filtered_faces = cylinder_faces().filter_by(GeomType.CYLINDER)
if not filtered_faces or not isinstance(filtered_faces[0], Face):
raise RuntimeError(
"Expected cylindrical face while building cylinder primitive"
)
.faces()
.filter_by(GeomType.CYLINDER)[0]
)
return filtered_faces[0]
def build_sphere_face(patch: SpherePatch, support_faces: Sequence[Face]) -> Face:
vertices = _unique_face_vertices(support_faces)
radius = _median_scalar([(vertex - patch.center).length for vertex in vertices])
return (Pos(*tuple(patch.center)) * Sphere(radius)).face()
return _as_face(Pos(*tuple(patch.center)) * Sphere(radius), "sphere primitive")
# Local signature and patch-growth helpers
@ -511,6 +530,24 @@ def _bounding_boxes_overlap(box1, box2, tolerance: float = 0.0) -> bool:
)
@overload
def grow_curved_patch(
mesh_index: MeshIndex,
patch: CylinderPatch,
allowed_indices: set[int],
shape_scale: float,
) -> CylinderPatch: ...
@overload
def grow_curved_patch(
mesh_index: MeshIndex,
patch: SpherePatch,
allowed_indices: set[int],
shape_scale: float,
) -> SpherePatch: ...
def grow_curved_patch(
mesh_index: MeshIndex,
patch: CylinderPatch | SpherePatch,
@ -920,7 +957,7 @@ def fit_local_cylinder(
if not masks:
return None
best_mask = max(masks, key=np.count_nonzero)
best_mask = max(masks, key=lambda mask: int(np.count_nonzero(mask)))
cluster_records = [record for record, keep in zip(records, best_mask) if keep]
face_indices = sorted(
{index for indices, _ in cluster_records for index in indices}
@ -973,10 +1010,9 @@ def fit_local_cylinder(
)
if not point_masks:
return None
best_point_mask = max(point_masks, key=lambda mask: int(np.count_nonzero(mask)))
best_points = [
point
for point, keep in zip(intersections_2d, max(point_masks, key=np.count_nonzero))
if keep
point for point, keep in zip(intersections_2d, best_point_mask) if keep
]
center_2d = (
_mean_scalar([point[0] for point in best_points]),
@ -1516,11 +1552,16 @@ def shapes_to_code(primitives: Iterable[Shape]) -> list[str]:
.sort_by(Axis.Y)[0]
)
global_origin = pln.from_local_coords(local_origin)
pln = pln.shift_origin(global_origin)
shifted_plane = pln.shift_origin(global_origin)
if not isinstance(shifted_plane, Plane):
raise RuntimeError("Expected Plane.shift_origin() to return Plane")
pln = shifted_plane
bbox = center_oriented_rect.bounding_box()
w, h = bbox.size.X, bbox.size.Y
rect = pln * Rectangle(w, h, align=Align.MIN)
rect = _as_face(
pln * Rectangle(w, h, align=Align.MIN), "planar rectangle"
)
common = rect.intersect(primitive)
if not common or not isinstance(common[0], Face):
raise RuntimeError("Error in generating planar rectangle")
@ -1617,7 +1658,7 @@ def detect_primitives(
# primitives: list[tuple[Face, Shell]] = []
primitives: list[Face] = []
claimed = set()
claimed: set[int] = set()
for patch in patches:
support_faces = mesh_index.face_set(sorted(patch.face_indices))
claimed.update(patch.face_indices)