mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Add example script for bare metal Policy with custom view_requirements. (#17896)

parent 40330ca439
commit 60aee4a330

5 changed files with 180 additions and 4 deletions
@@ -15,7 +15,7 @@ import ray.cluster_utils
 import ray._private.profiling as profiling
 from ray._private.test_utils import (client_test_enabled,
-                                     RayTestTimeoutException, SignalActor)
+                                     RayTestTimeoutException)
 
 if client_test_enabled():
     from ray.util.client import ray
 
@@ -1660,6 +1660,14 @@ py_test(
     args = ["--as-test", "--framework=torch", "--stop-reward=150", "--num-cpus=4"]
 )
 
+py_test(
+    name = "examples/bare_metal_policy_with_custom_view_reqs",
+    main = "examples/bare_metal_policy_with_custom_view_reqs.py",
+    tags = ["team:ml", "examples", "examples_B"],
+    size = "small",
+    srcs = ["examples/bare_metal_policy_with_custom_view_reqs.py"],
+)
+
 py_test(
     name = "examples/batch_norm_model_ppo_tf",
     main = "examples/batch_norm_model.py",
@@ -563,9 +563,13 @@ class SimpleListCollector(SampleCollector):
                     data_list.append(self.agent_collectors[k].episode_id)
                 else:
                     if data_col not in buffers[k]:
-                        fill_value = np.zeros_like(view_req.space.sample()) \
-                            if isinstance(view_req.space, Space) else \
-                            view_req.space
+                        if view_req.data_col is not None:
+                            space = policy.view_requirements[
+                                view_req.data_col].space
+                        else:
+                            space = view_req.space
+                        fill_value = np.zeros_like(space.sample()) \
+                            if isinstance(space, Space) else space
                         self.agent_collectors[k]._build_buffers({
                             data_col: fill_value
                         })
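
For context (not part of the diff): the change above now resolves the Space through `view_req.data_col` before computing the zero-fill value for a missing buffer column. A minimal standalone sketch of that fill-value rule, using hypothetical stand-in values rather than RLlib's own objects:

    import numpy as np
    from gym.spaces import Box, Space

    def fill_value_for(space):
        # Same rule as in the diff above: zero-fill for real gym Spaces,
        # otherwise use the provided default value as-is.
        return np.zeros_like(space.sample()) if isinstance(space, Space) else space

    box = Box(low=-1.0, high=1.0, shape=(4, ), dtype=np.float32)
    print(fill_value_for(box))   # -> array([0., 0., 0., 0.], dtype=float32)
    print(fill_value_for(0.5))   # -> 0.5 (a plain default, not a Space)
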
rllib/examples/bare_metal_policy_with_custom_view_reqs.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import argparse

import ray
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.examples.policy.bare_metal_policy_with_custom_view_reqs \
    import BareMetalPolicyWithCustomViewReqs


def get_cli_args():
    """Create CLI parser and return parsed arguments."""
    parser = argparse.ArgumentParser()

    # General args.
    parser.add_argument(
        "--run", default="PPO", help="The RLlib-registered algorithm to use.")
    parser.add_argument("--num-cpus", type=int, default=3)
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Init Ray in local mode for easier debugging.")

    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    # Create a custom Trainer class using our custom Policy.
    BareMetalPolicyTrainer = build_trainer(
        name="MyPolicy", default_policy=BareMetalPolicyWithCustomViewReqs)

    config = {
        "env": "CartPole-v0",
        "model": {
            # Necessary to get the whole trajectory of "state_in_0" in the
            # sample batch.
            "max_seq_len": 1,
        },
        "num_workers": 1,
        # No deep-learning framework is needed here: the bare metal policy
        # samples actions itself, so `None` avoids loading TF/PyTorch.
        "framework": None,
        "log_level": "DEBUG",
        "create_env_on_driver": True,
    }

    # Train the Trainer with our policy.
    my_trainer = BareMetalPolicyTrainer(config=config)
    results = my_trainer.train()
    print(results)
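
As a usage note (not part of the commit): because `build_trainer` returns a regular trainable class, the same example could also be launched through Ray Tune instead of calling `.train()` by hand. A minimal sketch, assuming the `BareMetalPolicyTrainer` and `config` objects defined in the script above are in scope:

    from ray import tune

    # Run a single training iteration of the trainer defined above via Tune.
    tune.run(
        BareMetalPolicyTrainer,
        config=config,
        stop={"training_iteration": 1},
    )
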
rllib/examples/policy/bare_metal_policy_with_custom_view_reqs.py (new file, 110 lines)
@@ -0,0 +1,110 @@
import numpy as np
from gym.spaces import Box

from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModelWeights


class BareMetalPolicyWithCustomViewReqs(Policy):
    """This policy does not do much with its state, but shows how the
    Trajectory View API can be used to pass user-specific view
    requirements to RLlib.
    """

    def __init__(self, observation_space, action_space, model_config, *args,
                 **kwargs):
        super().__init__(observation_space, action_space, model_config, *args,
                         **kwargs)
        self.observation_space = observation_space
        self.action_space = action_space
        self.state_size = 10
        self.model_config = model_config or {}
        space = Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.state_size, ),
            dtype=np.float64)
        # Set view requirements such that the policy state is held in
        # memory for 2 environment steps.
        self.view_requirements["state_in_0"] = \
            ViewRequirement("state_out_0",
                            shift="-2:-1",
                            used_for_training=False,
                            used_for_compute_actions=True,
                            batch_repeat_value=1)
        self.view_requirements["state_out_0"] = \
            ViewRequirement(
                space=space,
                used_for_training=False,
                used_for_compute_actions=True,
                batch_repeat_value=1)

    # Set the initial state. This is necessary for starting
    # the policy.
    def get_initial_state(self):
        return [np.zeros((self.state_size, ), dtype=np.float32)]

    def compute_actions(self,
                        obs_batch=None,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        # `state_batches` is a list with one entry per state component.
        # Each entry has shape (batch_size, shift (=2), state_size).
        batch_size = state_batches[0].shape[0]
        actions = np.array(
            [self.action_space.sample() for _ in range(batch_size)])
        new_state_batches = list(state_batches[0][0])
        return actions, [new_state_batches], {}

    def compute_actions_from_input_dict(self,
                                        input_dict,
                                        explore=None,
                                        timestep=None,
                                        episodes=None,
                                        **kwargs):
        # Access the `infos` key here so it will always show up during
        # action sampling.
        infos = input_dict.get("infos")
        assert infos is not None

        # Default implementation just passes obs, prev-a/r, and states on to
        # `self.compute_actions()`.
        state_batches = [
            s for k, s in input_dict.items() if k[:9] == "state_in_"
        ]
        # Make sure that two (shift="-2:-1") past states are contained in the
        # state batch.
        assert state_batches[0].shape[1] == 2
        assert state_batches[0].shape[2] == self.state_size
        return self.compute_actions(
            input_dict[SampleBatch.OBS],
            state_batches,
            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
            info_batch=input_dict.get(SampleBatch.INFOS),
            explore=explore,
            timestep=timestep,
            episodes=episodes,
            **kwargs,
        )

    def learn_on_batch(self, samples):
        # This policy has no parameters to train.
        return

    @override(Policy)
    def get_weights(self) -> ModelWeights:
        """No weights to save."""
        return {}

    @override(Policy)
    def set_weights(self, weights: ModelWeights) -> None:
        """No weights to set."""
        pass
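
For illustration only (not part of the commit): a rough sketch of how the state-batch shapes asserted above line up when the policy is exercised by hand. The dummy spaces and batch below are assumptions, not fixtures shipped with RLlib, and the class is taken from the file above:

    import numpy as np
    from gym.spaces import Box, Discrete

    # Build the policy with dummy spaces and feed it a state batch of shape
    # (batch_size, 2, state_size), matching the shift="-2:-1" view requirement.
    obs_space = Box(low=-1.0, high=1.0, shape=(4, ), dtype=np.float32)
    act_space = Discrete(2)
    policy = BareMetalPolicyWithCustomViewReqs(obs_space, act_space, {})

    batch_size = 3
    state_batches = [np.zeros((batch_size, 2, policy.state_size))]
    actions, state_out, _ = policy.compute_actions(
        obs_batch=np.zeros((batch_size, 4)), state_batches=state_batches)
    print(actions.shape)  # (3,): one sampled action per batch item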