[BUG] pettingzoo environment wrapper does not work with dictionary action spaces #2680

Open
rerz opened this issue Dec 29, 2024 · 0 comments
Labels: bug (Something isn't working)

Describe the bug

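I am trying to use a parallel PettingZoo environment with a gymnasium.spaces.Dict action space, using a ProbabilisticActor to sample the actions. The wrapper's _step_parallel method does not handle this case: it indexes the converted group action by agent index, which raises KeyError: 0 when the action is a dictionary (see the traceback below).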

To Reproduce

Define a parallel pettingzoo environment with a dictionary action space and pass a tensordict where the leaves of the actions have shape (2, 2).

The environment in my test repo https://github.com/rerz/rltest should work for this purpose.

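import gymnasium
import pettingzoo
from gymnasium import spaces
from pettingzoo.utils.env import AgentID
from torchrl.envs import PettingZooWrapper


class Environment(pettingzoo.ParallelEnv):
    # reset() and step() are omitted from this excerpt
    def __init__(self):
        pettingzoo.ParallelEnv.__init__(self)

        agents = [
            "agent_0",
            "agent_1",
        ]

        self.agents = agents
        self.possible_agents = agents

    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        # every agent gets the same dictionary action space
        return spaces.Dict([
            ("target", spaces.Box(0, 1, [2])),
            ("strength", spaces.Box(0, 1, [2])),
            ("healing", spaces.Box(0, 100, [2])),
        ])

env = PettingZooWrapper(
    env=Environment(),
    categorical_actions=False,
)

Stepping this environment through a collector fails with:
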
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1052, in iterator
    tensordict_out = self.rollout()
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torchrl/_utils.py", line 546, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1187, in rollout
    env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torchrl/envs/common.py", line 3257, in step_and_maybe_reset
    tensordict = self.step(tensordict)
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torchrl/envs/common.py", line 1844, in step
    next_tensordict = self._step(tensordict)
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 794, in _step
    next_tensordict = self.base_env._step(tensordict_in)
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torchrl/envs/libs/pettingzoo.py", line 600, in _step
    ) = self._step_parallel(tensordict)
  File "<project>/rltest/.venv/lib/python3.10/site-packages/torchrl/envs/libs/pettingzoo.py", line 738, in _step_parallel
    action_dict[agent] = group_action_np[index]
KeyError: 0
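
The failing line can be illustrated without torchrl. This is only a sketch of what the traceback suggests, assuming that converting a dictionary action to NumPy yields a mapping from action key to an array with the agents along the first dimension (an assumption inferred from the workaround further down, not checked against the torchrl source). The values are made up for illustration.

```python
import numpy as np

# Assumed layout of the converted group action: one array per action key,
# each of shape (n_agents, 2) = (2, 2). Values are illustrative only.
group_action_np = {
    "target": np.zeros((2, 2)),
    "strength": np.ones((2, 2)),
    "healing": np.full((2, 2), 50.0),
}

try:
    # Same access pattern as the failing line in _step_parallel:
    # an integer agent index applied to a dictionary.
    group_action_np[0]
except KeyError as err:
    error_message = f"KeyError: {err}"

# Indexing by key first, then by agent index, works on this layout.
agent_0_target = group_action_np["target"][0]
```

Under that assumption, the integer index is looked up as a dictionary key, which matches the error at the end of the traceback.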

Expected behavior

I can use dictionary action spaces without workarounds.

System info

Characteristics of my environment:

  • Python 3.12
  • torchrl/tensordict main branch

Reason and Possible fixes

The following subclass of PettingZooWrapper overrides _step_parallel and seems to work for Dict action spaces, but it probably breaks other cases, since it always indexes the action by key:

from typing import Dict, Tuple

from tensordict import TensorDictBase
from torchrl.envs import PettingZooWrapper


class PettingZooDictWrapper(PettingZooWrapper):
    def __init__(self, env, **kwargs):
        super().__init__(env, **kwargs)

    def _step_parallel(
        self,
        tensordict: TensorDictBase,
    ) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
        action_dict = {}
        for group, agents in self.group_map.items():
            group_action = tensordict.get((group, "action"))
            group_action_np = self.input_spec[
                "full_action_spec", group, "action"
            ].to_numpy(group_action)
            for index, agent in enumerate(agents):
                # extract this agent's slice of every action entry
                action = {
                    key: group_action_np[key][index]
                    for key in group_action.keys()
                }
                action_dict[agent] = action

        return self._env.step(action_dict)
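
One way the per-agent extraction might cover both dictionary and plain array actions is to branch on the type of the converted group action. The helper below is a hypothetical, torchrl-free sketch (the name `split_group_action` is not part of any library), and it assumes the converted action is either a flat mapping of key to array or a single array, with agents along the first dimension in both cases.

```python
from collections.abc import Mapping

import numpy as np


def split_group_action(group_action_np, agents):
    """Hypothetical helper: split a group's action into per-agent actions.

    Assumes `group_action_np` is either a flat mapping of key -> array or a
    single array, with agents along the first dimension in both cases.
    """
    action_dict = {}
    for index, agent in enumerate(agents):
        if isinstance(group_action_np, Mapping):
            # Dictionary action: index each entry by the agent's position.
            action_dict[agent] = {
                key: value[index] for key, value in group_action_np.items()
            }
        else:
            # Plain array action: index the array directly.
            action_dict[agent] = group_action_np[index]
    return action_dict


agents = ["agent_0", "agent_1"]
dict_actions = {
    "target": np.array([[0.1, 0.2], [0.3, 0.4]]),
    "strength": np.array([[0.5, 0.6], [0.7, 0.8]]),
}
per_agent_dict = split_group_action(dict_actions, agents)
per_agent_flat = split_group_action(np.arange(4).reshape(2, 2), agents)
```

Nested dictionaries are not handled here; whether something along these lines fits inside _step_parallel is for the maintainers to judge.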

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

rerz added the bug label Dec 29, 2024