From d6d3e03cd4868b1f69d4f334ddd341e4d414c342 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Tue, 24 Dec 2024 07:48:00 -0800 Subject: [PATCH] Use torch.xpu.mem_get_info for XPU (#3275) The torch.xpu.mem_get_info API is available starting from PyTorch 2.6 (and in nightly 2.6.0.dev20241206+xpu or later). To work properly, this method requires PyTorch built with a SYCL runtime that supports the API to query device memory stats. If that support is not available, torch.xpu.mem_get_info raises an exception; accelerate catches it and falls back to the previous behavior (a warning plus torch.xpu.max_memory_allocated). Requires: https://github.com/pytorch/pytorch/pull/141230 Fixes: #2929 Fixes: https://github.com/huggingface/transformers/issues/31922 Signed-off-by: Dmitry Rogozhkin --- src/accelerate/utils/memory.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/accelerate/utils/memory.py b/src/accelerate/utils/memory.py index 42e944d550b..ce220c1b8e4 100644 --- a/src/accelerate/utils/memory.py +++ b/src/accelerate/utils/memory.py @@ -174,7 +174,19 @@ def get_xpu_available_memory(device_index: int): from intel_extension_for_pytorch.xpu import mem_get_info return mem_get_info(device_index)[0] + elif version.parse(torch.__version__).release >= version.parse("2.6").release: + # torch.xpu.mem_get_info API is available starting from PyTorch 2.6 + # It further requires PyTorch built with the SYCL runtime which supports API + # to query available device memory. If not available, exception will be + # raised. Version of SYCL runtime used to build PyTorch is being reported + # with print(torch.version.xpu) and corresponds to the version of Intel DPC++ + # SYCL compiler. First version to support required feature is 20250001. + try: + return torch.xpu.mem_get_info(device_index)[0] + except Exception: + pass + warnings.warn( - "The XPU `mem_get_info` API is available in IPEX version >=2.5. The current returned available memory is incorrect. Please consider upgrading your IPEX version." + "The XPU `mem_get_info` API is available in IPEX version >=2.5 or PyTorch >=2.6. 
The current returned available memory is incorrect. Please consider upgrading your IPEX or PyTorch version." ) return torch.xpu.max_memory_allocated(device_index)