Unexpected Behavior: Fabric.load operates out-of-place on nested states #20208

Open
Markus28 opened this issue Aug 17, 2024 · 1 comment · May be fixed by #20210
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.3.x

Comments

@Markus28

Bug description

Usually, I would expect Fabric.load(ckpt_path, state) to operate in place, i.e. to call the load_state_dict methods of the loadable objects in state. This is indeed what happens when state is a flat dictionary. However, when state is nested (e.g. because one handles multiple models with their corresponding optimizers), this is no longer the case: after the call to load, the modules in the state dictionary are replaced by different objects. I have provided a minimal reproducible example below.
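
To make the expected semantics concrete, here is a minimal plain-Python sketch of what "in place" would mean for a nested state. It is hypothetical and not Lightning's implementation: `Stateful` is a stand-in for an `nn.Module` or optimizer (anything exposing `state_dict()` and `load_state_dict()`), and `load_in_place` is an illustrative helper that recurses into nested dicts instead of replacing them, so existing references stay valid.

```python
from typing import Any, Dict


class Stateful:
    """Hypothetical stand-in for an nn.Module or optimizer:
    anything that exposes state_dict() and load_state_dict()."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.value = state_dict["value"]


def load_in_place(state: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
    """Hypothetical helper (not Lightning's code): restore `checkpoint`
    into `state`, mutating loadable objects instead of replacing them."""
    for key, obj in state.items():
        if key not in checkpoint:
            continue  # a strict variant would raise here instead
        saved = checkpoint[key]
        if hasattr(obj, "load_state_dict"):
            # Loadable object: update it so existing references stay valid.
            obj.load_state_dict(saved)
        elif isinstance(obj, dict) and isinstance(saved, dict):
            # Nested container: recurse rather than swapping in a new dict.
            load_in_place(obj, saved)
        else:
            # Plain value (None, int, ...): nothing to mutate, so replace it.
            # Overwriting an existing key does not resize the dict, so this
            # is safe while iterating over state.items().
            state[key] = saved


if __name__ == "__main__":
    model = Stateful(0)
    state = {"generator": {"model": model, "optim": None}}
    checkpoint = {"generator": {"model": {"value": 42}, "optim": None}}
    load_in_place(state, checkpoint)
    assert state["generator"]["model"] is model  # same object as before
    print(model.value)  # 42
```

With this behavior, a reference like model would still point at the restored object no matter how deeply it is nested in state.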

To me this is clearly unexpected behavior, and it will lead to problems when resuming training with references to the old modules (e.g. imagine that in the example below we train with model rather than with state["generator"]["model"]). Currently, one has to update the old references explicitly, which is cumbersome and error-prone, especially for users who are not aware of this behavior.

What version are you seeing the problem on?

v2.3

How to reproduce the bug

from lightning import Fabric
from torch import nn

if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu")
    model = nn.Linear(2, 2)
    model = fabric.setup(model)
    state_flat = {"model": model}
    state_nested = {"generator": {"model": model, "optim": None}}
    fabric.save("flat.pt", state_flat)
    fabric.save("nested.pt", state_nested)

    fabric.load("flat.pt", state_flat, strict=True)
    assert model is state_flat["model"]     # This is fine

    fabric.load("nested.pt", state_nested, strict=True)
    assert model is state_nested["generator"]["model"]      # This will fail
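
Until this is fixed, one possible workaround (an untested sketch, not an official Lightning API) is to keep the nested layout in your own code but hand Fabric a flat view of it. The hypothetical flatten_state helper below maps nested keys to compound keys such as "generator/model" while keeping the very same objects, so the top-level in-place path that already works for state_flat should apply.

```python
from typing import Any, Dict


def flatten_state(state: Dict[str, Any], sep: str = "/", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper (not part of Lightning): flatten nested dicts
    into a single level with compound keys.

    Values are not copied, so the flat dict holds the same objects as the
    nested one; anything loaded into them in place is visible through both.
    """
    flat: Dict[str, Any] = {}
    for key, obj in state.items():
        name = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(obj, dict):
            # Recurse into nested containers, extending the key prefix.
            flat.update(flatten_state(obj, sep, name))
        else:
            flat[name] = obj
    return flat


if __name__ == "__main__":
    model = object()  # stand-in for the module from the reproduction
    nested = {"generator": {"model": model, "optim": None}}
    flat = flatten_state(nested)
    print(sorted(flat))  # ['generator/model', 'generator/optim']
    assert flat["generator/model"] is model
```

With the objects from the reproduction, this would mean calling fabric.save("nested.pt", flatten_state(state_nested)) and then fabric.load("nested.pt", flat, strict=True) with flat = flatten_state(state_nested). Plain values such as None or a step counter are only replaced in the flat dict passed to load, so read them back from that dict.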

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    - GPU:
    - Quadro RTX 5000
    - Quadro RTX 5000
    - available: True
    - version: 12.1
  • Lightning:
    - graph-transformer-pytorch: 0.1.1
    - lightning: 2.3.0
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.2
    - rotary-embedding-torch: 0.6.2
    - torch: 2.3.0
    - torch-geometric: 2.5.3
    - torchmetrics: 1.4.0.post0
  • Packages:
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - autoregressive-graph-generation: 0.0.0
    - beartype: 0.18.5
    - biopandas: 0.4.1
    - brotli: 1.1.0
    - certifi: 2024.6.2
    - cffi: 1.16.0
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - colorama: 0.4.6
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - deepspeed: 0.14.4
    - distlib: 0.3.8
    - docker-pycreds: 0.4.0
    - docopt: 0.6.2
    - einops: 0.8.0
    - et-xmlfile: 1.1.0
    - fastavro: 1.9.5
    - filelock: 3.14.0
    - fonttools: 4.53.0
    - freesasa: 2.2.1
    - frozenlist: 1.4.1
    - fsspec: 2024.5.0
    - ftpretty: 0.4.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - goatools: 1.4.12
    - graph-transformer-pytorch: 0.1.1
    - gudhi: 3.9.0
    - h5py: 3.11.0
    - heapdict: 1.0.1
    - hjson: 3.1.0
    - hydra-core: 1.3.2
    - identify: 2.5.36
    - idna: 3.7
    - imageio: 2.34.1
    - jinja2: 3.1.4
    - joblib: 1.4.2
    - kiwisolver: 1.4.5
    - lightning: 2.3.0
    - lightning-utilities: 0.11.2
    - loguru: 0.7.2
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.8.4
    - mdurl: 0.1.2
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - munkres: 1.1.4
    - networkx: 3.3
    - ninja: 1.11.1.1
    - nodeenv: 1.9.0
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-ml-py: 12.555.43
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.5.40
    - nvidia-nvtx-cu12: 12.1.105
    - omegaconf: 2.3.0
    - openpyxl: 3.1.5
    - packaging: 24.0
    - pandas: 2.2.2
    - patsy: 0.5.6
    - pillow: 10.3.0
    - pip: 24.0
    - platformdirs: 4.2.2
    - pot: 0.9.3
    - pre-commit: 3.7.1
    - proteinshake: 0.3.14
    - protobuf: 4.25.3
    - psutil: 5.9.8
    - py-cpuinfo: 9.0.0
    - pycairo: 1.25.0
    - pycparser: 2.22
    - pydantic: 2.7.4
    - pydantic-core: 2.18.4
    - pydot: 3.0.1
    - pyemd: 1.0.0
    - pygments: 2.18.0
    - pygobject: 3.46.0
    - pygsp: 0.5.1
    - pyparsing: 3.1.2
    - pysocks: 1.7.1
    - python-dateutil: 2.9.0
    - pytorch-lightning: 2.2.2
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - rdkit: 2024.3.3
    - rdkit-pypi: 2022.9.5
    - requests: 2.32.3
    - rich: 13.7.1
    - rotary-embedding-torch: 0.6.2
    - rustworkx: 0.14.2
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - sentry-sdk: 2.3.1
    - setproctitle: 1.3.3
    - setuptools: 69.5.1
    - six: 1.16.0
    - smmap: 5.0.1
    - statsmodels: 0.14.2
    - sympy: 1.12.1
    - threadpoolctl: 3.5.0
    - torch: 2.3.0
    - torch-geometric: 2.5.3
    - torchmetrics: 1.4.0.post0
    - tqdm: 4.66.4
    - triton: 2.3.0
    - typing-extensions: 4.12.1
    - tzdata: 2024.1
    - unicodedata2: 15.1.0
    - urllib3: 2.2.1
    - virtualenv: 20.26.2
    - wandb: 0.17.0
    - wheel: 0.43.0
    - xlsxwriter: 3.2.0
    - yarl: 1.9.4
    - zstandard: 0.22.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.14
    - release: 5.14.21-150400.24.97-default
    - version: #1 SMP PREEMPT_DYNAMIC Fri Oct 27 10:29:06 UTC 2023 (8546fda)

More info

No response

Markus28 added the bug and needs triage labels on Aug 17, 2024
@Markus28
Author

Markus28 commented Aug 18, 2024

The issue actually runs even deeper. If you look at the types of state_flat["model"] and state_nested["generator"]["model"] after loading, you will find that the first is a lightning.fabric.wrappers._FabricModule, while the second is a plain nn.Linear module.

Markus28 linked a pull request on Aug 18, 2024 that will close this issue