From 44a58b0233f45904b02d11b68c2ecea5cb7f1765 Mon Sep 17 00:00:00 2001 From: Tejas Attarde Date: Fri, 13 Mar 2026 17:21:19 -0400 Subject: [PATCH] set a timeout on launch utils HTTP calls --- modules/launch_utils.py | 964 +++++++++++++------------- modules/textual_inversion/autocrop.py | 690 +++++++++--------- 2 files changed, 827 insertions(+), 827 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 20c7dc127..9aa6df16f 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -1,482 +1,482 @@ -# this scripts installs necessary requirements and launches main program in webui.py -import logging -import re -import subprocess -import os -import shutil -import sys -import importlib.util -import importlib.metadata -import platform -import json -import shlex -from functools import lru_cache - -from modules import cmd_args, errors -from modules.paths_internal import script_path, extensions_dir -from modules.timer import startup_timer -from modules import logging_config - -args, _ = cmd_args.parser.parse_known_args() -logging_config.setup_logging(args.loglevel) - -python = sys.executable -git = os.environ.get('GIT', "git") -index_url = os.environ.get('INDEX_URL', "") -dir_repos = "repositories" - -# Whether to default to printing command output -default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1") - -os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False') - - -def check_python_version(): - is_windows = platform.system() == "Windows" - major = sys.version_info.major - minor = sys.version_info.minor - micro = sys.version_info.micro - - if is_windows: - supported_minors = [10] - else: - supported_minors = [7, 8, 9, 10, 11] - - if not (major == 3 and minor in supported_minors): - import modules.errors - - modules.errors.print_error_explanation(f""" -INCOMPATIBLE PYTHON VERSION - -This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}. -If you encounter an error with "RuntimeError: Couldn't install torch." message, -or any other error regarding unsuccessful package (library) installation, -please downgrade (or upgrade) to the latest version of 3.10 Python -and delete current Python and "venv" folder in WebUI's directory. - -You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/ - -{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""} - -Use --skip-python-version-check to suppress this warning. -""") - - -@lru_cache() -def commit_hash(): - try: - return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() - except Exception: - return "" - - -@lru_cache() -def git_tag(): - try: - return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip() - except Exception: - try: - - changelog_md = os.path.join(script_path, "CHANGELOG.md") - with open(changelog_md, "r", encoding="utf-8") as file: - line = next((line.strip() for line in file if line.strip()), "") - line = line.replace("## ", "") - return line - except Exception: - return "" - - -def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str: - if desc is not None: - print(desc) - - run_kwargs = { - "args": command, - "shell": True, - "env": os.environ if custom_env is None else custom_env, - "encoding": 'utf8', - "errors": 'ignore', - } - - if not live: - run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE - - result = subprocess.run(**run_kwargs) - - if result.returncode != 0: - error_bits = [ - f"{errdesc or 'Error running command'}.", - f"Command: {command}", - f"Error code: {result.returncode}", - ] - if result.stdout: - error_bits.append(f"stdout: {result.stdout}") - if result.stderr: - error_bits.append(f"stderr: {result.stderr}") - raise RuntimeError("\n".join(error_bits)) - - return (result.stdout or "") - - -def is_installed(package): - try: - dist = importlib.metadata.distribution(package) - except importlib.metadata.PackageNotFoundError: - try: - spec = importlib.util.find_spec(package) - except ModuleNotFoundError: - return False - - return spec is not None - - return dist is not None - - -def repo_dir(name): - return os.path.join(script_path, dir_repos, name) - - -def run_pip(command, desc=None, live=default_command_live): - if args.skip_install: - return - - index_url_line = f' --index-url {index_url}' if index_url != '' else '' - return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live) - - -def check_run_python(code: str) -> bool: - result = subprocess.run([python, "-c", code], capture_output=True, shell=False) - return result.returncode == 0 - - -def git_fix_workspace(dir, name): - run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True) - run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True) - return - - -def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True): - try: - return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) - except RuntimeError: - if not autofix: - raise - - print(f"{errdesc}, attempting autofix...") - git_fix_workspace(dir, name) - - return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) - - -def git_clone(url, dir, name, commithash=None): - # TODO clone into temporary dir and move if successful - - if os.path.exists(dir): - if commithash is None: - return - - current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() - if current_hash == commithash: - return - - if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url: - run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False) - - run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False) - - run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) - - return - - try: - run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True) - except RuntimeError: - shutil.rmtree(dir, ignore_errors=True) - raise - - if commithash is not None: - run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") - - -def git_pull_recursive(dir): - for subdir, _, _ in os.walk(dir): - if os.path.exists(os.path.join(subdir, '.git')): - try: - output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash']) - print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n") - except subprocess.CalledProcessError as e: - print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n") - - -def version_check(commit): - try: - import requests - commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json() - if commit != "" and commits['commit']['sha'] != commit: - print("--------------------------------------------------------") - print("| You are not up to date with the most recent release. |") - print("| Consider running `git pull` to update. |") - print("--------------------------------------------------------") - elif commits['commit']['sha'] == commit: - print("You are up to date with the most recent release.") - else: - print("Not a git clone, can't perform version check.") - except Exception as e: - print("version check failed", e) - - -def run_extension_installer(extension_dir): - path_installer = os.path.join(extension_dir, "install.py") - if not os.path.isfile(path_installer): - return - - try: - env = os.environ.copy() - env['PYTHONPATH'] = f"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}" - - stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip() - if stdout: - print(stdout) - except Exception as e: - errors.report(str(e)) - - -def list_extensions(settings_file): - settings = {} - - try: - with open(settings_file, "r", encoding="utf8") as file: - settings = json.load(file) - except FileNotFoundError: - pass - except Exception: - errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True) - os.replace(settings_file, os.path.join(script_path, "tmp", "config.json")) - - disabled_extensions = set(settings.get('disabled_extensions', [])) - disable_all_extensions = settings.get('disable_all_extensions', 'none') - - if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir): - return [] - - return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions] - - -def run_extensions_installers(settings_file): - if not os.path.isdir(extensions_dir): - return - - with startup_timer.subcategory("run extensions installers"): - for dirname_extension in list_extensions(settings_file): - logging.debug(f"Installing {dirname_extension}") - - path = os.path.join(extensions_dir, dirname_extension) - - if os.path.isdir(path): - run_extension_installer(path) - startup_timer.record(dirname_extension) - - -re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*") - - -def requirements_met(requirements_file): - """ - Does a simple parse of a requirements.txt file to determine if all rerqirements in it - are already installed. Returns True if so, False if not installed or parsing fails. - """ - - import importlib.metadata - import packaging.version - - with open(requirements_file, "r", encoding="utf8") as file: - for line in file: - if line.strip() == "": - continue - - m = re.match(re_requirement, line) - if m is None: - return False - - package = m.group(1).strip() - version_required = (m.group(2) or "").strip() - - if version_required == "": - continue - - try: - version_installed = importlib.metadata.version(package) - except Exception: - return False - - if packaging.version.parse(version_required) != packaging.version.parse(version_installed): - return False - - return True - - -def prepare_environment(): - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}") - if args.use_ipex: - if platform.system() == "Windows": - # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main - # This is NOT an Intel official release so please use it at your own risk!! - # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. - # - # Strengths (over official IPEX 2.0.110 windows release): - # - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 - # - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. - # - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 - # Limitation: - # - Only works for python 3.10 - url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle" - torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl") - else: - # Using official IPEX release for linux since it's already an AOT build. - # However, users still have to install oneAPI toolkit and activate oneAPI environment manually. - # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details. - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") - requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt") - - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1') - clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") - openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") - - assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") - stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") - k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') - blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') - - assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") - stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") - stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") - k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") - blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") - - try: - # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution - os.remove(os.path.join(script_path, "tmp", "restart")) - os.environ.setdefault('SD_WEBUI_RESTARTING', '1') - except OSError: - pass - - if not args.skip_python_version_check: - check_python_version() - - startup_timer.record("checks") - - commit = commit_hash() - tag = git_tag() - startup_timer.record("git version info") - - print(f"Python {sys.version}") - print(f"Version: {tag}") - print(f"Commit hash: {commit}") - - if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): - run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) - startup_timer.record("install torch") - - if args.use_ipex: - args.skip_torch_cuda_test = True - if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"): - raise RuntimeError( - 'Torch is not able to use GPU; ' - 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check' - ) - startup_timer.record("torch GPU test") - - if not is_installed("clip"): - run_pip(f"install {clip_package}", "clip") - startup_timer.record("install clip") - - if not is_installed("open_clip"): - run_pip(f"install {openclip_package}", "open_clip") - startup_timer.record("install open_clip") - - if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers: - run_pip(f"install -U -I --no-deps {xformers_package}", "xformers") - startup_timer.record("install xformers") - - if not is_installed("ngrok") and args.ngrok: - run_pip("install ngrok", "ngrok") - startup_timer.record("install ngrok") - - os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) - - git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) - git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) - git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) - git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) - git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) - - startup_timer.record("clone repositores") - - if not os.path.isfile(requirements_file): - requirements_file = os.path.join(script_path, requirements_file) - - if not requirements_met(requirements_file): - run_pip(f"install -r \"{requirements_file}\"", "requirements") - startup_timer.record("install requirements") - - if not os.path.isfile(requirements_file_for_npu): - requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu) - - if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu): - run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu") - startup_timer.record("install requirements_for_npu") - - if not args.skip_install: - run_extensions_installers(settings_file=args.ui_settings_file) - - if args.update_check: - version_check(commit) - startup_timer.record("check version") - - if args.update_all_extensions: - git_pull_recursive(extensions_dir) - startup_timer.record("update extensions") - - if "--exit" in sys.argv: - print("Exiting because of --exit argument") - exit(0) - - -def configure_for_tests(): - if "--api" not in sys.argv: - sys.argv.append("--api") - if "--ckpt" not in sys.argv: - sys.argv.append("--ckpt") - sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt")) - if "--skip-torch-cuda-test" not in sys.argv: - sys.argv.append("--skip-torch-cuda-test") - if "--disable-nan-check" not in sys.argv: - sys.argv.append("--disable-nan-check") - - os.environ['COMMANDLINE_ARGS'] = "" - - -def start(): - print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}") - import webui - if '--nowebui' in sys.argv: - webui.api_only() - else: - webui.webui() - - -def dump_sysinfo(): - from modules import sysinfo - import datetime - - text = sysinfo.get() - filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" - - with open(filename, "w", encoding="utf8") as file: - file.write(text) - - return filename +# this scripts installs necessary requirements and launches main program in webui.py +import logging +import re +import subprocess +import os +import shutil +import sys +import importlib.util +import importlib.metadata +import platform +import json +import shlex +from functools import lru_cache + +from modules import cmd_args, errors +from modules.paths_internal import script_path, extensions_dir +from modules.timer import startup_timer +from modules import logging_config + +args, _ = cmd_args.parser.parse_known_args() +logging_config.setup_logging(args.loglevel) + +python = sys.executable +git = os.environ.get('GIT', "git") +index_url = os.environ.get('INDEX_URL', "") +dir_repos = "repositories" + +# Whether to default to printing command output +default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1") + +os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False') + + +def check_python_version(): + is_windows = platform.system() == "Windows" + major = sys.version_info.major + minor = sys.version_info.minor + micro = sys.version_info.micro + + if is_windows: + supported_minors = [10] + else: + supported_minors = [7, 8, 9, 10, 11] + + if not (major == 3 and minor in supported_minors): + import modules.errors + + modules.errors.print_error_explanation(f""" +INCOMPATIBLE PYTHON VERSION + +This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}. +If you encounter an error with "RuntimeError: Couldn't install torch." message, +or any other error regarding unsuccessful package (library) installation, +please downgrade (or upgrade) to the latest version of 3.10 Python +and delete current Python and "venv" folder in WebUI's directory. + +You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/ + +{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""} + +Use --skip-python-version-check to suppress this warning. +""") + + +@lru_cache() +def commit_hash(): + try: + return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() + except Exception: + return "" + + +@lru_cache() +def git_tag(): + try: + return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip() + except Exception: + try: + + changelog_md = os.path.join(script_path, "CHANGELOG.md") + with open(changelog_md, "r", encoding="utf-8") as file: + line = next((line.strip() for line in file if line.strip()), "") + line = line.replace("## ", "") + return line + except Exception: + return "" + + +def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str: + if desc is not None: + print(desc) + + run_kwargs = { + "args": command, + "shell": True, + "env": os.environ if custom_env is None else custom_env, + "encoding": 'utf8', + "errors": 'ignore', + } + + if not live: + run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE + + result = subprocess.run(**run_kwargs) + + if result.returncode != 0: + error_bits = [ + f"{errdesc or 'Error running command'}.", + f"Command: {command}", + f"Error code: {result.returncode}", + ] + if result.stdout: + error_bits.append(f"stdout: {result.stdout}") + if result.stderr: + error_bits.append(f"stderr: {result.stderr}") + raise RuntimeError("\n".join(error_bits)) + + return (result.stdout or "") + + +def is_installed(package): + try: + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False + + return spec is not None + + return dist is not None + + +def repo_dir(name): + return os.path.join(script_path, dir_repos, name) + + +def run_pip(command, desc=None, live=default_command_live): + if args.skip_install: + return + + index_url_line = f' --index-url {index_url}' if index_url != '' else '' + return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live) + + +def check_run_python(code: str) -> bool: + result = subprocess.run([python, "-c", code], capture_output=True, shell=False) + return result.returncode == 0 + + +def git_fix_workspace(dir, name): + run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True) + run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True) + return + + +def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True): + try: + return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) + except RuntimeError: + if not autofix: + raise + + print(f"{errdesc}, attempting autofix...") + git_fix_workspace(dir, name) + + return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) + + +def git_clone(url, dir, name, commithash=None): + # TODO clone into temporary dir and move if successful + + if os.path.exists(dir): + if commithash is None: + return + + current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() + if current_hash == commithash: + return + + if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url: + run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False) + + run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False) + + run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) + + return + + try: + run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True) + except RuntimeError: + shutil.rmtree(dir, ignore_errors=True) + raise + + if commithash is not None: + run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") + + +def git_pull_recursive(dir): + for subdir, _, _ in os.walk(dir): + if os.path.exists(os.path.join(subdir, '.git')): + try: + output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash']) + print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n") + except subprocess.CalledProcessError as e: + print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n") + + +def version_check(commit): + try: + import requests + commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master', timeout=10.0).json() + if commit != "" and commits['commit']['sha'] != commit: + print("--------------------------------------------------------") + print("| You are not up to date with the most recent release. |") + print("| Consider running `git pull` to update. |") + print("--------------------------------------------------------") + elif commits['commit']['sha'] == commit: + print("You are up to date with the most recent release.") + else: + print("Not a git clone, can't perform version check.") + except Exception as e: + print("version check failed", e) + + +def run_extension_installer(extension_dir): + path_installer = os.path.join(extension_dir, "install.py") + if not os.path.isfile(path_installer): + return + + try: + env = os.environ.copy() + env['PYTHONPATH'] = f"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}" + + stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip() + if stdout: + print(stdout) + except Exception as e: + errors.report(str(e)) + + +def list_extensions(settings_file): + settings = {} + + try: + with open(settings_file, "r", encoding="utf8") as file: + settings = json.load(file) + except FileNotFoundError: + pass + except Exception: + errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True) + os.replace(settings_file, os.path.join(script_path, "tmp", "config.json")) + + disabled_extensions = set(settings.get('disabled_extensions', [])) + disable_all_extensions = settings.get('disable_all_extensions', 'none') + + if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir): + return [] + + return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions] + + +def run_extensions_installers(settings_file): + if not os.path.isdir(extensions_dir): + return + + with startup_timer.subcategory("run extensions installers"): + for dirname_extension in list_extensions(settings_file): + logging.debug(f"Installing {dirname_extension}") + + path = os.path.join(extensions_dir, dirname_extension) + + if os.path.isdir(path): + run_extension_installer(path) + startup_timer.record(dirname_extension) + + +re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*") + + +def requirements_met(requirements_file): + """ + Does a simple parse of a requirements.txt file to determine if all rerqirements in it + are already installed. Returns True if so, False if not installed or parsing fails. + """ + + import importlib.metadata + import packaging.version + + with open(requirements_file, "r", encoding="utf8") as file: + for line in file: + if line.strip() == "": + continue + + m = re.match(re_requirement, line) + if m is None: + return False + + package = m.group(1).strip() + version_required = (m.group(2) or "").strip() + + if version_required == "": + continue + + try: + version_installed = importlib.metadata.version(package) + except Exception: + return False + + if packaging.version.parse(version_required) != packaging.version.parse(version_installed): + return False + + return True + + +def prepare_environment(): + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}") + if args.use_ipex: + if platform.system() == "Windows": + # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main + # This is NOT an Intel official release so please use it at your own risk!! + # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. + # + # Strengths (over official IPEX 2.0.110 windows release): + # - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 + # - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. + # - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 + # Limitation: + # - Only works for python 3.10 + url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle" + torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl") + else: + # Using official IPEX release for linux since it's already an AOT build. + # However, users still have to install oneAPI toolkit and activate oneAPI environment manually. + # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details. + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") + requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") + requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt") + + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1') + clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") + openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") + + assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") + stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") + stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") + k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') + blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') + + assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") + stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") + stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") + k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") + blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") + + try: + # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution + os.remove(os.path.join(script_path, "tmp", "restart")) + os.environ.setdefault('SD_WEBUI_RESTARTING', '1') + except OSError: + pass + + if not args.skip_python_version_check: + check_python_version() + + startup_timer.record("checks") + + commit = commit_hash() + tag = git_tag() + startup_timer.record("git version info") + + print(f"Python {sys.version}") + print(f"Version: {tag}") + print(f"Commit hash: {commit}") + + if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): + run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) + startup_timer.record("install torch") + + if args.use_ipex: + args.skip_torch_cuda_test = True + if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"): + raise RuntimeError( + 'Torch is not able to use GPU; ' + 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check' + ) + startup_timer.record("torch GPU test") + + if not is_installed("clip"): + run_pip(f"install {clip_package}", "clip") + startup_timer.record("install clip") + + if not is_installed("open_clip"): + run_pip(f"install {openclip_package}", "open_clip") + startup_timer.record("install open_clip") + + if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers: + run_pip(f"install -U -I --no-deps {xformers_package}", "xformers") + startup_timer.record("install xformers") + + if not is_installed("ngrok") and args.ngrok: + run_pip("install ngrok", "ngrok") + startup_timer.record("install ngrok") + + os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) + + git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) + git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) + git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) + git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) + + startup_timer.record("clone repositores") + + if not os.path.isfile(requirements_file): + requirements_file = os.path.join(script_path, requirements_file) + + if not requirements_met(requirements_file): + run_pip(f"install -r \"{requirements_file}\"", "requirements") + startup_timer.record("install requirements") + + if not os.path.isfile(requirements_file_for_npu): + requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu) + + if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu): + run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu") + startup_timer.record("install requirements_for_npu") + + if not args.skip_install: + run_extensions_installers(settings_file=args.ui_settings_file) + + if args.update_check: + version_check(commit) + startup_timer.record("check version") + + if args.update_all_extensions: + git_pull_recursive(extensions_dir) + startup_timer.record("update extensions") + + if "--exit" in sys.argv: + print("Exiting because of --exit argument") + exit(0) + + +def configure_for_tests(): + if "--api" not in sys.argv: + sys.argv.append("--api") + if "--ckpt" not in sys.argv: + sys.argv.append("--ckpt") + sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt")) + if "--skip-torch-cuda-test" not in sys.argv: + sys.argv.append("--skip-torch-cuda-test") + if "--disable-nan-check" not in sys.argv: + sys.argv.append("--disable-nan-check") + + os.environ['COMMANDLINE_ARGS'] = "" + + +def start(): + print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}") + import webui + if '--nowebui' in sys.argv: + webui.api_only() + else: + webui.webui() + + +def dump_sysinfo(): + from modules import sysinfo + import datetime + + text = sysinfo.get() + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" + + with open(filename, "w", encoding="utf8") as file: + file.write(text) + + return filename diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index ca858ef4c..c75887b7e 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,345 +1,345 @@ -import cv2 -import requests -import os -import numpy as np -from PIL import ImageDraw -from modules import paths_internal -from pkg_resources import parse_version - -GREEN = "#0F0" -BLUE = "#00F" -RED = "#F00" - - -def crop_image(im, settings): - """ Intelligently crop an image to the subject matter """ - - scale_by = 1 - if is_landscape(im.width, im.height): - scale_by = settings.crop_height / im.height - elif is_portrait(im.width, im.height): - scale_by = settings.crop_width / im.width - elif is_square(im.width, im.height): - if is_square(settings.crop_width, settings.crop_height): - scale_by = settings.crop_width / im.width - elif is_landscape(settings.crop_width, settings.crop_height): - scale_by = settings.crop_width / im.width - elif is_portrait(settings.crop_width, settings.crop_height): - scale_by = settings.crop_height / im.height - - im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) - im_debug = im.copy() - - focus = focal_point(im_debug, settings) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(settings.crop_height / 2) - x_half = int(settings.crop_width / 2) - - x1 = focus.x - x_half - if x1 < 0: - x1 = 0 - elif x1 + settings.crop_width > im.width: - x1 = im.width - settings.crop_width - - y1 = focus.y - y_half - if y1 < 0: - y1 = 0 - elif y1 + settings.crop_height > im.height: - y1 = im.height - settings.crop_height - - x2 = x1 + settings.crop_width - y2 = y1 + settings.crop_height - - crop = [x1, y1, x2, y2] - - results = [] - - results.append(im.crop(tuple(crop))) - - if settings.annotate_image: - d = ImageDraw.Draw(im_debug) - rect = list(crop) - rect[2] -= 1 - rect[3] -= 1 - d.rectangle(rect, outline=GREEN) - results.append(im_debug) - if settings.desktop_view_image: - im_debug.show() - - return results - - -def focal_point(im, settings): - corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] - entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] - face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else [] - - pois = [] - - weight_pref_total = 0 - if corner_points: - weight_pref_total += settings.corner_points_weight - if entropy_points: - weight_pref_total += settings.entropy_points_weight - if face_points: - weight_pref_total += settings.face_points_weight - - corner_centroid = None - if corner_points: - corner_centroid = centroid(corner_points) - corner_centroid.weight = settings.corner_points_weight / weight_pref_total - pois.append(corner_centroid) - - entropy_centroid = None - if entropy_points: - entropy_centroid = centroid(entropy_points) - entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total - pois.append(entropy_centroid) - - face_centroid = None - if face_points: - face_centroid = centroid(face_points) - face_centroid.weight = settings.face_points_weight / weight_pref_total - pois.append(face_centroid) - - average_point = poi_average(pois, settings) - - if settings.annotate_image: - d = ImageDraw.Draw(im) - max_size = min(im.width, im.height) * 0.07 - if corner_centroid is not None: - color = BLUE - box = corner_centroid.bounding(max_size * corner_centroid.weight) - d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(corner_points) > 1: - for f in corner_points: - d.rectangle(f.bounding(4), outline=color) - if entropy_centroid is not None: - color = "#ff0" - box = entropy_centroid.bounding(max_size * entropy_centroid.weight) - d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(entropy_points) > 1: - for f in entropy_points: - d.rectangle(f.bounding(4), outline=color) - if face_centroid is not None: - color = RED - box = face_centroid.bounding(max_size * face_centroid.weight) - d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(face_points) > 1: - for f in face_points: - d.rectangle(f.bounding(4), outline=color) - - d.ellipse(average_point.bounding(max_size), outline=GREEN) - - return average_point - - -def image_face_points(im, settings): - if settings.dnn_model_path is not None: - detector = cv2.FaceDetectorYN.create( - settings.dnn_model_path, - "", - (im.width, im.height), - 0.9, # score threshold - 0.3, # nms threshold - 5000 # keep top k before nms - ) - faces = detector.detect(np.array(im)) - results = [] - if faces[1] is not None: - for face in faces[1]: - x = face[0] - y = face[1] - w = face[2] - h = face[3] - results.append( - PointOfInterest( - int(x + (w * 0.5)), # face focus left/right is center - int(y + (h * 0.33)), # face focus up/down is close to the top of the head - size=w, - weight=1 / len(faces[1]) - ) - ) - return results - else: - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01], - [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05], - [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05], - [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05], - [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05], - [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05], - [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05], - [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05] - ] - for t in tries: - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), - flags=cv2.CASCADE_SCALE_IMAGE) - except Exception: - continue - - if faces: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]), - weight=1 / len(rects)) for r in rects] - return [] - - -def image_corner_points(im, settings): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height) * 0.06, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points))) - - return focal_points - - -def image_entropy_points(im, settings): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - else: - return [] - - e_max = 0 - crop_current = [0, 0, settings.crop_width, settings.crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = int(crop_best[0] + settings.crop_width / 2) - y_mid = int(crop_best[1] + settings.crop_height / 2) - - return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] - - -def image_entropy(im): - # greyscale image entropy - # band = np.asarray(im.convert("L")) - band = np.asarray(im.convert("1"), dtype=np.uint8) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - - -def centroid(pois): - x = [poi.x for poi in pois] - y = [poi.y for poi in pois] - return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois)) - - -def poi_average(pois, settings): - weight = 0.0 - x = 0.0 - y = 0.0 - for poi in pois: - weight += poi.weight - x += poi.x * poi.weight - y += poi.y * poi.weight - avg_x = round(weight and x / weight) - avg_y = round(weight and y / weight) - - return PointOfInterest(avg_x, avg_y) - - -def is_landscape(w, h): - return w > h - - -def is_portrait(w, h): - return h > w - - -def is_square(w, h): - return w == h - - -model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv') -if parse_version(cv2.__version__) >= parse_version('4.8'): - model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx') - model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true' -else: - model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx') - model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - - -def download_and_cache_models(): - if not os.path.exists(model_file_path): - os.makedirs(model_dir_opencv, exist_ok=True) - print(f"downloading face detection model from '{model_url}' to '{model_file_path}'") - response = requests.get(model_url) - with open(model_file_path, "wb") as f: - f.write(response.content) - return model_file_path - - -class PointOfInterest: - def __init__(self, x, y, weight=1.0, size=10): - self.x = x - self.y = y - self.weight = weight - self.size = size - - def bounding(self, size): - return [ - self.x - size // 2, - self.y - size // 2, - self.x + size // 2, - self.y + size // 2 - ] - - -class Settings: - def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): - self.crop_width = crop_width - self.crop_height = crop_height - self.corner_points_weight = corner_points_weight - self.entropy_points_weight = entropy_points_weight - self.face_points_weight = face_points_weight - self.annotate_image = annotate_image - self.desktop_view_image = False - self.dnn_model_path = dnn_model_path +import cv2 +import requests +import os +import numpy as np +from PIL import ImageDraw +from modules import paths_internal +from pkg_resources import parse_version + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + + scale_by = 1 + if is_landscape(im.width, im.height): + scale_by = settings.crop_height / im.height + elif is_portrait(im.width, im.height): + scale_by = settings.crop_width / im.width + elif is_square(im.width, im.height): + if is_square(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_landscape(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_portrait(settings.crop_width, settings.crop_height): + scale_by = settings.crop_height / im.height + + im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + im_debug = im.copy() + + focus = focal_point(im_debug, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + results = [] + + results.append(im.crop(tuple(crop))) + + if settings.annotate_image: + d = ImageDraw.Draw(im_debug) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + results.append(im_debug) + if settings.desktop_view_image: + im_debug.show() + + return results + + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] + entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] + face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else [] + + pois = [] + + weight_pref_total = 0 + if corner_points: + weight_pref_total += settings.corner_points_weight + if entropy_points: + weight_pref_total += settings.entropy_points_weight + if face_points: + weight_pref_total += settings.face_points_weight + + corner_centroid = None + if corner_points: + corner_centroid = centroid(corner_points) + corner_centroid.weight = settings.corner_points_weight / weight_pref_total + pois.append(corner_centroid) + + entropy_centroid = None + if entropy_points: + entropy_centroid = centroid(entropy_points) + entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total + pois.append(entropy_centroid) + + face_centroid = None + if face_points: + face_centroid = centroid(face_points) + face_centroid.weight = settings.face_points_weight / weight_pref_total + pois.append(face_centroid) + + average_point = poi_average(pois, settings) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + max_size = min(im.width, im.height) * 0.07 + if corner_centroid is not None: + color = BLUE + box = corner_centroid.bounding(max_size * corner_centroid.weight) + d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(corner_points) > 1: + for f in corner_points: + d.rectangle(f.bounding(4), outline=color) + if entropy_centroid is not None: + color = "#ff0" + box = entropy_centroid.bounding(max_size * entropy_centroid.weight) + d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(entropy_points) > 1: + for f in entropy_points: + d.rectangle(f.bounding(4), outline=color) + if face_centroid is not None: + color = RED + box = face_centroid.bounding(max_size * face_centroid.weight) + d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(face_points) > 1: + for f in face_points: + d.rectangle(f.bounding(4), outline=color) + + d.ellipse(average_point.bounding(max_size), outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + if settings.dnn_model_path is not None: + detector = cv2.FaceDetectorYN.create( + settings.dnn_model_path, + "", + (im.width, im.height), + 0.9, # score threshold + 0.3, # nms threshold + 5000 # keep top k before nms + ) + faces = detector.detect(np.array(im)) + results = [] + if faces[1] is not None: + for face in faces[1]: + x = face[0] + y = face[1] + w = face[2] + h = face[3] + results.append( + PointOfInterest( + int(x + (w * 0.5)), # face focus left/right is center + int(y + (h * 0.33)), # face focus up/down is close to the top of the head + size=w, + weight=1 / len(faces[1]) + ) + ) + return results + else: + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01], + [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05] + ] + for t in tries: + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), + flags=cv2.CASCADE_SCALE_IMAGE) + except Exception: + continue + + if faces: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]), + weight=1 / len(rects)) for r in rects] + return [] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height) * 0.06, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points))) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width / 2) + y_mid = int(crop_best[1] + settings.crop_height / 2) + + return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] + + +def image_entropy(im): + # greyscale image entropy + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def centroid(pois): + x = [poi.x for poi in pois] + y = [poi.y for poi in pois] + return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois)) + + +def poi_average(pois, settings): + weight = 0.0 + x = 0.0 + y = 0.0 + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight + avg_x = round(weight and x / weight) + avg_y = round(weight and y / weight) + + return PointOfInterest(avg_x, avg_y) + + +def is_landscape(w, h): + return w > h + + +def is_portrait(w, h): + return h > w + + +def is_square(w, h): + return w == h + + +model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv') +if parse_version(cv2.__version__) >= parse_version('4.8'): + model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true' +else: + model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' + + +def download_and_cache_models(): + if not os.path.exists(model_file_path): + os.makedirs(model_dir_opencv, exist_ok=True) + print(f"downloading face detection model from '{model_url}' to '{model_file_path}'") + response = requests.get(model_url, timeout=10.0) + with open(model_file_path, "wb") as f: + f.write(response.content) + return model_file_path + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0, size=10): + self.x = x + self.y = y + self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size // 2, + self.y - size // 2, + self.x + size // 2, + self.y + size // 2 + ] + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = face_points_weight + self.annotate_image = annotate_image + self.desktop_view_image = False + self.dnn_model_path = dnn_model_path