diff --git a/cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py b/cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py
index ac038aadfe7..0b8cb176ff5 100644
--- a/cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py
+++ b/cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py
@@ -3,6 +3,7 @@
 
 import functools
 import os
+import re
 from dataclasses import dataclass
 from typing import NoReturn, TypedDict
 
@@ -62,6 +63,7 @@ class _BitcodeLibInfo(TypedDict):
         name for name, info in _SUPPORTED_BITCODE_LIBS_INFO.items() if not IS_WINDOWS or info["available_on_windows"]
     )
 )
+_SM_ARCH_PATTERN = re.compile(r"sm[0-9]+[a-z]?")
 
 
 def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str], attachments: list[str]) -> None:
@@ -74,13 +76,24 @@ def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str]
         attachments.append(f' Directory does not exist: "{dir_path}"')
 
 
+def _filename_with_sm_arch(filename: str, sm_arch: str | None) -> str:
+    if sm_arch is None:
+        return filename
+
+    if not _SM_ARCH_PATTERN.fullmatch(sm_arch):
+        raise ValueError(f"Invalid sm_arch: {sm_arch!r} must match {_SM_ARCH_PATTERN.pattern!r}")
+
+    stem, ext = os.path.splitext(filename)
+    return f"{stem}_{sm_arch}{ext}"
+
+
 class _FindBitcodeLib:
-    def __init__(self, name: str) -> None:
+    def __init__(self, name: str, sm_arch: str | None = None) -> None:
         if name not in _SUPPORTED_BITCODE_LIBS_INFO:  # Updated reference
             raise ValueError(f"Unknown bitcode library: '{name}'. Supported: {', '.join(SUPPORTED_BITCODE_LIBS)}")
         self.name: str = name
         self.config: _BitcodeLibInfo = _SUPPORTED_BITCODE_LIBS_INFO[name]  # Updated reference
-        self.filename: str = self.config["filename"]
+        self.filename: str = _filename_with_sm_arch(self.config["filename"], sm_arch)
         self.rel_path: str = self.config["rel_path"]
         self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"]
         self.error_messages: list[str] = []
@@ -130,14 +143,25 @@ def raise_not_found_error(self) -> NoReturn:
         raise BitcodeLibNotFoundError(f'Failure finding "{self.filename}": {err}\n{att}')
 
 
-def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
+def locate_bitcode_lib(name: str, *, sm_arch: str | None = None) -> LocatedBitcodeLib:
     """Locate a bitcode library by name.
 
+    When ``sm_arch`` is not ``None``, locate the architecture-specific bitcode
+    filename with ``_{sm_arch}`` inserted before the ``.bc`` suffix.
+
+    Args:
+        name: Name of the supported bitcode library to locate.
+        sm_arch: Optional SM architecture suffix, such as ``"sm90"`` or
+            ``"sm90a"``. If not ``None``, it must match
+            ``sm[0-9]+[a-z]?``.
+
     Raises:
-        ValueError: If ``name`` is not a supported bitcode library.
+        ValueError: If ``name`` is not a supported bitcode library, or if
+            ``sm_arch`` is not ``None`` and does not match
+            ``sm[0-9]+[a-z]?``.
         BitcodeLibNotFoundError: If the bitcode library cannot be found.
     """
-    finder = _FindBitcodeLib(name)
+    finder = _FindBitcodeLib(name, sm_arch)
 
     abs_path = finder.try_site_packages()
     if abs_path is not None:
@@ -170,11 +194,22 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
 
 
 @functools.cache
-def find_bitcode_lib(name: str) -> str:
+def find_bitcode_lib(name: str, *, sm_arch: str | None = None) -> str:
     """Find the absolute path to a bitcode library.
 
+    When ``sm_arch`` is not ``None``, find the architecture-specific bitcode
+    filename with ``_{sm_arch}`` inserted before the ``.bc`` suffix.
+
+    Args:
+        name: Name of the supported bitcode library to find.
+        sm_arch: Optional SM architecture suffix, such as ``"sm90"`` or
+            ``"sm90a"``. If not ``None``, it must match
+            ``sm[0-9]+[a-z]?``.
+
     Raises:
-        ValueError: If ``name`` is not a supported bitcode library.
+        ValueError: If ``name`` is not a supported bitcode library, or if
+            ``sm_arch`` is not ``None`` and does not match
+            ``sm[0-9]+[a-z]?``.
         BitcodeLibNotFoundError: If the bitcode library cannot be found.
     """
