Skip to content

Commit

Permalink
Use torch.xpu.mem_get_info for XPU (#3275)
Browse files Browse the repository at this point in the history
torch.xpu.mem_get_info API is available starting from PyTorch 2.6 (and
in nightly 2.6.0.dev20241206+xpu or later). To work properly this method
requires PyTorch built with the SYCL runtime which supports API to query
device memory stats. If not available, exception will be raised.

Requires: pytorch/pytorch#141230
Fixes: #2929
Fixes: huggingface/transformers#31922

Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh authored Dec 24, 2024
1 parent acfbf72 commit d6d3e03
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/accelerate/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,19 @@ def get_xpu_available_memory(device_index: int):
from intel_extension_for_pytorch.xpu import mem_get_info

return mem_get_info(device_index)[0]
elif version.parse(torch.__version__).release >= version.parse("2.6").release:
# torch.xpu.mem_get_info API is available starting from PyTorch 2.6
# It further requires PyTorch built with the SYCL runtime which supports API
# to query available device memory. If not available, exception will be
# raised. Version of SYCL runtime used to build PyTorch is being reported
# with print(torch.version.xpu) and corresponds to the version of Intel DPC++
# SYCL compiler. First version to support required feature is 20250001.
try:
return torch.xpu.mem_get_info(device_index)[0]
except Exception:
pass

warnings.warn(
"The XPU `mem_get_info` API is available in IPEX version >=2.5. The current returned available memory is incorrect. Please consider upgrading your IPEX version."
"The XPU `mem_get_info` API is available in IPEX version >=2.5 or PyTorch >=2.6. The current returned available memory is incorrect. Please consider upgrading your IPEX or PyTorch version."
)
return torch.xpu.max_memory_allocated(device_index)

0 comments on commit d6d3e03

Please sign in to comment.