[RLlib] Attention Net prep PR #3. (#12450)

Sven Mika 2020-12-07 13:08:17 +01:00 committed by GitHub
parent 401d342602
commit 99c81c6795
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 354 additions and 247 deletions

View file

@ -13,12 +13,12 @@ from typing import Dict, List, Optional, Type, Union
from ray.rllib.agents.impala import vtrace_tf as vtrace
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
clip_gradients, choose_optimizer
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule, TFPolicy
from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin
@ -338,31 +338,14 @@ def postprocess_trajectory(
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
"""
if not policy.config["vtrace"]:
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
batch = compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"],
use_critic=policy.config["use_critic"])
else:
batch = sample_batch
sample_batch = postprocess_ppo_gae(policy, sample_batch,
other_agent_batches, episode)
# TODO: (sven) remove this del once we have trajectory view API fully in
# place.
del batch.data["new_obs"] # not used, so save some bandwidth
del sample_batch.data["new_obs"] # not used, so save some bandwidth
return batch
return sample_batch
def add_values(policy):

View file

@ -38,7 +38,7 @@ DEFAULT_CONFIG = with_common_config({
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,
# The GAE(lambda) parameter.
# The GAE (lambda) parameter.
"lambda": 1.0,
# Initial coefficient for KL divergence.
"kl_coeff": 0.2,

View file

@ -193,13 +193,22 @@ def postprocess_ppo_gae(
last_r = 0.0
# Trajectory has been truncated -> last r=VF estimate of last obs.
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
if policy.config["_use_trajectory_view_api"]:
# Create an input dict according to the Model's requirements.
input_dict = policy.model.get_input_dict(sample_batch, index=-1)
last_r = policy._value(**input_dict)
# TODO: (sven) Remove once trajectory view API is all-algo default.
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
# Adds the policy logits, VF preds, and advantages to the batch,
# using GAE ("generalized advantage estimation") or not.
@ -208,7 +217,9 @@ def postprocess_ppo_gae(
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
use_gae=policy.config["use_gae"],
use_critic=policy.config.get("use_critic", True))
return batch
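For context, a hedged sketch of the GAE computation that `compute_advantages` performs with the bootstrapped `last_r` (standard GAE from the paper linked in the config comment above; the helper below is illustrative, not RLlib's API):

# Hedged sketch (not part of the diff): toy GAE with last_r bootstrapping.
import numpy as np

def gae_advantages(rewards, vf_preds, last_r, gamma=0.99, lambda_=1.0):
    # Append last_r as the value of the state following the final timestep:
    # 0.0 for terminated episodes, otherwise the VF estimate computed above.
    vf = np.append(np.asarray(vf_preds, dtype=np.float32), last_r)
    rewards = np.asarray(rewards, dtype=np.float32)
    # TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
    deltas = rewards + gamma * vf[1:] - vf[:-1]
    # A_t = delta_t + gamma * lambda * A_{t+1}, accumulated backwards.
    advantages = np.zeros(len(rewards), dtype=np.float32)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lambda_ * running
        advantages[t] = running
    return advantages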
@ -292,25 +303,40 @@ class ValueNetworkMixin:
# observation.
if config["use_gae"]:
@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
[prev_action]),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
[prev_reward]),
"is_training": tf.convert_to_tensor([False]),
}, [tf.convert_to_tensor([s]) for s in state],
tf.convert_to_tensor([1]))
# [0] = remove the batch dim.
return self.model.value_function()[0]
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
if config["_use_trajectory_view_api"]:
@make_tf_callable(self.get_session())
def value(**input_dict):
model_out, _ = self.model.from_batch(
input_dict, is_training=False)
# [0] = remove the batch dim.
return self.model.value_function()[0]
# TODO: (sven) Remove once trajectory view API is all-algo default.
else:
@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
[prev_action]),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
[prev_reward]),
"is_training": tf.convert_to_tensor([False]),
}, [tf.convert_to_tensor([s]) for s in state],
tf.convert_to_tensor([1]))
# [0] = remove the batch dim.
return self.model.value_function()[0]
# When not doing GAE, we do not require the value function's output.
else:
@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
def value(*args, **kwargs):
return tf.constant(0.0)
self._value = value
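A hedged sketch of how callers are expected to invoke the two `value()` variants set up by this mixin (mirroring the `postprocess_ppo_gae` change above; the wrapper name is illustrative, not RLlib API):

# Hedged sketch (illustrative, not part of the diff).
def bootstrap_value(policy, sample_batch):
    if policy.config["_use_trajectory_view_api"]:
        # New path: single-timestep input dict, unpacked as keyword args.
        input_dict = policy.model.get_input_dict(sample_batch, index=-1)
        return policy._value(**input_dict)
    # Legacy path: obs / prev_action / prev_reward / state, passed positionally.
    state = [sample_batch["state_out_{}".format(i)][-1]
             for i in range(policy.num_state_tensors())]
    return policy._value(sample_batch["new_obs"][-1],
                         sample_batch["actions"][-1],
                         sample_batch["rewards"][-1],
                         *state)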

View file

@ -210,22 +210,36 @@ class ValueNetworkMixin:
# When doing GAE, we need the value function estimate on the
# observation.
if config["use_gae"]:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
if config["_use_trajectory_view_api"]:
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: convert_to_torch_tensor(
np.asarray([ob]), self.device),
SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
np.asarray([prev_action]), self.device),
SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
np.asarray([prev_reward]), self.device),
"is_training": False,
}, [
convert_to_torch_tensor(np.asarray([s]), self.device)
for s in state
], convert_to_torch_tensor(np.asarray([1]), self.device))
# [0] = remove the batch dim.
return self.model.value_function()[0]
def value(**input_dict):
model_out, _ = self.model.from_batch(
convert_to_torch_tensor(input_dict, self.device),
is_training=False)
# [0] = remove the batch dim.
return self.model.value_function()[0]
# TODO: (sven) Remove once trajectory view API is all-algo default.
else:
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: convert_to_torch_tensor(
np.asarray([ob]), self.device),
SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
np.asarray([prev_action]), self.device),
SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
np.asarray([prev_reward]), self.device),
"is_training": False,
}, [
convert_to_torch_tensor(np.asarray([s]), self.device)
for s in state
], convert_to_torch_tensor(np.asarray([1]), self.device))
# [0] = remove the batch dim.
return self.model.value_function()[0]
# When not doing GAE, we do not require the value function's output.
else:

View file

@ -1,9 +1,6 @@
from gym.spaces import Box
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
@ -25,17 +22,13 @@ class RNNModel(TorchModelV2, nn.Module):
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
self.n_agents = model_config["n_agents"]
self.inference_view_requirements.update({
"state_in_0": ViewRequirement(
"state_out_0",
data_rel_pos=-1,
space=Box(-1.0, 1.0, (self.n_agents, self.rnn_hidden_dim)))
})
@override(ModelV2)
def get_initial_state(self):
# Place hidden states on same device as model.
return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
return [
self.fc1.weight.new(self.n_agents,
self.rnn_hidden_dim).zero_().squeeze(0)
]
@override(ModelV2)
def forward(self, input_dict, hidden_state, seq_lens):

View file

@ -215,9 +215,6 @@ class QMixTorchPolicy(Policy):
name="target_model",
default_model=RNNModel).to(self.device)
# Combine view_requirements for Model and Policy.
self.view_requirements.update(self.model.inference_view_requirements)
self.exploration = self._create_exploration()
# Setup the mixer network.

View file

@ -28,7 +28,7 @@ class MADDPGPostprocessing:
other_agent_batches=None,
episode=None):
# FIXME: Get done from info is required since agentwise done is not
# supported now.
# supported now.
sample_batch.data[SampleBatch.DONES] = self.get_done_from_info(
sample_batch.data[SampleBatch.INFOS])
@ -251,6 +251,9 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
loss_inputs=loss_inputs,
dist_inputs=actor_feature)
del self.view_requirements["prev_actions"]
del self.view_requirements["prev_rewards"]
self.sess.run(tf1.global_variables_initializer())
# Hard initial update

View file

@ -191,7 +191,7 @@ class _SampleCollector(metaclass=ABCMeta):
postprocessor.
This is usually called to collect samples for policy training.
If not enough data has been collected yet (`rollout_fragment_length`),
returns None.
returns an empty list.
Returns:
List[Union[MultiAgentBatch, SampleBatch]]: Returns a (possibly

View file

@ -1,4 +1,5 @@
import collections
from gym.spaces import Space
import logging
import numpy as np
from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union
@ -8,12 +9,11 @@ from ray.rllib.evaluation.collectors.sample_collector import _SampleCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
TensorType, ViewRequirementsDict
from ray.util.debug import log_once
_, tf, _ = try_import_tf()
@ -48,13 +48,13 @@ class _AgentCollector:
def __init__(self, shift_before: int = 0):
self.shift_before = max(shift_before, 1)
self.buffers: Dict[str, List] = {}
self.episode_id = None
# The simple timestep count for this agent. Gets increased by one
# each time a (non-initial!) observation is added.
self.count = 0
def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
env_id: EnvID, t: int, init_obs: TensorType,
view_requirements: Dict[str, ViewRequirement]) -> None:
env_id: EnvID, t: int, init_obs: TensorType) -> None:
"""Adds an initial observation (after reset) to the Agent's trajectory.
Args:
@ -67,19 +67,17 @@ class _AgentCollector:
ts=-1(!), then an action/reward/next-obs at t=0, etc..
init_obs (TensorType): The initial observation tensor (after
`env.reset()`).
view_requirements (Dict[str, ViewRequirements])
"""
if SampleBatch.OBS not in self.buffers:
self._build_buffers(
single_row={
SampleBatch.OBS: init_obs,
SampleBatch.EPS_ID: episode_id,
SampleBatch.AGENT_INDEX: agent_index,
"env_id": env_id,
"t": t,
})
self.buffers[SampleBatch.OBS].append(init_obs)
self.buffers[SampleBatch.EPS_ID].append(episode_id)
self.episode_id = episode_id
self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
self.buffers["env_id"].append(env_id)
self.buffers["t"].append(t)
@ -97,6 +95,11 @@ class _AgentCollector:
assert SampleBatch.OBS not in values
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
del values[SampleBatch.NEXT_OBS]
# Make sure EPS_ID stays the same for this agent. Usually, it should
# not be part of `values` anyways.
if SampleBatch.EPS_ID in values:
assert values[SampleBatch.EPS_ID] == self.episode_id
del values[SampleBatch.EPS_ID]
for k, v in values.items():
if k not in self.buffers:
@ -104,8 +107,7 @@ class _AgentCollector:
self.buffers[k].append(v)
self.count += 1
def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
SampleBatch:
def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
"""Builds a SampleBatch from the thus-far collected agent data.
If the episode/trajectory has no DONE=True at the end, will copy
@ -115,32 +117,29 @@ class _AgentCollector:
by a Policy.
Args:
view_requirements (Dict[str, ViewRequirement]: The view
view_requirements (ViewRequirementsDict): The view
requirements dict needed to build the SampleBatch from the raw
buffers (which may have data shifts as well as mappings from
view-col to data-col in them).
Returns:
SampleBatch: The built SampleBatch for this agent, ready to go into
postprocessing.
"""
# TODO: measure performance gains when using a UsageTrackingDict
# instead of a SampleBatch for postprocessing (this would eliminate
# copies (for creating this SampleBatch) of many unused columns for
# no reason (not used by postprocessor)).
batch_data = {}
np_data = {}
for view_col, view_req in view_requirements.items():
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
# Some columns don't exist yet (get created during postprocessing).
# -> skip.
if data_col not in self.buffers:
continue
# OBS are already shifted by -1 (the initial obs starts one ts
# before all other data columns).
shift = view_req.data_rel_pos - \
shift = view_req.shift - \
(1 if data_col == SampleBatch.OBS else 0)
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
@ -161,8 +160,12 @@ class _AgentCollector:
data = np_data[data_col][self.shift_before + shift:shift]
if len(data) > 0:
batch_data[view_col] = data
batch = SampleBatch(batch_data)
# Add EPS_ID and UNROLL_ID to batch.
batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id,
batch.count)
if SampleBatch.UNROLL_ID not in batch.data:
# TODO: (sven) Once we have the additional
# model.preprocess_train_batch in place (attention net PR), we
@ -200,7 +203,7 @@ class _AgentCollector:
] else 0)
# Python primitive or dict (e.g. INFOs).
if isinstance(data, (int, float, bool, str, dict)):
self.buffers[col] = [0 for _ in range(shift)]
self.buffers[col] = [data for _ in range(shift)]
# np.ndarray, torch.Tensor, or tf.Tensor.
else:
shape = data.shape
@ -239,25 +242,24 @@ class _PolicyCollector:
def add_postprocessed_batch_for_training(
self, batch: SampleBatch,
view_requirements: Dict[str, ViewRequirement]) -> None:
view_requirements: ViewRequirementsDict) -> None:
"""Adds a postprocessed SampleBatch (single agent) to our buffers.
Args:
batch (SampleBatch): A single agent (one trajectory) SampleBatch
to be added to the Policy's buffers.
view_requirements (Dict[str, ViewRequirement]: The view
view_requirements (ViewRequirementsDict): The view
requirements for the policy. This is so we know whether a
view-column needs to be copied at all (not needed for
training).
"""
for view_col, data in batch.items():
# TODO(ekl) how do we handle this for policies that don't extend
# Torch / TF Policy template (no inference of view reqs)?
# Skip columns that are not used for training.
# if view_col not in view_requirements or \
# not view_requirements[view_col].used_for_training:
# continue
self.buffers[view_col].extend(data)
# 1) If col is not in view_requirements, we must have a direct
# child of the base Policy that doesn't do auto-view req creation.
# 2) Col is in view-reqs and needed for training.
if view_col not in view_requirements or \
view_requirements[view_col].used_for_training:
self.buffers[view_col].extend(data)
# Add the agent's trajectory length to our count.
self.count += batch.count
@ -380,9 +382,6 @@ class _SimpleListCollector(_SampleCollector):
self.agent_key_to_policy_id[agent_key] = policy_id
else:
assert self.agent_key_to_policy_id[agent_key] == policy_id
policy = self.policy_map[policy_id]
view_reqs = policy.model.inference_view_requirements if \
getattr(policy, "model", None) else policy.view_requirements
# Add initial obs to Trajectory.
assert agent_key not in self.agent_collectors
@ -393,8 +392,7 @@ class _SimpleListCollector(_SampleCollector):
agent_index=episode._agent_index(agent_id),
env_id=env_id,
t=t,
init_obs=init_obs,
view_requirements=view_reqs)
init_obs=init_obs)
self.episodes[episode.episode_id] = episode
if episode.batch_builder is None:
@ -442,17 +440,22 @@ class _SimpleListCollector(_SampleCollector):
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
time_indices = \
view_req.data_rel_pos - (
view_req.shift - (
1 if data_col in [SampleBatch.OBS, "t", "env_id",
SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX] else 0)
data_list = []
for k in keys:
if data_col not in buffers[k]:
self.agent_collectors[k]._build_buffers({
data_col: view_req.space.sample()
})
data_list.append(buffers[k][data_col][time_indices])
if data_col == SampleBatch.EPS_ID:
data_list.append(self.agent_collectors[k].episode_id)
else:
if data_col not in buffers[k]:
fill_value = np.zeros_like(view_req.space.sample()) \
if isinstance(view_req.space, Space) else \
view_req.space
self.agent_collectors[k]._build_buffers({
data_col: fill_value
})
data_list.append(buffers[k][data_col][time_indices])
input_dict[view_col] = np.array(data_list)
self._reset_inference_calls(policy_id)
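Note that `view_req.space` does double duty in the block above: a real gym Space is zero-filled, while anything else is used directly as the fill constant. A hedged illustration (helper name is made up for the example):

# Hedged illustration of the fill-value rule above (not part of the diff).
import numpy as np
from gym.spaces import Box, Space

def fill_value_for(space_or_value):
    # Real gym Space -> zeros in the space's shape and dtype.
    if isinstance(space_or_value, Space):
        return np.zeros_like(space_or_value.sample())
    # Anything else is treated as a constant placeholder value.
    return space_or_value

print(fill_value_for(Box(-1.0, 1.0, shape=(4,))))  # -> [0. 0. 0. 0.]
print(fill_value_for(0))                           # -> 0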
@ -517,8 +520,8 @@ class _SimpleListCollector(_SampleCollector):
del other_batches[agent_id]
pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
policy = self.policy_map[pid]
if any(pre_batch["dones"][:-1]) or len(set(
pre_batch["eps_id"])) > 1:
if any(pre_batch[SampleBatch.DONES][:-1]) or len(
set(pre_batch[SampleBatch.EPS_ID])) > 1:
raise ValueError(
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.", pre_batch)

View file

@ -177,6 +177,7 @@ class RolloutWorker(ParallelIteratorWorker):
fake_sampler: bool = False,
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
gym.spaces.Space]]] = None,
_use_trajectory_view_api: bool = True,
policy: Union[type, Dict[
str, Tuple[Optional[type], gym.Space, gym.Space,
PartialTrainerConfigDict]]] = None,
@ -295,6 +296,8 @@ class RolloutWorker(ParallelIteratorWorker):
gym.spaces.Space]]]): An optional space dict mapping policy IDs
to (obs_space, action_space)-tuples. This is used in case no
Env is created on this RolloutWorker.
_use_trajectory_view_api (bool): Whether to collect samples through
the experimental Trajectory View API.
policy: Obsoleted arg. Use `policy_spec` instead.
"""
# Deprecated arg.
@ -459,6 +462,14 @@ class RolloutWorker(ParallelIteratorWorker):
self.policy_map, self.preprocessors = self._build_policy_map(
policy_dict, policy_config)
# Update Policy's view requirements from Model, only if Policy directly
# inherited from base `Policy` class. At this point, the Policy must
# have its Model (if any) defined and ready to output an initial
# state.
for pol in self.policy_map.values():
if not pol._model_init_state_automatically_added:
pol._update_model_inference_view_requirements_from_init_state()
if (ray.is_initialized()
and ray.worker._mode() != ray.worker.LOCAL_MODE):
# Check available number of GPUs
@ -568,8 +579,8 @@ class RolloutWorker(ParallelIteratorWorker):
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_use_trajectory_view_api=policy_config.get(
"_use_trajectory_view_api", False))
_use_trajectory_view_api=_use_trajectory_view_api,
)
# Start the Sampler thread.
self.sampler.start()
else:
@ -590,8 +601,8 @@ class RolloutWorker(ParallelIteratorWorker):
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_use_trajectory_view_api=policy_config.get(
"_use_trajectory_view_api", False))
_use_trajectory_view_api=_use_trajectory_view_api,
)
self.input_reader: InputReader = input_creator(self.io_context)
self.output_writer: OutputWriter = output_creator(self.io_context)

View file

@ -1046,7 +1046,6 @@ def _process_observations_w_trajectory_view_api(
# Add actions, rewards, next-obs to collectors.
values_dict = {
"t": episode.length - 1,
"eps_id": episode.episode_id,
"env_id": env_id,
"agent_index": episode._agent_index(agent_id),
# Action (slot 0) taken at timestep t.

View file

@ -59,7 +59,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].data_rel_pos == 1
assert view_req_policy[key].shift == 1
rollout_worker = trainer.workers.local_worker()
sample_batch = rollout_worker.sample()
expected_count = \
@ -99,10 +99,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):
if key == SampleBatch.PREV_ACTIONS:
assert view_req_policy[key].data_col == SampleBatch.ACTIONS
assert view_req_policy[key].data_rel_pos == -1
assert view_req_policy[key].shift == -1
elif key == SampleBatch.PREV_REWARDS:
assert view_req_policy[key].data_col == SampleBatch.REWARDS
assert view_req_policy[key].data_rel_pos == -1
assert view_req_policy[key].shift == -1
elif key not in [
SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS
@ -110,7 +110,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].data_rel_pos == 1
assert view_req_policy[key].shift == 1
trainer.stop()
def test_traj_view_simple_performance(self):

View file

@ -352,6 +352,7 @@ class WorkerSet:
fake_sampler=config["fake_sampler"],
extra_python_environs=extra_python_environs,
spaces=spaces,
_use_trajectory_view_api=config["_use_trajectory_view_api"],
)
return worker

View file

@ -28,15 +28,15 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
"t": ViewRequirement(),
SampleBatch.OBS: ViewRequirement(),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
SampleBatch.ACTIONS, space=self.action_space, shift=-1),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, data_rel_pos=-1),
SampleBatch.REWARDS, shift=-1),
}
for i in range(2):
self.model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
data_rel_pos=-1,
shift=-1,
space=self.state_space)
self.model.inference_view_requirements[
"state_out_{}".format(i)] = \
@ -45,7 +45,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
self.view_requirements = dict(
**{
SampleBatch.NEXT_OBS: ViewRequirement(
SampleBatch.OBS, data_rel_pos=1),
SampleBatch.OBS, shift=1),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
@ -106,7 +106,7 @@ class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
"state_in_0": ViewRequirement(
"state_out_0",
# Provide state outs -50 to -1 as "state-in".
data_rel_pos="-50:-1",
shift="-50:-1",
# Repeat the incoming state every n time steps (usually max seq
# len).
batch_repeat_value=self.config["model"]["max_seq_len"],

View file

@ -16,7 +16,7 @@ class AlwaysSameHeuristic(Policy):
self.view_requirements.update({
"state_in_0": ViewRequirement(
"state_out_0",
data_rel_pos=-1,
shift=-1,
space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
})

View file

@ -143,6 +143,7 @@ if __name__ == "__main__":
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"framework": "torch" if args.torch else "tf",
"_use_trajectory_view_api": True,
}
stop = {

View file

@ -61,8 +61,7 @@ class ModelV2:
self.time_major = self.model_config.get("_time_major")
# Basic view requirement for all models: Use the observation as input.
self.inference_view_requirements = {
SampleBatch.OBS: ViewRequirement(
data_rel_pos=0, space=self.obs_space),
SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
}
# TODO: (sven): Get rid of `get_initial_state` once Trajectory
@ -315,6 +314,29 @@ class ModelV2:
"""
return self.time_major is True
# TODO: (sven) Experimental method.
def get_input_dict(self, sample_batch,
index: int = -1) -> Dict[str, TensorType]:
if index < 0:
index = sample_batch.count - 1
input_dict = {}
for view_col, view_req in self.inference_view_requirements.items():
# Create batches of size 1 (single-agent input-dict).
# Index range.
if isinstance(index, tuple):
data = sample_batch[view_col][index[0]:index[1] + 1]
input_dict[view_col] = np.array([data])
# Single index.
else:
input_dict[view_col] = sample_batch[view_col][index:index + 1]
# Add valid `seq_lens`, just in case RNNs need it.
input_dict["seq_lens"] = np.array([1], dtype=np.int32)
return input_dict
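The new `get_input_dict` ships without a docstring, so here is a hedged usage sketch (the wrapper and variable names are assumptions, not RLlib API):

# Hedged sketch: what the experimental get_input_dict() above returns.
# Assumes `model` is a ModelV2 whose only inference view requirement is
# SampleBatch.OBS and `batch` is a completed SampleBatch.
def last_step_input_dict(model, batch):
    input_dict = model.get_input_dict(batch, index=-1)
    # Each view column holds a batch of size 1 with the last timestep, e.g.
    # {"obs": batch["obs"][-1:], "seq_lens": np.array([1], dtype=np.int32)},
    # ready to be fed to model.from_batch(...) or policy._value(**input_dict).
    return input_dict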
class NullContextManager:
"""No-op context manager"""

View file

@ -178,10 +178,10 @@ class LSTMWrapper(RecurrentNetwork):
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
data_rel_pos=-1)
shift=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
ViewRequirement(SampleBatch.REWARDS, shift=-1)
@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],

View file

@ -159,10 +159,10 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
data_rel_pos=-1)
shift=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
ViewRequirement(SampleBatch.REWARDS, shift=-1)
@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],

View file

@ -80,8 +80,6 @@ class DynamicTFPolicy(TFPolicy):
], Tuple[TensorType, type, List[TensorType]]]] = None,
existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
existing_model: Optional[ModelV2] = None,
view_requirements_fn: Optional[Callable[[Policy], Dict[
str, ViewRequirement]]] = None,
get_batch_divisibility_req: Optional[Callable[[Policy],
int]] = None,
obs_include_prev_action_reward: bool = True):
@ -292,14 +290,6 @@ class DynamicTFPolicy(TFPolicy):
action_distribution=action_dist,
timestep=timestep,
explore=explore)
if self.config["_use_trajectory_view_api"]:
self._dummy_batch[SampleBatch.ACTION_DIST_INPUTS] = \
np.zeros(
[1 if not s else s for s in
dist_inputs.shape.as_list()])
self._input_dict[SampleBatch.ACTION_DIST_INPUTS] = \
tf1.placeholder(shape=dist_inputs.shape.as_list(),
dtype=tf.float32)
# Phase 1 init.
sess = tf1.get_default_session() or tf1.Session()
@ -417,42 +407,37 @@ class DynamicTFPolicy(TFPolicy):
input_dict/dummy_batch tuple.
"""
input_dict = {}
dummy_batch = {}
for view_col, view_req in view_requirements.items():
# Point state_in to the already existing self._state_inputs.
mo = re.match("state_in_(\d+)", view_col)
if mo is not None:
input_dict[view_col] = self._state_inputs[int(mo.group(1))]
dummy_batch[view_col] = np.zeros_like(
[view_req.space.sample()])
# State-outs (no placeholders needed).
elif view_col.startswith("state_out_"):
dummy_batch[view_col] = np.zeros_like(
[view_req.space.sample()])
continue
# Skip action dist inputs placeholder (do later).
elif view_col == SampleBatch.ACTION_DIST_INPUTS:
continue
elif view_col in existing_inputs:
input_dict[view_col] = existing_inputs[view_col]
dummy_batch[view_col] = np.zeros(
shape=[
1 if s is None else s
for s in existing_inputs[view_col].shape.as_list()
],
dtype=existing_inputs[view_col].dtype.as_numpy_dtype)
# All others.
else:
if view_req.used_for_training:
input_dict[view_col] = get_placeholder(
space=view_req.space, name=view_col)
dummy_batch[view_col] = np.zeros_like(
[view_req.space.sample()])
dummy_batch = self._get_dummy_batch_from_view_requirements(
batch_size=32)
return input_dict, dummy_batch
def _initialize_loss_from_dummy_batch(
self, auto_remove_unneeded_view_reqs: bool = True,
stats_fn=None) -> None:
# Create the optimizer/exploration optimizer here. Some initialization
# steps (e.g. exploration postprocessing) may need this.
self._optimizer = self.optimizer()
# Test calls depend on variable init, so initialize model first.
self._sess.run(tf1.global_variables_initializer())
@ -509,6 +494,8 @@ class DynamicTFPolicy(TFPolicy):
batch_for_postproc = UsageTrackingDict(sb)
batch_for_postproc.count = sb.count
logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
self.exploration.postprocess_trajectory(self, batch_for_postproc,
self._sess)
postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
# Add new columns automatically to (loss) input_dict.
if self.config["_use_trajectory_view_api"]:
@ -588,7 +575,8 @@ class DynamicTFPolicy(TFPolicy):
batch_for_postproc.accessed_keys
# Tag those only needed for post-processing.
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys:
if key not in train_batch.accessed_keys and \
key not in self.model.inference_view_requirements:
self.view_requirements[key].used_for_training = False
if key in self._loss_input_dict:
del self._loss_input_dict[key]

View file

@ -194,7 +194,6 @@ def build_eager_tf_policy(name,
action_sampler_fn=None,
action_distribution_fn=None,
mixins=None,
view_requirements_fn=None,
obs_include_prev_action_reward=True,
get_batch_divisibility_req=None):
"""Build an eager TF policy.
@ -265,9 +264,6 @@ def build_eager_tf_policy(name,
for s in self.model.get_initial_state()
]
# Update this Policy's ViewRequirements (if function given).
if callable(view_requirements_fn):
self.view_requirements.update(view_requirements_fn(self))
# Combine view_requirements for Model and Policy.
self.view_requirements.update(
self.model.inference_view_requirements)
@ -275,12 +271,6 @@ def build_eager_tf_policy(name,
if before_loss_init:
before_loss_init(self, observation_space, action_space, config)
self._initialize_loss_from_dummy_batch(
auto_remove_unneeded_view_reqs=True,
stats_fn=stats_fn,
)
self._loss_initialized = True
if optimizer_fn:
optimizers = optimizer_fn(self, config)
else:
@ -293,10 +283,16 @@ def build_eager_tf_policy(name,
# Just like torch Policy does.
self._optimizer = optimizers[0] if optimizers else None
self._initialize_loss_from_dummy_batch(
auto_remove_unneeded_view_reqs=True,
stats_fn=stats_fn,
)
self._loss_initialized = True
if after_init:
after_init(self, observation_space, action_space, config)
# Got to reset global_timestep again after this fake run-through.
# Got to reset global_timestep again after fake run-throughs.
self.global_timestep = 0
@override(Policy)
@ -410,7 +406,7 @@ def build_eager_tf_policy(name,
timestep=timestep, explore=explore)
if action_distribution_fn:
dist_inputs, dist_class, state_out = \
dist_inputs, self.dist_class, state_out = \
action_distribution_fn(
self, self.model,
input_dict[SampleBatch.CUR_OBS],
@ -418,11 +414,10 @@ def build_eager_tf_policy(name,
timestep=timestep,
is_training=False)
else:
dist_class = self.dist_class
dist_inputs, state_out = self.model(
input_dict, state_batches, seq_lens)
action_dist = dist_class(dist_inputs, self.model)
action_dist = self.dist_class(dist_inputs, self.model)
# Get the exploration action from the forward results.
actions, logp = self.exploration.get_exploration_action(
@ -466,12 +461,12 @@ def build_eager_tf_policy(name,
"is_training": tf.constant(False),
}
if obs_include_prev_action_reward:
input_dict.update({
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
prev_action_batch),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
prev_reward_batch),
})
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \
tf.convert_to_tensor(prev_action_batch)
if prev_reward_batch is not None:
input_dict[SampleBatch.PREV_REWARDS] = \
tf.convert_to_tensor(prev_reward_batch)
# Exploration hook before each forward pass.
self.exploration.before_compute_actions(explore=False)
@ -559,7 +554,9 @@ def build_eager_tf_policy(name,
@override(Policy)
def get_initial_state(self):
return self.model.get_initial_state()
if hasattr(self, "model"):
return self.model.get_initial_state()
return []
def get_session(self):
return None # None implies eager

View file

@ -92,6 +92,7 @@ class Policy(metaclass=ABCMeta):
self.view_requirements = view_reqs
else:
self.view_requirements.update(view_reqs)
self._model_init_state_automatically_added = False
@abstractmethod
@DeveloperAPI
@ -278,7 +279,8 @@ class Policy(metaclass=ABCMeta):
# `self.compute_actions()`.
state_batches = [
# TODO: (sven) remove unsqueezing code here for non-traj.view API.
s if self.config["_use_trajectory_view_api"] else s.unsqueeze(0)
s if self.config.get("_use_trajectory_view_api", False) else
s.unsqueeze(0)
if torch and isinstance(s, torch.Tensor) else np.expand_dims(s, 0)
for k, s in input_dict.items() if k[:9] == "state_in_"
]
@ -564,16 +566,25 @@ class Policy(metaclass=ABCMeta):
SampleBatch.OBS: ViewRequirement(space=self.observation_space),
SampleBatch.NEXT_OBS: ViewRequirement(
data_col=SampleBatch.OBS,
data_rel_pos=1,
shift=1,
space=self.observation_space),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
# For backward compatibility with custom Models that don't specify
# these explicitly (will be removed by Policy if not used).
SampleBatch.PREV_ACTIONS: ViewRequirement(
data_col=SampleBatch.ACTIONS,
shift=-1,
space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
# For backward compatibility with custom Models that don't specify
# these explicitly (will be removed by Policy if not used).
SampleBatch.PREV_REWARDS: ViewRequirement(
data_col=SampleBatch.REWARDS, shift=-1),
SampleBatch.DONES: ViewRequirement(),
SampleBatch.INFOS: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
"t": ViewRequirement(),
}
@ -616,6 +627,7 @@ class Policy(metaclass=ABCMeta):
-1.0, 1.0, shape=value.shape[1:], dtype=value.dtype))
batch_for_postproc = UsageTrackingDict(self._dummy_batch)
batch_for_postproc.count = self._dummy_batch.count
self.exploration.postprocess_trajectory(self, batch_for_postproc)
postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
if state_outs:
B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size]
@ -700,27 +712,33 @@ class Policy(metaclass=ABCMeta):
ret[view_col] = \
np.zeros((batch_size, ) + shape[1:], np.float32)
else:
ret[view_col] = np.zeros_like(
[view_req.space.sample() for _ in range(batch_size)])
if isinstance(view_req.space, gym.spaces.Space):
ret[view_col] = np.zeros_like(
[view_req.space.sample() for _ in range(batch_size)])
else:
ret[view_col] = [view_req.space for _ in range(batch_size)]
return SampleBatch(ret)
def _update_model_inference_view_requirements_from_init_state(self):
"""Uses this Model's initial state to auto-add necessary ViewReqs.
"""Uses Model's (or this Policy's) init state to add needed ViewReqs.
Can be called from within a Policy to make sure RNNs automatically
update their internal state-related view requirements.
Changes the `self.inference_view_requirements` dict.
"""
model = self.model
self._model_init_state_automatically_added = True
model = getattr(self, "model", None)
obj = model or self
# Add state-ins to this model's view.
for i, state in enumerate(model.get_initial_state()):
model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
data_rel_pos=-1,
space=Box(-1.0, 1.0, shape=state.shape))
model.inference_view_requirements["state_out_{}".format(i)] = \
ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
for i, state in enumerate(obj.get_initial_state()):
space = Box(-1.0, 1.0, shape=state.shape) if \
hasattr(state, "shape") else state
view_reqs = model.inference_view_requirements if model else \
self.view_requirements
view_reqs["state_in_{}".format(i)] = ViewRequirement(
"state_out_{}".format(i), shift=-1, space=space)
view_reqs["state_out_{}".format(i)] = ViewRequirement(space=space)
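A hedged sketch of what this helper produces for a model whose `get_initial_state()` returns a single tensor of shape (256,) (values below just mirror the loop above):

# Hedged sketch (illustrative, not part of the diff).
from gym.spaces import Box
from ray.rllib.policy.view_requirement import ViewRequirement

space = Box(-1.0, 1.0, shape=(256,))
view_reqs = {
    # The state fed into the model at timestep t is the state it emitted
    # at t-1, which is what the sample collectors need to know for RNNs.
    "state_in_0": ViewRequirement("state_out_0", shift=-1, space=space),
    "state_out_0": ViewRequirement(space=space),
}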
def clip_action(action, action_space):

View file

@ -115,7 +115,9 @@ class SampleBatch:
[s[k] for s in concat_samples],
time_major=concat_samples[0].time_major)
return SampleBatch(
out, _seq_lens=seq_lens, _time_major=concat_samples[0].time_major)
out,
_seq_lens=np.array(seq_lens, dtype=np.int32),
_time_major=concat_samples[0].time_major)
@PublicAPI
def concat(self, other: "SampleBatch") -> "SampleBatch":
@ -154,7 +156,8 @@ class SampleBatch:
"""
return SampleBatch(
{k: np.array(v, copy=True)
for (k, v) in self.data.items()})
for (k, v) in self.data.items()},
_seq_lens=self.seq_lens)
@PublicAPI
def rows(self) -> Dict[str, TensorType]:

View file

@ -2,6 +2,7 @@ import numpy as np
from scipy.stats import norm
import unittest
import ray
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.pg as pg
import ray.rllib.agents.ppo as ppo
@ -87,8 +88,8 @@ def do_test_log_likelihood(run,
logp = policy.compute_log_likelihoods(
np.array([a]),
preprocessed_obs_batch,
prev_action_batch=np.array([prev_a]),
prev_reward_batch=np.array([prev_r]))
prev_action_batch=np.array([prev_a]) if prev_a else None,
prev_reward_batch=np.array([prev_r]) if prev_r else None)
check(logp, expected_logp[0], rtol=0.2)
# Test all available actions for their logp values.
else:
@ -98,12 +99,20 @@ def do_test_log_likelihood(run,
logp = policy.compute_log_likelihoods(
np.array([a]),
preprocessed_obs_batch,
prev_action_batch=np.array([prev_a]),
prev_reward_batch=np.array([prev_r]))
prev_action_batch=np.array([prev_a]) if prev_a else None,
prev_reward_batch=np.array([prev_r]) if prev_r else None)
check(np.exp(logp), expected_prob, atol=0.2)
class TestComputeLogLikelihood(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_dqn(self):
"""Tests, whether DQN correctly computes logp in soft-q mode."""
config = dqn.DEFAULT_CONFIG.copy()

View file

@ -274,7 +274,8 @@ class TFPolicy(Policy):
else:
self._loss = loss
self._optimizer = self.optimizer()
if self._optimizer is None:
self._optimizer = self.optimizer()
self._grads_and_vars = [
(g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
if g is not None

View file

@ -8,7 +8,6 @@ from ray.rllib.policy import eager_tf_policy
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
@ -66,8 +65,6 @@ def build_tf_policy(
Policy, ModelV2, TensorType, TensorType, TensorType
], Tuple[TensorType, type, List[TensorType]]]] = None,
mixins: Optional[List[type]] = None,
view_requirements_fn: Optional[Callable[[Policy], Dict[
str, ViewRequirement]]] = None,
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
# TODO: (sven) deprecate once _use_trajectory_view_api is always True.
obs_include_prev_action_reward: bool = True,
@ -231,7 +228,6 @@ def build_tf_policy(
action_distribution_fn=action_distribution_fn,
existing_inputs=existing_inputs,
existing_model=existing_model,
view_requirements_fn=view_requirements_fn,
get_batch_divisibility_req=get_batch_divisibility_req,
obs_include_prev_action_reward=obs_include_prev_action_reward)

View file

@ -8,7 +8,6 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
@ -70,8 +69,6 @@ def build_torch_policy(
apply_gradients_fn: Optional[Callable[
[Policy, "torch.optim.Optimizer"], None]] = None,
mixins: Optional[List[type]] = None,
view_requirements_fn: Optional[Callable[[Policy], Dict[
str, ViewRequirement]]] = None,
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
"""Helper function for creating a torch policy class at runtime.
@ -174,9 +171,6 @@ def build_torch_policy(
mixins (Optional[List[type]]): Optional list of any class mixins for
the returned policy class. These mixins will be applied in order
and will have higher precedence than the TorchPolicy class.
view_requirements_fn (Optional[Callable[[Policy],
Dict[str, ViewRequirement]]]): An optional callable to retrieve
additional train view requirements for this policy.
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
Optional callable that returns the divisibility requirement for
sample batches. If None, will assume a value of 1.
@ -242,9 +236,6 @@ def build_torch_policy(
get_batch_divisibility_req=get_batch_divisibility_req,
)
# Update this Policy's ViewRequirements (if function given).
if callable(view_requirements_fn):
self.view_requirements.update(view_requirements_fn(self))
# Merge Model's view requirements into Policy's.
self.view_requirements.update(
self.model.inference_view_requirements)

View file

@ -29,31 +29,42 @@ class ViewRequirement:
def __init__(self,
data_col: Optional[str] = None,
space: gym.Space = None,
data_rel_pos: Union[int, List[int]] = 0,
shift: Union[int, List[int]] = 0,
index: Optional[int] = None,
used_for_training: bool = True):
"""Initializes a ViewRequirement object.
Args:
data_col (): The data column name from the SampleBatch (str key).
If None, use the dict key under which this ViewRequirement
resides.
data_col (Optional[str]): The data column name from the SampleBatch
(str key). If None, use the dict key under which this
ViewRequirement resides.
space (gym.Space): The gym Space used in case we need to pad data
in inaccessible areas of the trajectory (t<0 or t>H).
Default: Simple box space, e.g. rewards.
data_rel_pos (Union[int, str, List[int]]): Single shift value or
shift (Union[int, str, List[int]]): Single shift value or
list of relative positions to use (relative to the underlying
`data_col`).
Example: For a view column "prev_actions", you can set
`data_col="actions"` and `data_rel_pos=-1`.
`data_col="actions"` and `shift=-1`.
Example: For a view column "obs" in an Atari framestacking
fashion, you can set `data_col="obs"` and
`data_rel_pos=[-3, -2, -1, 0]`.
`shift=[-3, -2, -1, 0]`.
Example: For the obs input to an attention net, you can specify
a range via a str: `shift="-100:0"`, which will pass in
the past 100 observations plus the current one.
index (Optional[int]): An optional absolute position arg,
used e.g. for the location of a requested inference dict within
the trajectory. Negative values refer to counting from the end
of a trajectory.
used_for_training (bool): Whether the data will be used for
training. If False, the column will not be copied into the
final train batch.
"""
self.data_col = data_col
self.space = space or gym.spaces.Box(
self.space = space if space is not None else gym.spaces.Box(
float("-inf"), float("inf"), shape=())
self.data_rel_pos = data_rel_pos
self.index = index
self.shift = shift
self.used_for_training = used_for_training
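A few hedged construction examples mirroring the docstring above (not part of the diff):

from ray.rllib.policy.view_requirement import ViewRequirement

# "prev_actions" view: the action taken one timestep earlier.
prev_actions = ViewRequirement(data_col="actions", shift=-1)

# Atari-style framestacking: the last four observations.
framestack_obs = ViewRequirement(data_col="obs", shift=[-3, -2, -1, 0])

# Attention-net input: the past 100 observations plus the current one.
attention_obs = ViewRequirement(data_col="obs", shift="-100:0")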

View file

@ -151,10 +151,10 @@ def test_concat_batches(ray_start_regular_shared):
def test_standardize(ray_start_regular_shared):
workers = make_workers(0)
a = ParallelRollouts(workers, mode="async")
b = a.for_each(StandardizeFields(["t"]))
b = a.for_each(StandardizeFields([SampleBatch.EPS_ID]))
batch = next(b)
assert abs(np.mean(batch["t"])) < 0.001, batch
assert abs(np.std(batch["t"]) - 1.0) < 0.001, batch
assert abs(np.mean(batch[SampleBatch.EPS_ID])) < 0.001, batch
assert abs(np.std(batch[SampleBatch.EPS_ID]) - 1.0) < 0.001, batch
def test_async_grads(ray_start_regular_shared):

View file

@ -7,6 +7,8 @@ import ray
from ray.tune.registry import register_env
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
BasicMultiAgent, EarlyDoneMultiAgent, RoundRobinMultiAgent
@ -321,21 +323,31 @@ class TestMultiAgentEnv(unittest.TestCase):
if episodes is not None:
# Pretend we did a model-based rollout and want to return
# the extra trajectory.
builder = episodes[0].new_batch_builder()
rollout_id = random.randint(0, 10000)
for t in range(5):
builder.add_values(
agent_id="extra_0",
policy_id="p1", # use p1 so we can easily check it
t=t,
eps_id=rollout_id, # new id for each rollout
obs=obs_batch[0],
actions=0,
rewards=0,
dones=t == 4,
infos={},
new_obs=obs_batch[0])
batch = builder.build_and_reset(episode=None)
env_id = episodes[0].env_id
fake_eps = MultiAgentEpisode(
episodes[0]._policies, episodes[0]._policy_mapping_fn,
lambda: None, lambda x: None, env_id)
builder = get_global_worker().sampler.sample_collector
agent_id = "extra_0"
policy_id = "p1" # use p1 so we can easily check it
builder.add_init_obs(fake_eps, agent_id, env_id, policy_id,
-1, obs_batch[0])
for t in range(4):
builder.add_action_reward_next_obs(
episode_id=fake_eps.episode_id,
agent_id=agent_id,
env_id=env_id,
policy_id=policy_id,
agent_done=t == 3,
values=dict(
t=t,
actions=0,
rewards=0,
dones=t == 3,
infos={},
new_obs=obs_batch[0]))
batch = builder.postprocess_episode(
episode=fake_eps, build=True)
episodes[0].add_extra_batch(batch)
# Just return zeros for actions
@ -350,12 +362,17 @@ class TestMultiAgentEnv(unittest.TestCase):
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
"p1": (ModelBasedPolicy, obs_space, act_space, {}),
},
policy_config={"_use_trajectory_view_api": True},
policy_mapping_fn=lambda agent_id: "p0",
rollout_fragment_length=5)
batch = ev.sample()
# 5 environment steps (rollout_fragment_length).
self.assertEqual(batch.count, 5)
# 10 agent steps for p0: 2 agents, both using p0 as their policy.
self.assertEqual(batch.policy_batches["p0"].count, 10)
self.assertEqual(batch.policy_batches["p1"].count, 25)
# 20 agent steps for p1: Each time both(!) agents take 1 step,
# p1 takes 4: 5 (rollout-fragment length) * 4 = 20
self.assertEqual(batch.policy_batches["p1"].count, 20)
def test_train_multi_agent_cartpole_single_policy(self):
n = 10

View file

@ -1,12 +1,21 @@
import numpy as np
import unittest
import ray
import ray.rllib.agents.ddpg as ddpg
import ray.rllib.agents.dqn as dqn
from ray.rllib.utils.test_utils import check, framework_iterator
class TestParameterNoise(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_ddpg_parameter_noise(self):
self.do_test_parameter_noise_exploration(
ddpg.DDPGTrainer, ddpg.DEFAULT_CONFIG, "Pendulum-v0", {},
@ -37,6 +46,10 @@ class TestParameterNoise(unittest.TestCase):
trainer = trainer_cls(config=config, env=env)
policy = trainer.get_policy()
pol_sess = getattr(policy, "_sess", None)
# Remove noise that has been added during policy initialization
# (exploration.postprocess_trajectory does add noise to measure
# the delta).
policy.exploration._remove_noise(tf_sess=pol_sess)
self.assertFalse(policy.exploration.weights_are_currently_noisy)
noise_before = self._get_current_noise(policy, fw)
@ -96,6 +109,12 @@ class TestParameterNoise(unittest.TestCase):
config["explore"] = False
trainer = trainer_cls(config=config, env=env)
policy = trainer.get_policy()
pol_sess = getattr(policy, "_sess", None)
# Remove noise that has been added during policy initialization
# (exploration.postprocess_trajectory does add noise to measure
# the delta).
policy.exploration._remove_noise(tf_sess=pol_sess)
self.assertFalse(policy.exploration.weights_are_currently_noisy)
initial_weights = self._get_current_weight(policy, fw)

View file

@ -67,6 +67,10 @@ EnvInfoDict = dict
# Represents a File object
FileType = Any
# Represents a ViewRequirements dict mapping column names (str) to
# ViewRequirement objects.
ViewRequirementsDict = Dict[str, "ViewRequirement"]
# Represents the result dict returned by Trainer.train().
ResultDict = dict