[RLlib] CQL iteration count fixes: Remove dummy buffer and unnecessary store op from exec_plan. (#16332)

Sven Mika 2021-06-10 07:49:17 +02:00 committed by GitHub
parent c8a5d7ba85
commit 3d4dc60e2e
6 changed files with 68 additions and 53 deletions

rllib/BUILD (View file)

@@ -2271,6 +2271,15 @@ py_test(
args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6"]
)
py_test(
name = "examples/parallel_evaluation_and_training_tf2",
main = "examples/parallel_evaluation_and_training.py",
tags = ["examples", "examples_P"],
size = "medium",
srcs = ["examples/parallel_evaluation_and_training.py"],
args = ["--as-test", "--framework=tf2", "--stop-reward=30.0", "--num-cpus=6"]
)
py_test(
name = "examples/parametric_actions_cartpole_pg_tf",
main = "examples/parametric_actions_cartpole.py",

rllib/agents/cql/cql.py (View file)

@@ -1,19 +1,15 @@
"""CQL (derived from SAC).
"""
import numpy as np
from typing import Optional, Type, List
from typing import Optional, Type
from ray.actor import ActorHandle
from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.agents.dqn.dqn import calculate_rr_weights
from ray.rllib.agents.sac.sac import SACTrainer, \
DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainTFMultiGPU, TrainOneStep, \
UpdateTargetNetwork
from ray.rllib.offline.shuffled_input import ShuffledInput
@@ -57,17 +53,6 @@ def validate_config(config: TrainerConfigDict):
replay_buffer = None
class NoOpReplayBuffer:
def __init__(self,
*,
local_buffer: LocalReplayBuffer = None,
actors: List[ActorHandle] = None):
return
def __call__(self, batch):
return batch
def execution_plan(workers, config):
if config.get("prioritized_replay"):
prio_args = {
@@ -92,14 +77,6 @@ def execution_plan(workers, config):
global replay_buffer
replay_buffer = local_replay_buffer
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# NoReplayBuffer ensures that no online data is added
# The Dataset is added to the Replay Buffer in after_init()
# method below the execution plan.
store_op = rollouts.for_each(
NoOpReplayBuffer(local_buffer=local_replay_buffer))
def update_prio(item):
samples, info_dict = item
if config.get("prioritized_replay"):
@@ -141,16 +118,8 @@ def execution_plan(workers, config):
.for_each(UpdateTargetNetwork(
workers, config["target_network_update_freq"]))
# Alternate deterministically between (1) and (2).
train_op = Concurrently(
[store_op, replay_op],
mode="round_robin",
# Only return the output
# of (2) since training metrics are not available until (2) runs.
output_indexes=[1],
round_robin_weights=calculate_rr_weights(config))
return StandardMetricsReporting(train_op, workers, config)
return StandardMetricsReporting(
replay_op, workers, config, by_steps_trained=True)
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
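The net effect of the two hunks above, as a condensed sketch (not part of the commit; `cql_execution_plan_sketch` is a hypothetical name, and the buffer construction, prioritized-replay priority updates, and the multi-GPU train-op variant are elided): the dummy store op and the Concurrently() round-robin are gone, leaving a pure replay -> train -> target-update pipeline whose iteration boundaries are counted in trained steps.

from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork


def cql_execution_plan_sketch(workers, config, local_replay_buffer):
    # Read batches from the buffer (pre-filled from the offline dataset in
    # `after_init`), train on them, and periodically sync the target net.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))
    # CQL never samples from the env during training, so report iterations
    # by STEPS_TRAINED_COUNTER rather than STEPS_SAMPLED_COUNTER.
    return StandardMetricsReporting(
        replay_op, workers, config, by_steps_trained=True)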
@@ -185,7 +154,7 @@ def after_init(trainer):
batch[SampleBatch.DONES][-1] = True
replay_buffer.add_batch(batch)
print(
f"Loaded {num_batches} batches ({total_timesteps} ts) into "
f"Loaded {num_batches} batches ({total_timesteps} ts) into the "
f"replay buffer, which has capacity {replay_buffer.buffer_size}.")
else:
raise ValueError(

rllib/agents/cql/tests/test_cql.py (View file)

@@ -42,22 +42,30 @@ class TestCQL(unittest.TestCase):
config["num_workers"] = 0 # Run locally.
config["twin_q"] = True
config["clip_actions"] = False
config["clip_actions"] = True
config["normalize_actions"] = True
config["learning_starts"] = 0
config["rollout_fragment_length"] = 1
config["train_batch_size"] = 10
# Switch on off-policy evaluation.
config["input_evaluation"] = ["is"]
num_iterations = 2
config["evaluation_interval"] = 2
config["evaluation_num_episodes"] = 10
config["evaluation_config"]["input"] = "sampler"
config["evaluation_parallel_to_training"] = True
config["evaluation_num_workers"] = 2
num_iterations = 3
# Test for tf/torch frameworks.
for fw in framework_iterator(config):
trainer = cql.CQLTrainer(config=config)
for i in range(num_iterations):
print(trainer.train())
results = trainer.train().get("evaluation")
if results:
print(f"iter={trainer.iteration} "
f"R={results['episode_reward_mean']}")
check_compute_single_action(trainer)
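For reference, a minimal set of config overrides (an illustrative sketch mirroring the test above, not part of the commit) for evaluating a CQL agent on live environment rollouts in parallel to its offline training; the "input" key and dataset path are placeholders not taken from this diff.

# Illustrative config overrides; merge these into the CQL default config.
config_overrides = {
    "input": "/path/to/offline/dataset.json",  # placeholder offline data
    "input_evaluation": ["is"],       # importance-sampling OPE on the data
    "evaluation_interval": 2,         # evaluate every 2nd train() call
    "evaluation_num_episodes": 10,
    "evaluation_num_workers": 2,
    "evaluation_parallel_to_training": True,
    # Evaluation uses the real environment sampler, not the offline data.
    "evaluation_config": {"input": "sampler"},
}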

rllib/agents/trainer.py (View file)

@@ -816,6 +816,13 @@ class Trainer(Trainable):
Note that this default implementation does not do anything beyond
merging evaluation_config with the normal trainer config.
"""
# In case we are evaluating (in a thread) parallel to training,
# we may have to re-enable eager mode here (gets disabled in the
# thread).
if self.config.get("framework") in ["tf2", "tfe"] and \
not tf.executing_eagerly():
tf1.enable_eager_execution()
# Call the `_before_evaluate` hook.
self._before_evaluate()
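The guard above exists because, per the commit's own comment, eager mode can end up disabled inside the thread that runs evaluation in parallel to training. A standalone sketch of the same check using plain TensorFlow (RLlib goes through its `tf1`/`tf` import wrappers instead):

import tensorflow as tf


def ensure_eager(framework: str) -> None:
    # Mirrors the guard added above: if an eager framework ("tf2"/"tfe")
    # is configured but this thread is not executing eagerly, switch
    # eager execution back on before evaluating.
    if framework in ["tf2", "tfe"] and not tf.executing_eagerly():
        tf.compat.v1.enable_eager_execution()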

rllib/agents/trainer_template.py (View file)

@@ -176,17 +176,22 @@ def build_trainer(
# No parallelism.
if not self.config["evaluation_parallel_to_training"]:
res = next(self.train_exec_impl)
# Kick off evaluation-loop (and parallel train() call,
# if requested).
with concurrent.futures.ThreadPoolExecutor() as executor:
eval_future = executor.submit(self.evaluate)
# Parallelism.
if self.config["evaluation_parallel_to_training"]:
# Parallel eval + training.
if self.config["evaluation_parallel_to_training"]:
with concurrent.futures.ThreadPoolExecutor() as executor:
eval_future = executor.submit(self.evaluate)
res = next(self.train_exec_impl)
evaluation_metrics = eval_future.result()
assert isinstance(evaluation_metrics, dict), \
"_evaluate() needs to return a dict."
res.update(evaluation_metrics)
evaluation_metrics = eval_future.result()
# Sequential: train (already done above), then eval.
else:
evaluation_metrics = self.evaluate()
assert isinstance(evaluation_metrics, dict), \
"_evaluate() needs to return a dict."
res.update(evaluation_metrics)
# Check `env_task_fn` for possible update of the env's task.
if self.config["env_task_fn"] is not None:

rllib/execution/metric_ops.py (View file)

@@ -5,7 +5,7 @@ from ray.actor import ActorHandle
from ray.util.iter import LocalIterator
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
STEPS_SAMPLED_COUNTER, _get_shared_metrics
STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, _get_shared_metrics
from ray.rllib.evaluation.worker_set import WorkerSet
@@ -13,7 +13,9 @@ def StandardMetricsReporting(
train_op: LocalIterator[Any],
workers: WorkerSet,
config: dict,
selected_workers: List[ActorHandle] = None) -> LocalIterator[dict]:
selected_workers: List[ActorHandle] = None,
by_steps_trained: bool = False,
) -> LocalIterator[dict]:
"""Operator to periodically collect and report metrics.
Args:
@@ -24,6 +26,8 @@ of stats reporting.
of stats reporting.
selected_workers (list): Override the list of remote workers
to collect metrics from.
by_steps_trained (bool): If True, uses the `STEPS_TRAINED_COUNTER`
instead of the `STEPS_SAMPLED_COUNTER` in metrics.
Returns:
LocalIterator[dict]: A local iterator over training results.
@@ -36,10 +40,12 @@
"""
output_op = train_op \
.filter(OncePerTimestepsElapsed(config["timesteps_per_iteration"])) \
.filter(OncePerTimestepsElapsed(config["timesteps_per_iteration"],
by_steps_trained=by_steps_trained)) \
.filter(OncePerTimeInterval(config["min_iter_time_s"])) \
.for_each(CollectMetrics(
workers, min_history=config["metrics_smoothing_episodes"],
workers,
min_history=config["metrics_smoothing_episodes"],
timeout_seconds=config["collect_metrics_timeout"],
selected_workers=selected_workers))
return output_op
@@ -158,15 +164,26 @@ class OncePerTimestepsElapsed:
# will only return after 1000 steps have elapsed
"""
def __init__(self, delay_steps: int):
def __init__(self, delay_steps: int, by_steps_trained: bool = False):
"""
Args:
delay_steps (int): The number of steps (sampled or trained) that
must elapse before this op returns True again.
by_steps_trained (bool): If True, uses the `STEPS_TRAINED_COUNTER`
instead of the `STEPS_SAMPLED_COUNTER` in metrics.
"""
self.delay_steps = delay_steps
self.by_steps_trained = by_steps_trained
self.last_called = 0
def __call__(self, item: Any) -> bool:
if self.delay_steps <= 0:
return True
metrics = _get_shared_metrics()
now = metrics.counters[STEPS_SAMPLED_COUNTER]
if self.by_steps_trained:
now = metrics.counters[STEPS_TRAINED_COUNTER]
else:
now = metrics.counters[STEPS_SAMPLED_COUNTER]
if now - self.last_called >= self.delay_steps:
self.last_called = now
return True
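To make the counter semantics concrete, here is a standalone, simplified re-implementation of the gating logic above (illustrative only: the real op reads `now` from the shared metrics counters, picking STEPS_TRAINED_COUNTER when `by_steps_trained=True`, instead of taking it as an argument):

class EveryNSteps:
    def __init__(self, delay_steps: int):
        self.delay_steps = delay_steps
        self.last_called = 0

    def __call__(self, now: int) -> bool:
        if self.delay_steps <= 0:
            return True
        if now - self.last_called >= self.delay_steps:
            self.last_called = now
            return True
        return False


gate = EveryNSteps(1000)
assert gate(999) is False   # fewer than 1000 steps elapsed
assert gate(1000) is True   # fires once 1000 steps have elapsed
assert gate(1500) is False  # not again until another 1000 have passed
assert gate(2000) is True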