-    return locate_bitcode_lib(name).abs_path
+    return locate_bitcode_lib(name, sm_arch=sm_arch).abs_path
diff --git a/cuda_pathfinder/tests/test_find_bitcode_lib.py b/cuda_pathfinder/tests/test_find_bitcode_lib.py
index 659b068f0ff..83323ac0a95 100644
--- a/cuda_pathfinder/tests/test_find_bitcode_lib.py
+++ b/cuda_pathfinder/tests/test_find_bitcode_lib.py
@@ -23,10 +23,6 @@ def _bitcode_lib_info(libname: str):
     return find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO[libname]
 
 
-def _bitcode_lib_filename(libname: str) -> str:
-    return _bitcode_lib_info(libname)["filename"]
-
-
 @pytest.fixture
 def clear_find_bitcode_lib_cache():
     find_bitcode_lib_module.find_bitcode_lib.cache_clear()
@@ -36,9 +32,9 @@ def clear_find_bitcode_lib_cache():
     get_cuda_path_or_home.cache_clear()
 
 
-def _make_bitcode_lib_file(dir_path: Path, libname: str) -> str:
+def _make_bitcode_lib_file(dir_path: Path, filename: str) -> str:
     dir_path.mkdir(parents=True, exist_ok=True)
-    file_path = dir_path / _bitcode_lib_filename(libname)
+    file_path = dir_path / filename
     file_path.touch()
     return str(file_path)
 
@@ -92,14 +88,16 @@ def test_locate_bitcode_lib(info_summary_append, libname):
 @pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
 @pytest.mark.parametrize("libname", SUPPORTED_BITCODE_LIBS)
 def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path, libname):
+    filename = _bitcode_lib_info(libname)["filename"]
+
     site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
-    site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, libname)
+    site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, filename)
 
     conda_prefix = tmp_path / "conda-prefix"
-    conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), libname)
+    conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), filename)
 
     cuda_home = tmp_path / "cuda-home"
-    cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), libname)
+    cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), filename)
 
     site_packages_sub_dirs = tuple(
         tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
@@ -135,6 +133,84 @@ def find_expected_sub_dir(sub_dir):
     assert located_lib.found_via == "CUDA_PATH"
 
 
+@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
+@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
+def test_locate_bitcode_lib_with_sm_arch_search_order(monkeypatch, tmp_path):
+    libname = "nvshmem_device"
+    sm_arch = "sm90"
+    filename = "libnvshmem_device_sm90.bc"
+
+    site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
+    site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, filename)
+
+    conda_prefix = tmp_path / "conda-prefix"
+    conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), filename)
+
+    cuda_home = tmp_path / "cuda-home"
+    cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), filename)
+
+    site_packages_sub_dirs = tuple(
+        tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
+    )
+
+    def find_expected_sub_dir(sub_dir):
+        assert sub_dir in site_packages_sub_dirs
+        if sub_dir == site_packages_sub_dirs[0]:
+            return [str(site_packages_lib_dir)]
+        return []
+
+    monkeypatch.setattr(
+        find_bitcode_lib_module,
+        "find_sub_dirs_all_sitepackages",
+        find_expected_sub_dir,
+    )
+    monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix))
+    monkeypatch.setenv("CUDA_HOME", str(cuda_home))
+    monkeypatch.delenv("CUDA_PATH", raising=False)
+
+    located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
+    assert located_lib.abs_path == site_packages_path
+    assert located_lib.filename == filename
+    assert located_lib.found_via == "site-packages"
+    assert find_bitcode_lib(libname, sm_arch=sm_arch) == site_packages_path
+    os.remove(site_packages_path)
+    find_bitcode_lib_module.find_bitcode_lib.cache_clear()
+
+    located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
+    assert located_lib.abs_path == conda_path
+    assert located_lib.filename == filename
+    assert located_lib.found_via == "conda"
+    os.remove(conda_path)
+
+    located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
+    assert located_lib.abs_path == cuda_home_path
+    assert located_lib.filename == filename
+    assert located_lib.found_via == "CUDA_PATH"
+
+
+@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
+@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
+def test_find_bitcode_lib_cache_keeps_sm_arch_separate(monkeypatch, tmp_path):
+    libname = "nvshmem_device"
+    site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
+    sm80_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm80.bc")
+    sm90_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm90.bc")
+    sm90a_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm90a.bc")
+
+    monkeypatch.setattr(
+        find_bitcode_lib_module,
+        "find_sub_dirs_all_sitepackages",
+        lambda _sub_dir: [str(site_packages_lib_dir)],
+    )
+    monkeypatch.delenv("CONDA_PREFIX", raising=False)
+    monkeypatch.delenv("CUDA_HOME", raising=False)
+    monkeypatch.delenv("CUDA_PATH", raising=False)
+
+    assert find_bitcode_lib(libname, sm_arch="sm80") == sm80_path
+    assert find_bitcode_lib(libname, sm_arch="sm90") == sm90_path
+    assert find_bitcode_lib(libname, sm_arch="sm90a") == sm90a_path
+
+
 @pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
 def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(monkeypatch, tmp_path):
     cuda_home = tmp_path / "cuda-home"
