[BUG] SyncDataCollector Not Working With ParallelEnv when Built with Replay Buffer #2617

Open
3 tasks done
AlexandreBrown opened this issue Nov 29, 2024 · 0 comments
Labels
bug Something isn't working

Describe the bug

When a SyncDataCollector is created with a replay buffer passed to its constructor and the environment is a ParallelEnv, the collector crashes as soon as it is iterated over.

To Reproduce

  1. Create a ParallelEnv (I used num_workers=2 in my test)
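from torchrl.envs import ParallelEnv

# create_env_fn is a callable that builds a single environment instance
train_env = ParallelEnv(
    num_workers=int(cfg["env"]["num_workers"]), create_env_fn=create_env_fn
)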
  2. Create a replay buffer whose storage has ndim=2
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer

# capacity, storage_device (a torch.device) and transform are defined elsewhere
storage_kwargs = {}
storage_kwargs["max_size"] = capacity
storage_kwargs["device"] = storage_device
# Add one storage dimension when several workers are batched together
storage_dim = 1
if int(cfg["env"]["num_workers"]) > 1:
    storage_dim += 1
storage_kwargs["ndim"] = storage_dim

if "cpu" in storage_device.type:
    # LazyMemmapStorage is only supported on CPU
    replay_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(**storage_kwargs),
        transform=transform,
    )
else:
    replay_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(**storage_kwargs),
        transform=transform,
    )
  3. Create a SyncDataCollector and pass the replay buffer to its constructor:
from torchrl.collectors import SyncDataCollector

max_frames_per_traj = cfg["env"]["max_frames_per_traj"]
frames_per_batch = 128

data_collector = SyncDataCollector(
    create_env_fn=train_env,  # the ParallelEnv created in the first step
    policy=policy,
    total_frames=data_collector_cfg["total_frames"],
    max_frames_per_traj=max_frames_per_traj,
    frames_per_batch=frames_per_batch,
    env_device=cfg["env"]["device"],
    storing_device=cfg["storage_device"],
    policy_device=cfg["policy_device"],
    exploration_type=exploration_type,
    init_random_frames=data_collector_cfg.get("init_random_frames", 0),
    postproc=None,
    replay_buffer=replay_buffer,  # passing the buffer here triggers the crash
)
  4. Iterate over the data collector:
for _ in data_collector:
    pass
  5. Observe the crash:
Traceback (most recent call last):
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
    cli.main()
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 501, in main
    run()
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 351, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
    return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
    _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
    exec(code, run_globals)
  File "scripts/train_rl.py", line 125, in <module>
    main()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "scripts/train_rl.py", line 118, in main
    trainer.train()
  File "/home/mila/b/myuser/SegDAC/segdac_dev/src/segdac_dev/trainers/rl_trainer.py", line 41, in train
    for _ in tqdm(self.train_data_collector, "Env Data Collection"):
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 247, in __iter__
    yield from self.iterator()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1035, in iterator
    tensordict_out = self.rollout()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1177, in rollout
    self.replay_buffer.add(self._shuttle)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1202, in add
    self._set_index_in_td(data, index)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1246, in _set_index_in_td
    tensordict.set("index", expand_as_right(index, tensordict))
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/utils.py", line 370, in expand_as_right
    raise RuntimeError(
RuntimeError: expand_as_right requires the destination tensor to have less dimensions than the input tensor, got tensor.ndimension()=2 and dest.ndimension()=1
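
The last frame reports a dimension mismatch: the index handed to expand_as_right has 2 dimensions while the destination has 1. The sketch below is a pure-Python stand-in for that check, written from the error message above. The name check_expand_as_right and the example shapes are hypothetical, and this is not tensordict's implementation. It assumes the index follows the storage's ndim (2 here) while the item passed to add() carries a single batch dimension of size num_workers, which is one plausible reading of the traceback rather than a confirmed diagnosis.

```python
from typing import Tuple


def check_expand_as_right(
    tensor_shape: Tuple[int, ...], dest_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Model the dimension rule reported in the traceback (illustrative only).

    The input may not have more dimensions than the destination, and each of
    its leading sizes must match the destination or be 1. On success the
    input would be expanded to the destination's shape, so that shape is
    returned.
    """
    if len(dest_shape) < len(tensor_shape):
        raise RuntimeError(
            "destination has fewer dimensions than the input: "
            f"tensor ndim={len(tensor_shape)}, dest ndim={len(dest_shape)}"
        )
    for i, size in enumerate(tensor_shape):
        if size != 1 and size != dest_shape[i]:
            raise RuntimeError(
                f"size mismatch at dim {i}: {size} vs {dest_shape[i]}"
            )
    return dest_shape


# Hypothetical shapes: a 2-D index (storage ndim=2) against an item whose
# batch size only holds the worker dimension (num_workers=2).
num_workers = 2
index_shape = (1, 1)
item_batch_shape = (num_workers,)

try:
    check_expand_as_right(index_shape, item_batch_shape)
except RuntimeError as err:
    print(f"Fails as in the report: {err}")
```

Under these assumptions, a 1-D index against the same item would pass the check, which is consistent with the crash appearing only once the storage is given ndim=2.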

Expected behavior

I would expect the same behavior as when the replay buffer is not passed to the SyncDataCollector and each collected batch is added to it manually:

for data in tqdm(self.train_data_collector, "Env Data Collection"):
    self.replay_buffer.extend(data)
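
Until the constructor path works, the manual loop above can be wrapped in a small helper. This is a minimal sketch that assumes only that the collector is iterable and that the buffer exposes an extend method; fill_buffer and SupportsExtend are hypothetical names and not part of torchrl.

```python
from typing import Any, Iterable, Protocol


class SupportsExtend(Protocol):
    """Anything with an extend(data) method, such as a replay buffer."""

    def extend(self, data: Any) -> Any: ...


def fill_buffer(collector: Iterable[Any], buffer: SupportsExtend) -> int:
    """Write every batch yielded by the collector into the buffer.

    Returns the number of batches written.
    """
    n_batches = 0
    for batch in collector:
        buffer.extend(batch)
        n_batches += 1
    return n_batches
```

With the objects from the reproduction steps, this would be called as fill_buffer(data_collector, replay_buffer), with the collector built without the replay_buffer argument.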

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...): pip
  • Python version: 3.10
  • Versions of any other relevant libraries (output of pip list):
Package                   Version                                                       Editable project location
------------------------- ------------------------------------------------------------- ------------------------------------------
absl-py                   2.1.0
antlr4-python3-runtime    4.9.3
asttokens                 2.4.1
attrs                     24.2.0
av                        13.1.0
certifi                   2024.8.30
charset-normalizer        3.4.0
click                     8.1.7
clip                      1.0
cloudpickle               3.1.0
coloredlogs               15.0.1
comet-ml                  3.47.1
comm                      0.2.2
configobj                 5.0.9
contourpy                 1.3.1
cycler                    0.12.1
Cython                    3.0.11
debugpy                   1.8.9
decorator                 5.1.1
diffusers                 0.31.0
dm_control                1.0.25
dm-env                    1.6
dm-tree                   0.1.8
docker-pycreds            0.4.0
drqv2                     1.0.0                                                        
dulwich                   0.22.6
efficientvit              0.0.0
einops                    0.8.0
etils                     1.10.0
everett                   3.1.0
exceptiongroup            1.2.2
executing                 2.1.0
filelock                  3.16.1
flatbuffers               24.3.25
fonttools                 4.55.0
fsspec                    2024.10.0
ftfy                      6.3.1
gitdb                     4.0.11
GitPython                 3.1.43
glfw                      2.8.0
huggingface-hub           0.26.2
humanfriendly             10.0
hydra-core                1.3.2
idna                      3.10
igraph                    0.11.8
imageio                   2.36.0
importlib_metadata        8.5.0
importlib_resources       6.4.5
ipdb                      0.13.13
ipykernel                 6.29.5
ipython                   8.29.0
jedi                      0.19.2
Jinja2                    3.1.4
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter_client            8.6.3
jupyter_core              5.7.2
kiwisolver                1.4.7
labmaze                   1.0.6
lazy_loader               0.4
lightning-utilities       0.11.9
loguru                    0.7.2
lvis                      0.5.3
lxml                      5.3.0
markdown-it-py            3.0.0
MarkupSafe                3.0.2
matplotlib                3.9.2
matplotlib-inline         0.1.7
mdurl                     0.1.2
mpmath                    1.3.0
mujoco                    3.2.5
nest-asyncio              1.6.0
networkx                  3.4.2
numpy                     2.1.3
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
omegaconf                 2.3.0
onnx                      1.17.0
onnxruntime               1.20.1
onnxsim                   0.4.36
opencv-python             4.10.0.84
opencv-python-headless    4.10.0.84
orjson                    3.10.12
packaging                 24.2
pandas                    2.2.3
parso                     0.8.4
pexpect                   4.9.0
pillow                    11.0.0
pip                       24.3.1
platformdirs              4.3.6
prompt_toolkit            3.0.48
protobuf                  5.28.3
psutil                    6.1.0
ptyprocess                0.7.0
pure_eval                 0.2.3
py-cpuinfo                9.0.0
pycocotools               2.0.8
Pygments                  2.18.0
PyOpenGL                  3.1.7
PyOpenGL-accelerate       3.1.7
pyparsing                 3.2.0
python-box                6.1.0
python-dateutil           2.9.0.post0
pytz                      2024.2
PyYAML                    6.0.2
pyzmq                     26.2.0
referencing               0.35.1
regex                     2024.11.6
requests                  2.32.3
requests-toolbelt         1.0.0
rich                      13.9.4
rpds-py                   0.21.0
ruamel.yaml               0.18.6
ruamel.yaml.clib          0.2.12
safetensors               0.4.5
scikit-image              0.24.0
scipy                     1.14.1
seaborn                   0.13.2
XXX                    0.0.1                                                         
XXX                0.0.1                                                        
segment_anything          1.0
semantic-version          2.10.0
sentry-sdk                2.19.0
setproctitle              1.3.4
setuptools                75.6.0
simplejson                3.19.3
six                       1.16.0
smmap                     5.0.1
stack-data                0.6.3
sympy                     1.13.1
tensordict                0.6.0
texttable                 1.7.0
tifffile                  2024.9.20
timm                      1.0.11
TinyNeuralNetwork         0.1.0.20241024123327+19e5f6dd0f6e391d3c3640cf46d28f47eb76d289
tokenizers                0.20.4
tomli                     2.1.0
torch                     2.5.0
torch-fidelity            0.3.0
torchaudio                2.5.0
torchmetrics              1.6.0
torchprofile              0.0.4
torchrl                   0.6.0
torchvision               0.20.0
tornado                   6.4.2
tqdm                      4.66.5
traitlets                 5.14.3
transformers              4.46.3
triton                    3.1.0
typing_extensions         4.12.2
tzdata                    2024.2
ultralytics               8.3.38
ultralytics-thop          2.0.12
urllib3                   2.2.3
wandb                     0.18.7
wcwidth                   0.2.13
wheel                     0.45.1
wrapt                     1.17.0
wurlitzer                 3.1.1
zipp                      3.21.0
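
Of the packages listed, the ones most directly involved in the traceback are torch 2.5.0, torchrl 0.6.0 and tensordict 0.6.0. A short way to report only those is sketched below using the standard library; installed_versions is a hypothetical helper name, and the package list is an assumption about what is relevant here.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if missing."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Packages assumed to matter for this report
    for name, ver in installed_versions(["torch", "torchrl", "tensordict"]).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```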

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@AlexandreBrown AlexandreBrown added the bug Something isn't working label Nov 29, 2024