How to export a trained model as a .pt (PyTorch) or ONNX model #227
Although I do not know how to do this directly with Ray, I unwrapped a Ray checkpoint and figured out its structure. RLlib's … So the problem comes down to loading the checkpoint that MARLlib saved. I've written a script to load it:

```python
# load_model is defined in my own eval.py script; it is not part of MARLlib.
from eval import load_model
from marllib import marl

ckpt = load_model(
    {
        "model_path": "best_model/checkpoint",
        "params_path": "best_model/params.json",
    }
)
env = marl.make_env(environment_name=ckpt.env_name, map_name=ckpt.map_name)
env_instance, env_info = env
# Change the policy name accordingly
policy = ckpt.trainer.get_policy("shared_policy")
# Export the policy's model to the given directory (RLlib's Policy.export_model).
policy.export_model("/directory/to/save")
```

PS: In case anybody wants to know how to use the raw model:

```python
import numpy as np
import torch

model = policy.model
state = policy.get_initial_state()
# Put inputs on the same device as the model's weights to avoid a device mismatch.
DEVICE = next(model.parameters()).device


def get_obs(env):
    obs = env.observation_space.sample()
    # Suppose the observation is a dict, e.g.
    # obs = {
    #     "action_mask": [0, 0, 1, 0],
    #     "obs": [1, 1, 4, 5, 1, 4],
    # }
    for key in obs:
        # Add a batch dimension of 1 and move the tensor to the model's device.
        obs[key] = torch.from_numpy(np.array([obs[key]])).to(DEVICE)
    return obs


dummy_input = {
    "input_dict": {"obs": get_obs(env_instance)},
    # One tensor per recurrent state, each with a batch dimension of 1.
    "state": [torch.from_numpy(np.asarray(s)[None]).to(DEVICE) for s in state],
    "seq_lens": np.array([1]),
}
with torch.no_grad():
    # RLlib models return a tuple: (model outputs, output states).
    output, state_out = model(**dummy_input)
```
I have fully trained my model and want to deploy it in a Unity ML-Agents environment, so I need to export the trained model as either a PyTorch or an ONNX model.
The only related option I could find in the documentation is `algo.render()`.