diff --git a/modules/safe.py b/modules/safe.py index af019ffd980..9014043c6c8 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -37,6 +37,9 @@ def find_class(self, module, name): if res is not None: return res + class Empty: + pass + if module == 'collections' and name == 'OrderedDict': return getattr(collections, name) if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: @@ -51,12 +54,8 @@ def find_class(self, module, name): return getattr(numpy, name) if module == '_codecs' and name == 'encode': return encode - if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': - import pytorch_lightning.callbacks - return pytorch_lightning.callbacks.model_checkpoint - if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': - import pytorch_lightning.callbacks.model_checkpoint - return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint + if module.startswith("pytorch_lightning"): + return Empty if module == "__builtin__" and name == 'set': return set