diff --git a/scripts/dev/create_stub_repos.py b/scripts/dev/create_stub_repos.py index fb6187043..33b9b47fb 100644 --- a/scripts/dev/create_stub_repos.py +++ b/scripts/dev/create_stub_repos.py @@ -68,7 +68,10 @@ def main() -> None: "class CrossAttention:\n def forward(self, *a, **k): pass\n" "\nSDP_IS_AVAILABLE = True\nXFORMERS_IS_AVAILABLE = False\n", ) - touch(os.path.join(REPOS, gm, "sgm", "modules", "diffusionmodules", "__init__.py")) + touch( + os.path.join(REPOS, gm, "sgm", "modules", "diffusionmodules", "__init__.py"), + "from . import model, openaimodel\n", + ) touch( os.path.join(REPOS, gm, "sgm", "modules", "diffusionmodules", "model.py"), "class AttnBlock:\n def forward(self, *a, **k): pass\n",