# ray/rllib/examples/policy/random_policy.py

from gym.spaces import Box
import numpy as np
import random
import tree  # pip install dm_tree

from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModelWeights


class RandomPolicy(Policy):
    """Hand-coded policy that returns random actions."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Whether the bounds given in the action_space should be ignored when
        # sampling actions in `compute_actions` (default: False). This is
        # useful for testing action clipping and an Env's reaction to
        # out-of-bounds actions.
        if self.config.get("ignore_action_bounds", False) and isinstance(
            self.action_space, Box
        ):
            # Sample from an unbounded Box with the same shape and dtype
            # instead of the actual (bounded) action space.
            self.action_space_for_sampling = Box(
                -float("inf"),
                float("inf"),
                shape=self.action_space.shape,
                dtype=self.action_space.dtype,
            )
        else:
            self.action_space_for_sampling = self.action_space

    @override(Policy)
    def init_view_requirements(self):
        super().init_view_requirements()
        # Disable the `used_for_training` and `used_for_compute_actions` flags
        # on the SampleBatch.INFOS column, since info dicts cannot be properly
        # batched.
        vr = self.view_requirements[SampleBatch.INFOS]
        vr.used_for_training = False
        vr.used_for_compute_actions = False

    @override(Policy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        **kwargs
    ):
        # Return one randomly sampled action per observation in the batch,
        # plus empty state-outs and an empty extra-fetches dict.
        # Alternatively, a numpy array would work here as well,
        # e.g. np.array([random.choice([0, 1])] * len(obs_batch)).
        return [self.action_space_for_sampling.sample() for _ in obs_batch], [], {}

    @override(Policy)
    def learn_on_batch(self, samples):
        """No learning."""
        return {}

    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
    ):
        # Return a (dummy) random log-likelihood for each observation.
        return np.array([random.random()] * len(obs_batch))

    @override(Policy)
    def get_weights(self) -> ModelWeights:
        """No weights to save."""
        return {}

    @override(Policy)
    def set_weights(self, weights: ModelWeights) -> None:
        """No weights to set."""
        pass

    @override(Policy)
    def _get_dummy_batch_from_view_requirements(self, batch_size: int = 1):
        # Build a dummy batch that only contains a single (batched)
        # observation sampled from the observation space.
        return SampleBatch(
            {
                SampleBatch.OBS: tree.map_structure(
                    lambda s: s[None], self.observation_space.sample()
                ),
            }
        )
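

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original
    # example): build a RandomPolicy over simple Box spaces and sample a batch
    # of actions. An empty config dict is assumed to suffice here, since
    # RandomPolicy itself only reads the optional "ignore_action_bounds" key;
    # the exact config keys the Policy base class expects may vary across
    # RLlib versions.
    obs_space = Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
    act_space = Box(-2.0, 2.0, shape=(2,), dtype=np.float32)
    policy = RandomPolicy(obs_space, act_space, {})

    # Sample 3 random actions, one per observation in the batch.
    obs_batch = np.stack([obs_space.sample() for _ in range(3)])
    actions, state_outs, extra_fetches = policy.compute_actions(obs_batch)
    print(actions)  # -> list of 3 actions, each drawn uniformly from act_space

    # In practice, this policy is typically plugged into a multi-agent setup,
    # e.g. via `PolicySpec(policy_class=RandomPolicy)` in the `policies` dict
    # of a multi-agent config.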