from gym.spaces import Box
import numpy as np
import random
import tree  # pip install dm_tree

from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModelWeights


class RandomPolicy(Policy):
    """Hand-coded policy that returns random actions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Whether the bounds given in `action_space` should be ignored when
        # computing actions (default: False). This is used to test action
        # clipping and any env's reaction to bounds breaches.
        if self.config.get("ignore_action_bounds", False) and isinstance(
            self.action_space, Box
        ):
            self.action_space_for_sampling = Box(
                -float("inf"),
                float("inf"),
                shape=self.action_space.shape,
                dtype=self.action_space.dtype,
            )
        else:
            self.action_space_for_sampling = self.action_space

    @override(Policy)
    def init_view_requirements(self):
        super().init_view_requirements()
        # Disable the `used_for_training` and `used_for_compute_actions` flags
        # for the SampleBatch.INFOS column, since infos cannot be properly
        # batched.
        vr = self.view_requirements[SampleBatch.INFOS]
        vr.used_for_training = False
        vr.used_for_compute_actions = False

    @override(Policy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        **kwargs
    ):
        # Sample one random action per observation in the batch.
        # Alternatively, a numpy array would work here as well,
        # e.g.: np.array([random.choice([0, 1])] * len(obs_batch))
        return [self.action_space_for_sampling.sample() for _ in obs_batch], [], {}

    @override(Policy)
    def learn_on_batch(self, samples):
        """No learning."""
        return {}

    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
    ):
        return np.array([random.random()] * len(obs_batch))

    @override(Policy)
    def get_weights(self) -> ModelWeights:
        """No weights to save."""
        return {}

    @override(Policy)
    def set_weights(self, weights: ModelWeights) -> None:
        """No weights to set."""
        pass

    @override(Policy)
    def _get_dummy_batch_from_view_requirements(self, batch_size: int = 1):
        return SampleBatch(
            {
                SampleBatch.OBS: tree.map_structure(
                    lambda s: s[None], self.observation_space.sample()
                ),
            }
        )
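

# Minimal usage sketch: constructs the policy directly and samples a few
# actions. This assumes RLlib's `Policy` base class accepts
# (observation_space, action_space, config) positionally with a plain config
# dict; depending on the installed Ray version, a fuller config may be
# required. The spaces and the dummy observation batch are illustrative only.
if __name__ == "__main__":
    obs_space = Box(-1.0, 1.0, shape=(4,))
    act_space = Box(-1.0, 1.0, shape=(2,))
    policy = RandomPolicy(obs_space, act_space, {"ignore_action_bounds": False})

    # A batch of 3 dummy observations; the policy ignores their contents.
    dummy_obs = np.zeros((3, 4), dtype=np.float32)
    actions, state_outs, extra_fetches = policy.compute_actions(dummy_obs)
    print("sampled actions:", actions)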