# Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00.
import numpy as np
|
|
|
|
import ray
|
|
import ray.rllib.agents.ppo as ppo
|
|
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
|
from ray.rllib.models.catalog import ModelCatalog
|
|
from ray.rllib.utils.framework import try_import_torch
|
|
|
|
torch, _ = try_import_torch()
|
|
|
|
# __sphinx_doc_begin__
|
|
|
|
|
|
# The custom model that will be wrapped by an LSTM.
class MyCustomModel(TorchModelV2):
    """Minimal TorchModelV2 that returns 2x its flattened observations.

    Intended to be auto-wrapped by RLlib's LSTM wrapper (``use_lstm=True``
    in the model config below): this model's output tensor becomes the
    input sequence of the automatically provided LSTM head.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        # Our output size equals the flattened obs size. NOTE: use np.prod —
        # the np.product alias is deprecated and was removed in NumPy 2.0.
        self.num_outputs = int(np.prod(self.obs_space.shape))
        # Batch size of the most recent forward() call; value_function()
        # needs it to size its (dummy) output.
        self._last_batch_size = None

    # Implement your own forward logic, whose output will then be sent
    # through an LSTM.
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs_flat"]
        # Store last batch size for value_function output.
        self._last_batch_size = obs.shape[0]
        # Return 2x the obs (and empty states).
        # This will further be sent through an automatically provided
        # LSTM head (b/c we are setting use_lstm=True below).
        return obs * 2.0, []

    def value_function(self):
        # Dummy value branch: one zero per item in the last forward batch.
        return torch.from_numpy(np.zeros(shape=(self._last_batch_size,)))
|
|
|
|
|
|
if __name__ == "__main__":
    ray.init()

    # Make the custom model available to RLlib under a string handle.
    ModelCatalog.register_custom_model("my_torch_model", MyCustomModel)

    # Model sub-config, kept separate for readability.
    model_config = {
        # Auto-wrap the custom(!) model with an LSTM.
        "use_lstm": True,
        # To further customize the LSTM auto-wrapper.
        "lstm_cell_size": 64,
        # Specify our custom model from above.
        "custom_model": "my_torch_model",
        # Extra kwargs to be passed to your model's c'tor.
        "custom_model_config": {},
    }

    # Create the Trainer and run a single training iteration.
    trainer = ppo.PPOTrainer(
        env="CartPole-v0",
        config={
            "framework": "torch",
            "model": model_config,
        },
    )
    trainer.train()
|
|
|
|
# __sphinx_doc_end__
|