mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-03-26 08:00:46 -07:00
Merge 44a58b0233 into 82a973c043
This commit is contained in:
commit
54a9f6827b
2 changed files with 827 additions and 827 deletions
|
|
@ -1,482 +1,482 @@
|
|||
# this scripts installs necessary requirements and launches main program in webui.py
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import importlib.util
|
||||
import importlib.metadata
|
||||
import platform
|
||||
import json
|
||||
import shlex
|
||||
from functools import lru_cache
|
||||
|
||||
from modules import cmd_args, errors
|
||||
from modules.paths_internal import script_path, extensions_dir
|
||||
from modules.timer import startup_timer
|
||||
from modules import logging_config
|
||||
|
||||
args, _ = cmd_args.parser.parse_known_args()
|
||||
logging_config.setup_logging(args.loglevel)
|
||||
|
||||
python = sys.executable
|
||||
git = os.environ.get('GIT', "git")
|
||||
index_url = os.environ.get('INDEX_URL', "")
|
||||
dir_repos = "repositories"
|
||||
|
||||
# Whether to default to printing command output
|
||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||
|
||||
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||
|
||||
|
||||
def check_python_version():
|
||||
is_windows = platform.system() == "Windows"
|
||||
major = sys.version_info.major
|
||||
minor = sys.version_info.minor
|
||||
micro = sys.version_info.micro
|
||||
|
||||
if is_windows:
|
||||
supported_minors = [10]
|
||||
else:
|
||||
supported_minors = [7, 8, 9, 10, 11]
|
||||
|
||||
if not (major == 3 and minor in supported_minors):
|
||||
import modules.errors
|
||||
|
||||
modules.errors.print_error_explanation(f"""
|
||||
INCOMPATIBLE PYTHON VERSION
|
||||
|
||||
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
||||
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
||||
or any other error regarding unsuccessful package (library) installation,
|
||||
please downgrade (or upgrade) to the latest version of 3.10 Python
|
||||
and delete current Python and "venv" folder in WebUI's directory.
|
||||
|
||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
||||
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""}
|
||||
|
||||
Use --skip-python-version-check to suppress this warning.
|
||||
""")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def commit_hash():
|
||||
try:
|
||||
return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
return "<none>"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def git_tag():
|
||||
try:
|
||||
return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
try:
|
||||
|
||||
changelog_md = os.path.join(script_path, "CHANGELOG.md")
|
||||
with open(changelog_md, "r", encoding="utf-8") as file:
|
||||
line = next((line.strip() for line in file if line.strip()), "<none>")
|
||||
line = line.replace("## ", "")
|
||||
return line
|
||||
except Exception:
|
||||
return "<none>"
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
||||
if desc is not None:
|
||||
print(desc)
|
||||
|
||||
run_kwargs = {
|
||||
"args": command,
|
||||
"shell": True,
|
||||
"env": os.environ if custom_env is None else custom_env,
|
||||
"encoding": 'utf8',
|
||||
"errors": 'ignore',
|
||||
}
|
||||
|
||||
if not live:
|
||||
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
||||
|
||||
result = subprocess.run(**run_kwargs)
|
||||
|
||||
if result.returncode != 0:
|
||||
error_bits = [
|
||||
f"{errdesc or 'Error running command'}.",
|
||||
f"Command: {command}",
|
||||
f"Error code: {result.returncode}",
|
||||
]
|
||||
if result.stdout:
|
||||
error_bits.append(f"stdout: {result.stdout}")
|
||||
if result.stderr:
|
||||
error_bits.append(f"stderr: {result.stderr}")
|
||||
raise RuntimeError("\n".join(error_bits))
|
||||
|
||||
return (result.stdout or "")
|
||||
|
||||
|
||||
def is_installed(package):
|
||||
try:
|
||||
dist = importlib.metadata.distribution(package)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
try:
|
||||
spec = importlib.util.find_spec(package)
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
return spec is not None
|
||||
|
||||
return dist is not None
|
||||
|
||||
|
||||
def repo_dir(name):
|
||||
return os.path.join(script_path, dir_repos, name)
|
||||
|
||||
|
||||
def run_pip(command, desc=None, live=default_command_live):
|
||||
if args.skip_install:
|
||||
return
|
||||
|
||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||
|
||||
|
||||
def check_run_python(code: str) -> bool:
|
||||
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def git_fix_workspace(dir, name):
|
||||
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
|
||||
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
|
||||
return
|
||||
|
||||
|
||||
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
|
||||
try:
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
except RuntimeError:
|
||||
if not autofix:
|
||||
raise
|
||||
|
||||
print(f"{errdesc}, attempting autofix...")
|
||||
git_fix_workspace(dir, name)
|
||||
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
|
||||
|
||||
def git_clone(url, dir, name, commithash=None):
|
||||
# TODO clone into temporary dir and move if successful
|
||||
|
||||
if os.path.exists(dir):
|
||||
if commithash is None:
|
||||
return
|
||||
|
||||
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||
if current_hash == commithash:
|
||||
return
|
||||
|
||||
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
|
||||
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
|
||||
|
||||
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
||||
|
||||
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||
|
||||
return
|
||||
|
||||
try:
|
||||
run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
except RuntimeError:
|
||||
shutil.rmtree(dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if commithash is not None:
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
|
||||
|
||||
def git_pull_recursive(dir):
|
||||
for subdir, _, _ in os.walk(dir):
|
||||
if os.path.exists(os.path.join(subdir, '.git')):
|
||||
try:
|
||||
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
||||
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
||||
|
||||
|
||||
def version_check(commit):
|
||||
try:
|
||||
import requests
|
||||
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
|
||||
if commit != "<none>" and commits['commit']['sha'] != commit:
|
||||
print("--------------------------------------------------------")
|
||||
print("| You are not up to date with the most recent release. |")
|
||||
print("| Consider running `git pull` to update. |")
|
||||
print("--------------------------------------------------------")
|
||||
elif commits['commit']['sha'] == commit:
|
||||
print("You are up to date with the most recent release.")
|
||||
else:
|
||||
print("Not a git clone, can't perform version check.")
|
||||
except Exception as e:
|
||||
print("version check failed", e)
|
||||
|
||||
|
||||
def run_extension_installer(extension_dir):
|
||||
path_installer = os.path.join(extension_dir, "install.py")
|
||||
if not os.path.isfile(path_installer):
|
||||
return
|
||||
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = f"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
|
||||
stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip()
|
||||
if stdout:
|
||||
print(stdout)
|
||||
except Exception as e:
|
||||
errors.report(str(e))
|
||||
|
||||
|
||||
def list_extensions(settings_file):
|
||||
settings = {}
|
||||
|
||||
try:
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
|
||||
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
|
||||
return []
|
||||
|
||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||
|
||||
|
||||
def run_extensions_installers(settings_file):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
with startup_timer.subcategory("run extensions installers"):
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
logging.debug(f"Installing {dirname_extension}")
|
||||
|
||||
path = os.path.join(extensions_dir, dirname_extension)
|
||||
|
||||
if os.path.isdir(path):
|
||||
run_extension_installer(path)
|
||||
startup_timer.record(dirname_extension)
|
||||
|
||||
|
||||
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||
|
||||
|
||||
def requirements_met(requirements_file):
|
||||
"""
|
||||
Does a simple parse of a requirements.txt file to determine if all rerqirements in it
|
||||
are already installed. Returns True if so, False if not installed or parsing fails.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
import packaging.version
|
||||
|
||||
with open(requirements_file, "r", encoding="utf8") as file:
|
||||
for line in file:
|
||||
if line.strip() == "":
|
||||
continue
|
||||
|
||||
m = re.match(re_requirement, line)
|
||||
if m is None:
|
||||
return False
|
||||
|
||||
package = m.group(1).strip()
|
||||
version_required = (m.group(2) or "").strip()
|
||||
|
||||
if version_required == "":
|
||||
continue
|
||||
|
||||
try:
|
||||
version_installed = importlib.metadata.version(package)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def prepare_environment():
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
|
||||
if args.use_ipex:
|
||||
if platform.system() == "Windows":
|
||||
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
|
||||
# This is NOT an Intel official release so please use it at your own risk!!
|
||||
# See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
|
||||
#
|
||||
# Strengths (over official IPEX 2.0.110 windows release):
|
||||
# - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
|
||||
# - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
|
||||
# - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
|
||||
# Limitation:
|
||||
# - Only works for python 3.10
|
||||
url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
|
||||
else:
|
||||
# Using official IPEX release for linux since it's already an AOT build.
|
||||
# However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
|
||||
# See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
try:
|
||||
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||
os.remove(os.path.join(script_path, "tmp", "restart"))
|
||||
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if not args.skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
startup_timer.record("checks")
|
||||
|
||||
commit = commit_hash()
|
||||
tag = git_tag()
|
||||
startup_timer.record("git version info")
|
||||
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Version: {tag}")
|
||||
print(f"Commit hash: {commit}")
|
||||
|
||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
startup_timer.record("install torch")
|
||||
|
||||
if args.use_ipex:
|
||||
args.skip_torch_cuda_test = True
|
||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||
raise RuntimeError(
|
||||
'Torch is not able to use GPU; '
|
||||
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||
)
|
||||
startup_timer.record("torch GPU test")
|
||||
|
||||
if not is_installed("clip"):
|
||||
run_pip(f"install {clip_package}", "clip")
|
||||
startup_timer.record("install clip")
|
||||
|
||||
if not is_installed("open_clip"):
|
||||
run_pip(f"install {openclip_package}", "open_clip")
|
||||
startup_timer.record("install open_clip")
|
||||
|
||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||
startup_timer.record("install xformers")
|
||||
|
||||
if not is_installed("ngrok") and args.ngrok:
|
||||
run_pip("install ngrok", "ngrok")
|
||||
startup_timer.record("install ngrok")
|
||||
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
startup_timer.record("clone repositores")
|
||||
|
||||
if not os.path.isfile(requirements_file):
|
||||
requirements_file = os.path.join(script_path, requirements_file)
|
||||
|
||||
if not requirements_met(requirements_file):
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
startup_timer.record("install requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file_for_npu):
|
||||
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
|
||||
|
||||
if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
|
||||
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
|
||||
startup_timer.record("install requirements_for_npu")
|
||||
|
||||
if not args.skip_install:
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
if args.update_check:
|
||||
version_check(commit)
|
||||
startup_timer.record("check version")
|
||||
|
||||
if args.update_all_extensions:
|
||||
git_pull_recursive(extensions_dir)
|
||||
startup_timer.record("update extensions")
|
||||
|
||||
if "--exit" in sys.argv:
|
||||
print("Exiting because of --exit argument")
|
||||
exit(0)
|
||||
|
||||
|
||||
def configure_for_tests():
|
||||
if "--api" not in sys.argv:
|
||||
sys.argv.append("--api")
|
||||
if "--ckpt" not in sys.argv:
|
||||
sys.argv.append("--ckpt")
|
||||
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
|
||||
if "--skip-torch-cuda-test" not in sys.argv:
|
||||
sys.argv.append("--skip-torch-cuda-test")
|
||||
if "--disable-nan-check" not in sys.argv:
|
||||
sys.argv.append("--disable-nan-check")
|
||||
|
||||
os.environ['COMMANDLINE_ARGS'] = ""
|
||||
|
||||
|
||||
def start():
|
||||
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
|
||||
import webui
|
||||
if '--nowebui' in sys.argv:
|
||||
webui.api_only()
|
||||
else:
|
||||
webui.webui()
|
||||
|
||||
|
||||
def dump_sysinfo():
|
||||
from modules import sysinfo
|
||||
import datetime
|
||||
|
||||
text = sysinfo.get()
|
||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
file.write(text)
|
||||
|
||||
return filename
|
||||
# this scripts installs necessary requirements and launches main program in webui.py
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import importlib.util
|
||||
import importlib.metadata
|
||||
import platform
|
||||
import json
|
||||
import shlex
|
||||
from functools import lru_cache
|
||||
|
||||
from modules import cmd_args, errors
|
||||
from modules.paths_internal import script_path, extensions_dir
|
||||
from modules.timer import startup_timer
|
||||
from modules import logging_config
|
||||
|
||||
args, _ = cmd_args.parser.parse_known_args()
|
||||
logging_config.setup_logging(args.loglevel)
|
||||
|
||||
python = sys.executable
|
||||
git = os.environ.get('GIT', "git")
|
||||
index_url = os.environ.get('INDEX_URL', "")
|
||||
dir_repos = "repositories"
|
||||
|
||||
# Whether to default to printing command output
|
||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||
|
||||
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||
|
||||
|
||||
def check_python_version():
|
||||
is_windows = platform.system() == "Windows"
|
||||
major = sys.version_info.major
|
||||
minor = sys.version_info.minor
|
||||
micro = sys.version_info.micro
|
||||
|
||||
if is_windows:
|
||||
supported_minors = [10]
|
||||
else:
|
||||
supported_minors = [7, 8, 9, 10, 11]
|
||||
|
||||
if not (major == 3 and minor in supported_minors):
|
||||
import modules.errors
|
||||
|
||||
modules.errors.print_error_explanation(f"""
|
||||
INCOMPATIBLE PYTHON VERSION
|
||||
|
||||
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
||||
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
||||
or any other error regarding unsuccessful package (library) installation,
|
||||
please downgrade (or upgrade) to the latest version of 3.10 Python
|
||||
and delete current Python and "venv" folder in WebUI's directory.
|
||||
|
||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
||||
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""}
|
||||
|
||||
Use --skip-python-version-check to suppress this warning.
|
||||
""")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def commit_hash():
|
||||
try:
|
||||
return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
return "<none>"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def git_tag():
|
||||
try:
|
||||
return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
try:
|
||||
|
||||
changelog_md = os.path.join(script_path, "CHANGELOG.md")
|
||||
with open(changelog_md, "r", encoding="utf-8") as file:
|
||||
line = next((line.strip() for line in file if line.strip()), "<none>")
|
||||
line = line.replace("## ", "")
|
||||
return line
|
||||
except Exception:
|
||||
return "<none>"
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
||||
if desc is not None:
|
||||
print(desc)
|
||||
|
||||
run_kwargs = {
|
||||
"args": command,
|
||||
"shell": True,
|
||||
"env": os.environ if custom_env is None else custom_env,
|
||||
"encoding": 'utf8',
|
||||
"errors": 'ignore',
|
||||
}
|
||||
|
||||
if not live:
|
||||
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
||||
|
||||
result = subprocess.run(**run_kwargs)
|
||||
|
||||
if result.returncode != 0:
|
||||
error_bits = [
|
||||
f"{errdesc or 'Error running command'}.",
|
||||
f"Command: {command}",
|
||||
f"Error code: {result.returncode}",
|
||||
]
|
||||
if result.stdout:
|
||||
error_bits.append(f"stdout: {result.stdout}")
|
||||
if result.stderr:
|
||||
error_bits.append(f"stderr: {result.stderr}")
|
||||
raise RuntimeError("\n".join(error_bits))
|
||||
|
||||
return (result.stdout or "")
|
||||
|
||||
|
||||
def is_installed(package):
|
||||
try:
|
||||
dist = importlib.metadata.distribution(package)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
try:
|
||||
spec = importlib.util.find_spec(package)
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
return spec is not None
|
||||
|
||||
return dist is not None
|
||||
|
||||
|
||||
def repo_dir(name):
|
||||
return os.path.join(script_path, dir_repos, name)
|
||||
|
||||
|
||||
def run_pip(command, desc=None, live=default_command_live):
|
||||
if args.skip_install:
|
||||
return
|
||||
|
||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||
|
||||
|
||||
def check_run_python(code: str) -> bool:
|
||||
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def git_fix_workspace(dir, name):
|
||||
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
|
||||
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
|
||||
return
|
||||
|
||||
|
||||
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
|
||||
try:
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
except RuntimeError:
|
||||
if not autofix:
|
||||
raise
|
||||
|
||||
print(f"{errdesc}, attempting autofix...")
|
||||
git_fix_workspace(dir, name)
|
||||
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
|
||||
|
||||
def git_clone(url, dir, name, commithash=None):
|
||||
# TODO clone into temporary dir and move if successful
|
||||
|
||||
if os.path.exists(dir):
|
||||
if commithash is None:
|
||||
return
|
||||
|
||||
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||
if current_hash == commithash:
|
||||
return
|
||||
|
||||
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
|
||||
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
|
||||
|
||||
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
||||
|
||||
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||
|
||||
return
|
||||
|
||||
try:
|
||||
run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
except RuntimeError:
|
||||
shutil.rmtree(dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if commithash is not None:
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
|
||||
|
||||
def git_pull_recursive(dir):
|
||||
for subdir, _, _ in os.walk(dir):
|
||||
if os.path.exists(os.path.join(subdir, '.git')):
|
||||
try:
|
||||
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
||||
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
||||
|
||||
|
||||
def version_check(commit):
|
||||
try:
|
||||
import requests
|
||||
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master', timeout=10.0).json()
|
||||
if commit != "<none>" and commits['commit']['sha'] != commit:
|
||||
print("--------------------------------------------------------")
|
||||
print("| You are not up to date with the most recent release. |")
|
||||
print("| Consider running `git pull` to update. |")
|
||||
print("--------------------------------------------------------")
|
||||
elif commits['commit']['sha'] == commit:
|
||||
print("You are up to date with the most recent release.")
|
||||
else:
|
||||
print("Not a git clone, can't perform version check.")
|
||||
except Exception as e:
|
||||
print("version check failed", e)
|
||||
|
||||
|
||||
def run_extension_installer(extension_dir):
|
||||
path_installer = os.path.join(extension_dir, "install.py")
|
||||
if not os.path.isfile(path_installer):
|
||||
return
|
||||
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = f"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
|
||||
stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip()
|
||||
if stdout:
|
||||
print(stdout)
|
||||
except Exception as e:
|
||||
errors.report(str(e))
|
||||
|
||||
|
||||
def list_extensions(settings_file):
|
||||
settings = {}
|
||||
|
||||
try:
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
|
||||
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
|
||||
return []
|
||||
|
||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||
|
||||
|
||||
def run_extensions_installers(settings_file):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
with startup_timer.subcategory("run extensions installers"):
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
logging.debug(f"Installing {dirname_extension}")
|
||||
|
||||
path = os.path.join(extensions_dir, dirname_extension)
|
||||
|
||||
if os.path.isdir(path):
|
||||
run_extension_installer(path)
|
||||
startup_timer.record(dirname_extension)
|
||||
|
||||
|
||||
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||
|
||||
|
||||
def requirements_met(requirements_file):
|
||||
"""
|
||||
Does a simple parse of a requirements.txt file to determine if all rerqirements in it
|
||||
are already installed. Returns True if so, False if not installed or parsing fails.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
import packaging.version
|
||||
|
||||
with open(requirements_file, "r", encoding="utf8") as file:
|
||||
for line in file:
|
||||
if line.strip() == "":
|
||||
continue
|
||||
|
||||
m = re.match(re_requirement, line)
|
||||
if m is None:
|
||||
return False
|
||||
|
||||
package = m.group(1).strip()
|
||||
version_required = (m.group(2) or "").strip()
|
||||
|
||||
if version_required == "":
|
||||
continue
|
||||
|
||||
try:
|
||||
version_installed = importlib.metadata.version(package)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def prepare_environment():
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
|
||||
if args.use_ipex:
|
||||
if platform.system() == "Windows":
|
||||
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
|
||||
# This is NOT an Intel official release so please use it at your own risk!!
|
||||
# See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
|
||||
#
|
||||
# Strengths (over official IPEX 2.0.110 windows release):
|
||||
# - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
|
||||
# - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
|
||||
# - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
|
||||
# Limitation:
|
||||
# - Only works for python 3.10
|
||||
url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
|
||||
else:
|
||||
# Using official IPEX release for linux since it's already an AOT build.
|
||||
# However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
|
||||
# See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
try:
|
||||
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||
os.remove(os.path.join(script_path, "tmp", "restart"))
|
||||
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if not args.skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
startup_timer.record("checks")
|
||||
|
||||
commit = commit_hash()
|
||||
tag = git_tag()
|
||||
startup_timer.record("git version info")
|
||||
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Version: {tag}")
|
||||
print(f"Commit hash: {commit}")
|
||||
|
||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
startup_timer.record("install torch")
|
||||
|
||||
if args.use_ipex:
|
||||
args.skip_torch_cuda_test = True
|
||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||
raise RuntimeError(
|
||||
'Torch is not able to use GPU; '
|
||||
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||
)
|
||||
startup_timer.record("torch GPU test")
|
||||
|
||||
if not is_installed("clip"):
|
||||
run_pip(f"install {clip_package}", "clip")
|
||||
startup_timer.record("install clip")
|
||||
|
||||
if not is_installed("open_clip"):
|
||||
run_pip(f"install {openclip_package}", "open_clip")
|
||||
startup_timer.record("install open_clip")
|
||||
|
||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||
startup_timer.record("install xformers")
|
||||
|
||||
if not is_installed("ngrok") and args.ngrok:
|
||||
run_pip("install ngrok", "ngrok")
|
||||
startup_timer.record("install ngrok")
|
||||
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
startup_timer.record("clone repositores")
|
||||
|
||||
if not os.path.isfile(requirements_file):
|
||||
requirements_file = os.path.join(script_path, requirements_file)
|
||||
|
||||
if not requirements_met(requirements_file):
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
startup_timer.record("install requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file_for_npu):
|
||||
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
|
||||
|
||||
if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
|
||||
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
|
||||
startup_timer.record("install requirements_for_npu")
|
||||
|
||||
if not args.skip_install:
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
if args.update_check:
|
||||
version_check(commit)
|
||||
startup_timer.record("check version")
|
||||
|
||||
if args.update_all_extensions:
|
||||
git_pull_recursive(extensions_dir)
|
||||
startup_timer.record("update extensions")
|
||||
|
||||
if "--exit" in sys.argv:
|
||||
print("Exiting because of --exit argument")
|
||||
exit(0)
|
||||
|
||||
|
||||
def configure_for_tests():
|
||||
if "--api" not in sys.argv:
|
||||
sys.argv.append("--api")
|
||||
if "--ckpt" not in sys.argv:
|
||||
sys.argv.append("--ckpt")
|
||||
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
|
||||
if "--skip-torch-cuda-test" not in sys.argv:
|
||||
sys.argv.append("--skip-torch-cuda-test")
|
||||
if "--disable-nan-check" not in sys.argv:
|
||||
sys.argv.append("--disable-nan-check")
|
||||
|
||||
os.environ['COMMANDLINE_ARGS'] = ""
|
||||
|
||||
|
||||
def start():
|
||||
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
|
||||
import webui
|
||||
if '--nowebui' in sys.argv:
|
||||
webui.api_only()
|
||||
else:
|
||||
webui.webui()
|
||||
|
||||
|
||||
def dump_sysinfo():
|
||||
from modules import sysinfo
|
||||
import datetime
|
||||
|
||||
text = sysinfo.get()
|
||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
file.write(text)
|
||||
|
||||
return filename
|
||||
|
|
|
|||
|
|
@ -1,345 +1,345 @@
|
|||
import cv2
|
||||
import requests
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import ImageDraw
|
||||
from modules import paths_internal
|
||||
from pkg_resources import parse_version
|
||||
|
||||
GREEN = "#0F0"
|
||||
BLUE = "#00F"
|
||||
RED = "#F00"
|
||||
|
||||
|
||||
def crop_image(im, settings):
|
||||
""" Intelligently crop an image to the subject matter """
|
||||
|
||||
scale_by = 1
|
||||
if is_landscape(im.width, im.height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
elif is_portrait(im.width, im.height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_square(im.width, im.height):
|
||||
if is_square(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_landscape(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
|
||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
||||
im_debug = im.copy()
|
||||
|
||||
focus = focal_point(im_debug, settings)
|
||||
|
||||
# take the focal point and turn it into crop coordinates that try to center over the focal
|
||||
# point but then get adjusted back into the frame
|
||||
y_half = int(settings.crop_height / 2)
|
||||
x_half = int(settings.crop_width / 2)
|
||||
|
||||
x1 = focus.x - x_half
|
||||
if x1 < 0:
|
||||
x1 = 0
|
||||
elif x1 + settings.crop_width > im.width:
|
||||
x1 = im.width - settings.crop_width
|
||||
|
||||
y1 = focus.y - y_half
|
||||
if y1 < 0:
|
||||
y1 = 0
|
||||
elif y1 + settings.crop_height > im.height:
|
||||
y1 = im.height - settings.crop_height
|
||||
|
||||
x2 = x1 + settings.crop_width
|
||||
y2 = y1 + settings.crop_height
|
||||
|
||||
crop = [x1, y1, x2, y2]
|
||||
|
||||
results = []
|
||||
|
||||
results.append(im.crop(tuple(crop)))
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im_debug)
|
||||
rect = list(crop)
|
||||
rect[2] -= 1
|
||||
rect[3] -= 1
|
||||
d.rectangle(rect, outline=GREEN)
|
||||
results.append(im_debug)
|
||||
if settings.desktop_view_image:
|
||||
im_debug.show()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def focal_point(im, settings):
|
||||
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
||||
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
|
||||
face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
|
||||
|
||||
pois = []
|
||||
|
||||
weight_pref_total = 0
|
||||
if corner_points:
|
||||
weight_pref_total += settings.corner_points_weight
|
||||
if entropy_points:
|
||||
weight_pref_total += settings.entropy_points_weight
|
||||
if face_points:
|
||||
weight_pref_total += settings.face_points_weight
|
||||
|
||||
corner_centroid = None
|
||||
if corner_points:
|
||||
corner_centroid = centroid(corner_points)
|
||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
||||
pois.append(corner_centroid)
|
||||
|
||||
entropy_centroid = None
|
||||
if entropy_points:
|
||||
entropy_centroid = centroid(entropy_points)
|
||||
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
|
||||
pois.append(entropy_centroid)
|
||||
|
||||
face_centroid = None
|
||||
if face_points:
|
||||
face_centroid = centroid(face_points)
|
||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
||||
pois.append(face_centroid)
|
||||
|
||||
average_point = poi_average(pois, settings)
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im)
|
||||
max_size = min(im.width, im.height) * 0.07
|
||||
if corner_centroid is not None:
|
||||
color = BLUE
|
||||
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
||||
d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(corner_points) > 1:
|
||||
for f in corner_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
if entropy_centroid is not None:
|
||||
color = "#ff0"
|
||||
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
||||
d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(entropy_points) > 1:
|
||||
for f in entropy_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
if face_centroid is not None:
|
||||
color = RED
|
||||
box = face_centroid.bounding(max_size * face_centroid.weight)
|
||||
d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(face_points) > 1:
|
||||
for f in face_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
|
||||
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
||||
|
||||
return average_point
|
||||
|
||||
|
||||
def image_face_points(im, settings):
|
||||
if settings.dnn_model_path is not None:
|
||||
detector = cv2.FaceDetectorYN.create(
|
||||
settings.dnn_model_path,
|
||||
"",
|
||||
(im.width, im.height),
|
||||
0.9, # score threshold
|
||||
0.3, # nms threshold
|
||||
5000 # keep top k before nms
|
||||
)
|
||||
faces = detector.detect(np.array(im))
|
||||
results = []
|
||||
if faces[1] is not None:
|
||||
for face in faces[1]:
|
||||
x = face[0]
|
||||
y = face[1]
|
||||
w = face[2]
|
||||
h = face[3]
|
||||
results.append(
|
||||
PointOfInterest(
|
||||
int(x + (w * 0.5)), # face focus left/right is center
|
||||
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
|
||||
size=w,
|
||||
weight=1 / len(faces[1])
|
||||
)
|
||||
)
|
||||
return results
|
||||
else:
|
||||
np_im = np.array(im)
|
||||
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
tries = [
|
||||
[f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01],
|
||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05]
|
||||
]
|
||||
for t in tries:
|
||||
classifier = cv2.CascadeClassifier(t[0])
|
||||
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
|
||||
try:
|
||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
||||
minNeighbors=7, minSize=(minsize, minsize),
|
||||
flags=cv2.CASCADE_SCALE_IMAGE)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if faces:
|
||||
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
|
||||
return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]),
|
||||
weight=1 / len(rects)) for r in rects]
|
||||
return []
|
||||
|
||||
|
||||
def image_corner_points(im, settings):
|
||||
grayscale = im.convert("L")
|
||||
|
||||
# naive attempt at preventing focal points from collecting at watermarks near the bottom
|
||||
gd = ImageDraw.Draw(grayscale)
|
||||
gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999")
|
||||
|
||||
np_im = np.array(grayscale)
|
||||
|
||||
points = cv2.goodFeaturesToTrack(
|
||||
np_im,
|
||||
maxCorners=100,
|
||||
qualityLevel=0.04,
|
||||
minDistance=min(grayscale.width, grayscale.height) * 0.06,
|
||||
useHarrisDetector=False,
|
||||
)
|
||||
|
||||
if points is None:
|
||||
return []
|
||||
|
||||
focal_points = []
|
||||
for point in points:
|
||||
x, y = point.ravel()
|
||||
focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points)))
|
||||
|
||||
return focal_points
|
||||
|
||||
|
||||
def image_entropy_points(im, settings):
|
||||
landscape = im.height < im.width
|
||||
portrait = im.height > im.width
|
||||
if landscape:
|
||||
move_idx = [0, 2]
|
||||
move_max = im.size[0]
|
||||
elif portrait:
|
||||
move_idx = [1, 3]
|
||||
move_max = im.size[1]
|
||||
else:
|
||||
return []
|
||||
|
||||
e_max = 0
|
||||
crop_current = [0, 0, settings.crop_width, settings.crop_height]
|
||||
crop_best = crop_current
|
||||
while crop_current[move_idx[1]] < move_max:
|
||||
crop = im.crop(tuple(crop_current))
|
||||
e = image_entropy(crop)
|
||||
|
||||
if (e > e_max):
|
||||
e_max = e
|
||||
crop_best = list(crop_current)
|
||||
|
||||
crop_current[move_idx[0]] += 4
|
||||
crop_current[move_idx[1]] += 4
|
||||
|
||||
x_mid = int(crop_best[0] + settings.crop_width / 2)
|
||||
y_mid = int(crop_best[1] + settings.crop_height / 2)
|
||||
|
||||
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
|
||||
|
||||
|
||||
def image_entropy(im):
|
||||
# greyscale image entropy
|
||||
# band = np.asarray(im.convert("L"))
|
||||
band = np.asarray(im.convert("1"), dtype=np.uint8)
|
||||
hist, _ = np.histogram(band, bins=range(0, 256))
|
||||
hist = hist[hist > 0]
|
||||
return -np.log2(hist / hist.sum()).sum()
|
||||
|
||||
|
||||
def centroid(pois):
|
||||
x = [poi.x for poi in pois]
|
||||
y = [poi.y for poi in pois]
|
||||
return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
|
||||
|
||||
|
||||
def poi_average(pois, settings):
|
||||
weight = 0.0
|
||||
x = 0.0
|
||||
y = 0.0
|
||||
for poi in pois:
|
||||
weight += poi.weight
|
||||
x += poi.x * poi.weight
|
||||
y += poi.y * poi.weight
|
||||
avg_x = round(weight and x / weight)
|
||||
avg_y = round(weight and y / weight)
|
||||
|
||||
return PointOfInterest(avg_x, avg_y)
|
||||
|
||||
|
||||
def is_landscape(w, h):
|
||||
return w > h
|
||||
|
||||
|
||||
def is_portrait(w, h):
|
||||
return h > w
|
||||
|
||||
|
||||
def is_square(w, h):
|
||||
return w == h
|
||||
|
||||
|
||||
model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')
|
||||
if parse_version(cv2.__version__) >= parse_version('4.8'):
|
||||
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')
|
||||
model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
|
||||
else:
|
||||
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx')
|
||||
model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||
|
||||
|
||||
def download_and_cache_models():
|
||||
if not os.path.exists(model_file_path):
|
||||
os.makedirs(model_dir_opencv, exist_ok=True)
|
||||
print(f"downloading face detection model from '{model_url}' to '{model_file_path}'")
|
||||
response = requests.get(model_url)
|
||||
with open(model_file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
return model_file_path
|
||||
|
||||
|
||||
class PointOfInterest:
|
||||
def __init__(self, x, y, weight=1.0, size=10):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.weight = weight
|
||||
self.size = size
|
||||
|
||||
def bounding(self, size):
|
||||
return [
|
||||
self.x - size // 2,
|
||||
self.y - size // 2,
|
||||
self.x + size // 2,
|
||||
self.y + size // 2
|
||||
]
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
||||
self.crop_width = crop_width
|
||||
self.crop_height = crop_height
|
||||
self.corner_points_weight = corner_points_weight
|
||||
self.entropy_points_weight = entropy_points_weight
|
||||
self.face_points_weight = face_points_weight
|
||||
self.annotate_image = annotate_image
|
||||
self.desktop_view_image = False
|
||||
self.dnn_model_path = dnn_model_path
|
||||
import cv2
|
||||
import requests
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import ImageDraw
|
||||
from modules import paths_internal
|
||||
from pkg_resources import parse_version
|
||||
|
||||
GREEN = "#0F0"
|
||||
BLUE = "#00F"
|
||||
RED = "#F00"
|
||||
|
||||
|
||||
def crop_image(im, settings):
|
||||
""" Intelligently crop an image to the subject matter """
|
||||
|
||||
scale_by = 1
|
||||
if is_landscape(im.width, im.height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
elif is_portrait(im.width, im.height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_square(im.width, im.height):
|
||||
if is_square(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_landscape(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
|
||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
||||
im_debug = im.copy()
|
||||
|
||||
focus = focal_point(im_debug, settings)
|
||||
|
||||
# take the focal point and turn it into crop coordinates that try to center over the focal
|
||||
# point but then get adjusted back into the frame
|
||||
y_half = int(settings.crop_height / 2)
|
||||
x_half = int(settings.crop_width / 2)
|
||||
|
||||
x1 = focus.x - x_half
|
||||
if x1 < 0:
|
||||
x1 = 0
|
||||
elif x1 + settings.crop_width > im.width:
|
||||
x1 = im.width - settings.crop_width
|
||||
|
||||
y1 = focus.y - y_half
|
||||
if y1 < 0:
|
||||
y1 = 0
|
||||
elif y1 + settings.crop_height > im.height:
|
||||
y1 = im.height - settings.crop_height
|
||||
|
||||
x2 = x1 + settings.crop_width
|
||||
y2 = y1 + settings.crop_height
|
||||
|
||||
crop = [x1, y1, x2, y2]
|
||||
|
||||
results = []
|
||||
|
||||
results.append(im.crop(tuple(crop)))
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im_debug)
|
||||
rect = list(crop)
|
||||
rect[2] -= 1
|
||||
rect[3] -= 1
|
||||
d.rectangle(rect, outline=GREEN)
|
||||
results.append(im_debug)
|
||||
if settings.desktop_view_image:
|
||||
im_debug.show()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def focal_point(im, settings):
|
||||
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
||||
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
|
||||
face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
|
||||
|
||||
pois = []
|
||||
|
||||
weight_pref_total = 0
|
||||
if corner_points:
|
||||
weight_pref_total += settings.corner_points_weight
|
||||
if entropy_points:
|
||||
weight_pref_total += settings.entropy_points_weight
|
||||
if face_points:
|
||||
weight_pref_total += settings.face_points_weight
|
||||
|
||||
corner_centroid = None
|
||||
if corner_points:
|
||||
corner_centroid = centroid(corner_points)
|
||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
||||
pois.append(corner_centroid)
|
||||
|
||||
entropy_centroid = None
|
||||
if entropy_points:
|
||||
entropy_centroid = centroid(entropy_points)
|
||||
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
|
||||
pois.append(entropy_centroid)
|
||||
|
||||
face_centroid = None
|
||||
if face_points:
|
||||
face_centroid = centroid(face_points)
|
||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
||||
pois.append(face_centroid)
|
||||
|
||||
average_point = poi_average(pois, settings)
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im)
|
||||
max_size = min(im.width, im.height) * 0.07
|
||||
if corner_centroid is not None:
|
||||
color = BLUE
|
||||
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
||||
d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(corner_points) > 1:
|
||||
for f in corner_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
if entropy_centroid is not None:
|
||||
color = "#ff0"
|
||||
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
||||
d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(entropy_points) > 1:
|
||||
for f in entropy_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
if face_centroid is not None:
|
||||
color = RED
|
||||
box = face_centroid.bounding(max_size * face_centroid.weight)
|
||||
d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(face_points) > 1:
|
||||
for f in face_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
|
||||
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
||||
|
||||
return average_point
|
||||
|
||||
|
||||
def image_face_points(im, settings):
|
||||
if settings.dnn_model_path is not None:
|
||||
detector = cv2.FaceDetectorYN.create(
|
||||
settings.dnn_model_path,
|
||||
"",
|
||||
(im.width, im.height),
|
||||
0.9, # score threshold
|
||||
0.3, # nms threshold
|
||||
5000 # keep top k before nms
|
||||
)
|
||||
faces = detector.detect(np.array(im))
|
||||
results = []
|
||||
if faces[1] is not None:
|
||||
for face in faces[1]:
|
||||
x = face[0]
|
||||
y = face[1]
|
||||
w = face[2]
|
||||
h = face[3]
|
||||
results.append(
|
||||
PointOfInterest(
|
||||
int(x + (w * 0.5)), # face focus left/right is center
|
||||
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
|
||||
size=w,
|
||||
weight=1 / len(faces[1])
|
||||
)
|
||||
)
|
||||
return results
|
||||
else:
|
||||
np_im = np.array(im)
|
||||
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
tries = [
|
||||
[f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01],
|
||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05],
|
||||
[f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05]
|
||||
]
|
||||
for t in tries:
|
||||
classifier = cv2.CascadeClassifier(t[0])
|
||||
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
|
||||
try:
|
||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
||||
minNeighbors=7, minSize=(minsize, minsize),
|
||||
flags=cv2.CASCADE_SCALE_IMAGE)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if faces:
|
||||
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
|
||||
return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]),
|
||||
weight=1 / len(rects)) for r in rects]
|
||||
return []
|
||||
|
||||
|
||||
def image_corner_points(im, settings):
|
||||
grayscale = im.convert("L")
|
||||
|
||||
# naive attempt at preventing focal points from collecting at watermarks near the bottom
|
||||
gd = ImageDraw.Draw(grayscale)
|
||||
gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999")
|
||||
|
||||
np_im = np.array(grayscale)
|
||||
|
||||
points = cv2.goodFeaturesToTrack(
|
||||
np_im,
|
||||
maxCorners=100,
|
||||
qualityLevel=0.04,
|
||||
minDistance=min(grayscale.width, grayscale.height) * 0.06,
|
||||
useHarrisDetector=False,
|
||||
)
|
||||
|
||||
if points is None:
|
||||
return []
|
||||
|
||||
focal_points = []
|
||||
for point in points:
|
||||
x, y = point.ravel()
|
||||
focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points)))
|
||||
|
||||
return focal_points
|
||||
|
||||
|
||||
def image_entropy_points(im, settings):
|
||||
landscape = im.height < im.width
|
||||
portrait = im.height > im.width
|
||||
if landscape:
|
||||
move_idx = [0, 2]
|
||||
move_max = im.size[0]
|
||||
elif portrait:
|
||||
move_idx = [1, 3]
|
||||
move_max = im.size[1]
|
||||
else:
|
||||
return []
|
||||
|
||||
e_max = 0
|
||||
crop_current = [0, 0, settings.crop_width, settings.crop_height]
|
||||
crop_best = crop_current
|
||||
while crop_current[move_idx[1]] < move_max:
|
||||
crop = im.crop(tuple(crop_current))
|
||||
e = image_entropy(crop)
|
||||
|
||||
if (e > e_max):
|
||||
e_max = e
|
||||
crop_best = list(crop_current)
|
||||
|
||||
crop_current[move_idx[0]] += 4
|
||||
crop_current[move_idx[1]] += 4
|
||||
|
||||
x_mid = int(crop_best[0] + settings.crop_width / 2)
|
||||
y_mid = int(crop_best[1] + settings.crop_height / 2)
|
||||
|
||||
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
|
||||
|
||||
|
||||
def image_entropy(im):
|
||||
# greyscale image entropy
|
||||
# band = np.asarray(im.convert("L"))
|
||||
band = np.asarray(im.convert("1"), dtype=np.uint8)
|
||||
hist, _ = np.histogram(band, bins=range(0, 256))
|
||||
hist = hist[hist > 0]
|
||||
return -np.log2(hist / hist.sum()).sum()
|
||||
|
||||
|
||||
def centroid(pois):
|
||||
x = [poi.x for poi in pois]
|
||||
y = [poi.y for poi in pois]
|
||||
return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
|
||||
|
||||
|
||||
def poi_average(pois, settings):
|
||||
weight = 0.0
|
||||
x = 0.0
|
||||
y = 0.0
|
||||
for poi in pois:
|
||||
weight += poi.weight
|
||||
x += poi.x * poi.weight
|
||||
y += poi.y * poi.weight
|
||||
avg_x = round(weight and x / weight)
|
||||
avg_y = round(weight and y / weight)
|
||||
|
||||
return PointOfInterest(avg_x, avg_y)
|
||||
|
||||
|
||||
def is_landscape(w, h):
|
||||
return w > h
|
||||
|
||||
|
||||
def is_portrait(w, h):
|
||||
return h > w
|
||||
|
||||
|
||||
def is_square(w, h):
|
||||
return w == h
|
||||
|
||||
|
||||
model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')
|
||||
if parse_version(cv2.__version__) >= parse_version('4.8'):
|
||||
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')
|
||||
model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
|
||||
else:
|
||||
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx')
|
||||
model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||
|
||||
|
||||
def download_and_cache_models():
|
||||
if not os.path.exists(model_file_path):
|
||||
os.makedirs(model_dir_opencv, exist_ok=True)
|
||||
print(f"downloading face detection model from '{model_url}' to '{model_file_path}'")
|
||||
response = requests.get(model_url, timeout=10.0)
|
||||
with open(model_file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
return model_file_path
|
||||
|
||||
|
||||
class PointOfInterest:
|
||||
def __init__(self, x, y, weight=1.0, size=10):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.weight = weight
|
||||
self.size = size
|
||||
|
||||
def bounding(self, size):
|
||||
return [
|
||||
self.x - size // 2,
|
||||
self.y - size // 2,
|
||||
self.x + size // 2,
|
||||
self.y + size // 2
|
||||
]
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
||||
self.crop_width = crop_width
|
||||
self.crop_height = crop_height
|
||||
self.corner_points_weight = corner_points_weight
|
||||
self.entropy_points_weight = entropy_points_weight
|
||||
self.face_points_weight = face_points_weight
|
||||
self.annotate_image = annotate_image
|
||||
self.desktop_view_image = False
|
||||
self.dnn_model_path = dnn_model_path
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue