mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)
parent 401d342602
commit 99c81c6795
32 changed files with 354 additions and 247 deletions
@ -13,12 +13,12 @@ from typing import Dict, List, Optional, Type, Union
|
|||
from ray.rllib.agents.impala import vtrace_tf as vtrace
|
||||
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
|
||||
clip_gradients, choose_optimizer
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule, TFPolicy
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin
|
||||
|
@ -338,31 +338,14 @@ def postprocess_trajectory(
|
|||
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
||||
"""
|
||||
if not policy.config["vtrace"]:
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(policy.num_state_tensors()):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
batch = compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
policy.config["gamma"],
|
||||
policy.config["lambda"],
|
||||
use_gae=policy.config["use_gae"],
|
||||
use_critic=policy.config["use_critic"])
|
||||
else:
|
||||
batch = sample_batch
|
||||
sample_batch = postprocess_ppo_gae(policy, sample_batch,
|
||||
other_agent_batches, episode)
|
||||
|
||||
# TODO: (sven) remove this del once we have trajectory view API fully in
|
||||
# place.
|
||||
del batch.data["new_obs"] # not used, so save some bandwidth
|
||||
del sample_batch.data["new_obs"] # not used, so save some bandwidth
|
||||
|
||||
return batch
|
||||
return sample_batch
|
||||
|
||||
|
||||
def add_values(policy):
|
||||
|
|
|
@ -38,7 +38,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# If true, use the Generalized Advantage Estimator (GAE)
|
||||
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
|
||||
"use_gae": True,
|
||||
# The GAE(lambda) parameter.
|
||||
# The GAE (lambda) parameter.
|
||||
"lambda": 1.0,
|
||||
# Initial coefficient for KL divergence.
|
||||
"kl_coeff": 0.2,
|
||||
|
|
|
@ -193,13 +193,22 @@ def postprocess_ppo_gae(
|
|||
last_r = 0.0
|
||||
# Trajectory has been truncated -> last r=VF estimate of last obs.
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(policy.num_state_tensors()):
|
||||
next_state.append(sample_batch["state_out_{}".format(i)][-1])
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
# Input dict is provided to us automatically via the Model's
|
||||
# requirements. It's a single-timestep (last one in trajectory)
|
||||
# input_dict.
|
||||
if policy.config["_use_trajectory_view_api"]:
|
||||
# Create an input dict according to the Model's requirements.
|
||||
input_dict = policy.model.get_input_dict(sample_batch, index=-1)
|
||||
last_r = policy._value(**input_dict)
|
||||
# TODO: (sven) Remove once trajectory view API is all-algo default.
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(policy.num_state_tensors()):
|
||||
next_state.append(sample_batch["state_out_{}".format(i)][-1])
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
|
||||
# Adds the policy logits, VF preds, and advantages to the batch,
|
||||
# using GAE ("generalized advantage estimation") or not.
|
||||
|
@ -208,7 +217,9 @@ def postprocess_ppo_gae(
|
|||
last_r,
|
||||
policy.config["gamma"],
|
||||
policy.config["lambda"],
|
||||
use_gae=policy.config["use_gae"])
|
||||
use_gae=policy.config["use_gae"],
|
||||
use_critic=policy.config.get("use_critic", True))
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
|
@ -292,25 +303,40 @@ class ValueNetworkMixin:
|
|||
# observation.
|
||||
if config["use_gae"]:
|
||||
|
||||
@make_tf_callable(self.get_session())
|
||||
def value(ob, prev_action, prev_reward, *state):
|
||||
model_out, _ = self.model({
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
|
||||
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
|
||||
[prev_action]),
|
||||
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
|
||||
[prev_reward]),
|
||||
"is_training": tf.convert_to_tensor([False]),
|
||||
}, [tf.convert_to_tensor([s]) for s in state],
|
||||
tf.convert_to_tensor([1]))
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
# Input dict is provided to us automatically via the Model's
|
||||
# requirements. It's a single-timestep (last one in trajectory)
|
||||
# input_dict.
|
||||
if config["_use_trajectory_view_api"]:
|
||||
|
||||
@make_tf_callable(self.get_session())
|
||||
def value(**input_dict):
|
||||
model_out, _ = self.model.from_batch(
|
||||
input_dict, is_training=False)
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
|
||||
# TODO: (sven) Remove once trajectory view API is all-algo default.
|
||||
else:
|
||||
|
||||
@make_tf_callable(self.get_session())
|
||||
def value(ob, prev_action, prev_reward, *state):
|
||||
model_out, _ = self.model({
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
|
||||
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
|
||||
[prev_action]),
|
||||
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
|
||||
[prev_reward]),
|
||||
"is_training": tf.convert_to_tensor([False]),
|
||||
}, [tf.convert_to_tensor([s]) for s in state],
|
||||
tf.convert_to_tensor([1]))
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
|
||||
# When not doing GAE, we do not require the value function's output.
|
||||
else:
|
||||
|
||||
@make_tf_callable(self.get_session())
|
||||
def value(ob, prev_action, prev_reward, *state):
|
||||
def value(*args, **kwargs):
|
||||
return tf.constant(0.0)
|
||||
|
||||
self._value = value
|
||||
|
|
|
@ -210,22 +210,36 @@ class ValueNetworkMixin:
|
|||
# When doing GAE, we need the value function estimate on the
|
||||
# observation.
|
||||
if config["use_gae"]:
|
||||
# Input dict is provided to us automatically via the Model's
|
||||
# requirements. It's a single-timestep (last one in trajectory)
|
||||
# input_dict.
|
||||
if config["_use_trajectory_view_api"]:
|
||||
|
||||
def value(ob, prev_action, prev_reward, *state):
|
||||
model_out, _ = self.model({
|
||||
SampleBatch.CUR_OBS: convert_to_torch_tensor(
|
||||
np.asarray([ob]), self.device),
|
||||
SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
|
||||
np.asarray([prev_action]), self.device),
|
||||
SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
|
||||
np.asarray([prev_reward]), self.device),
|
||||
"is_training": False,
|
||||
}, [
|
||||
convert_to_torch_tensor(np.asarray([s]), self.device)
|
||||
for s in state
|
||||
], convert_to_torch_tensor(np.asarray([1]), self.device))
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
def value(**input_dict):
|
||||
model_out, _ = self.model.from_batch(
|
||||
convert_to_torch_tensor(input_dict, self.device),
|
||||
is_training=False)
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
|
||||
# TODO: (sven) Remove once trajectory view API is all-algo default.
|
||||
else:
|
||||
|
||||
def value(ob, prev_action, prev_reward, *state):
|
||||
model_out, _ = self.model({
|
||||
SampleBatch.CUR_OBS: convert_to_torch_tensor(
|
||||
np.asarray([ob]), self.device),
|
||||
SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
|
||||
np.asarray([prev_action]), self.device),
|
||||
SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
|
||||
np.asarray([prev_reward]), self.device),
|
||||
"is_training": False,
|
||||
}, [
|
||||
convert_to_torch_tensor(np.asarray([s]), self.device)
|
||||
for s in state
|
||||
], convert_to_torch_tensor(np.asarray([1]), self.device))
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
|
||||
# When not doing GAE, we do not require the value function's output.
|
||||
else:
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
from gym.spaces import Box
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
|
@ -25,17 +22,13 @@ class RNNModel(TorchModelV2, nn.Module):
|
|||
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
|
||||
self.n_agents = model_config["n_agents"]
|
||||
|
||||
self.inference_view_requirements.update({
|
||||
"state_in_0": ViewRequirement(
|
||||
"state_out_0",
|
||||
data_rel_pos=-1,
|
||||
space=Box(-1.0, 1.0, (self.n_agents, self.rnn_hidden_dim)))
|
||||
})
|
||||
|
||||
@override(ModelV2)
|
||||
def get_initial_state(self):
|
||||
# Place hidden states on same device as model.
|
||||
return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
|
||||
return [
|
||||
self.fc1.weight.new(self.n_agents,
|
||||
self.rnn_hidden_dim).zero_().squeeze(0)
|
||||
]
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, hidden_state, seq_lens):
|
||||
|
|
|
@ -215,9 +215,6 @@ class QMixTorchPolicy(Policy):
|
|||
name="target_model",
|
||||
default_model=RNNModel).to(self.device)
|
||||
|
||||
# Combine view_requirements for Model and Policy.
|
||||
self.view_requirements.update(self.model.inference_view_requirements)
|
||||
|
||||
self.exploration = self._create_exploration()
|
||||
|
||||
# Setup the mixer network.
|
||||
|
|
|
@ -28,7 +28,7 @@ class MADDPGPostprocessing:
|
|||
other_agent_batches=None,
|
||||
episode=None):
|
||||
# FIXME: Get done from info is required since agentwise done is not
|
||||
# supported now.
|
||||
# supported now.
|
||||
sample_batch.data[SampleBatch.DONES] = self.get_done_from_info(
|
||||
sample_batch.data[SampleBatch.INFOS])
|
||||
|
||||
|
@ -251,6 +251,9 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
|
|||
loss_inputs=loss_inputs,
|
||||
dist_inputs=actor_feature)
|
||||
|
||||
del self.view_requirements["prev_actions"]
|
||||
del self.view_requirements["prev_rewards"]
|
||||
|
||||
self.sess.run(tf1.global_variables_initializer())
|
||||
|
||||
# Hard initial update
|
||||
|
|
|
@ -191,7 +191,7 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
postprocessor.
|
||||
This is usually called to collect samples for policy training.
|
||||
If not enough data has been collected yet (`rollout_fragment_length`),
|
||||
returns None.
|
||||
returns an empty list.
|
||||
|
||||
Returns:
|
||||
List[Union[MultiAgentBatch, SampleBatch]]: Returns a (possibly
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import collections
|
||||
from gym.spaces import Space
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union
|
||||
|
@ -8,12 +9,11 @@ from ray.rllib.evaluation.collectors.sample_collector import _SampleCollector
|
|||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
|
||||
TensorType
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
|
||||
TensorType, ViewRequirementsDict
|
||||
from ray.util.debug import log_once
|
||||
|
||||
_, tf, _ = try_import_tf()
|
||||
|
@ -48,13 +48,13 @@ class _AgentCollector:
|
|||
def __init__(self, shift_before: int = 0):
|
||||
self.shift_before = max(shift_before, 1)
|
||||
self.buffers: Dict[str, List] = {}
|
||||
self.episode_id = None
|
||||
# The simple timestep count for this agent. Gets increased by one
|
||||
# each time a (non-initial!) observation is added.
|
||||
self.count = 0
|
||||
|
||||
def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
|
||||
env_id: EnvID, t: int, init_obs: TensorType,
|
||||
view_requirements: Dict[str, ViewRequirement]) -> None:
|
||||
env_id: EnvID, t: int, init_obs: TensorType) -> None:
|
||||
"""Adds an initial observation (after reset) to the Agent's trajectory.
|
||||
|
||||
Args:
|
||||
|
@ -67,19 +67,17 @@ class _AgentCollector:
|
|||
ts=-1(!), then an action/reward/next-obs at t=0, etc..
|
||||
init_obs (TensorType): The initial observation tensor (after
|
||||
`env.reset()`).
|
||||
view_requirements (Dict[str, ViewRequirements])
|
||||
"""
|
||||
if SampleBatch.OBS not in self.buffers:
|
||||
self._build_buffers(
|
||||
single_row={
|
||||
SampleBatch.OBS: init_obs,
|
||||
SampleBatch.EPS_ID: episode_id,
|
||||
SampleBatch.AGENT_INDEX: agent_index,
|
||||
"env_id": env_id,
|
||||
"t": t,
|
||||
})
|
||||
self.buffers[SampleBatch.OBS].append(init_obs)
|
||||
self.buffers[SampleBatch.EPS_ID].append(episode_id)
|
||||
self.episode_id = episode_id
|
||||
self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
|
||||
self.buffers["env_id"].append(env_id)
|
||||
self.buffers["t"].append(t)
|
||||
|
@ -97,6 +95,11 @@ class _AgentCollector:
|
|||
assert SampleBatch.OBS not in values
|
||||
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
|
||||
del values[SampleBatch.NEXT_OBS]
|
||||
# Make sure EPS_ID stays the same for this agent. Usually, it should
|
||||
# not be part of `values` anyways.
|
||||
if SampleBatch.EPS_ID in values:
|
||||
assert values[SampleBatch.EPS_ID] == self.episode_id
|
||||
del values[SampleBatch.EPS_ID]
|
||||
|
||||
for k, v in values.items():
|
||||
if k not in self.buffers:
|
||||
|
@ -104,8 +107,7 @@ class _AgentCollector:
|
|||
self.buffers[k].append(v)
|
||||
self.count += 1
|
||||
|
||||
def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
|
||||
SampleBatch:
|
||||
def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
|
||||
"""Builds a SampleBatch from the thus-far collected agent data.
|
||||
|
||||
If the episode/trajectory has no DONE=True at the end, will copy
|
||||
|
@ -115,32 +117,29 @@ class _AgentCollector:
|
|||
by a Policy.
|
||||
|
||||
Args:
|
||||
view_requirements (Dict[str, ViewRequirement]: The view
|
||||
view_requirements (ViewRequirementsDict): The view
|
||||
requirements dict needed to build the SampleBatch from the raw
|
||||
buffers (which may have data shifts as well as mappings from
|
||||
view-col to data-col in them).
|
||||
|
||||
Returns:
|
||||
SampleBatch: The built SampleBatch for this agent, ready to go into
|
||||
postprocessing.
|
||||
"""
|
||||
|
||||
# TODO: measure performance gains when using a UsageTrackingDict
|
||||
# instead of a SampleBatch for postprocessing (this would eliminate
|
||||
# copies (for creating this SampleBatch) of many unused columns for
|
||||
# no reason (not used by postprocessor)).
|
||||
|
||||
batch_data = {}
|
||||
np_data = {}
|
||||
for view_col, view_req in view_requirements.items():
|
||||
# Create the batch of data from the different buffers.
|
||||
data_col = view_req.data_col or view_col
|
||||
|
||||
# Some columns don't exist yet (get created during postprocessing).
|
||||
# -> skip.
|
||||
if data_col not in self.buffers:
|
||||
continue
|
||||
# OBS are already shifted by -1 (the initial obs starts one ts
|
||||
# before all other data columns).
|
||||
shift = view_req.data_rel_pos - \
|
||||
shift = view_req.shift - \
|
||||
(1 if data_col == SampleBatch.OBS else 0)
|
||||
if data_col not in np_data:
|
||||
np_data[data_col] = to_float_np_array(self.buffers[data_col])
|
||||
|
@ -161,8 +160,12 @@ class _AgentCollector:
|
|||
data = np_data[data_col][self.shift_before + shift:shift]
|
||||
if len(data) > 0:
|
||||
batch_data[view_col] = data
|
||||
|
||||
batch = SampleBatch(batch_data)
|
||||
|
||||
# Add EPS_ID and UNROLL_ID to batch.
|
||||
batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id,
|
||||
batch.count)
|
||||
if SampleBatch.UNROLL_ID not in batch.data:
|
||||
# TODO: (sven) Once we have the additional
|
||||
# model.preprocess_train_batch in place (attention net PR), we
|
||||
|
@ -200,7 +203,7 @@ class _AgentCollector:
|
|||
] else 0)
|
||||
# Python primitive or dict (e.g. INFOs).
|
||||
if isinstance(data, (int, float, bool, str, dict)):
|
||||
self.buffers[col] = [0 for _ in range(shift)]
|
||||
self.buffers[col] = [data for _ in range(shift)]
|
||||
# np.ndarray, torch.Tensor, or tf.Tensor.
|
||||
else:
|
||||
shape = data.shape
|
||||
|
@ -239,25 +242,24 @@ class _PolicyCollector:
|
|||
|
||||
def add_postprocessed_batch_for_training(
|
||||
self, batch: SampleBatch,
|
||||
view_requirements: Dict[str, ViewRequirement]) -> None:
|
||||
view_requirements: ViewRequirementsDict) -> None:
|
||||
"""Adds a postprocessed SampleBatch (single agent) to our buffers.
|
||||
|
||||
Args:
|
||||
batch (SampleBatch): A single agent (one trajectory) SampleBatch
|
||||
to be added to the Policy's buffers.
|
||||
view_requirements (Dict[str, ViewRequirement]: The view
|
||||
view_requirements (ViewRequirementsDict): The view
|
||||
requirements for the policy. This is so we know, whether a
|
||||
view-column needs to be copied at all (not needed for
|
||||
training).
|
||||
"""
|
||||
for view_col, data in batch.items():
|
||||
# TODO(ekl) how do we handle this for policies that don't extend
|
||||
# Torch / TF Policy template (no inference of view reqs)?
|
||||
# Skip columns that are not used for training.
|
||||
# if view_col not in view_requirements or \
|
||||
# not view_requirements[view_col].used_for_training:
|
||||
# continue
|
||||
self.buffers[view_col].extend(data)
|
||||
# 1) If col is not in view_requirements, we must have a direct
|
||||
# child of the base Policy that doesn't do auto-view req creation.
|
||||
# 2) Col is in view-reqs and needed for training.
|
||||
if view_col not in view_requirements or \
|
||||
view_requirements[view_col].used_for_training:
|
||||
self.buffers[view_col].extend(data)
|
||||
# Add the agent's trajectory length to our count.
|
||||
self.count += batch.count
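
As an illustration of the column-filtering logic introduced in the hunk above, here is a minimal, self-contained sketch (plain Python with a hypothetical stand-in class, not RLlib's actual `_PolicyCollector`) of how columns whose view requirement has `used_for_training=False` are kept out of the training buffers, while unknown columns are still copied:

```python
import collections


class FakeViewRequirement:
    """Stand-in for RLlib's ViewRequirement (only the flag needed here)."""

    def __init__(self, used_for_training=True):
        self.used_for_training = used_for_training


# Buffers that accumulate postprocessed trajectory columns for one policy.
buffers = collections.defaultdict(list)

view_requirements = {
    "obs": FakeViewRequirement(used_for_training=True),
    "new_obs": FakeViewRequirement(used_for_training=False),  # inference-only
}

postprocessed_batch = {
    "obs": [1, 2, 3],
    "new_obs": [2, 3, 4],
    "custom_col": [0, 0, 0],  # not in view_requirements -> copied anyway
}

for view_col, data in postprocessed_batch.items():
    # 1) Unknown columns are kept (covers Policy subclasses that do no
    #    automatic view-requirement creation).
    # 2) Known columns are kept only if they are needed for training.
    if view_col not in view_requirements or \
            view_requirements[view_col].used_for_training:
        buffers[view_col].extend(data)

print(sorted(buffers))  # ['custom_col', 'obs'] -- 'new_obs' was skipped
```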
|
||||
|
||||
|
@ -380,9 +382,6 @@ class _SimpleListCollector(_SampleCollector):
|
|||
self.agent_key_to_policy_id[agent_key] = policy_id
|
||||
else:
|
||||
assert self.agent_key_to_policy_id[agent_key] == policy_id
|
||||
policy = self.policy_map[policy_id]
|
||||
view_reqs = policy.model.inference_view_requirements if \
|
||||
getattr(policy, "model", None) else policy.view_requirements
|
||||
|
||||
# Add initial obs to Trajectory.
|
||||
assert agent_key not in self.agent_collectors
|
||||
|
@ -393,8 +392,7 @@ class _SimpleListCollector(_SampleCollector):
|
|||
agent_index=episode._agent_index(agent_id),
|
||||
env_id=env_id,
|
||||
t=t,
|
||||
init_obs=init_obs,
|
||||
view_requirements=view_reqs)
|
||||
init_obs=init_obs)
|
||||
|
||||
self.episodes[episode.episode_id] = episode
|
||||
if episode.batch_builder is None:
|
||||
|
@ -442,17 +440,22 @@ class _SimpleListCollector(_SampleCollector):
|
|||
# Create the batch of data from the different buffers.
|
||||
data_col = view_req.data_col or view_col
|
||||
time_indices = \
|
||||
view_req.data_rel_pos - (
|
||||
view_req.shift - (
|
||||
1 if data_col in [SampleBatch.OBS, "t", "env_id",
|
||||
SampleBatch.EPS_ID,
|
||||
SampleBatch.AGENT_INDEX] else 0)
|
||||
data_list = []
|
||||
for k in keys:
|
||||
if data_col not in buffers[k]:
|
||||
self.agent_collectors[k]._build_buffers({
|
||||
data_col: view_req.space.sample()
|
||||
})
|
||||
data_list.append(buffers[k][data_col][time_indices])
|
||||
if data_col == SampleBatch.EPS_ID:
|
||||
data_list.append(self.agent_collectors[k].episode_id)
|
||||
else:
|
||||
if data_col not in buffers[k]:
|
||||
fill_value = np.zeros_like(view_req.space.sample()) \
|
||||
if isinstance(view_req.space, Space) else \
|
||||
view_req.space
|
||||
self.agent_collectors[k]._build_buffers({
|
||||
data_col: fill_value
|
||||
})
|
||||
data_list.append(buffers[k][data_col][time_indices])
|
||||
input_dict[view_col] = np.array(data_list)
|
||||
|
||||
self._reset_inference_calls(policy_id)
|
||||
|
@ -517,8 +520,8 @@ class _SimpleListCollector(_SampleCollector):
|
|||
del other_batches[agent_id]
|
||||
pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
|
||||
policy = self.policy_map[pid]
|
||||
if any(pre_batch["dones"][:-1]) or len(set(
|
||||
pre_batch["eps_id"])) > 1:
|
||||
if any(pre_batch[SampleBatch.DONES][:-1]) or len(
|
||||
set(pre_batch[SampleBatch.EPS_ID])) > 1:
|
||||
raise ValueError(
|
||||
"Batches sent to postprocessing must only contain steps "
|
||||
"from a single trajectory.", pre_batch)
|
||||
|
|
|
@ -177,6 +177,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
fake_sampler: bool = False,
|
||||
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
|
||||
gym.spaces.Space]]] = None,
|
||||
_use_trajectory_view_api: bool = True,
|
||||
policy: Union[type, Dict[
|
||||
str, Tuple[Optional[type], gym.Space, gym.Space,
|
||||
PartialTrainerConfigDict]]] = None,
|
||||
|
@ -295,6 +296,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
gym.spaces.Space]]]): An optional space dict mapping policy IDs
|
||||
to (obs_space, action_space)-tuples. This is used in case no
|
||||
Env is created on this RolloutWorker.
|
||||
_use_trajectory_view_api (bool): Whether to collect samples through
|
||||
the experimental Trajectory View API.
|
||||
policy: Obsoleted arg. Use `policy_spec` instead.
|
||||
"""
|
||||
# Deprecated arg.
|
||||
|
@ -459,6 +462,14 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
self.policy_map, self.preprocessors = self._build_policy_map(
|
||||
policy_dict, policy_config)
|
||||
|
||||
# Update Policy's view requirements from Model, only if Policy directly
|
||||
# inherited from base `Policy` class. At this point here, the Policy
|
||||
# must have it's Model (if any) defined and ready to output an initial
|
||||
# state.
|
||||
for pol in self.policy_map.values():
|
||||
if not pol._model_init_state_automatically_added:
|
||||
pol._update_model_inference_view_requirements_from_init_state()
|
||||
|
||||
if (ray.is_initialized()
|
||||
and ray.worker._mode() != ray.worker.LOCAL_MODE):
|
||||
# Check available number of GPUs
|
||||
|
@ -568,8 +579,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
soft_horizon=soft_horizon,
|
||||
no_done_at_end=no_done_at_end,
|
||||
observation_fn=observation_fn,
|
||||
_use_trajectory_view_api=policy_config.get(
|
||||
"_use_trajectory_view_api", False))
|
||||
_use_trajectory_view_api=_use_trajectory_view_api,
|
||||
)
|
||||
# Start the Sampler thread.
|
||||
self.sampler.start()
|
||||
else:
|
||||
|
@ -590,8 +601,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
soft_horizon=soft_horizon,
|
||||
no_done_at_end=no_done_at_end,
|
||||
observation_fn=observation_fn,
|
||||
_use_trajectory_view_api=policy_config.get(
|
||||
"_use_trajectory_view_api", False))
|
||||
_use_trajectory_view_api=_use_trajectory_view_api,
|
||||
)
|
||||
|
||||
self.input_reader: InputReader = input_creator(self.io_context)
|
||||
self.output_writer: OutputWriter = output_creator(self.io_context)
|
||||
|
|
|
@ -1046,7 +1046,6 @@ def _process_observations_w_trajectory_view_api(
|
|||
# Add actions, rewards, next-obs to collectors.
|
||||
values_dict = {
|
||||
"t": episode.length - 1,
|
||||
"eps_id": episode.episode_id,
|
||||
"env_id": env_id,
|
||||
"agent_index": episode._agent_index(agent_id),
|
||||
# Action (slot 0) taken at timestep t.
|
||||
|
|
|
@ -59,7 +59,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
assert view_req_policy[key].data_col is None
|
||||
else:
|
||||
assert view_req_policy[key].data_col == SampleBatch.OBS
|
||||
assert view_req_policy[key].data_rel_pos == 1
|
||||
assert view_req_policy[key].shift == 1
|
||||
rollout_worker = trainer.workers.local_worker()
|
||||
sample_batch = rollout_worker.sample()
|
||||
expected_count = \
|
||||
|
@ -99,10 +99,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
|
||||
if key == SampleBatch.PREV_ACTIONS:
|
||||
assert view_req_policy[key].data_col == SampleBatch.ACTIONS
|
||||
assert view_req_policy[key].data_rel_pos == -1
|
||||
assert view_req_policy[key].shift == -1
|
||||
elif key == SampleBatch.PREV_REWARDS:
|
||||
assert view_req_policy[key].data_col == SampleBatch.REWARDS
|
||||
assert view_req_policy[key].data_rel_pos == -1
|
||||
assert view_req_policy[key].shift == -1
|
||||
elif key not in [
|
||||
SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
|
||||
SampleBatch.PREV_REWARDS
|
||||
|
@ -110,7 +110,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
assert view_req_policy[key].data_col is None
|
||||
else:
|
||||
assert view_req_policy[key].data_col == SampleBatch.OBS
|
||||
assert view_req_policy[key].data_rel_pos == 1
|
||||
assert view_req_policy[key].shift == 1
|
||||
trainer.stop()
|
||||
|
||||
def test_traj_view_simple_performance(self):
|
||||
|
|
|
@ -352,6 +352,7 @@ class WorkerSet:
|
|||
fake_sampler=config["fake_sampler"],
|
||||
extra_python_environs=extra_python_environs,
|
||||
spaces=spaces,
|
||||
_use_trajectory_view_api=config["_use_trajectory_view_api"],
|
||||
)
|
||||
|
||||
return worker
|
||||
|
|
|
@ -28,15 +28,15 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
|
|||
"t": ViewRequirement(),
|
||||
SampleBatch.OBS: ViewRequirement(),
|
||||
SampleBatch.PREV_ACTIONS: ViewRequirement(
|
||||
SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
|
||||
SampleBatch.ACTIONS, space=self.action_space, shift=-1),
|
||||
SampleBatch.PREV_REWARDS: ViewRequirement(
|
||||
SampleBatch.REWARDS, data_rel_pos=-1),
|
||||
SampleBatch.REWARDS, shift=-1),
|
||||
}
|
||||
for i in range(2):
|
||||
self.model.inference_view_requirements["state_in_{}".format(i)] = \
|
||||
ViewRequirement(
|
||||
"state_out_{}".format(i),
|
||||
data_rel_pos=-1,
|
||||
shift=-1,
|
||||
space=self.state_space)
|
||||
self.model.inference_view_requirements[
|
||||
"state_out_{}".format(i)] = \
|
||||
|
@ -45,7 +45,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
|
|||
self.view_requirements = dict(
|
||||
**{
|
||||
SampleBatch.NEXT_OBS: ViewRequirement(
|
||||
SampleBatch.OBS, data_rel_pos=1),
|
||||
SampleBatch.OBS, shift=1),
|
||||
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
|
||||
SampleBatch.REWARDS: ViewRequirement(),
|
||||
SampleBatch.DONES: ViewRequirement(),
|
||||
|
@ -106,7 +106,7 @@ class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
|
|||
"state_in_0": ViewRequirement(
|
||||
"state_out_0",
|
||||
# Provide state outs -50 to -1 as "state-in".
|
||||
data_rel_pos="-50:-1",
|
||||
shift="-50:-1",
|
||||
# Repeat the incoming state every n time steps (usually max seq
|
||||
# len).
|
||||
batch_repeat_value=self.config["model"]["max_seq_len"],
|
||||
|
|
|
@ -16,7 +16,7 @@ class AlwaysSameHeuristic(Policy):
|
|||
self.view_requirements.update({
|
||||
"state_in_0": ViewRequirement(
|
||||
"state_out_0",
|
||||
data_rel_pos=-1,
|
||||
shift=-1,
|
||||
space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
|
||||
})
|
||||
|
||||
|
|
|
@ -143,6 +143,7 @@ if __name__ == "__main__":
|
|||
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
"_use_trajectory_view_api": True,
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -61,8 +61,7 @@ class ModelV2:
|
|||
self.time_major = self.model_config.get("_time_major")
|
||||
# Basic view requirement for all models: Use the observation as input.
|
||||
self.inference_view_requirements = {
|
||||
SampleBatch.OBS: ViewRequirement(
|
||||
data_rel_pos=0, space=self.obs_space),
|
||||
SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
|
||||
}
|
||||
|
||||
# TODO: (sven): Get rid of `get_initial_state` once Trajectory
|
||||
|
@ -315,6 +314,29 @@ class ModelV2:
|
|||
"""
|
||||
return self.time_major is True
|
||||
|
||||
# TODO: (sven) Experimental method.
|
||||
def get_input_dict(self, sample_batch,
|
||||
index: int = -1) -> Dict[str, TensorType]:
|
||||
if index < 0:
|
||||
index = sample_batch.count - 1
|
||||
|
||||
input_dict = {}
|
||||
for view_col, view_req in self.inference_view_requirements.items():
|
||||
# Create batches of size 1 (single-agent input-dict).
|
||||
|
||||
# Index range.
|
||||
if isinstance(index, tuple):
|
||||
data = sample_batch[view_col][index[0]:index[1] + 1]
|
||||
input_dict[view_col] = np.array([data])
|
||||
# Single index.
|
||||
else:
|
||||
input_dict[view_col] = sample_batch[view_col][index:index + 1]
|
||||
|
||||
# Add valid `seq_lens`, just in case RNNs need it.
|
||||
input_dict["seq_lens"] = np.array([1], dtype=np.int32)
|
||||
|
||||
return input_dict
|
||||
|
||||
|
||||
class NullContextManager:
|
||||
"""No-op context manager"""
|
||||
|
|
|
@ -178,10 +178,10 @@ class LSTMWrapper(RecurrentNetwork):
|
|||
if model_config["lstm_use_prev_action"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
|
||||
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
|
||||
data_rel_pos=-1)
|
||||
shift=-1)
|
||||
if model_config["lstm_use_prev_reward"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
|
||||
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
|
||||
ViewRequirement(SampleBatch.REWARDS, shift=-1)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
|
|
|
@ -159,10 +159,10 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
|
|||
if model_config["lstm_use_prev_action"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
|
||||
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
|
||||
data_rel_pos=-1)
|
||||
shift=-1)
|
||||
if model_config["lstm_use_prev_reward"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
|
||||
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
|
||||
ViewRequirement(SampleBatch.REWARDS, shift=-1)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
|
|
|
@ -80,8 +80,6 @@ class DynamicTFPolicy(TFPolicy):
|
|||
], Tuple[TensorType, type, List[TensorType]]]] = None,
|
||||
existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
|
||||
existing_model: Optional[ModelV2] = None,
|
||||
view_requirements_fn: Optional[Callable[[Policy], Dict[
|
||||
str, ViewRequirement]]] = None,
|
||||
get_batch_divisibility_req: Optional[Callable[[Policy],
|
||||
int]] = None,
|
||||
obs_include_prev_action_reward: bool = True):
|
||||
|
@ -292,14 +290,6 @@ class DynamicTFPolicy(TFPolicy):
|
|||
action_distribution=action_dist,
|
||||
timestep=timestep,
|
||||
explore=explore)
|
||||
if self.config["_use_trajectory_view_api"]:
|
||||
self._dummy_batch[SampleBatch.ACTION_DIST_INPUTS] = \
|
||||
np.zeros(
|
||||
[1 if not s else s for s in
|
||||
dist_inputs.shape.as_list()])
|
||||
self._input_dict[SampleBatch.ACTION_DIST_INPUTS] = \
|
||||
tf1.placeholder(shape=dist_inputs.shape.as_list(),
|
||||
dtype=tf.float32)
|
||||
|
||||
# Phase 1 init.
|
||||
sess = tf1.get_default_session() or tf1.Session()
|
||||
|
@ -417,42 +407,37 @@ class DynamicTFPolicy(TFPolicy):
|
|||
input_dict/dummy_batch tuple.
|
||||
"""
|
||||
input_dict = {}
|
||||
dummy_batch = {}
|
||||
for view_col, view_req in view_requirements.items():
|
||||
# Point state_in to the already existing self._state_inputs.
|
||||
mo = re.match("state_in_(\d+)", view_col)
|
||||
if mo is not None:
|
||||
input_dict[view_col] = self._state_inputs[int(mo.group(1))]
|
||||
dummy_batch[view_col] = np.zeros_like(
|
||||
[view_req.space.sample()])
|
||||
# State-outs (no placeholders needed).
|
||||
elif view_col.startswith("state_out_"):
|
||||
dummy_batch[view_col] = np.zeros_like(
|
||||
[view_req.space.sample()])
|
||||
continue
|
||||
# Skip action dist inputs placeholder (do later).
|
||||
elif view_col == SampleBatch.ACTION_DIST_INPUTS:
|
||||
continue
|
||||
elif view_col in existing_inputs:
|
||||
input_dict[view_col] = existing_inputs[view_col]
|
||||
dummy_batch[view_col] = np.zeros(
|
||||
shape=[
|
||||
1 if s is None else s
|
||||
for s in existing_inputs[view_col].shape.as_list()
|
||||
],
|
||||
dtype=existing_inputs[view_col].dtype.as_numpy_dtype)
|
||||
# All others.
|
||||
else:
|
||||
if view_req.used_for_training:
|
||||
input_dict[view_col] = get_placeholder(
|
||||
space=view_req.space, name=view_col)
|
||||
dummy_batch[view_col] = np.zeros_like(
|
||||
[view_req.space.sample()])
|
||||
dummy_batch = self._get_dummy_batch_from_view_requirements(
|
||||
batch_size=32)
|
||||
|
||||
return input_dict, dummy_batch
|
||||
|
||||
def _initialize_loss_from_dummy_batch(
|
||||
self, auto_remove_unneeded_view_reqs: bool = True,
|
||||
stats_fn=None) -> None:
|
||||
|
||||
# Create the optimizer/exploration optimizer here. Some initialization
|
||||
# steps (e.g. exploration postprocessing) may need this.
|
||||
self._optimizer = self.optimizer()
|
||||
|
||||
# Test calls depend on variable init, so initialize model first.
|
||||
self._sess.run(tf1.global_variables_initializer())
|
||||
|
||||
|
@ -509,6 +494,8 @@ class DynamicTFPolicy(TFPolicy):
|
|||
batch_for_postproc = UsageTrackingDict(sb)
|
||||
batch_for_postproc.count = sb.count
|
||||
logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
|
||||
self.exploration.postprocess_trajectory(self, batch_for_postproc,
|
||||
self._sess)
|
||||
postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
|
||||
# Add new columns automatically to (loss) input_dict.
|
||||
if self.config["_use_trajectory_view_api"]:
|
||||
|
@ -588,7 +575,8 @@ class DynamicTFPolicy(TFPolicy):
|
|||
batch_for_postproc.accessed_keys
|
||||
# Tag those only needed for post-processing.
|
||||
for key in batch_for_postproc.accessed_keys:
|
||||
if key not in train_batch.accessed_keys:
|
||||
if key not in train_batch.accessed_keys and \
|
||||
key not in self.model.inference_view_requirements:
|
||||
self.view_requirements[key].used_for_training = False
|
||||
if key in self._loss_input_dict:
|
||||
del self._loss_input_dict[key]
|
||||
|
|
|
@ -194,7 +194,6 @@ def build_eager_tf_policy(name,
|
|||
action_sampler_fn=None,
|
||||
action_distribution_fn=None,
|
||||
mixins=None,
|
||||
view_requirements_fn=None,
|
||||
obs_include_prev_action_reward=True,
|
||||
get_batch_divisibility_req=None):
|
||||
"""Build an eager TF policy.
|
||||
|
@ -265,9 +264,6 @@ def build_eager_tf_policy(name,
|
|||
for s in self.model.get_initial_state()
|
||||
]
|
||||
|
||||
# Update this Policy's ViewRequirements (if function given).
|
||||
if callable(view_requirements_fn):
|
||||
self.view_requirements.update(view_requirements_fn(self))
|
||||
# Combine view_requirements for Model and Policy.
|
||||
self.view_requirements.update(
|
||||
self.model.inference_view_requirements)
|
||||
|
@ -275,12 +271,6 @@ def build_eager_tf_policy(name,
|
|||
if before_loss_init:
|
||||
before_loss_init(self, observation_space, action_space, config)
|
||||
|
||||
self._initialize_loss_from_dummy_batch(
|
||||
auto_remove_unneeded_view_reqs=True,
|
||||
stats_fn=stats_fn,
|
||||
)
|
||||
self._loss_initialized = True
|
||||
|
||||
if optimizer_fn:
|
||||
optimizers = optimizer_fn(self, config)
|
||||
else:
|
||||
|
@ -293,10 +283,16 @@ def build_eager_tf_policy(name,
|
|||
# Just like torch Policy does.
|
||||
self._optimizer = optimizers[0] if optimizers else None
|
||||
|
||||
self._initialize_loss_from_dummy_batch(
|
||||
auto_remove_unneeded_view_reqs=True,
|
||||
stats_fn=stats_fn,
|
||||
)
|
||||
self._loss_initialized = True
|
||||
|
||||
if after_init:
|
||||
after_init(self, observation_space, action_space, config)
|
||||
|
||||
# Got to reset global_timestep again after this fake run-through.
|
||||
# Got to reset global_timestep again after fake run-throughs.
|
||||
self.global_timestep = 0
|
||||
|
||||
@override(Policy)
|
||||
|
@ -410,7 +406,7 @@ def build_eager_tf_policy(name,
|
|||
timestep=timestep, explore=explore)
|
||||
|
||||
if action_distribution_fn:
|
||||
dist_inputs, dist_class, state_out = \
|
||||
dist_inputs, self.dist_class, state_out = \
|
||||
action_distribution_fn(
|
||||
self, self.model,
|
||||
input_dict[SampleBatch.CUR_OBS],
|
||||
|
@ -418,11 +414,10 @@ def build_eager_tf_policy(name,
|
|||
timestep=timestep,
|
||||
is_training=False)
|
||||
else:
|
||||
dist_class = self.dist_class
|
||||
dist_inputs, state_out = self.model(
|
||||
input_dict, state_batches, seq_lens)
|
||||
|
||||
action_dist = dist_class(dist_inputs, self.model)
|
||||
action_dist = self.dist_class(dist_inputs, self.model)
|
||||
|
||||
# Get the exploration action from the forward results.
|
||||
actions, logp = self.exploration.get_exploration_action(
|
||||
|
@ -466,12 +461,12 @@ def build_eager_tf_policy(name,
|
|||
"is_training": tf.constant(False),
|
||||
}
|
||||
if obs_include_prev_action_reward:
|
||||
input_dict.update({
|
||||
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
|
||||
prev_action_batch),
|
||||
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
|
||||
prev_reward_batch),
|
||||
})
|
||||
if prev_action_batch is not None:
|
||||
input_dict[SampleBatch.PREV_ACTIONS] = \
|
||||
tf.convert_to_tensor(prev_action_batch)
|
||||
if prev_reward_batch is not None:
|
||||
input_dict[SampleBatch.PREV_REWARDS] = \
|
||||
tf.convert_to_tensor(prev_reward_batch)
|
||||
|
||||
# Exploration hook before each forward pass.
|
||||
self.exploration.before_compute_actions(explore=False)
|
||||
|
@ -559,7 +554,9 @@ def build_eager_tf_policy(name,
|
|||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return self.model.get_initial_state()
|
||||
if hasattr(self, "model"):
|
||||
return self.model.get_initial_state()
|
||||
return []
|
||||
|
||||
def get_session(self):
|
||||
return None # None implies eager
|
||||
|
|
|
@ -92,6 +92,7 @@ class Policy(metaclass=ABCMeta):
|
|||
self.view_requirements = view_reqs
|
||||
else:
|
||||
self.view_requirements.update(view_reqs)
|
||||
self._model_init_state_automatically_added = False
|
||||
|
||||
@abstractmethod
|
||||
@DeveloperAPI
|
||||
|
@ -278,7 +279,8 @@ class Policy(metaclass=ABCMeta):
|
|||
# `self.compute_actions()`.
|
||||
state_batches = [
|
||||
# TODO: (sven) remove unsqueezing code here for non-traj.view API.
|
||||
s if self.config["_use_trajectory_view_api"] else s.unsqueeze(0)
|
||||
s if self.config.get("_use_trajectory_view_api", False) else
|
||||
s.unsqueeze(0)
|
||||
if torch and isinstance(s, torch.Tensor) else np.expand_dims(s, 0)
|
||||
for k, s in input_dict.items() if k[:9] == "state_in_"
|
||||
]
|
||||
|
@ -564,16 +566,25 @@ class Policy(metaclass=ABCMeta):
|
|||
SampleBatch.OBS: ViewRequirement(space=self.observation_space),
|
||||
SampleBatch.NEXT_OBS: ViewRequirement(
|
||||
data_col=SampleBatch.OBS,
|
||||
data_rel_pos=1,
|
||||
shift=1,
|
||||
space=self.observation_space),
|
||||
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
|
||||
# For backward compatibility with custom Models that don't specify
|
||||
# these explicitly (will be removed by Policy if not used).
|
||||
SampleBatch.PREV_ACTIONS: ViewRequirement(
|
||||
data_col=SampleBatch.ACTIONS,
|
||||
shift=-1,
|
||||
space=self.action_space),
|
||||
SampleBatch.REWARDS: ViewRequirement(),
|
||||
# For backward compatibility with custom Models that don't specify
|
||||
# these explicitly (will be removed by Policy if not used).
|
||||
SampleBatch.PREV_REWARDS: ViewRequirement(
|
||||
data_col=SampleBatch.REWARDS, shift=-1),
|
||||
SampleBatch.DONES: ViewRequirement(),
|
||||
SampleBatch.INFOS: ViewRequirement(),
|
||||
SampleBatch.EPS_ID: ViewRequirement(),
|
||||
SampleBatch.UNROLL_ID: ViewRequirement(),
|
||||
SampleBatch.AGENT_INDEX: ViewRequirement(),
|
||||
SampleBatch.UNROLL_ID: ViewRequirement(),
|
||||
"t": ViewRequirement(),
|
||||
}
|
||||
|
||||
|
@ -616,6 +627,7 @@ class Policy(metaclass=ABCMeta):
|
|||
-1.0, 1.0, shape=value.shape[1:], dtype=value.dtype))
|
||||
batch_for_postproc = UsageTrackingDict(self._dummy_batch)
|
||||
batch_for_postproc.count = self._dummy_batch.count
|
||||
self.exploration.postprocess_trajectory(self, batch_for_postproc)
|
||||
postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
|
||||
if state_outs:
|
||||
B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size]
|
||||
|
@ -700,27 +712,33 @@ class Policy(metaclass=ABCMeta):
|
|||
ret[view_col] = \
|
||||
np.zeros((batch_size, ) + shape[1:], np.float32)
|
||||
else:
|
||||
ret[view_col] = np.zeros_like(
|
||||
[view_req.space.sample() for _ in range(batch_size)])
|
||||
if isinstance(view_req.space, gym.spaces.Space):
|
||||
ret[view_col] = np.zeros_like(
|
||||
[view_req.space.sample() for _ in range(batch_size)])
|
||||
else:
|
||||
ret[view_col] = [view_req.space for _ in range(batch_size)]
|
||||
|
||||
return SampleBatch(ret)
|
||||
|
||||
def _update_model_inference_view_requirements_from_init_state(self):
|
||||
"""Uses this Model's initial state to auto-add necessary ViewReqs.
|
||||
"""Uses Model's (or this Policy's) init state to add needed ViewReqs.
|
||||
|
||||
Can be called from within a Policy to make sure RNNs automatically
|
||||
update their internal state-related view requirements.
|
||||
Changes the `self.inference_view_requirements` dict.
|
||||
"""
|
||||
model = self.model
|
||||
self._model_init_state_automatically_added = True
|
||||
model = getattr(self, "model", None)
|
||||
obj = model or self
|
||||
# Add state-ins to this model's view.
|
||||
for i, state in enumerate(model.get_initial_state()):
|
||||
model.inference_view_requirements["state_in_{}".format(i)] = \
|
||||
ViewRequirement(
|
||||
"state_out_{}".format(i),
|
||||
data_rel_pos=-1,
|
||||
space=Box(-1.0, 1.0, shape=state.shape))
|
||||
model.inference_view_requirements["state_out_{}".format(i)] = \
|
||||
ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
|
||||
for i, state in enumerate(obj.get_initial_state()):
|
||||
space = Box(-1.0, 1.0, shape=state.shape) if \
|
||||
hasattr(state, "shape") else state
|
||||
view_reqs = model.inference_view_requirements if model else \
|
||||
self.view_requirements
|
||||
view_reqs["state_in_{}".format(i)] = ViewRequirement(
|
||||
"state_out_{}".format(i), shift=-1, space=space)
|
||||
view_reqs["state_out_{}".format(i)] = ViewRequirement(space=space)
|
||||
|
||||
|
||||
def clip_action(action, action_space):
|
||||
|
|
|
@ -115,7 +115,9 @@ class SampleBatch:
|
|||
[s[k] for s in concat_samples],
|
||||
time_major=concat_samples[0].time_major)
|
||||
return SampleBatch(
|
||||
out, _seq_lens=seq_lens, _time_major=concat_samples[0].time_major)
|
||||
out,
|
||||
_seq_lens=np.array(seq_lens, dtype=np.int32),
|
||||
_time_major=concat_samples[0].time_major)
|
||||
|
||||
@PublicAPI
|
||||
def concat(self, other: "SampleBatch") -> "SampleBatch":
|
||||
|
@ -154,7 +156,8 @@ class SampleBatch:
|
|||
"""
|
||||
return SampleBatch(
|
||||
{k: np.array(v, copy=True)
|
||||
for (k, v) in self.data.items()})
|
||||
for (k, v) in self.data.items()},
|
||||
_seq_lens=self.seq_lens)
|
||||
|
||||
@PublicAPI
|
||||
def rows(self) -> Dict[str, TensorType]:
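
A small usage sketch of the `copy()` fix above, assuming the SampleBatch constructor of this RLlib version (which accepts the private `_seq_lens` argument, as used in `concat_samples` in the same hunk): copied batches now retain their sequence-length information instead of dropping it.

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Two sequences of length 2 packed into one 4-row batch.
batch = SampleBatch(
    {"obs": np.zeros((4, 3), dtype=np.float32)},
    _seq_lens=np.array([2, 2], dtype=np.int32))

copied = batch.copy()
# With the change above, seq_lens survives the copy.
print(copied.seq_lens)  # [2 2]
```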
|
||||
|
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
from scipy.stats import norm
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.dqn as dqn
|
||||
import ray.rllib.agents.pg as pg
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
|
@ -87,8 +88,8 @@ def do_test_log_likelihood(run,
|
|||
logp = policy.compute_log_likelihoods(
|
||||
np.array([a]),
|
||||
preprocessed_obs_batch,
|
||||
prev_action_batch=np.array([prev_a]),
|
||||
prev_reward_batch=np.array([prev_r]))
|
||||
prev_action_batch=np.array([prev_a]) if prev_a else None,
|
||||
prev_reward_batch=np.array([prev_r]) if prev_r else None)
|
||||
check(logp, expected_logp[0], rtol=0.2)
|
||||
# Test all available actions for their logp values.
|
||||
else:
|
||||
|
@ -98,12 +99,20 @@ def do_test_log_likelihood(run,
|
|||
logp = policy.compute_log_likelihoods(
|
||||
np.array([a]),
|
||||
preprocessed_obs_batch,
|
||||
prev_action_batch=np.array([prev_a]),
|
||||
prev_reward_batch=np.array([prev_r]))
|
||||
prev_action_batch=np.array([prev_a]) if prev_a else None,
|
||||
prev_reward_batch=np.array([prev_r]) if prev_r else None)
|
||||
check(np.exp(logp), expected_prob, atol=0.2)
|
||||
|
||||
|
||||
class TestComputeLogLikelihood(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
ray.init()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
ray.shutdown()
|
||||
|
||||
def test_dqn(self):
|
||||
"""Tests, whether DQN correctly computes logp in soft-q mode."""
|
||||
config = dqn.DEFAULT_CONFIG.copy()
|
||||
|
|
|
@ -274,7 +274,8 @@ class TFPolicy(Policy):
|
|||
else:
|
||||
self._loss = loss
|
||||
|
||||
self._optimizer = self.optimizer()
|
||||
if self._optimizer is None:
|
||||
self._optimizer = self.optimizer()
|
||||
self._grads_and_vars = [
|
||||
(g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
|
||||
if g is not None
|
||||
|
|
|
@ -8,7 +8,6 @@ from ray.rllib.policy import eager_tf_policy
|
|||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils import add_mixins, force_list
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
@ -66,8 +65,6 @@ def build_tf_policy(
|
|||
Policy, ModelV2, TensorType, TensorType, TensorType
|
||||
], Tuple[TensorType, type, List[TensorType]]]] = None,
|
||||
mixins: Optional[List[type]] = None,
|
||||
view_requirements_fn: Optional[Callable[[Policy], Dict[
|
||||
str, ViewRequirement]]] = None,
|
||||
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
|
||||
# TODO: (sven) deprecate once _use_trajectory_view_api is always True.
|
||||
obs_include_prev_action_reward: bool = True,
|
||||
|
@ -231,7 +228,6 @@ def build_tf_policy(
|
|||
action_distribution_fn=action_distribution_fn,
|
||||
existing_inputs=existing_inputs,
|
||||
existing_model=existing_model,
|
||||
view_requirements_fn=view_requirements_fn,
|
||||
get_batch_divisibility_req=get_batch_divisibility_req,
|
||||
obs_include_prev_action_reward=obs_include_prev_action_reward)
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
|||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils import add_mixins, force_list
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
@ -70,8 +69,6 @@ def build_torch_policy(
|
|||
apply_gradients_fn: Optional[Callable[
|
||||
[Policy, "torch.optim.Optimizer"], None]] = None,
|
||||
mixins: Optional[List[type]] = None,
|
||||
view_requirements_fn: Optional[Callable[[Policy], Dict[
|
||||
str, ViewRequirement]]] = None,
|
||||
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
|
||||
) -> Type[TorchPolicy]:
|
||||
"""Helper function for creating a torch policy class at runtime.
|
||||
|
@ -174,9 +171,6 @@ def build_torch_policy(
|
|||
mixins (Optional[List[type]]): Optional list of any class mixins for
|
||||
the returned policy class. These mixins will be applied in order
|
||||
and will have higher precedence than the TorchPolicy class.
|
||||
view_requirements_fn (Optional[Callable[[Policy],
|
||||
Dict[str, ViewRequirement]]]): An optional callable to retrieve
|
||||
additional train view requirements for this policy.
|
||||
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
|
||||
Optional callable that returns the divisibility requirement for
|
||||
sample batches. If None, will assume a value of 1.
|
||||
|
@ -242,9 +236,6 @@ def build_torch_policy(
|
|||
get_batch_divisibility_req=get_batch_divisibility_req,
|
||||
)
|
||||
|
||||
# Update this Policy's ViewRequirements (if function given).
|
||||
if callable(view_requirements_fn):
|
||||
self.view_requirements.update(view_requirements_fn(self))
|
||||
# Merge Model's view requirements into Policy's.
|
||||
self.view_requirements.update(
|
||||
self.model.inference_view_requirements)
|
||||
|
|
|
@ -29,31 +29,42 @@ class ViewRequirement:
|
|||
def __init__(self,
|
||||
data_col: Optional[str] = None,
|
||||
space: gym.Space = None,
|
||||
data_rel_pos: Union[int, List[int]] = 0,
|
||||
shift: Union[int, List[int]] = 0,
|
||||
index: Optional[int] = None,
|
||||
used_for_training: bool = True):
|
||||
"""Initializes a ViewRequirement object.
|
||||
|
||||
Args:
|
||||
data_col (): The data column name from the SampleBatch (str key).
|
||||
If None, use the dict key under which this ViewRequirement
|
||||
resides.
|
||||
data_col (Optional[str]): The data column name from the SampleBatch
|
||||
(str key). If None, use the dict key under which this
|
||||
ViewRequirement resides.
|
||||
space (gym.Space): The gym Space used in case we need to pad data
|
||||
in inaccessible areas of the trajectory (t<0 or t>H).
|
||||
Default: Simple box space, e.g. rewards.
|
||||
data_rel_pos (Union[int, str, List[int]]): Single shift value or
|
||||
shift (Union[int, str, List[int]]): Single shift value or
|
||||
list of relative positions to use (relative to the underlying
|
||||
`data_col`).
|
||||
Example: For a view column "prev_actions", you can set
|
||||
`data_col="actions"` and `data_rel_pos=-1`.
|
||||
`data_col="actions"` and `shift=-1`.
|
||||
Example: For a view column "obs" in an Atari framestacking
|
||||
fashion, you can set `data_col="obs"` and
|
||||
`data_rel_pos=[-3, -2, -1, 0]`.
|
||||
`shift=[-3, -2, -1, 0]`.
|
||||
Example: For the obs input to an attention net, you can specify
|
||||
a range via a str: `shift="-100:0"`, which will pass in
|
||||
the past 100 observations plus the current one.
|
||||
index (Optional[int]): An optional absolute position arg,
|
||||
used e.g. for the location of a requested inference dict within
|
||||
the trajectory. Negative values refer to counting from the end
|
||||
of a trajectory.
|
||||
used_for_training (bool): Whether the data will be used for
|
||||
training. If False, the column will not be copied into the
|
||||
final train batch.
|
||||
"""
|
||||
self.data_col = data_col
|
||||
self.space = space or gym.spaces.Box(
|
||||
self.space = space if space is not None else gym.spaces.Box(
|
||||
float("-inf"), float("inf"), shape=())
|
||||
self.data_rel_pos = data_rel_pos
|
||||
|
||||
self.index = index
|
||||
|
||||
self.shift = shift
|
||||
self.used_for_training = used_for_training
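
To make the `data_rel_pos` -> `shift` rename concrete, a short sketch of declaring view requirements with the new argument name. The examples mirror the docstring above; running it requires `gym` and this RLlib version, and the dict keys other than the SampleBatch constants are illustrative names only.

```python
import gym

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

action_space = gym.spaces.Discrete(2)

view_requirements = {
    # "prev_actions" is the "actions" column shifted back by one timestep.
    SampleBatch.PREV_ACTIONS: ViewRequirement(
        data_col=SampleBatch.ACTIONS, space=action_space, shift=-1),
    # Atari-style framestacking: the current obs plus the three before it.
    "stacked_obs": ViewRequirement(
        data_col=SampleBatch.OBS, shift=[-3, -2, -1, 0]),
    # Attention-net style input: a range given as a string (past 50 state-outs).
    "state_in_0": ViewRequirement(data_col="state_out_0", shift="-50:-1"),
}
```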
|
||||
|
|
|
@ -151,10 +151,10 @@ def test_concat_batches(ray_start_regular_shared):
|
|||
def test_standardize(ray_start_regular_shared):
|
||||
workers = make_workers(0)
|
||||
a = ParallelRollouts(workers, mode="async")
|
||||
b = a.for_each(StandardizeFields(["t"]))
|
||||
b = a.for_each(StandardizeFields([SampleBatch.EPS_ID]))
|
||||
batch = next(b)
|
||||
assert abs(np.mean(batch["t"])) < 0.001, batch
|
||||
assert abs(np.std(batch["t"]) - 1.0) < 0.001, batch
|
||||
assert abs(np.mean(batch[SampleBatch.EPS_ID])) < 0.001, batch
|
||||
assert abs(np.std(batch[SampleBatch.EPS_ID]) - 1.0) < 0.001, batch
|
||||
|
||||
|
||||
def test_async_grads(ray_start_regular_shared):
|
||||
|
|
|
@ -7,6 +7,8 @@ import ray
|
|||
from ray.tune.registry import register_env
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
from ray.rllib.examples.policy.random_policy import RandomPolicy
|
||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
|
||||
BasicMultiAgent, EarlyDoneMultiAgent, RoundRobinMultiAgent
|
||||
|
@ -321,21 +323,31 @@ class TestMultiAgentEnv(unittest.TestCase):
|
|||
if episodes is not None:
|
||||
# Pretend we did a model-based rollout and want to return
|
||||
# the extra trajectory.
|
||||
builder = episodes[0].new_batch_builder()
|
||||
rollout_id = random.randint(0, 10000)
|
||||
for t in range(5):
|
||||
builder.add_values(
|
||||
agent_id="extra_0",
|
||||
policy_id="p1", # use p1 so we can easily check it
|
||||
t=t,
|
||||
eps_id=rollout_id, # new id for each rollout
|
||||
obs=obs_batch[0],
|
||||
actions=0,
|
||||
rewards=0,
|
||||
dones=t == 4,
|
||||
infos={},
|
||||
new_obs=obs_batch[0])
|
||||
batch = builder.build_and_reset(episode=None)
|
||||
env_id = episodes[0].env_id
|
||||
fake_eps = MultiAgentEpisode(
|
||||
episodes[0]._policies, episodes[0]._policy_mapping_fn,
|
||||
lambda: None, lambda x: None, env_id)
|
||||
builder = get_global_worker().sampler.sample_collector
|
||||
agent_id = "extra_0"
|
||||
policy_id = "p1" # use p1 so we can easily check it
|
||||
builder.add_init_obs(fake_eps, agent_id, env_id, policy_id,
|
||||
-1, obs_batch[0])
|
||||
for t in range(4):
|
||||
builder.add_action_reward_next_obs(
|
||||
episode_id=fake_eps.episode_id,
|
||||
agent_id=agent_id,
|
||||
env_id=env_id,
|
||||
policy_id=policy_id,
|
||||
agent_done=t == 3,
|
||||
values=dict(
|
||||
t=t,
|
||||
actions=0,
|
||||
rewards=0,
|
||||
dones=t == 3,
|
||||
infos={},
|
||||
new_obs=obs_batch[0]))
|
||||
batch = builder.postprocess_episode(
|
||||
episode=fake_eps, build=True)
|
||||
episodes[0].add_extra_batch(batch)
|
||||
|
||||
# Just return zeros for actions
|
||||
|
@ -350,12 +362,17 @@ class TestMultiAgentEnv(unittest.TestCase):
|
|||
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
|
||||
"p1": (ModelBasedPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_config={"_use_trajectory_view_api": True},
|
||||
policy_mapping_fn=lambda agent_id: "p0",
|
||||
rollout_fragment_length=5)
|
||||
batch = ev.sample()
|
||||
# 5 environment steps (rollout_fragment_length).
|
||||
self.assertEqual(batch.count, 5)
|
||||
# 10 agent steps for p0: 2 agents, both using p0 as their policy.
|
||||
self.assertEqual(batch.policy_batches["p0"].count, 10)
|
||||
self.assertEqual(batch.policy_batches["p1"].count, 25)
|
||||
# 20 agent steps for p1: Each time both(!) agents takes 1 step,
|
||||
# p1 takes 4: 5 (rollout-fragment length) * 4 = 20
|
||||
self.assertEqual(batch.policy_batches["p1"].count, 20)
|
||||
|
||||
def test_train_multi_agent_cartpole_single_policy(self):
|
||||
n = 10
|
||||
|
|
|
@ -1,12 +1,21 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.ddpg as ddpg
|
||||
import ray.rllib.agents.dqn as dqn
|
||||
from ray.rllib.utils.test_utils import check, framework_iterator
|
||||
|
||||
|
||||
class TestParameterNoise(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
ray.init()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
ray.shutdown()
|
||||
|
||||
def test_ddpg_parameter_noise(self):
|
||||
self.do_test_parameter_noise_exploration(
|
||||
ddpg.DDPGTrainer, ddpg.DEFAULT_CONFIG, "Pendulum-v0", {},
|
||||
|
@ -37,6 +46,10 @@ class TestParameterNoise(unittest.TestCase):
|
|||
trainer = trainer_cls(config=config, env=env)
|
||||
policy = trainer.get_policy()
|
||||
pol_sess = getattr(policy, "_sess", None)
|
||||
# Remove noise that has been added during policy initialization
|
||||
# (exploration.postprocess_trajectory does add noise to measure
|
||||
# the delta).
|
||||
policy.exploration._remove_noise(tf_sess=pol_sess)
|
||||
|
||||
self.assertFalse(policy.exploration.weights_are_currently_noisy)
|
||||
noise_before = self._get_current_noise(policy, fw)
|
||||
|
@ -96,6 +109,12 @@ class TestParameterNoise(unittest.TestCase):
|
|||
config["explore"] = False
|
||||
trainer = trainer_cls(config=config, env=env)
|
||||
policy = trainer.get_policy()
|
||||
pol_sess = getattr(policy, "_sess", None)
|
||||
# Remove noise that has been added during policy initialization
|
||||
# (exploration.postprocess_trajectory does add noise to measure
|
||||
# the delta).
|
||||
policy.exploration._remove_noise(tf_sess=pol_sess)
|
||||
|
||||
self.assertFalse(policy.exploration.weights_are_currently_noisy)
|
||||
initial_weights = self._get_current_weight(policy, fw)
|
||||
|
||||
|
|
|
@ -67,6 +67,10 @@ EnvInfoDict = dict
|
|||
# Represents a File object
|
||||
FileType = Any
|
||||
|
||||
# Represents a ViewRequirements dict mapping column names (str) to
|
||||
# ViewRequirement objects.
|
||||
ViewRequirementsDict = Dict[str, "ViewRequirement"]
|
||||
|
||||
# Represents the result dict returned by Trainer.train().
|
||||
ResultDict = dict
|
||||
|
||||
|
|