diff --git a/scripts/dev/create_stub_repos.py b/scripts/dev/create_stub_repos.py index ced2b1197..530de0c8b 100644 --- a/scripts/dev/create_stub_repos.py +++ b/scripts/dev/create_stub_repos.py @@ -109,7 +109,10 @@ def main() -> None: # k-diffusion: k_diffusion.sampling, utils (sd_schedulers, sd_samplers_lcm) kd = "k-diffusion" - touch(os.path.join(REPOS, kd, "k_diffusion", "__init__.py")) + touch( + os.path.join(REPOS, kd, "k_diffusion", "__init__.py"), + "from . import utils, sampling, external\n", + ) touch(os.path.join(REPOS, kd, "k_diffusion", "utils.py"), "# stub\n") touch( os.path.join(REPOS, kd, "k_diffusion", "external.py"),