Order functions, methods and properties in a class by Python's conventional order

This commit is contained in:
gumyr 2025-01-04 15:46:21 -05:00
parent b5396639dc
commit 93513b1449

View file

@ -233,6 +233,188 @@ interface for efficient and extensible CAD modeling workflows.
}
def sort_class_methods_by_convention(class_def: cst.ClassDef) -> cst.ClassDef:
"""Sort methods and properties in a class according to Python conventions."""
methods, properties = extract_methods_and_properties(class_def)
sorted_body = order_methods_by_convention(methods, properties)
other_statements = [
stmt for stmt in class_def.body.body if not isinstance(stmt, cst.FunctionDef)
]
final_body = cst.IndentedBlock(body=other_statements + sorted_body)
return class_def.with_changes(body=final_body)
def extract_methods_and_properties(
class_def: cst.ClassDef,
) -> tuple[List[cst.FunctionDef], List[List[cst.FunctionDef]]]:
"""
Extract methods and properties (with setters grouped together) from a class.
Returns:
- methods: Regular methods in the class.
- properties: List of grouped properties, where each group contains a getter
and its associated setter, if present.
"""
methods = []
properties = {}
for stmt in class_def.body.body:
if isinstance(stmt, cst.FunctionDef):
for decorator in stmt.decorators:
# Handle @property
if (
isinstance(decorator.decorator, cst.Name)
and decorator.decorator.value == "property"
):
properties[stmt.name.value] = [stmt] # Initialize with getter
# Handle @property.setter
elif (
isinstance(decorator.decorator, cst.Attribute)
and decorator.decorator.attr.value == "setter"
):
base_name = decorator.decorator.value.value # Extract base name
if base_name in properties:
properties[base_name].append(
stmt
) # Add setter to the property group
else:
# Setter appears before the getter
properties[base_name] = [None, stmt]
# Add non-property methods
if not any(
isinstance(decorator.decorator, cst.Name)
and decorator.decorator.value == "property"
or isinstance(decorator.decorator, cst.Attribute)
and decorator.decorator.attr.value == "setter"
for decorator in stmt.decorators
):
methods.append(stmt)
# Convert property dictionary into a sorted list of grouped properties
sorted_properties = [group for _, group in sorted(properties.items())]
return methods, sorted_properties
def order_methods_by_convention(
methods: List[cst.FunctionDef], properties: List[List[cst.FunctionDef]]
) -> List[cst.BaseStatement]:
"""
Order methods and properties in a class by Python's conventional order with section headers.
Sections:
- Constructor
- Properties (grouped by getter and setter)
- Class Methods
- Static Methods
- Public and Private Instance Methods
"""
def method_key(method: cst.FunctionDef) -> tuple[int, str]:
name = method.name.value
decorators = {
decorator.decorator.value
for decorator in method.decorators
if isinstance(decorator.decorator, cst.Name)
}
if name == "__init__":
return (0, name) # Constructor always comes first
elif name.startswith("__") and name.endswith("__"):
return (1, name) # Dunder methods follow
elif any(
decorator == "property" or decorator.endswith(".setter")
for decorator in decorators
):
return (2, name) # Properties and setters follow dunder methods
elif "classmethod" in decorators:
return (3, name) # Class methods follow properties
elif "staticmethod" in decorators:
return (4, name) # Static methods follow class methods
elif not name.startswith("_"):
return (5, name) # Public instance methods
else:
return (6, name) # Private methods last
# Flatten properties into a single sorted list
flattened_properties = [
prop for group in properties for prop in group if prop is not None
]
# Separate __init__, class methods, static methods, and instance methods
init_methods = [m for m in methods if m.name.value == "__init__"]
class_methods = [
m
for m in methods
if any(decorator.decorator.value == "classmethod" for decorator in m.decorators)
]
static_methods = [
m
for m in methods
if any(
decorator.decorator.value == "staticmethod" for decorator in m.decorators
)
]
instance_methods = [
m
for m in methods
if m.name.value != "__init__"
and not any(
decorator.decorator.value in {"classmethod", "staticmethod"}
for decorator in m.decorators
)
]
# Sort properties and each method group alphabetically
sorted_properties = sorted(flattened_properties, key=lambda prop: prop.name.value)
sorted_class_methods = sorted(class_methods, key=lambda m: m.name.value)
sorted_static_methods = sorted(static_methods, key=lambda m: m.name.value)
sorted_instance_methods = sorted(instance_methods, key=lambda m: method_key(m))
# Combine all sections with headers
ordered_sections: List[cst.BaseStatement] = []
if init_methods:
ordered_sections.append(
cst.SimpleStatementLine([cst.Expr(cst.Comment("# ---- Constructor ----"))])
)
ordered_sections.extend(init_methods)
if sorted_properties:
ordered_sections.append(
cst.SimpleStatementLine([cst.Expr(cst.Comment("# ---- Properties ----"))])
)
ordered_sections.extend(sorted_properties)
if sorted_class_methods:
ordered_sections.append(
cst.SimpleStatementLine(
[cst.Expr(cst.Comment("# ---- Class Methods ----"))]
)
)
ordered_sections.extend(sorted_class_methods)
if sorted_static_methods:
ordered_sections.append(
cst.SimpleStatementLine(
[cst.Expr(cst.Comment("# ---- Static Methods ----"))]
)
)
ordered_sections.extend(sorted_static_methods)
if sorted_instance_methods:
ordered_sections.append(
cst.SimpleStatementLine(
[cst.Expr(cst.Comment("# ---- Instance Methods ----"))]
)
)
ordered_sections.extend(sorted_instance_methods)
return ordered_sections
class ImportCollector(cst.CSTVisitor):
def __init__(self):
self.imports: Set[str] = set()
@ -259,6 +441,22 @@ class ClassExtractor(cst.CSTVisitor):
self.extracted_classes[node.name.value] = node
class ClassMethodExtractor(cst.CSTVisitor):
def __init__(self):
self.class_methods: Dict[str, List[cst.FunctionDef]] = {}
def visit_ClassDef(self, node: cst.ClassDef) -> None:
class_name = node.name.value
self.class_methods[class_name] = []
for statement in node.body.body:
if isinstance(statement, cst.FunctionDef):
self.class_methods[class_name].append(statement)
# Sort methods alphabetically by name
self.class_methods[class_name].sort(key=lambda method: method.name.value)
class MixinClassExtractor(cst.CSTVisitor):
def __init__(self):
self.extracted_classes: Dict[str, cst.ClassDef] = {}
@ -285,6 +483,9 @@ class StandaloneFunctionAndVariableCollector(cst.CSTVisitor):
if self.current_scope_level == 0:
self.functions.append(node)
def get_sorted_functions(self) -> List[cst.FunctionDef]:
return sorted(self.functions, key=lambda func: func.name.value)
class GlobalVariableExtractor(cst.CSTVisitor):
def __init__(self):
@ -402,6 +603,7 @@ def write_topo_class_files(
}
for group_name, class_names in class_groups.items():
module_docstring = f"""
build123d topology
@ -442,9 +644,10 @@ license:
source_tree.visit(variable_collector)
group_classes = [
extracted_classes[name] for name in class_names if name in extracted_classes
sort_class_methods_by_convention(extracted_classes[name])
for name in class_names
if name in extracted_classes
]
# Add imports for base classes based on layer dependencies
additional_imports = []
if group_name != "shape_core":
@ -535,7 +738,7 @@ license:
body.append(var)
body.append(cst.EmptyLine(indent=False))
for func in function_collector.functions:
for func in function_collector.get_sorted_functions():
if func.name.value in function_source[group_name]:
body.append(func)
class_module = cst.Module(body=body, header=header)