@@ -156,12 +232,44 @@ def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(m
         find_bitcode_lib("device")
 
     message = str(exc_info.value)
-    expected_missing_file = os.path.join(str(lib_dir), _bitcode_lib_filename("device"))
+    expected_missing_file = os.path.join(str(lib_dir), _bitcode_lib_info("device")["filename"])
     assert f"No such file: {expected_missing_file}" in message
     assert f'listdir("{lib_dir}"):' in message
     assert "README.txt" in message
 
 
+@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
+@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
+def test_find_bitcode_lib_with_sm_arch_not_found_error_uses_arch_specific_filename(monkeypatch, tmp_path):
+    libname = "nvshmem_device"
+    sm_arch = "sm90"
+    expected_filename = "libnvshmem_device_sm90.bc"
+
+    cuda_home = tmp_path / "cuda-home"
+    lib_dir = _bitcode_lib_dir_under(cuda_home, libname)
+    lib_dir.mkdir(parents=True, exist_ok=True)
+    extra_file = lib_dir / "libnvshmem_device.bc"
+    extra_file.touch()
+
+    monkeypatch.setattr(
+        find_bitcode_lib_module,
+        "find_sub_dirs_all_sitepackages",
+        lambda _sub_dir: [],
+    )
+    monkeypatch.delenv("CONDA_PREFIX", raising=False)
+    monkeypatch.setenv("CUDA_HOME", str(cuda_home))
+    monkeypatch.delenv("CUDA_PATH", raising=False)
+
+    with pytest.raises(BitcodeLibNotFoundError, match=rf'Failure finding "{expected_filename}"') as exc_info:
+        find_bitcode_lib(libname, sm_arch=sm_arch)
+
+    message = str(exc_info.value)
+    expected_missing_file = os.path.join(str(lib_dir), expected_filename)
+    assert f"No such file: {expected_missing_file}" in message
+    assert f'listdir("{lib_dir}"):' in message
+    assert "libnvshmem_device.bc" in message
+
+
 @pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
 def test_find_bitcode_lib_not_found_error_without_cuda_home(monkeypatch):
     monkeypatch.setattr(
@@ -183,3 +291,34 @@ def test_find_bitcode_lib_not_found_error_without_cuda_home(monkeypatch):
 def test_find_bitcode_lib_invalid_name():
     with pytest.raises(ValueError, match="Unknown bitcode library"):
         find_bitcode_lib_module.locate_bitcode_lib("invalid")
+
+
+@pytest.mark.parametrize(
+    "find_fn",
+    [
+        find_bitcode_lib,
+        locate_bitcode_lib,
+    ],
+)
+def test_bitcode_lib_sm_arch_is_keyword_only(find_fn):
+    with pytest.raises(TypeError):
+        find_fn("nvshmem_device", "sm90")
+
+
+@pytest.mark.parametrize(
+    "sm_arch",
+    [
+        "",
+        "../sm90",
+        "compute90",
+        "sm_90",
+        "sm",
+        "sm90/extra",
+        "sm90A",
+    ],
+)
+def test_find_bitcode_lib_invalid_sm_arch(sm_arch):
+    expected_pattern = find_bitcode_lib_module._SM_ARCH_PATTERN.pattern
+    with pytest.raises(ValueError) as exc_info:
+        find_bitcode_lib_module.locate_bitcode_lib("device", sm_arch=sm_arch)
+    assert f"must match {expected_pattern!r}" in str(exc_info.value)