[RLlib] Add example script for bare metal Policy with custom view_requirements. (#17896)

This commit is contained in:
simonsays1980 2021-08-20 12:17:13 +02:00 committed by GitHub
parent 40330ca439
commit 60aee4a330
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 180 additions and 4 deletions

View file

@ -15,7 +15,7 @@ import ray.cluster_utils
import ray._private.profiling as profiling
from ray._private.test_utils import (client_test_enabled,
RayTestTimeoutException, SignalActor)
RayTestTimeoutException)
if client_test_enabled():
from ray.util.client import ray

View file

@ -1660,6 +1660,14 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=150", "--num-cpus=4"]
)
# Runs the bare-metal Policy example (custom view requirements) as a small
# test in the "examples_B" group.
py_test(
name = "examples/bare_metal_policy_with_custom_view_reqs",
main = "examples/bare_metal_policy_with_custom_view_reqs.py",
tags = ["team:ml", "examples", "examples_B"],
size = "small",
srcs = ["examples/bare_metal_policy_with_custom_view_reqs.py"],
)
py_test(
name = "examples/batch_norm_model_ppo_tf",
main = "examples/batch_norm_model.py",

View file

@ -563,9 +563,13 @@ class SimpleListCollector(SampleCollector):
data_list.append(self.agent_collectors[k].episode_id)
else:
if data_col not in buffers[k]:
fill_value = np.zeros_like(view_req.space.sample()) \
if isinstance(view_req.space, Space) else \
view_req.space
if view_req.data_col is not None:
space = policy.view_requirements[
view_req.data_col].space
else:
space = view_req.space
fill_value = np.zeros_like(space.sample()) \
if isinstance(space, Space) else space
self.agent_collectors[k]._build_buffers({
data_col: fill_value
})

View file

@ -0,0 +1,54 @@
import argparse
import ray
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.examples.policy.bare_metal_policy_with_custom_view_reqs \
import BareMetalPolicyWithCustomViewReqs
def get_cli_args():
    """Parse the command line and return the resulting namespace.

    Registers the `--run`, `--num-cpus` and `--local-mode` options, parses
    the process arguments, prints the parsed values, and returns them.
    """
    # (flags, keyword options) for every supported command-line option.
    option_table = (
        (("--run", ), {
            "default": "PPO",
            "help": "The RLlib-registered algorithm to use.",
        }),
        (("--num-cpus", ), {
            "type": int,
            "default": 3,
        }),
        (("--local-mode", ), {
            "action": "store_true",
            "help": "Init Ray in local mode for easier debugging.",
        }),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    parsed = cli.parse_args()
    print(f"Running with following CLI args: {parsed}")
    return parsed
if __name__ == "__main__":
    # Script entry point: build a Trainer around the bare-metal example
    # policy, run a single training iteration, and print its results.
    args = get_cli_args()

    # A `--num-cpus` value of 0 is passed on to Ray as None.
    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    try:
        # Create a custom Trainer class using our custom Policy.
        BareMetalPolicyTrainer = build_trainer(
            name="MyPolicy", default_policy=BareMetalPolicyWithCustomViewReqs)

        config = {
            "env": "CartPole-v0",
            "model": {
                # Necessary to get the whole trajectory of 'state_in_0' in
                # the sample batch.
                "max_seq_len": 1,
            },
            "num_workers": 1,
            # No framework is set so that neither tensorflow nor pytorch
            # gets loaded.
            # NOTE(review): whether this has other consequences is
            # unconfirmed.
            "framework": None,
            "log_level": "DEBUG",
            "create_env_on_driver": True,
        }

        # Train the Trainer with our policy.
        my_trainer = BareMetalPolicyTrainer(config=config)
        results = my_trainer.train()
        print(results)
    finally:
        # Always release the Ray runtime, even when training raised.
        ray.shutdown()

View file

@ -0,0 +1,110 @@
import numpy as np
from gym.spaces import Box
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModelWeights
class BareMetalPolicyWithCustomViewReqs(Policy):
    """Minimal Policy demonstrating custom view requirements.

    This policy does not do much with its state. It only shows how the
    Trajectory View API can be used to pass user-specific view requirements
    to RLlib. Actions are sampled randomly from the action space and
    nothing is learned.
    """

    def __init__(self, observation_space, action_space, model_config, *args,
                 **kwargs):
        """Set up the policy and register its custom view requirements.

        Args:
            observation_space: Observation space handed to the base Policy.
            action_space: Action space; random actions are sampled from it.
            model_config: Config handed to the base Policy; a falsy value
                is stored as an empty dict in `self.model_config`.
            *args: Forwarded to `Policy.__init__`.
            **kwargs: Forwarded to `Policy.__init__`.
        """
        super().__init__(observation_space, action_space, model_config, *args,
                         **kwargs)
        self.observation_space = observation_space
        self.action_space = action_space
        self.state_size = 10
        self.model_config = model_config or {}
        space = Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.state_size, ),
            dtype=np.float64)
        # Set view requirements such that the policy state is held in
        # memory for 2 environment steps.
        self.view_requirements["state_in_0"] = \
            ViewRequirement("state_out_0",
                            shift="-2:-1",
                            used_for_training=False,
                            used_for_compute_actions=True,
                            batch_repeat_value=1)
        self.view_requirements["state_out_0"] = \
            ViewRequirement(
                space=space,
                used_for_training=False,
                used_for_compute_actions=True,
                batch_repeat_value=1)

    @override(Policy)
    def get_initial_state(self):
        """Return the initial state, which is necessary to start the policy.

        Returns:
            A one-element list holding a float32 zero vector of length
            `self.state_size`.
        """
        return [np.zeros((self.state_size, ), dtype=np.float32)]

    @override(Policy)
    def compute_actions(self,
                        obs_batch=None,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Sample one random action per batch row and pass states through.

        Only `state_batches` is read; all other arguments are ignored.

        Args:
            state_batches: List of state arrays. The first dimension of
                `state_batches[0]` is taken as the batch size.

        Returns:
            A tuple of the sampled actions (numpy array), a one-element
            list with the outgoing states, and an empty extra-info dict.
        """
        # `compute_actions_from_input_dict` asserts that each state batch is
        # laid out as (batch, shift, state_size).
        batch_size = state_batches[0].shape[0]
        actions = np.array(
            [self.action_space.sample() for _ in range(batch_size)])
        # NOTE(review): only the states of the first batch row are passed
        # on; presumably fine for a single env per worker - confirm.
        new_state_batches = list(state_batches[0][0])
        return actions, [new_state_batches], {}

    @override(Policy)
    def compute_actions_from_input_dict(self,
                                        input_dict,
                                        explore=None,
                                        timestep=None,
                                        episodes=None,
                                        **kwargs):
        """Check the incoming view and delegate to `compute_actions`.

        Args:
            input_dict: Mapping that must contain "infos", the observations
                and at least one "state_in_*" entry.
            explore: Forwarded to `compute_actions`.
            timestep: Forwarded to `compute_actions`.
            episodes: Forwarded to `compute_actions`.
            **kwargs: Forwarded to `compute_actions`.

        Returns:
            Whatever `compute_actions` returns.
        """
        # Access the `infos` key here so it'll show up here always during
        # action sampling.
        infos = input_dict.get("infos")
        assert infos is not None
        # Default implementation just passes obs, prev-a/r, and states on to
        # `self.compute_actions()`.
        state_batches = [
            s for k, s in input_dict.items() if k.startswith("state_in_")
        ]
        # Make sure that two (shift="-2:-1") past states are contained in the
        # state_batch.
        assert state_batches[0].shape[1] == 2
        assert state_batches[0].shape[2] == self.state_size
        return self.compute_actions(
            input_dict[SampleBatch.OBS],
            state_batches,
            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
            info_batch=input_dict.get(SampleBatch.INFOS),
            explore=explore,
            timestep=timestep,
            episodes=episodes,
            **kwargs,
        )

    @override(Policy)
    def learn_on_batch(self, samples):
        """Do nothing: this policy has nothing to learn.

        Args:
            samples: Ignored.

        Returns:
            An empty stats dict.
        """
        return {}

    @override(Policy)
    def get_weights(self) -> ModelWeights:
        """No weights to save."""
        return {}

    @override(Policy)
    def set_weights(self, weights: ModelWeights) -> None:
        """No weights to set."""
        pass