Supported pytorch added for Cuda 11.3 and 11.7, dynamically.

Also called the early_access_blackwell_wheels under def prepare_environment()
This commit is contained in:
Kavya Mali 2025-10-28 02:32:23 +05:30 committed by GitHub
parent 6685e532df
commit fe5ea3c0d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -347,9 +347,53 @@ def early_access_blackwell_wheels():
return f'pip install {ea_whl.get(sys.version_info.minor)}'
def cu113_cu117():
"""For older GPUs, provides latest PyTorch for CUDA 11.3 (cu113) and 11.7 (cu117)"""
print('cu113_cu117')
cc = get_cuda_comp_cap()
py_minor = sys.version_info.minor
if os.environ.get('TORCH_INDEX_URL') is not None:
return None
if platform.system() != "Windows":
return None
if sys.version_info.major != 3 or py_minor not in (10, 11, 12):
return None
if 5.0 <= cc < 6.0: # Maxwell -> cu113, Compute Capability 5.0/5.2
torch_pkg = "torch==1.12.1+cu113"
tv_pkg = "torchvision==0.13.1+cu113"
torch_index_url = "https://download.pytorch.org/whl/cu113"
torch_command = os.environ.get(
'TORCH_COMMAND', f"pip install {torch_pkg} {tv_pkg} --extra-index-url {torch_index_url}"
)
return torch_command
elif 6.0 <= cc < 7.0: # Pascal -> cu117, Compute Capability 6.0/6.1
torch_pkg = "torch==2.0.1+cu117"
tv_pkg = "torchvision==0.15.2+cu117"
torch_index_url = "https://download.pytorch.org/whl/cu117"
torch_command = os.environ.get(
'TORCH_COMMAND', f"pip install {torch_pkg} {tv_pkg} --extra-index-url {torch_index_url}"
)
return torch_command
return None # fallback
def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu128")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.7.0 torchvision==0.22.0 --extra-index-url {torch_index_url}")
torch_command = early_access_blackwell_wheels() or cu113_cu117()
if torch_command is None:
# fallback to default
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu128")
torch_command = os.environ.get(
'TORCH_COMMAND',
f"pip install torch==2.7.0 torchvision==0.22.0 --extra-index-url {torch_index_url}"
)
if args.use_ipex:
if platform.system() == "Windows":
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
@ -512,3 +556,4 @@ def dump_sysinfo():
file.write(text)
return filename