mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
import gym
from typing import Optional, List, Dict

from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import override, force_list
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, _ = try_import_torch()

class RNNSACTorchModel(SACTorchModel):
    """Recurrent (RNN-capable) variant of SACTorchModel.

    Adds prev-action/prev-reward view requirements and threads RNN state
    and sequence lengths through the policy- and Q-network calls.
    """

    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: Optional[int],
                 model_config: ModelConfigDict,
                 name: str,
                 policy_model_config: ModelConfigDict = None,
                 q_model_config: ModelConfigDict = None,
                 twin_q: bool = False,
                 initial_alpha: float = 1.0,
                 target_entropy: Optional[float] = None):
        super().__init__(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
            policy_model_config=policy_model_config,
            q_model_config=q_model_config,
            twin_q=twin_q,
            initial_alpha=initial_alpha,
            target_entropy=target_entropy)

        self.use_prev_action = (model_config["lstm_use_prev_action"]
                                or policy_model_config["lstm_use_prev_action"]
                                or q_model_config["lstm_use_prev_action"])

        self.use_prev_reward = (model_config["lstm_use_prev_reward"]
                                or policy_model_config["lstm_use_prev_reward"]
                                or q_model_config["lstm_use_prev_reward"])

        # Request the previous action/reward (shift=-1) from the sampler so
        # they show up in the input dict handed to `forward()`.
        if self.use_prev_action:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
                ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                                shift=-1)
        if self.use_prev_reward:
            self.view_requirements[SampleBatch.PREV_REWARDS] = \
                ViewRequirement(SampleBatch.REWARDS, shift=-1)
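    # The "lstm_use_prev_action"/"lstm_use_prev_reward" flags may be set on
    # the outer model config or on either sub-model config; any one of them
    # being True enables the extra input. A minimal config sketch (the dict
    # shown is an assumption; only these two keys are read here):
    #     policy_model_config = {"lstm_use_prev_action": True,
    #                            "lstm_use_prev_reward": False}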

    @override(SACTorchModel)
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """The common (Q-net and policy-net) forward pass.

        NOTE: It is not(!) recommended to override this method as it would
        introduce a shared pre-network, which would be updated by both
        actor- and critic optimizers.

        For RNN support, the input dict is not filtered down to a flat
        observation tensor here; the dict built below, plus `state` and
        `seq_lens`, is passed on to the recurrent sub-models.
        """
        model_out = {"obs": input_dict[SampleBatch.OBS]}
        if self.use_prev_action:
            model_out["prev_actions"] = input_dict[SampleBatch.PREV_ACTIONS]
        if self.use_prev_reward:
            model_out["prev_rewards"] = input_dict[SampleBatch.PREV_REWARDS]
        return model_out, state
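    # With both flags enabled, the dict returned by `forward()` has roughly
    # the form
    #     {"obs": ..., "prev_actions": ..., "prev_rewards": ...},
    # and it is this dict (together with `state_in` and `seq_lens`) that the
    # getters below hand to `self.q_net`, `self.twin_q_net` and
    # `self.action_model`.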

    @override(SACTorchModel)
    def _get_q_value(self, model_out: TensorType, actions, net,
                     state_in: List[TensorType],
                     seq_lens: TensorType) -> (TensorType, List[TensorType]):
        # Continuous case -> concat actions to model_out.
        if actions is not None:
            if self.concat_obs_and_actions:
                model_out[SampleBatch.OBS] = \
                    torch.cat([model_out[SampleBatch.OBS], actions], dim=-1)
            else:
                model_out[SampleBatch.OBS] = \
                    force_list(model_out[SampleBatch.OBS]) + [actions]

        # Switch on training mode (when getting Q-values, we are usually in
        # training).
        model_out["is_training"] = True

        out, state_out = net(model_out, state_in, seq_lens)
        return out, state_out

    @override(SACTorchModel)
    def get_q_values(self,
                     model_out: TensorType,
                     state_in: List[TensorType],
                     seq_lens: TensorType,
                     actions: Optional[TensorType] = None) -> TensorType:
        return self._get_q_value(model_out, actions, self.q_net, state_in,
                                 seq_lens)

    @override(SACTorchModel)
    def get_twin_q_values(self,
                          model_out: TensorType,
                          state_in: List[TensorType],
                          seq_lens: TensorType,
                          actions: Optional[TensorType] = None) -> TensorType:
        return self._get_q_value(model_out, actions, self.twin_q_net, state_in,
                                 seq_lens)

    @override(SACTorchModel)
    def get_policy_output(
            self, model_out: TensorType, state_in: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
        return self.action_model(model_out, state_in, seq_lens)

    @override(ModelV2)
    def get_initial_state(self):
        policy_initial_state = self.action_model.get_initial_state()
        q_initial_state = self.q_net.get_initial_state()
        if self.twin_q_net:
            q_initial_state *= 2
        return policy_initial_state + q_initial_state
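    # Rough layout of the flat state list returned above, assuming each
    # recurrent sub-model uses a single LSTM layer with states [h, c]
    # (an assumption for illustration only):
    #     [policy_h, policy_c, q_h, q_c, twin_q_h, twin_q_c]
    # `select_state()` below recovers the per-network slices from this list.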

    def select_state(self, state_batch: List[TensorType],
                     net: List[str]) -> Dict[str, List[TensorType]]:
        """Returns the RNN state slices belonging to the given sub-networks.

        Args:
            state_batch: Flat list of state tensors, ordered as produced by
                `get_initial_state()` (policy, then q, then twin-q).
            net: Names of the sub-networks to select states for
                ("policy", "q" and/or "twin_q").

        Returns:
            Dict mapping each requested name to its list of state tensors.
        """
        assert all(n in ["policy", "q", "twin_q"] for n in net), \
            "Selected state must be either for policy, q or twin_q network"
        policy_state_len = len(self.action_model.get_initial_state())
        q_state_len = len(self.q_net.get_initial_state())

        selected_state = {}
        for n in net:
            if n == "policy":
                selected_state[n] = state_batch[:policy_state_len]
            elif n == "q":
                selected_state[n] = state_batch[policy_state_len:
                                                policy_state_len + q_state_len]
            elif n == "twin_q":
                if self.twin_q_net:
                    selected_state[n] = state_batch[policy_state_len +
                                                    q_state_len:]
                else:
                    selected_state[n] = []
        return selected_state
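
# Usage sketch (hypothetical; `model`, `model_out`, `state_in`, `seq_lens`
# and `actions` are placeholders, not defined in this file): a policy that
# holds one flat RNN state list can split it per sub-network before calling
# the individual getters:
#
#     states = model.select_state(state_in, ["policy", "q", "twin_q"])
#     action_out, policy_state_out = model.get_policy_output(
#         model_out, states["policy"], seq_lens)
#     q_out, q_state_out = model.get_q_values(
#         model_out, states["q"], seq_lens, actions)
#     twin_q_out, twin_q_state_out = model.get_twin_q_values(
#         model_out, states["twin_q"], seq_lens, actions)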