Skip to content

Commit

Permalink
add kiui.seed_everything
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Dec 1, 2023
1 parent 7b1f675 commit da4b086
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ __pycache__/
build/
dist/

tmp_*
tmp*
17 changes: 12 additions & 5 deletions kiui/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,21 @@ def retrieve_globals(verbose=False):
if "kiui" in g:
G = g
if verbose:
print(f"[KiuiKit-INFO] located global frame at {frame_id}")
print(f"[INFO] located global frame at {frame_id}")
# print(G)
break
frame_id += 1
if G is None:
raise RuntimeError(
"Cannot locate global frame, make sure you called exactly `import kiui`!"
)

def is_imported(target, verbose=False):

if G is None:
retrieve_globals(verbose)

return target in G

def try_import(target, sources, verbose=False):

Expand All @@ -75,7 +82,7 @@ def try_import(target, sources, verbose=False):

if target in G:
if verbose:
print(f"[KiuiKit-INFO] {target} is already present, skipped.")
print(f"[INFO] {target} is already present, skipped.")
return

if not isinstance(sources, list):
Expand All @@ -84,7 +91,7 @@ def try_import(target, sources, verbose=False):
for source in sources:
try:
if verbose:
print(f"[KiuiKit-INFO] try to import {source}")
print(f"[INFO] try to import {source}")

# (module, component) or ("module", component)
if isinstance(source, tuple):
Expand All @@ -102,11 +109,11 @@ def try_import(target, sources, verbose=False):
G[target] = source

if verbose:
print(f"[KiuiKit-INFO] succeed to import {source} as {target}")
print(f"[INFO] succeed to import {source} as {target}")
break

except ImportError as e:
print(f"[KiuiKit-WARN] failed to import {source} as {target}: {str(e)}")
print(f"[WARN] failed to import {source} as {target}: {str(e)}")


def import_libs(pack, verbose=False):
Expand Down
49 changes: 45 additions & 4 deletions kiui/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import os
import sys
import glob
import tqdm
import cv2
import json
import pickle
import varname
from objprint import objstr
from rich.console import Console

import cv2
from PIL import Image

import torch
import numpy as np
from objprint import objstr
import torch

from rich.console import Console
from kiui.env import is_imported

''' utils
All functions will be automatically imported as kiui.<func>
'''

# inspect array like object x and report stats
def lo(*xs, verbose=0):
Expand Down Expand Up @@ -72,6 +79,40 @@ def _lo(x, name):
_lo(x, name)


def seed_everything(seed=42, verbose=False, strict=False):

os.environ['PYTHONHASHSEED'] = str(seed)

if is_imported('random'):
import random # still need to import it here
random.seed(seed)
if verbose: print(f'[INFO] set random.seed = {seed}')
else:
if verbose: print(f'[INFO] random not imported, skip setting seed')

# assume numpy is imported as np
if is_imported('np'):
import numpy as np
np.random.seed(seed)
if verbose: print(f'[INFO] set np.random.seed = {seed}')
else:
if verbose: print(f'[INFO] numpy not imported, skip setting seed')

if is_imported('torch'):
import torch
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if verbose: print(f'[INFO] set torch.manual_seed = {seed}')

if strict:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
if verbose: print(f'[INFO] set strict deterministic mode for torch.')
else:
if verbose: print(f'[INFO] torch not imported, skip setting seed')


def read_json(path):
with open(path, "r") as f:
return json.load(f)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
if __name__ == "__main__":
setup(
name="kiui",
version="0.1.10",
version="0.1.11",
description="A toolkit for 3D vision",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
Expand Down
24 changes: 24 additions & 0 deletions tests/test_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import kiui

kiui.seed_everything(42, True)

import random
kiui.seed_everything(42, True)
a = random.random()
kiui.seed_everything(42, True)
b = random.random()
assert a == b

import numpy as np
kiui.seed_everything(42, True)
a = np.random.randn(10)
kiui.seed_everything(42, True)
b = np.random.randn(10)
assert np.allclose(a, b)

import torch
kiui.seed_everything(42, True)
a = torch.randn(10)
kiui.seed_everything(42, True)
b = torch.randn(10)
assert torch.allclose(a, b)

0 comments on commit da4b086

Please sign in to comment.