Unexpected Behavior: `Fabric.load` operates out-of-place on nested states #20208
Bug description
Usually, I would expect `Fabric.load(ckpt_path, state)` to operate in-place, i.e. for it to call the `load_state_dict` methods of the loadable objects in `state`. This does indeed happen when `state` is a flat dictionary. However, when `state` is nested (e.g. because one handles multiple models with corresponding optimizers), this changes all of a sudden. Specifically, the modules in the `state` dictionary will reference different objects after the call to `load`. I have provided a minimal reproducible example below.
To me this is definitely unexpected behavior and will lead to problems when we are using the references to the old modules when resuming training (e.g. imagine in the example below we use `model` during training and not `state["generator"]["model"]`). Currently one needs to explicitly update the old references, which is cumbersome and error-prone, especially when users are not aware of this behavior.
What version are you seeing the problem on?
v2.3
How to reproduce the bug
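The in-place versus out-of-place distinction described above can be illustrated without Lightning installed. The following is only a sketch under stated assumptions: `TinyModule` and `toy_load` are hypothetical names, and `toy_load` is a toy loader written to mimic the reported behavior (top-level loadable objects restored in place, everything else overwritten). It is not Lightning's implementation and not the reporter's original snippet.

```python
class TinyModule:
    """Hypothetical stand-in for any object with state_dict/load_state_dict."""

    def __init__(self, weight=0.0):
        self.weight = weight

    def state_dict(self):
        return {"weight": self.weight}

    def load_state_dict(self, state_dict):
        self.weight = state_dict["weight"]


def toy_load(checkpoint, state):
    """Toy loader (an assumption, not Lightning's code).

    Top-level entries that expose load_state_dict are restored in place;
    any other entry, including a nested dict, is overwritten with the
    object stored in the checkpoint.
    """
    for key in list(state):
        obj = state[key]
        if hasattr(obj, "load_state_dict"):
            obj.load_state_dict(checkpoint[key])
        else:
            state[key] = checkpoint[key]


# Flat state: the module is updated in place, so the old reference stays valid.
flat_model = TinyModule()
flat_state = {"model": flat_model}
toy_load({"model": {"weight": 1.0}}, flat_state)
flat_in_place = flat_state["model"] is flat_model

# Nested state: the inner dict is replaced wholesale, so `model` goes stale.
model = TinyModule()
nested_state = {"generator": {"model": model}}
nested_ckpt = {"generator": {"model": TinyModule(weight=1.0)}}
toy_load(nested_ckpt, nested_state)
nested_in_place = nested_state["generator"]["model"] is model
stale_weight = model.weight  # the old reference never saw the checkpoint

# Manual fix mentioned in the report: re-bind the old reference explicitly.
model = nested_state["generator"]["model"]
```

In this toy setting `flat_in_place` is true while `nested_in_place` is false, which is the asymmetry the report describes. The final re-binding line shows the kind of manual reference update the report calls cumbersome and error-prone.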
Error messages and logs
No response
Environment
Current environment
- GPU:
    - Quadro RTX 5000
    - Quadro RTX 5000
- available: True
- version: 12.1
- graph-transformer-pytorch: 0.1.1
- lightning: 2.3.0
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.2.2
- rotary-embedding-torch: 0.6.2
- torch: 2.3.0
- torch-geometric: 2.5.3
- torchmetrics: 1.4.0.post0
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- async-timeout: 4.0.3
- attrs: 23.2.0
- autoregressive-graph-generation: 0.0.0
- beartype: 0.18.5
- biopandas: 0.4.1
- brotli: 1.1.0
- certifi: 2024.6.2
- cffi: 1.16.0
- cfgv: 3.4.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- colorama: 0.4.6
- contourpy: 1.2.1
- cycler: 0.12.1
- deepspeed: 0.14.4
- distlib: 0.3.8
- docker-pycreds: 0.4.0
- docopt: 0.6.2
- einops: 0.8.0
- et-xmlfile: 1.1.0
- fastavro: 1.9.5
- filelock: 3.14.0
- fonttools: 4.53.0
- freesasa: 2.2.1
- frozenlist: 1.4.1
- fsspec: 2024.5.0
- ftpretty: 0.4.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- goatools: 1.4.12
- graph-transformer-pytorch: 0.1.1
- gudhi: 3.9.0
- h5py: 3.11.0
- heapdict: 1.0.1
- hjson: 3.1.0
- hydra-core: 1.3.2
- identify: 2.5.36
- idna: 3.7
- imageio: 2.34.1
- jinja2: 3.1.4
- joblib: 1.4.2
- kiwisolver: 1.4.5
- lightning: 2.3.0
- lightning-utilities: 0.11.2
- loguru: 0.7.2
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.8.4
- mdurl: 0.1.2
- mpmath: 1.3.0
- multidict: 6.0.5
- munkres: 1.1.4
- networkx: 3.3
- ninja: 1.11.1.1
- nodeenv: 1.9.0
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-ml-py: 12.555.43
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.5.40
- nvidia-nvtx-cu12: 12.1.105
- omegaconf: 2.3.0
- openpyxl: 3.1.5
- packaging: 24.0
- pandas: 2.2.2
- patsy: 0.5.6
- pillow: 10.3.0
- pip: 24.0
- platformdirs: 4.2.2
- pot: 0.9.3
- pre-commit: 3.7.1
- proteinshake: 0.3.14
- protobuf: 4.25.3
- psutil: 5.9.8
- py-cpuinfo: 9.0.0
- pycairo: 1.25.0
- pycparser: 2.22
- pydantic: 2.7.4
- pydantic-core: 2.18.4
- pydot: 3.0.1
- pyemd: 1.0.0
- pygments: 2.18.0
- pygobject: 3.46.0
- pygsp: 0.5.1
- pyparsing: 3.1.2
- pysocks: 1.7.1
- python-dateutil: 2.9.0
- pytorch-lightning: 2.2.2
- pytz: 2024.1
- pyyaml: 6.0.1
- rdkit: 2024.3.3
- rdkit-pypi: 2022.9.5
- requests: 2.32.3
- rich: 13.7.1
- rotary-embedding-torch: 0.6.2
- rustworkx: 0.14.2
- scikit-learn: 1.5.0
- scipy: 1.13.1
- sentry-sdk: 2.3.1
- setproctitle: 1.3.3
- setuptools: 69.5.1
- six: 1.16.0
- smmap: 5.0.1
- statsmodels: 0.14.2
- sympy: 1.12.1
- threadpoolctl: 3.5.0
- torch: 2.3.0
- torch-geometric: 2.5.3
- torchmetrics: 1.4.0.post0
- tqdm: 4.66.4
- triton: 2.3.0
- typing-extensions: 4.12.1
- tzdata: 2024.1
- unicodedata2: 15.1.0
- urllib3: 2.2.1
- virtualenv: 20.26.2
- wandb: 0.17.0
- wheel: 0.43.0
- xlsxwriter: 3.2.0
- yarl: 1.9.4
- zstandard: 0.22.0
- OS: Linux
- architecture:
    - 64bit
    - ELF
- processor: x86_64
- python: 3.10.14
- release: 5.14.21-150400.24.97-default
- version: #1 SMP PREEMPT_DYNAMIC Fri Oct 27 10:29:06 UTC 2023 (8546fda)
More info
No response