From 2a8403e0a0b0ceab44c2802f834a05b1be7d1a80 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 15 Apr 2024 16:10:12 -0400 Subject: [PATCH] fix sorting of tree list --- modules/ui_extra_networks.py | 54 +++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a7c1ee725..ab432d212 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -14,8 +14,7 @@ from fastapi.exceptions import HTTPException from PIL import Image from starlette.responses import FileResponse, JSONResponse, Response -from modules import (errors, extra_networks, shared, - ui_extra_networks_user_metadata, util) +from modules import errors, extra_networks, shared, ui_extra_networks_user_metadata, util from modules.images import read_info_from_image, save_image_with_geninfo from modules.infotext_utils import image_from_url_text @@ -134,9 +133,9 @@ class DirectoryTreeNode: """Flattens the keys/values of the tree nodes into a dictionary. Args: - res: The dictionary result updated in place. On initial call, should be passed - as an empty dictionary. - dirs_only: Whether to only add directories to the result. + res: The dictionary result updated in place. On initial call, + should be passed as an empty dictionary. + dirs_only: Whether to only add directories to the result. Raises: KeyError: If any nodes in the tree have the same ID. @@ -150,6 +149,25 @@ class DirectoryTreeNode: for child in self.children: child.flatten(res, dirs_only) + def to_sorted_list(self, res: list) -> None: + """Sorts the tree by absolute path and groups by directories/files. + + Since we are sorting a directory tree, we always want the directories to come + before the files. So we have to sort these two lists separately. + + Args: + res: The list result updated in place. On initial call, should be passed + as an empty list. + """ + res.append(self) + dir_children = [x for x in self.children if x.is_dir] + file_children = [x for x in self.children if not x.is_dir] + for child in sorted(dir_children, key=lambda x: shared.natural_sort_key(x.abspath)): + child.to_sorted_list(res) + + for child in sorted(file_children, key=lambda x: shared.natural_sort_key(x.abspath)): + child.to_sorted_list(res) + def apply(self, fn: Callable) -> None: """Recursively calls passed function with instance for entire tree.""" fn(self) @@ -693,20 +711,21 @@ class ExtraNetworksPage: if not self.tree_roots: return {} - # Flatten each root into a single dict - tree = {} + # Flatten roots into a single sorted list of nodes. + # Directories always come before files. After that, natural sort is used. + sorted_nodes = [] for node in self.tree_roots.values(): - subtree = {} - node.flatten(subtree) - tree.update(subtree) + _sorted_nodes = [] + node.to_sorted_list(_sorted_nodes) + sorted_nodes.extend(_sorted_nodes) path_to_div_id = {} div_id_to_node = {} # reverse mapping # First assign div IDs to each node. Used for parent ID lookup later. - for i, path in enumerate(sorted(tree.keys(), key=shared.natural_sort_key)): + for i, node in enumerate(sorted_nodes): div_id = str(i) - path_to_div_id[path] = div_id - div_id_to_node[div_id] = tree[path] + path_to_div_id[node.abspath] = div_id + div_id_to_node[div_id] = node show_files = shared.opts.extra_networks_tree_view_show_files is True for div_id, node in div_id_to_node.items(): @@ -966,12 +985,9 @@ def initialize(): def register_default_pages(): - from modules.ui_extra_networks_checkpoints import \ - ExtraNetworksPageCheckpoints - from modules.ui_extra_networks_hypernets import \ - ExtraNetworksPageHypernetworks - from modules.ui_extra_networks_textual_inversion import \ - ExtraNetworksPageTextualInversion + from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints + from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks + from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion register_page(ExtraNetworksPageTextualInversion()) register_page(ExtraNetworksPageHypernetworks())