Skip to content

Commit

Permalink
lint, add init_cuda_malloc()
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Oct 20, 2024
1 parent e78be27 commit 0cc8146
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
28 changes: 14 additions & 14 deletions cuda_malloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def enum_display_devices():
else:
gpu_names = set()
out = subprocess.check_output(['nvidia-smi', '-L'])
for l in out.split(b'\n'):
if len(l) > 0:
gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
for line in out.split(b'\n'):
if len(line) > 0:
gpu_names.add(line.decode('utf-8').split(' (UUID')[0])
return gpu_names

blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
Expand All @@ -55,7 +55,7 @@ def enum_display_devices():
def cuda_malloc_supported():
try:
names = get_gpu_names()
except:
except Exception:
names = set()
for x in names:
if "NVIDIA" in x:
Expand All @@ -82,16 +82,16 @@ def cuda_malloc_supported():
version = module.__version__
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
args.cuda_malloc = cuda_malloc_supported()
except:
except Exception:
pass

def init_cuda_malloc():
if args.cuda_malloc and not args.disable_cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
env_var += ",backend:cudaMallocAsync"

if args.cuda_malloc and not args.disable_cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
env_var += ",backend:cudaMallocAsync"

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
print(f"Setup environment PYTORCH_CUDA_ALLOC_CONF={env_var}")
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
print(f"Setup environment PYTORCH_CUDA_ALLOC_CONF={env_var}")
3 changes: 2 additions & 1 deletion webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
startup_timer = timer.startup_timer
startup_timer.record("launcher")

import cuda_malloc
from cuda_malloc import init_cuda_malloc
init_cuda_malloc()
startup_timer.record("cuda_malloc")

initialize.imports()
Expand Down

0 comments on commit 0cc8146

Please sign in to comment.