Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] torch geometric layers not working in policy network #2679

Open
rerz opened this issue Dec 29, 2024 · 0 comments
Open

[BUG] torch geometric layers not working in policy network #2679

rerz opened this issue Dec 29, 2024 · 0 comments
Assignees
Labels
bug Something isn't working

Comments

@rerz
Copy link

rerz commented Dec 29, 2024

Describe the bug

I am trying to integrate torch geometric layers into my policy network and I think I am running into a variant of #1613

To Reproduce

Include any torch geometric layer (i used torch_geometric.nn.Linear) in a policy and try to collect data from it with a SyncDataCollector.

from torch_geometric.nn import Linear
from torch import nn

class GNNMessageNet(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)

        self.conv = Linear(10, 32)
Traceback (most recent call last):
  File "/home/rerz/.local/share/JetBrains/IntelliJIdea2024.2/python-ce/helpers/pydev/pydevd.py", line 1570, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rerz/.local/share/JetBrains/IntelliJIdea2024.2/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "<project>/packages/gridworks/src/gridworks/task/train.py", line 101, in <module>
    collector = SyncDataCollector(
                ^^^^^^^^^^^^^^^^^^
  File "<project>/.venv/lib/python3.12/site-packages/torchrl/collectors/collectors.py", line 654, in __init__
    (self.policy, self.get_weights_fn,) = self._get_policy_and_device(
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project>/.venv/lib/python3.12/site-packages/torchrl/collectors/collectors.py", line 236, in _get_policy_and_device
    policy = deepcopy(policy)
             ^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 143, in deepcopy
    y = copier(memo)
        ^^^^^^^^^^^^
  File "<project>/.venv/lib/python3.12/site-packages/torch_geometric/nn/dense/linear.py", line 129, in __deepcopy__
    out.weight = copy.deepcopy(self.weight, memo)
    ^^^^^^^^^^
  File "<project>/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1959, in __setattr__
    raise TypeError(
TypeError: cannot assign 'torch.meta.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)
  File ... 

Expected behavior

I can use torch geometric layers in my policy as any other torch layer.

System info

Describe the characteristic of your environment:

  • Python 3.12
  • torchrl/tensordict main branch
  • torch_geometric 2.6.1

Additional context

Issue seems to be happening while deepcopying the torch geometric layer.

Checklist

  • [x ] I have checked that there is no similar issue in the repo (required)
  • [x ] I have read the documentation (required)
  • [x ] I have provided a minimal working example to reproduce the bug (required)
@rerz rerz added the bug Something isn't working label Dec 29, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants