Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] CQL iteration count fixes: Remove dummy buffer and unnecessary store op from exec_plan. (#16332)
Commit 3d4dc60e2e (parent c8a5d7ba85)
6 changed files with 68 additions and 53 deletions
@@ -2271,6 +2271,15 @@ py_test(
     args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6"]
 )
 
+py_test(
+    name = "examples/parallel_evaluation_and_training_tf2",
+    main = "examples/parallel_evaluation_and_training.py",
+    tags = ["examples", "examples_P"],
+    size = "medium",
+    srcs = ["examples/parallel_evaluation_and_training.py"],
+    args = ["--as-test", "--framework=tf2", "--stop-reward=30.0", "--num-cpus=6"]
+)
+
 py_test(
     name = "examples/parametric_actions_cartpole_pg_tf",
     main = "examples/parametric_actions_cartpole.py",
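The new BUILD entry above runs the existing parallel_evaluation_and_training example under tf2. As a rough, hedged sketch of what that feature looks like from user code (the example script itself is not part of this diff; the algorithm name, environment, and numeric values below are placeholders, not taken from the example):

import ray
from ray import tune

if __name__ == "__main__":
    ray.init()
    tune.run(
        "PPO",  # placeholder trainer; the example's actual algo may differ
        config={
            "env": "CartPole-v0",
            "framework": "tf2",
            # Evaluate every training iteration ...
            "evaluation_interval": 1,
            "evaluation_num_episodes": 10,
            "evaluation_num_workers": 1,
            # ... in a thread running parallel to the training step.
            "evaluation_parallel_to_training": True,
        },
        stop={"episode_reward_mean": 150.0},
    )

With parallel evaluation on, a slow evaluation pass no longer serializes with the training step of the same iteration.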
@@ -1,19 +1,15 @@
 """CQL (derived from SAC).
 """
 import numpy as np
-from typing import Optional, Type, List
+from typing import Optional, Type
 
-from ray.actor import ActorHandle
 from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
 from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
-from ray.rllib.agents.dqn.dqn import calculate_rr_weights
 from ray.rllib.agents.sac.sac import SACTrainer, \
     DEFAULT_CONFIG as SAC_CONFIG
-from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay
-from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import TrainTFMultiGPU, TrainOneStep, \
     UpdateTargetNetwork
 from ray.rllib.offline.shuffled_input import ShuffledInput
@@ -57,17 +53,6 @@ def validate_config(config: TrainerConfigDict):
 replay_buffer = None
 
 
-class NoOpReplayBuffer:
-    def __init__(self,
-                 *,
-                 local_buffer: LocalReplayBuffer = None,
-                 actors: List[ActorHandle] = None):
-        return
-
-    def __call__(self, batch):
-        return batch
-
-
 def execution_plan(workers, config):
     if config.get("prioritized_replay"):
         prio_args = {
@@ -92,14 +77,6 @@ def execution_plan(workers, config):
     global replay_buffer
     replay_buffer = local_replay_buffer
 
-    rollouts = ParallelRollouts(workers, mode="bulk_sync")
-
-    # NoReplayBuffer ensures that no online data is added
-    # The Dataset is added to the Replay Buffer in after_init()
-    # method below the execution plan.
-    store_op = rollouts.for_each(
-        NoOpReplayBuffer(local_buffer=local_replay_buffer))
-
     def update_prio(item):
         samples, info_dict = item
         if config.get("prioritized_replay"):
@@ -141,16 +118,8 @@ def execution_plan(workers, config):
         .for_each(UpdateTargetNetwork(
             workers, config["target_network_update_freq"]))
 
-    # Alternate deterministically between (1) and (2).
-    train_op = Concurrently(
-        [store_op, replay_op],
-        mode="round_robin",
-        # Only return the output
-        # of (2) since training metrics are not available until (2) runs.
-        output_indexes=[1],
-        round_robin_weights=calculate_rr_weights(config))
-
-    return StandardMetricsReporting(train_op, workers, config)
+    return StandardMetricsReporting(
+        replay_op, workers, config, by_steps_trained=True)
 
 
 def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
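Taken together, the three hunks above remove the dummy NoOpReplayBuffer and the rollout/store op, so CQL's plan is just the replay/training op with metrics keyed on trained steps. A hedged sketch of the resulting plan shape (the real execution_plan also wires prioritized-replay updates and an optional multi-GPU train op, which are omitted here):

from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork


def cql_execution_plan_sketch(workers, config, local_replay_buffer):
    # No rollout/store op: batches come straight from the pre-filled buffer.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Key the metrics gate on trained steps: an offline algorithm barely
    # touches the sampled-steps counter, so iteration counts would stall.
    return StandardMetricsReporting(
        replay_op, workers, config, by_steps_trained=True)

Because offline CQL draws batches from a pre-filled buffer rather than from rollouts, gating StandardMetricsReporting on sampled steps would rarely trigger; by_steps_trained=True is what makes the iteration count advance correctly.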
@@ -185,7 +154,7 @@ def after_init(trainer):
                 batch[SampleBatch.DONES][-1] = True
             replay_buffer.add_batch(batch)
         print(
-            f"Loaded {num_batches} batches ({total_timesteps} ts) into "
+            f"Loaded {num_batches} batches ({total_timesteps} ts) into the "
            f"replay buffer, which has capacity {replay_buffer.buffer_size}.")
     else:
         raise ValueError(
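For context, the reworded print sits at the end of the offline-data load in after_init. A condensed, hedged sketch of that load (the real code walks the trainer's input reader; `batches` below is a stand-in for what that reader yields):

def load_offline_data_sketch(batches, replay_buffer):
    # Count batches and timesteps while filling the buffer, then report.
    num_batches, total_timesteps = 0, 0
    for batch in batches:
        num_batches += 1
        total_timesteps += len(batch)
        replay_buffer.add_batch(batch)
    print(f"Loaded {num_batches} batches ({total_timesteps} ts) into the "
          f"replay buffer, which has capacity {replay_buffer.buffer_size}.")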
@@ -42,22 +42,30 @@ class TestCQL(unittest.TestCase):
 
         config["num_workers"] = 0  # Run locally.
         config["twin_q"] = True
-        config["clip_actions"] = False
+        config["clip_actions"] = True
         config["normalize_actions"] = True
         config["learning_starts"] = 0
         config["rollout_fragment_length"] = 1
         config["train_batch_size"] = 10
 
         # Switch on off-policy evaluation.
         config["input_evaluation"] = ["is"]
 
-        num_iterations = 2
+        config["evaluation_interval"] = 2
+        config["evaluation_num_episodes"] = 10
+        config["evaluation_config"]["input"] = "sampler"
+        config["evaluation_parallel_to_training"] = True
+        config["evaluation_num_workers"] = 2
 
+        num_iterations = 3
+
         # Test for tf/torch frameworks.
         for fw in framework_iterator(config):
             trainer = cql.CQLTrainer(config=config)
             for i in range(num_iterations):
-                print(trainer.train())
+                results = trainer.train().get("evaluation")
+                if results:
+                    print(f"iter={trainer.iteration} "
+                          f"R={results['episode_reward_mean']}")
 
             check_compute_single_action(trainer)
-
@@ -816,6 +816,13 @@ class Trainer(Trainable):
         Note that this default implementation does not do anything beyond
         merging evaluation_config with the normal trainer config.
         """
+        # In case we are evaluating (in a thread) parallel to training,
+        # we may have to re-enable eager mode here (gets disabled in the
+        # thread).
+        if self.config.get("framework") in ["tf2", "tfe"] and \
+                not tf.executing_eagerly():
+            tf1.enable_eager_execution()
+
         # Call the `_before_evaluate` hook.
         self._before_evaluate()
 
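The added lines guard against TF eager mode being disabled inside the evaluation thread. A hedged, standalone sketch of the same pattern (using RLlib's try_import_tf helper; the surrounding Trainer state is assumed):

from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


def ensure_eager_in_this_thread(config):
    # Threads spawned for parallel evaluation may start with eager disabled;
    # re-enable it before any eager-mode ops run.
    if config.get("framework") in ["tf2", "tfe"] and not tf.executing_eagerly():
        tf1.enable_eager_execution()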
@@ -176,17 +176,22 @@ build_trainer(
             # No parallelism.
             if not self.config["evaluation_parallel_to_training"]:
                 res = next(self.train_exec_impl)
+
             # Kick off evaluation-loop (and parallel train() call,
             # if requested).
-            with concurrent.futures.ThreadPoolExecutor() as executor:
-                eval_future = executor.submit(self.evaluate)
-                # Parallelism.
-                if self.config["evaluation_parallel_to_training"]:
+            # Parallel eval + training.
+            if self.config["evaluation_parallel_to_training"]:
+                with concurrent.futures.ThreadPoolExecutor() as executor:
+                    eval_future = executor.submit(self.evaluate)
                     res = next(self.train_exec_impl)
-                evaluation_metrics = eval_future.result()
-                assert isinstance(evaluation_metrics, dict), \
-                    "_evaluate() needs to return a dict."
-                res.update(evaluation_metrics)
+                    evaluation_metrics = eval_future.result()
+            # Sequential: train (already done above), then eval.
+            else:
+                evaluation_metrics = self.evaluate()
+
+            assert isinstance(evaluation_metrics, dict), \
+                "_evaluate() needs to return a dict."
+            res.update(evaluation_metrics)
 
             # Check `env_task_fn` for possible update of the env's task.
             if self.config["env_task_fn"] is not None:
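The restructuring above only spins up a ThreadPoolExecutor when evaluation_parallel_to_training is set; otherwise evaluation simply follows the training step. A framework-free, hedged sketch of that control flow (train_step and evaluate are stand-ins for next(self.train_exec_impl) and self.evaluate()):

import concurrent.futures


def train_and_maybe_evaluate(train_step, evaluate, parallel_eval: bool):
    if not parallel_eval:
        # Sequential: train first, then evaluate.
        res = train_step()
        evaluation_metrics = evaluate()
    else:
        # Parallel: evaluation runs in a thread alongside the training step.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            eval_future = executor.submit(evaluate)
            res = train_step()
            evaluation_metrics = eval_future.result()
    assert isinstance(evaluation_metrics, dict)
    res.update(evaluation_metrics)
    return res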
@@ -5,7 +5,7 @@ from ray.actor import ActorHandle
 from ray.util.iter import LocalIterator
 from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
 from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
-    STEPS_SAMPLED_COUNTER, _get_shared_metrics
+    STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, _get_shared_metrics
 from ray.rllib.evaluation.worker_set import WorkerSet
 
 
@@ -13,7 +13,9 @@ def StandardMetricsReporting(
         train_op: LocalIterator[Any],
         workers: WorkerSet,
         config: dict,
-        selected_workers: List[ActorHandle] = None) -> LocalIterator[dict]:
+        selected_workers: List[ActorHandle] = None,
+        by_steps_trained: bool = False,
+) -> LocalIterator[dict]:
     """Operator to periodically collect and report metrics.
 
     Args:
@@ -24,6 +26,8 @@ def StandardMetricsReporting(
             of stats reporting.
         selected_workers (list): Override the list of remote workers
             to collect metrics from.
+        by_steps_trained (bool): If True, uses the `STEPS_TRAINED_COUNTER`
+            instead of the `STEPS_SAMPLED_COUNTER` in metrics.
 
     Returns:
         LocalIterator[dict]: A local iterator over training results.
@@ -36,10 +40,12 @@ def StandardMetricsReporting(
     """
 
     output_op = train_op \
-        .filter(OncePerTimestepsElapsed(config["timesteps_per_iteration"])) \
+        .filter(OncePerTimestepsElapsed(config["timesteps_per_iteration"],
+                                        by_steps_trained=by_steps_trained)) \
         .filter(OncePerTimeInterval(config["min_iter_time_s"])) \
         .for_each(CollectMetrics(
-            workers, min_history=config["metrics_smoothing_episodes"],
+            workers,
+            min_history=config["metrics_smoothing_episodes"],
             timeout_seconds=config["collect_metrics_timeout"],
             selected_workers=selected_workers))
     return output_op
@@ -158,15 +164,26 @@ class OncePerTimestepsElapsed:
         # will only return after 1000 steps have elapsed
     """
 
-    def __init__(self, delay_steps: int):
+    def __init__(self, delay_steps: int, by_steps_trained: bool = False):
+        """
+        Args:
+            delay_steps (int): The number of steps (sampled or trained) every
+                which this op returns True.
+            by_steps_trained (bool): If True, uses the `STEPS_TRAINED_COUNTER`
+                instead of the `STEPS_SAMPLED_COUNTER` in metrics.
+        """
         self.delay_steps = delay_steps
+        self.by_steps_trained = by_steps_trained
         self.last_called = 0
 
     def __call__(self, item: Any) -> bool:
         if self.delay_steps <= 0:
             return True
         metrics = _get_shared_metrics()
-        now = metrics.counters[STEPS_SAMPLED_COUNTER]
+        if self.by_steps_trained:
+            now = metrics.counters[STEPS_TRAINED_COUNTER]
+        else:
+            now = metrics.counters[STEPS_SAMPLED_COUNTER]
         if now - self.last_called >= self.delay_steps:
             self.last_called = now
             return True
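With the flag threaded through StandardMetricsReporting into OncePerTimestepsElapsed, an execution plan can gate reporting on trained rather than sampled timesteps. A hedged usage sketch (train_op, workers, and config are assumed to come from an execution plan):

from ray.rllib.execution.metric_ops import StandardMetricsReporting


def offline_metrics_reporting(train_op, workers, config):
    # Internally this flips the OncePerTimestepsElapsed filter from the
    # STEPS_SAMPLED_COUNTER to the STEPS_TRAINED_COUNTER.
    return StandardMetricsReporting(
        train_op, workers, config, by_steps_trained=True)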