[RLlib] Batch-size for truncate_episode batch_mode should be configurable in agent-steps (rather than env-steps), if needed. (#12420)
This commit is contained in: parent fd4e025da6, commit e40b14d255
28 changed files with 384 additions and 132 deletions
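In short: when batches are built with `batch_mode="truncate_episodes"`, the new `multiagent.count_steps_by` setting decides whether `rollout_fragment_length` and `train_batch_size` are measured in env-steps or in individual agent-steps. A minimal, hedged sketch of a config that opts into agent-step counting (the keys follow the COMMON_CONFIG additions in this diff; the concrete numbers are illustrative only):

    # Sketch only: keys mirror the COMMON_CONFIG changes below; values are
    # placeholders, not recommendations.
    config = {
        "batch_mode": "truncate_episodes",
        "rollout_fragment_length": 100,  # measured in the unit chosen below
        "train_batch_size": 400,         # same unit as above
        "multiagent": {
            # New in this commit: count batch sizes in individual agent steps
            # instead of env steps (the default stays "env_steps").
            "count_steps_by": "agent_steps",
        },
    }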
rllib/BUILD (14 changes)
@@ -1108,6 +1108,13 @@ py_test(
 # srcs = ["evaluation/tests/test_trajectory_view_api.py"]
 #)

+py_test(
+    name = "evaluation/tests/test_rollout_worker",
+    tags = ["evaluation"],
+    size = "medium",
+    srcs = ["evaluation/tests/test_rollout_worker.py"]
+)
+
 # --------------------------------------------------------------------
 # Optimizers and Memories
@@ -1411,13 +1418,6 @@ py_test(
     args = ["TestRolloutLearntPolicy"]
 )

-py_test(
-    name = "tests/test_rollout_worker",
-    tags = ["tests_dir", "tests_dir_R"],
-    size = "medium",
-    srcs = ["tests/test_rollout_worker.py"]
-)
-
 py_test(
     name = "tests/test_supported_multi_agent_pg",
     main = "tests/test_supported_multi_agent.py",
@@ -38,17 +38,20 @@ def execution_plan(workers, config):
         # allowing for extremely large experience batches to be used.
         train_op = (
             rollouts.combine(
-                ConcatBatches(min_batch_size=config["microbatch_size"]))
+                ConcatBatches(
+                    min_batch_size=config["microbatch_size"],
+                    count_steps_by=config["multiagent"]["count_steps_by"]))
             .for_each(ComputeGradients(workers))  # (grads, info)
             .batch(num_microbatches)  # List[(grads, info)]
             .for_each(AverageGradients())  # (avg_grads, info)
             .for_each(ApplyGradients(workers)))
     else:
         # In normal mode, we execute one SGD step per each train batch.
-        train_op = rollouts \
-            .combine(ConcatBatches(
-                min_batch_size=config["train_batch_size"])) \
-            .for_each(TrainOneStep(workers))
+        train_op = rollouts.combine(
+            ConcatBatches(
+                min_batch_size=config["train_batch_size"],
+                count_steps_by=config["multiagent"][
+                    "count_steps_by"])).for_each(TrainOneStep(workers))

     return StandardMetricsReporting(train_op, workers, config)
@@ -221,7 +221,10 @@ def gather_experiences_directly(workers, config):
             replay_proportion=config["replay_proportion"])) \
         .flatten() \
         .combine(
-            ConcatBatches(min_batch_size=config["train_batch_size"]))
+            ConcatBatches(
+                min_batch_size=config["train_batch_size"],
+                count_steps_by=config["multiagent"]["count_steps_by"],
+            ))

     return train_batches
@@ -56,7 +56,10 @@ def execution_plan(workers, config):

     replay_op = Replay(local_buffer=replay_buffer) \
         .combine(
-            ConcatBatches(min_batch_size=config["train_batch_size"])) \
+            ConcatBatches(
+                min_batch_size=config["train_batch_size"],
+                count_steps_by=config["multiagent"]["count_steps_by"],
+            )) \
         .for_each(TrainOneStep(workers))

     train_op = Concurrently(
@@ -244,7 +244,10 @@ def execution_plan(workers: WorkerSet,
         SelectExperiences(workers.trainable_policies()))
     # Concatenate the SampleBatches into one.
     rollouts = rollouts.combine(
-        ConcatBatches(min_batch_size=config["train_batch_size"]))
+        ConcatBatches(
+            min_batch_size=config["train_batch_size"],
+            count_steps_by=config["multiagent"]["count_steps_by"],
+        ))
     # Standardize advantages.
     rollouts = rollouts.for_each(StandardizeFields(["advantages"]))
@@ -73,7 +73,7 @@ class TestPPO(unittest.TestCase):
     def test_ppo_compilation_and_lr_schedule(self):
         """Test whether a PPOTrainer can be built with all frameworks."""
         config = copy.deepcopy(ppo.DEFAULT_CONFIG)
-        # for checking lr-schedule correctness
+        # For checking lr-schedule correctness.
         config["callbacks"] = MyCallbacks

         config["num_workers"] = 1
@@ -109,7 +109,10 @@ def execution_plan(workers, config):

     train_op = Replay(local_buffer=replay_buffer) \
         .combine(
-            ConcatBatches(min_batch_size=config["train_batch_size"])) \
+            ConcatBatches(
+                min_batch_size=config["train_batch_size"],
+                count_steps_by=config["multiagent"]["count_steps_by"]
+            )) \
         .for_each(TrainOneStep(workers)) \
         .for_each(UpdateTargetNetwork(
             workers, config["target_network_update_freq"]))
@@ -75,10 +75,18 @@ COMMON_CONFIG: TrainerConfigDict = {
     # The dataflow here can vary per algorithm. For example, PPO further
     # divides the train batch into minibatches for multi-epoch SGD.
     "rollout_fragment_length": 200,
-    # Whether to rollout "complete_episodes" or "truncate_episodes" to
-    # `rollout_fragment_length` length unrolls. Episode truncation guarantees
-    # evenly sized batches, but increases variance as the reward-to-go will
-    # need to be estimated at truncation boundaries.
+    # How to build per-Sampler (RolloutWorker) batches, which are then
+    # usually concat'd to form the train batch. Note that "steps" below can
+    # mean different things (either env- or agent-steps) and depends on the
+    # `count_steps_by` (multiagent) setting below.
+    # truncate_episodes: Each produced batch (when calling
+    # RolloutWorker.sample()) will contain exactly `rollout_fragment_length`
+    # steps. This mode guarantees evenly sized batches, but increases
+    # variance as the future return must now be estimated at truncation
+    # boundaries.
+    # complete_episodes: Each unroll happens exactly over one episode, from
+    # beginning to end. Data collection will not stop unless the episode
+    # terminates or a configured horizon (hard or soft) is hit.
     "batch_mode": "truncate_episodes",

     # === Settings for the Trainer process ===
@@ -357,6 +365,13 @@ COMMON_CONFIG: TrainerConfigDict = {
         # agents it controls at that timestep. When replay_mode=independent,
         # transitions are replayed independently per policy.
         "replay_mode": "independent",
+        # Which metric to use as the "batch size" when building a
+        # MultiAgentBatch. The two supported values are:
+        # env_steps: Count each time the env is "stepped" (no matter how many
+        #   multi-agent actions are passed/how many multi-agent observations
+        #   have been returned in the previous step).
+        # agent_steps: Count each individual agent step as one step.
+        "count_steps_by": "env_steps",
     },

     # === Logger ===
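To make the difference between the two units concrete, here is a small self-contained sketch (not taken from the diff; the episode layout is made up) that counts one rollout both ways:

    # Hypothetical rollout: at each env step, a subset of 4 agents acts.
    # 1 env-step = one call to env.step(); 1 agent-step = one individual
    # agent acting within such a call.
    agents_acting_per_env_step = [4, 4, 3, 4, 2]  # made-up numbers

    env_steps = len(agents_acting_per_env_step)    # -> 5
    agent_steps = sum(agents_acting_per_env_step)  # -> 17

    # With count_steps_by="env_steps", a rollout_fragment_length of 5 is
    # reached after these 5 env steps; with "agent_steps", the same setting
    # would already be exceeded after the second env step (4 + 4 = 8 >= 5).
    print(env_steps, agent_steps)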
@@ -1081,6 +1096,20 @@ class Trainer(Trainable):
             config["model"]["lstm_use_prev_action"] = prev_a_r
             config["model"]["lstm_use_prev_reward"] = prev_a_r

+        # Check batching/sample collection settings.
+        if config["batch_mode"] not in [
+                "truncate_episodes", "complete_episodes"
+        ]:
+            raise ValueError("`batch_mode` must be one of [truncate_episodes|"
+                             "complete_episodes]! Got {}".format(
+                                 config["batch_mode"]))
+
+        if config["multiagent"].get("count_steps_by", "env_steps") not in \
+                ["env_steps", "agent_steps"]:
+            raise ValueError(
+                "`count_steps_by` must be one of [env_steps|agent_steps]! "
+                "Got {}".format(config["multiagent"]["count_steps_by"]))
+
     def _try_recover(self):
         """Try to identify and remove any unhealthy workers.
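The two checks above simply reject unknown settings early. A standalone sketch of the same validation logic, useful when assembling configs outside a Trainer (the helper name is illustrative, not an RLlib API):

    # Hypothetical helper mirroring the checks added to
    # Trainer._validate_config in this diff.
    def validate_batching_settings(config: dict) -> None:
        if config.get("batch_mode") not in ("truncate_episodes",
                                            "complete_episodes"):
            raise ValueError(
                "`batch_mode` must be one of [truncate_episodes|"
                "complete_episodes]! Got {}".format(config.get("batch_mode")))
        count_steps_by = config.get("multiagent", {}).get(
            "count_steps_by", "env_steps")
        if count_steps_by not in ("env_steps", "agent_steps"):
            raise ValueError(
                "`count_steps_by` must be one of [env_steps|agent_steps]! "
                "Got {}".format(count_steps_by))

    validate_batching_settings({
        "batch_mode": "truncate_episodes",
        "multiagent": {"count_steps_by": "agent_steps"},
    })  # passes silently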
@@ -22,10 +22,11 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):

     # Combine experiences batches until we hit `train_batch_size` in size.
     # Then, train the policy on those experiences and update the workers.
-    train_op = rollouts \
-        .combine(ConcatBatches(
-            min_batch_size=config["train_batch_size"])) \
-        .for_each(TrainOneStep(workers))
+    train_op = rollouts.combine(
+        ConcatBatches(
+            min_batch_size=config["train_batch_size"],
+            count_steps_by=config["multiagent"]["count_steps_by"],
+        )).for_each(TrainOneStep(workers))

     # Add on the standard episode reward, etc. metrics reporting. This returns
     # a LocalIterator[metrics_dict] representing metrics for each train step.
@@ -164,11 +164,12 @@ def execution_plan(workers, config):
     rollouts = ParallelRollouts(workers, mode="bulk_sync")

     if config["simple_optimizer"]:
-        train_op = rollouts \
-            .combine(ConcatBatches(
-                min_batch_size=config["train_batch_size"])) \
-            .for_each(TrainOneStep(
-                workers, num_sgd_iter=config["num_sgd_iter"]))
+        train_op = rollouts.combine(
+            ConcatBatches(
+                min_batch_size=config["train_batch_size"],
+                count_steps_by=config["multiagent"]["count_steps_by"],
+            )).for_each(
+                TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))
     else:
         replay_buffer = SimpleReplayBuffer(config["buffer_size"])
@@ -178,7 +179,10 @@ def execution_plan(workers, config):
         replay_op = Replay(local_buffer=replay_buffer) \
             .filter(WaitUntilTimestepsElapsed(config["learning_starts"])) \
             .combine(
-                ConcatBatches(min_batch_size=config["train_batch_size"])) \
+                ConcatBatches(
+                    min_batch_size=config["train_batch_size"],
+                    count_steps_by=config["multiagent"]["count_steps_by"],
+                )) \
             .for_each(TrainOneStep(
                 workers, num_sgd_iter=config["num_sgd_iter"]))
@@ -110,11 +110,37 @@ class _SampleCollector(metaclass=ABCMeta):

     @abstractmethod
     def total_env_steps(self) -> int:
-        """Returns total number of steps taken in the env (sum of all agents).
+        """Returns total number of env-steps taken so far.
+
+        Thereby, a step in an N-agent multi-agent environment counts as only 1
+        for this metric. The returned count contains everything that has not
+        been built yet (and returned as MultiAgentBatches by the
+        `try_build_truncated_episode_multi_agent_batch` or
+        `postprocess_episode(build=True)` methods). After such build, this
+        counter is reset to 0.

         Returns:
-            int: The number of steps taken in total in the environment over all
-                agents.
+            int: The number of env-steps taken in total in the environment(s)
+                so far.
         """
         raise NotImplementedError

+    @abstractmethod
+    def total_agent_steps(self) -> int:
+        """Returns total number of (individual) agent-steps taken so far.
+
+        Thereby, a step in an N-agent multi-agent environment counts as N.
+        If less than N agents have stepped (because some agents were not
+        required to send actions), the count will be increased by less than N.
+        The returned count contains everything that has not been built yet
+        (and returned as MultiAgentBatches by the
+        `try_build_truncated_episode_multi_agent_batch` or
+        `postprocess_episode(build=True)` methods). After such build, this
+        counter is reset to 0.
+
+        Returns:
+            int: The number of (individual) agent-steps taken in total in the
+                environment(s) so far.
+        """
+        raise NotImplementedError
+
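The distinction the two abstract methods draw can be illustrated with a toy counter (a standalone sketch, not the actual collector implementation): env-steps advance once per env.step() call, agent-steps once per acting agent.

    # Toy illustration of the env-step vs. agent-step counters described
    # above. Standalone sketch, not RLlib's _SimpleListCollector.
    class ToyStepCounter:
        def __init__(self):
            self.env_steps = 0
            self.agent_steps = 0

        def on_env_step(self, acting_agent_ids):
            # One call to env.step() -> exactly one env-step, ...
            self.env_steps += 1
            # ... but one agent-step per agent that actually acted.
            self.agent_steps += len(acting_agent_ids)

    counter = ToyStepCounter()
    counter.on_env_step(["a0", "a1", "a2"])
    counter.on_env_step(["a0"])  # only one agent required to act here
    assert counter.env_steps == 2
    assert counter.agent_steps == 4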
@@ -51,7 +51,7 @@ class _AgentCollector:
         self.episode_id = None
         # The simple timestep count for this agent. Gets increased by one
         # each time a (non-initial!) observation is added.
-        self.count = 0
+        self.agent_steps = 0

     def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
                      env_id: EnvID, t: int, init_obs: TensorType) -> None:
@@ -105,7 +105,7 @@ class _AgentCollector:
             if k not in self.buffers:
                 self._build_buffers(single_row=values)
             self.buffers[k].append(v)
-        self.count += 1
+        self.agent_steps += 1

     def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
         """Builds a SampleBatch from the thus-far collected agent data.
@@ -183,7 +183,7 @@ class _AgentCollector:
         if self.shift_before > 0:
             for k, data in self.buffers.items():
                 self.buffers[k] = data[-self.shift_before:]
-        self.count = 0
+        self.agent_steps = 0

         return batch
@@ -238,7 +238,7 @@ class _PolicyCollector:
         # NOTE: This is not an env-step count (across n agents). AgentA and
         # agentB, both using this policy, acting in the same episode and both
         # doing n steps would increase the count by 2*n.
-        self.count = 0
+        self.agent_steps = 0

     def add_postprocessed_batch_for_training(
             self, batch: SampleBatch,
@@ -246,9 +246,9 @@ class _PolicyCollector:
         """Adds a postprocessed SampleBatch (single agent) to our buffers.

         Args:
-            batch (SampleBatch): A single agent (one trajectory) SampleBatch
-                to be added to the Policy's buffers.
-            view_requirements (DViewRequirementsDict): The view
+            batch (SampleBatch): An individual agent's (one trajectory)
+                SampleBatch to be added to the Policy's buffers.
+            view_requirements (ViewRequirementsDict): The view
                 requirements for the policy. This is so we know, whether a
                 view-column needs to be copied at all (not needed for
                 training).
@@ -261,7 +261,7 @@ class _PolicyCollector:
                     view_requirements[view_col].used_for_training:
                 self.buffers[view_col].extend(data)
         # Add the agent's trajectory length to our count.
-        self.count += batch.count
+        self.agent_steps += batch.count

     def build(self):
         """Builds a SampleBatch for this policy from the collected data.
@@ -277,8 +277,8 @@ class _PolicyCollector:
         assert SampleBatch.UNROLL_ID in batch.data
         # Clear buffers for future samples.
         self.buffers.clear()
-        # Reset count to 0.
-        self.count = 0
+        # Reset agent steps to 0.
+        self.agent_steps = 0
         return batch

@@ -288,7 +288,11 @@ class _PolicyCollectorGroup:
             pid: _PolicyCollector()
             for pid in policy_map.keys()
         }
-        self.count = 0
+        # Total env-steps (1 env-step=up to N agents stepped).
+        self.env_steps = 0
+        # Total agent steps (1 agent-step=1 individual agent (out of N)
+        # stepped).
+        self.agent_steps = 0


 class _SimpleListCollector(_SampleCollector):
@@ -305,7 +309,8 @@ class _SimpleListCollector(_SampleCollector):
                  clip_rewards: Union[bool, float],
                  callbacks: "DefaultCallbacks",
                  multiple_episodes_in_batch: bool = True,
-                 rollout_fragment_length: int = 200):
+                 rollout_fragment_length: int = 200,
+                 count_steps_by: str = "env_steps"):
         """Initializes a _SimpleListCollector instance.

         Args:
@@ -314,6 +319,10 @@ class _SimpleListCollector(_SampleCollector):
             clip_rewards (Union[bool, float]): Whether to clip rewards before
                 postprocessing (at +/-1.0) or the actual value to +/- clip.
             callbacks (DefaultCallbacks): RLlib callbacks.
+            multiple_episodes_in_batch (bool): Whether it's allowed to pack
+                multiple episodes into the same built batch.
+            rollout_fragment_length (int): The
+
         """

         self.policy_map = policy_map
@@ -321,6 +330,7 @@ class _SimpleListCollector(_SampleCollector):
         self.callbacks = callbacks
         self.multiple_episodes_in_batch = multiple_episodes_in_batch
         self.rollout_fragment_length = rollout_fragment_length
+        self.count_steps_by = count_steps_by
         self.large_batch_threshold: int = max(
             1000, rollout_fragment_length *
             10) if rollout_fragment_length != float("inf") else 5000
@@ -340,8 +350,10 @@ class _SimpleListCollector(_SampleCollector):
         self.forward_pass_size = {pid: 0 for pid in policy_map.keys()}

         # Maps episode ID to the (non-built) env steps taken in this episode.
-        self.episode_steps: Dict[EpisodeID, int] = \
-            collections.defaultdict(int)
+        self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
+        # Maps episode ID to the (non-built) individual agent steps in this
+        # episode.
+        self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
         # Maps episode ID to MultiAgentEpisode.
         self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
@@ -351,15 +363,17 @@ class _SimpleListCollector(_SampleCollector):
         self.episode_steps[episode_id] += 1
         episode.length += 1
         assert episode.batch_builder is not None
-        env_steps = episode.batch_builder.count
-        num_observations = sum(
-            c.count for c in episode.batch_builder.policy_collectors.values())
+        env_steps = episode.batch_builder.env_steps
+        num_individual_observations = sum(
+            c.agent_steps
+            for c in episode.batch_builder.policy_collectors.values())

-        if num_observations > self.large_batch_threshold and \
+        if num_individual_observations > self.large_batch_threshold and \
                 log_once("large_batch_warning"):
             logger.warning(
                 "More than {} observations in {} env steps for "
-                "episode {} ".format(num_observations, env_steps, episode_id) +
+                "episode {} ".format(num_individual_observations, env_steps,
+                                     episode_id) +
                 "are buffered in the sampler. If this is more than you "
                 "expected, check that that you set a horizon on your "
                 "environment correctly and that it terminates at some point. "
@@ -412,6 +426,8 @@ class _SimpleListCollector(_SampleCollector):
         assert self.agent_key_to_policy_id[agent_key] == policy_id
         assert agent_key in self.agent_collectors

+        self.agent_steps[episode_id] += 1
+
         # Include the current agent id for multi-agent algorithms.
         if agent_id != _DUMMY_AGENT_ID:
             values["agent_id"] = agent_id
@@ -424,7 +440,18 @@ class _SimpleListCollector(_SampleCollector):

     @override(_SampleCollector)
     def total_env_steps(self) -> int:
-        return sum(a.count for a in self.agent_collectors.values())
+        # Add the non-built ongoing-episode env steps + the already built
+        # env-steps.
+        return sum(self.episode_steps.values()) + sum(
+            pg.env_steps for pg in self.policy_collector_groups.values())
+
+    @override(_SampleCollector)
+    def total_agent_steps(self) -> int:
+        # Add the non-built ongoing-episode agent steps (still in the agent
+        # collectors) + the already built agent steps.
+        return sum(a.agent_steps for a in self.agent_collectors.values()) + \
+            sum(pg.agent_steps for pg in
+                self.policy_collector_groups.values())

     @override(_SampleCollector)
     def get_inference_input_dict(self, policy_id: PolicyID) -> \
@@ -463,11 +490,12 @@ class _SimpleListCollector(_SampleCollector):
         return input_dict

     @override(_SampleCollector)
-    def postprocess_episode(self,
-                            episode: MultiAgentEpisode,
-                            is_done: bool = False,
-                            check_dones: bool = False,
-                            build: bool = False) -> None:
+    def postprocess_episode(
+            self,
+            episode: MultiAgentEpisode,
+            is_done: bool = False,
+            check_dones: bool = False,
+            build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]:
         episode_id = episode.episode_id
         policy_collector_group = episode.batch_builder
@@ -478,7 +506,7 @@ class _SimpleListCollector(_SampleCollector):
         pre_batches = {}
         for (eps_id, agent_id), collector in self.agent_collectors.items():
             # Build only if there is data and agent is part of given episode.
-            if collector.count == 0 or eps_id != episode_id:
+            if collector.agent_steps == 0 or eps_id != episode_id:
                 continue
             pid = self.agent_key_to_policy_id[(eps_id, agent_id)]
             policy = self.policy_map[pid]
@@ -559,16 +587,19 @@ class _SimpleListCollector(_SampleCollector):
                 post_batch, policy.view_requirements)

         env_steps = self.episode_steps[episode_id]
-        policy_collector_group.count += env_steps
+        policy_collector_group.env_steps += env_steps
+        agent_steps = self.agent_steps[episode_id]
+        policy_collector_group.agent_steps += agent_steps

         if is_done:
             del self.episode_steps[episode_id]
+            del self.agent_steps[episode_id]
             del self.episodes[episode_id]
             # Make PolicyCollectorGroup available for more agent batches in
             # other episodes. Do not reset count to 0.
             self.policy_collector_groups.append(policy_collector_group)
         else:
-            self.episode_steps[episode_id] = 0
+            self.episode_steps[episode_id] = self.agent_steps[episode_id] = 0

         # Build a MultiAgentBatch from the episode and return.
         if build:
@@ -579,14 +610,15 @@ class _SimpleListCollector(_SampleCollector):

         ma_batch = {}
         for pid, collector in episode.batch_builder.policy_collectors.items():
-            if collector.count > 0:
+            if collector.agent_steps > 0:
                 ma_batch[pid] = collector.build()
         # Create the batch.
         ma_batch = MultiAgentBatch.wrap_as_needed(
-            ma_batch, env_steps=episode.batch_builder.count)
+            ma_batch, env_steps=episode.batch_builder.env_steps)

         # PolicyCollectorGroup is empty.
-        episode.batch_builder.count = 0
+        episode.batch_builder.env_steps = 0
+        episode.batch_builder.agent_steps = 0

         return ma_batch
@@ -595,16 +627,26 @@ class _SimpleListCollector(_SampleCollector):
             List[Union[MultiAgentBatch, SampleBatch]]:
         batches = []
         # Loop through ongoing episodes and see whether their length plus
-        # what's already in the policy collectors reaches the fragment-len.
+        # what's already in the policy collectors reaches the fragment-len
+        # (abiding to the unit used: env-steps or agent-steps).
         for episode_id, episode in self.episodes.items():
-            env_steps = episode.batch_builder.count + \
-                self.episode_steps[episode_id]
+            # Measure batch size in env-steps.
+            if self.count_steps_by == "env_steps":
+                built_steps = episode.batch_builder.env_steps
+                ongoing_steps = self.episode_steps[episode_id]
+            # Measure batch-size in agent-steps.
+            else:
+                built_steps = episode.batch_builder.agent_steps
+                ongoing_steps = self.agent_steps[episode_id]

             # Reached the fragment-len -> We should build an MA-Batch.
-            if env_steps >= self.rollout_fragment_length:
-                assert env_steps == self.rollout_fragment_length
+            if built_steps + ongoing_steps >= self.rollout_fragment_length:
+                if self.count_steps_by != "agent_steps":
+                    assert built_steps + ongoing_steps == \
+                        self.rollout_fragment_length
                 # If we reached the fragment-len only because of `episode_id`
                 # (still ongoing) -> postprocess `episode_id` first.
-                if episode.batch_builder.count < self.rollout_fragment_length:
+                if built_steps < self.rollout_fragment_length:
                     self.postprocess_episode(episode, is_done=False)
                 # Build the MA-batch and return.
                 batch = self._build_multi_agent_batch(episode=episode)
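The core of the truncation logic above is a threshold test over the chosen unit. A reduced, self-contained sketch of that decision (toy numbers; not the actual collector class):

    # Reduced sketch of the "should we build a truncated batch now?" check
    # used above; variable names mirror the diff, the numbers are made up.
    def should_build(built_steps, ongoing_steps, rollout_fragment_length,
                     count_steps_by):
        reached = built_steps + ongoing_steps >= rollout_fragment_length
        if reached and count_steps_by != "agent_steps":
            # In env-step mode the threshold is hit exactly, because the
            # check runs after every single env-step.
            assert built_steps + ongoing_steps == rollout_fragment_length
        return reached

    # Env-step mode: hit exactly at the fragment length.
    assert should_build(60, 40, 100, "env_steps")
    # Agent-step mode: several agents may step at once, so we can overshoot.
    assert should_build(95, 8, 100, "agent_steps")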
@@ -143,6 +143,7 @@ class RolloutWorker(ParallelIteratorWorker):
                  policies_to_train: Optional[List[PolicyID]] = None,
                  tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
                  rollout_fragment_length: int = 100,
+                 count_steps_by: str = "env_steps",
                  batch_mode: str = "truncate_episodes",
                  episode_horizon: int = None,
                  preprocessor_pref: str = "deepmind",
@@ -208,8 +209,11 @@ class RolloutWorker(ParallelIteratorWorker):
             tf_session_creator (Optional[Callable[[], tf1.Session]]): A
                 function that returns a TF session. This is optional and only
                 useful with TFPolicy.
-            rollout_fragment_length (int): The target number of env transitions
-                to include in each sample batch returned from this worker.
+            rollout_fragment_length (int): The target number of steps
+                (maesured in `count_steps_by`) to include in each sample
+                batch returned from this worker.
+            count_steps_by (str): The unit in which to count fragment
+                lengths. One of env_steps or agent_steps.
             batch_mode (str): One of the following batch modes:
                 "truncate_episodes": Each call to sample() will return a batch
                     of at most `rollout_fragment_length * num_envs` in size.
@@ -356,6 +360,7 @@ class RolloutWorker(ParallelIteratorWorker):
             raise ValueError("Policy mapping function not callable?")
         self.env_creator: Callable[[EnvContext], EnvType] = env_creator
         self.rollout_fragment_length: int = rollout_fragment_length * num_envs
+        self.count_steps_by: str = count_steps_by
         self.batch_mode: str = batch_mode
         self.compress_observations: bool = compress_observations
         self.preprocessing_enabled: bool = True
@@ -570,6 +575,7 @@ class RolloutWorker(ParallelIteratorWorker):
                 obs_filters=self.filters,
                 clip_rewards=clip_rewards,
                 rollout_fragment_length=rollout_fragment_length,
+                count_steps_by=count_steps_by,
                 callbacks=self.callbacks,
                 horizon=episode_horizon,
                 multiple_episodes_in_batch=pack,
@@ -593,6 +599,7 @@ class RolloutWorker(ParallelIteratorWorker):
                 obs_filters=self.filters,
                 clip_rewards=clip_rewards,
                 rollout_fragment_length=rollout_fragment_length,
+                count_steps_by=count_steps_by,
                 callbacks=self.callbacks,
                 horizon=episode_horizon,
                 multiple_episodes_in_batch=pack,
@@ -636,7 +643,9 @@ class RolloutWorker(ParallelIteratorWorker):
                 self.rollout_fragment_length))

         batches = [self.input_reader.next()]
-        steps_so_far = batches[0].count
+        steps_so_far = batches[0].count if \
+            self.count_steps_by == "env_steps" else \
+            batches[0].agent_steps()

         # In truncate_episodes mode, never pull more than 1 batch per env.
         # This avoids over-running the target batch size.
@@ -648,7 +657,9 @@ class RolloutWorker(ParallelIteratorWorker):
         while (steps_so_far < self.rollout_fragment_length
                and len(batches) < max_batches):
             batch = self.input_reader.next()
-            steps_so_far += batch.count
+            steps_so_far += batch.count if \
+                self.count_steps_by == "env_steps" else \
+                batch.agent_steps()
             batches.append(batch)
         batch = batches[0].concat_samples(batches) if len(batches) > 1 else \
             batches[0]
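A condensed sketch of how `sample()` accumulates `steps_so_far` in the configured unit (toy batch objects standing in for SampleBatch/MultiAgentBatch; not the actual RolloutWorker code):

    # Toy stand-in for a sampled batch: `count` is env-steps, `agent_steps()`
    # the per-agent total, matching the two units used in the loop above.
    class ToyBatch:
        def __init__(self, env_steps, agent_steps):
            self.count = env_steps
            self._agent_steps = agent_steps

        def agent_steps(self):
            return self._agent_steps

    def sampled_size(batches, count_steps_by):
        return sum(
            b.count if count_steps_by == "env_steps" else b.agent_steps()
            for b in batches)

    fragments = [ToyBatch(50, 120), ToyBatch(50, 130)]
    assert sampled_size(fragments, "env_steps") == 100
    assert sampled_size(fragments, "agent_steps") == 250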
@@ -129,6 +129,7 @@ class SyncSampler(SamplerInput):
                  obs_filters: Dict[PolicyID, Filter],
                  clip_rewards: bool,
                  rollout_fragment_length: int,
+                 count_steps_by: str = "env_steps",
                  callbacks: "DefaultCallbacks",
                  horizon: int = None,
                  multiple_episodes_in_batch: bool = False,
@@ -190,8 +191,12 @@ class SyncSampler(SamplerInput):
         self.perf_stats = _PerfStats()
         if _use_trajectory_view_api:
             self.sample_collector = _SimpleListCollector(
-                policies, clip_rewards, callbacks, multiple_episodes_in_batch,
-                rollout_fragment_length)
+                policies,
+                clip_rewards,
+                callbacks,
+                multiple_episodes_in_batch,
+                rollout_fragment_length,
+                count_steps_by=count_steps_by)
         else:
             self.sample_collector = None
@@ -254,6 +259,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
                  obs_filters: Dict[PolicyID, Filter],
                  clip_rewards: bool,
                  rollout_fragment_length: int,
+                 count_steps_by: str = "env_steps",
                  callbacks: "DefaultCallbacks",
                  horizon: int = None,
                  multiple_episodes_in_batch: bool = False,
@@ -282,6 +288,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
             rollout_fragment_length (int): The length of a fragment to collect
                 before building a SampleBatch from the data and resetting
                 the SampleBatchBuilder object.
+            count_steps_by (str): Either "env_steps" or "agent_steps".
+                Refers to the unit of `rollout_fragment_length`.
             callbacks (Callbacks): The Callbacks object to use when episode
                 events happen during rollout.
             horizon (Optional[int]): Hard-reset the Env
@@ -336,8 +344,12 @@ class AsyncSampler(threading.Thread, SamplerInput):
         self._use_trajectory_view_api = _use_trajectory_view_api
         if _use_trajectory_view_api:
             self.sample_collector = _SimpleListCollector(
-                policies, clip_rewards, callbacks, multiple_episodes_in_batch,
-                rollout_fragment_length)
+                policies,
+                clip_rewards,
+                callbacks,
+                multiple_episodes_in_batch,
+                rollout_fragment_length,
+                count_steps_by=count_steps_by)
         else:
             self.sample_collector = None
rllib/evaluation/tests/__init__.py (new empty file)
@@ -1,5 +1,6 @@
 from collections import Counter
 import gym
+from gym.spaces import Box, Discrete
 import numpy as np
 import os
 import random
@@ -13,9 +14,12 @@ from ray.rllib.env.vector_env import VectorEnv
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.evaluation.postprocessing import compute_advantages
+from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
+from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.examples.policy.random_policy import RandomPolicy
 from ray.rllib.policy.policy import Policy
-from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
+from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch, \
+    SampleBatch
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.test_utils import check, framework_iterator
 from ray.tune.registry import register_env
@@ -71,39 +75,6 @@ class FailOnStepEnv(gym.Env):
         raise ValueError("kaboom")


-class MockEnv(gym.Env):
-    def __init__(self, episode_length, config=None):
-        self.episode_length = episode_length
-        self.config = config
-        self.i = 0
-        self.observation_space = gym.spaces.Discrete(1)
-        self.action_space = gym.spaces.Discrete(2)
-
-    def reset(self):
-        self.i = 0
-        return self.i
-
-    def step(self, action):
-        self.i += 1
-        return 0, 1, self.i >= self.episode_length, {}
-
-
-class MockEnv2(gym.Env):
-    def __init__(self, episode_length):
-        self.episode_length = episode_length
-        self.i = 0
-        self.observation_space = gym.spaces.Discrete(100)
-        self.action_space = gym.spaces.Discrete(2)
-
-    def reset(self):
-        self.i = 0
-        return self.i
-
-    def step(self, action):
-        self.i += 1
-        return self.i, 100, self.i >= self.episode_length, {}
-
-
 class MockVectorEnv(VectorEnv):
     def __init__(self, episode_length, num_envs):
         super().__init__(
@@ -523,14 +494,57 @@ class TestRolloutWorker(unittest.TestCase):
         ev.stop()

     def test_truncate_episodes(self):
-        ev = RolloutWorker(
+        ev_env_steps = RolloutWorker(
             env_creator=lambda _: MockEnv(10),
             policy_spec=MockPolicy,
             policy_config={"_use_trajectory_view_api": True},
             rollout_fragment_length=15,
             batch_mode="truncate_episodes")
-        batch = ev.sample()
+        batch = ev_env_steps.sample()
         self.assertEqual(batch.count, 15)
-        ev.stop()
+        self.assertTrue(isinstance(batch, SampleBatch))
+        ev_env_steps.stop()
+
+        action_space = Discrete(2)
+        obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)
+        ev_agent_steps = RolloutWorker(
+            env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
+            policy_spec={
+                "pol0": (MockPolicy, obs_space, action_space, {}),
+                "pol1": (MockPolicy, obs_space, action_space, {}),
+            },
+            policy_config={"_use_trajectory_view_api": True},
+            policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1",
+            rollout_fragment_length=301,
+            count_steps_by="env_steps",
+            batch_mode="truncate_episodes",
+        )
+        batch = ev_agent_steps.sample()
+        self.assertTrue(isinstance(batch, MultiAgentBatch))
+        self.assertGreater(batch.agent_steps(), 301)
+        self.assertEqual(batch.env_steps(), 301)
+        ev_agent_steps.stop()
+
+        ev_agent_steps = RolloutWorker(
+            env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
+            policy_spec={
+                "pol0": (MockPolicy, obs_space, action_space, {}),
+                "pol1": (MockPolicy, obs_space, action_space, {}),
+            },
+            policy_config={"_use_trajectory_view_api": True},
+            policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1",
+            rollout_fragment_length=301,
+            count_steps_by="agent_steps",
+            batch_mode="truncate_episodes")
+        batch = ev_agent_steps.sample()
+        self.assertTrue(isinstance(batch, MultiAgentBatch))
+        self.assertLess(batch.env_steps(), 301)
+        # When counting agent steps, the count may be slightly larger than
+        # rollout_fragment_length, b/c we have up to N agents stepping in each
+        # env step and we only check, whether we should build after each env
+        # step.
+        self.assertGreaterEqual(batch.agent_steps(), 301)
+        ev_agent_steps.stop()

     def test_complete_episodes(self):
         ev = RolloutWorker(
@@ -1,13 +1,16 @@
 import copy
-import gym
+from gym.spaces import Box, Discrete
+import numpy as np
 import time
 import unittest

 import ray
+from ray import tune
 import ray.rllib.agents.dqn as dqn
 import ray.rllib.agents.ppo as ppo
 from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
+from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.examples.policy.episode_env_aware_policy import \
     EpisodeEnvAwareLSTMPolicy
@@ -295,6 +298,38 @@ class TestTrajectoryViewAPI(unittest.TestCase):
             pol_batch_wo = result.policy_batches["pol0"]
             check(pol_batch_w.data, pol_batch_wo.data)

+    def test_counting_by_agent_steps(self):
+        """Test whether a PPOTrainer can be built with all frameworks."""
+        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
+        action_space = Discrete(2)
+        obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)
+
+        config["num_workers"] = 2
+        config["num_sgd_iter"] = 2
+        config["framework"] = "torch"
+        config["rollout_fragment_length"] = 21
+        config["train_batch_size"] = 147
+        config["multiagent"] = {
+            "policies": {
+                "p0": (None, obs_space, action_space, {}),
+                "p1": (None, obs_space, action_space, {}),
+            },
+            "policy_mapping_fn": lambda aid: "p{}".format(aid),
+            "count_steps_by": "agent_steps",
+        }
+        tune.register_env(
+            "ma_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2}))
+        num_iterations = 2
+        trainer = ppo.PPOTrainer(config=config, env="ma_cartpole")
+        results = None
+        for i in range(num_iterations):
+            results = trainer.train()
+        self.assertGreater(results["timesteps_total"],
+                           num_iterations * config["train_batch_size"])
+        self.assertLess(results["timesteps_total"],
+                        (num_iterations + 1) * config["train_batch_size"])
+        trainer.stop()
+
+
 def analyze_rnn_batch(batch, max_seq_len):
     count = batch.count
@@ -321,6 +321,7 @@ class WorkerSet:
                 tf_session_creator=(session_creator
                                     if config["tf_session_args"] else None),
                 rollout_fragment_length=config["rollout_fragment_length"],
+                count_steps_by=config["multiagent"]["count_steps_by"],
                 batch_mode=config["batch_mode"],
                 episode_horizon=config["horizon"],
                 preprocessor_pref=config["preprocessor_pref"],
rllib/examples/env/mock_env.py (vendored, new file, 46 lines)
@@ -0,0 +1,46 @@
+import gym
+
+
+class MockEnv(gym.Env):
+    """Mock environment for testing purposes.
+
+    Observation=0, reward=1.0, episode-len is configurable.
+    Actions are ignored.
+    """
+
+    def __init__(self, episode_length, config=None):
+        self.episode_length = episode_length
+        self.config = config
+        self.i = 0
+        self.observation_space = gym.spaces.Discrete(1)
+        self.action_space = gym.spaces.Discrete(2)
+
+    def reset(self):
+        self.i = 0
+        return 0
+
+    def step(self, action):
+        self.i += 1
+        return 0, 1.0, self.i >= self.episode_length, {}
+
+
+class MockEnv2(gym.Env):
+    """Mock environment for testing purposes.
+
+    Observation=ts (discrete space!), reward=100.0, episode-len is
+    configurable. Actions are ignored.
+    """
+
+    def __init__(self, episode_length):
+        self.episode_length = episode_length
+        self.i = 0
+        self.observation_space = gym.spaces.Discrete(100)
+        self.action_space = gym.spaces.Discrete(2)
+
+    def reset(self):
+        self.i = 0
+        return self.i
+
+    def step(self, action):
+        self.i += 1
+        return self.i, 100.0, self.i >= self.episode_length, {}
rllib/examples/env/multi_agent.py (vendored, 2 changes)
@@ -1,8 +1,8 @@
 import gym

 from ray.rllib.env.multi_agent_env import MultiAgentEnv
+from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
-from ray.rllib.tests.test_rollout_worker import MockEnv, MockEnv2


 def make_multiagent(env_name_or_creator):
@@ -81,7 +81,8 @@ def custom_training_workflow(workers: WorkerSet, config: dict):

     # PPO sub-flow.
     ppo_train_op = r2.for_each(SelectExperiences(["ppo_policy"])) \
-        .combine(ConcatBatches(min_batch_size=200)) \
+        .combine(ConcatBatches(
+            min_batch_size=200, count_steps_by="env_steps")) \
         .for_each(add_ppo_metrics) \
         .for_each(StandardizeFields(["advantages"])) \
         .for_each(TrainOneStep(
@@ -141,13 +141,15 @@ class ConcatBatches:

     Examples:
         >>> rollouts = ParallelRollouts(...)
-        >>> rollouts = rollouts.combine(ConcatBatches(min_batch_size=10000))
+        >>> rollouts = rollouts.combine(ConcatBatches(
+        ...     min_batch_size=10000, count_steps_by="env_steps"))
         >>> print(next(rollouts).count)
         10000
     """

-    def __init__(self, min_batch_size: int):
+    def __init__(self, min_batch_size: int, count_steps_by: str = "env_steps"):
         self.min_batch_size = min_batch_size
+        self.count_steps_by = count_steps_by
         self.buffer = []
         self.count = 0
         self.batch_start_time = None
@@ -159,7 +161,15 @@ class ConcatBatches:
     def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
         _check_sample_batch_type(batch)
         self.buffer.append(batch)
-        self.count += batch.count
+
+        if self.count_steps_by == "env_steps":
+            self.count += batch.count
+        else:
+            assert isinstance(batch, MultiAgentBatch), \
+                "`count_steps_by=agent_steps` only allowed in multi-agent " \
+                "environments!"
+            self.count += batch.agent_steps()
+
         if self.count >= self.min_batch_size:
             if self.count > self.min_batch_size * 2:
                 logger.info("Collected more training samples than expected "
@@ -51,7 +51,9 @@ class Aggregator(ParallelIteratorWorker):
                 .flatten() \
                 .combine(
                     ConcatBatches(
-                        min_batch_size=config["train_batch_size"]))
+                        min_batch_size=config["train_batch_size"],
+                        count_steps_by=config["multiagent"]["count_steps_by"],
+                    ))

             for train_batch in it:
                 yield train_batch
@@ -417,16 +417,17 @@ class MultiAgentBatch:
         Args:
             policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
                 ids to SampleBatches of experiences.
-            env_steps (int): The number of timesteps in the environment this
-                batch contains. This will be less than the number of
+            env_steps (int): The number of environment steps in the environment
+                this batch contains. This will be less than the number of
                 transitions this batch contains across all policies in total.
         """

         for v in policy_batches.values():
             assert isinstance(v, SampleBatch)
         self.policy_batches = policy_batches
-        # Called count for uniformity with SampleBatch. Prefer to access this
-        # via the env_steps() method when possible for clarity.
+        # Called "count" for uniformity with SampleBatch.
+        # Prefer to access this via the `env_steps()` method when possible
+        # for clarity.
         self.count = env_steps

         @PublicAPI
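The two views on a MultiAgentBatch's size can be seen directly on a small hand-built batch; a hedged sketch, assuming the RLlib version this diff targets (where `env_steps()` and `agent_steps()` are available, as used in the tests above):

    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

    # Two policies, each with 3 transitions, collected over 3 env steps
    # (both agents acted at every step).
    batch = MultiAgentBatch(
        policy_batches={
            "pol0": SampleBatch({"obs": [1, 2, 3]}),
            "pol1": SampleBatch({"obs": [4, 5, 6]}),
        },
        env_steps=3,
    )
    print(batch.env_steps())    # -> 3
    print(batch.agent_steps())  # -> 6 (sum of per-policy transition counts)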
@@ -526,7 +527,8 @@ class MultiAgentBatch:
         """
         if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
             return policy_batches[DEFAULT_POLICY_ID]
-        return MultiAgentBatch(policy_batches, env_steps)
+        return MultiAgentBatch(
+            policy_batches=policy_batches, env_steps=env_steps)

     @staticmethod
     @PublicAPI
|
|||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.tests.test_rollout_worker import (BadPolicy, MockPolicy,
|
||||
MockEnv)
|
||||
from ray.rllib.evaluation.tests.test_rollout_worker import (BadPolicy,
|
||||
MockPolicy)
|
||||
from ray.rllib.examples.env.mock_env import MockEnv
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
|
|
@@ -5,8 +5,8 @@ import unittest
 import ray
 from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
+from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent
-from ray.rllib.tests.test_rollout_worker import MockPolicy
 from ray.rllib.tests.test_external_env import make_simple_serving

 SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)
@@ -12,8 +12,8 @@ from ray.rllib.evaluation.rollout_worker import get_global_worker
 from ray.rllib.examples.policy.random_policy import RandomPolicy
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
     BasicMultiAgent, EarlyDoneMultiAgent, RoundRobinMultiAgent
-from ray.rllib.tests.test_rollout_worker import MockPolicy
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
+from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
 from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
 from ray.rllib.utils.numpy import one_hot
 from ray.rllib.utils.test_utils import check
@@ -4,7 +4,7 @@ import unittest

 import ray
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
-from ray.rllib.tests.test_rollout_worker import MockPolicy
+from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy


 class TestPerf(unittest.TestCase):