Commit c3a15ecc0f (parent: 1d033fb552)
10 changed files with 66 additions and 48 deletions
@@ -22,35 +22,28 @@ class TestIMPALA(unittest.TestCase):
def test_impala_compilation(self):
"""Test whether an ImpalaTrainer can be built with both frameworks."""
config = impala.DEFAULT_CONFIG.copy()
config["model"]["lstm_use_prev_action"] = True
config["model"]["lstm_use_prev_reward"] = True
num_iterations = 1
env = "CartPole-v0"

for _ in framework_iterator(config):
local_cfg = config.copy()
for env in ["Pendulum-v0", "CartPole-v0"]:
print("Env={}".format(env))
print("w/o LSTM")
# Test w/o LSTM.
local_cfg["model"]["use_lstm"] = False
local_cfg["num_aggregation_workers"] = 0
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
for i in range(num_iterations):
print(trainer.train())
check_compute_single_action(trainer)
trainer.stop()

# Test w/ LSTM.
print("w/ LSTM")
local_cfg["model"]["use_lstm"] = True
local_cfg["model"]["lstm_use_prev_action"] = True
local_cfg["model"]["lstm_use_prev_reward"] = True
local_cfg["num_aggregation_workers"] = 1
for lstm in [False, True]:
local_cfg["num_aggregation_workers"] = 0 if not lstm else 1
local_cfg["model"]["use_lstm"] = lstm
print("lstm={} aggregation-worker={}".format(
lstm, local_cfg["num_aggregation_workers"]))
# Test with and w/o aggregation workers (this has nothing
# to do with LSTMs, though).
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
for i in range(num_iterations):
print(trainer.train())
check_compute_single_action(
trainer,
include_state=True,
include_prev_action_reward=True)
include_state=lstm,
include_prev_action_reward=lstm,
)
trainer.stop()

def test_impala_lr_schedule(self):
@@ -17,6 +17,8 @@ from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
STEPS_TRAINED_COUNTER
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch, \
SampleBatch

@@ -170,10 +172,10 @@ class TestRolloutWorker(unittest.TestCase):
policy = agent.get_policy()
for i in range(3):
result = agent.train()
print("num_steps_trained={}".format(
result["info"]["num_steps_trained"]))
print("num_steps_sampled={}".format(
result["info"]["num_steps_sampled"]))
print("{}={}".format(STEPS_TRAINED_COUNTER,
result["info"][STEPS_TRAINED_COUNTER]))
print("{}={}".format(STEPS_SAMPLED_COUNTER,
result["info"][STEPS_SAMPLED_COUNTER]))
global_timesteps = policy.global_timestep
print("global_timesteps={}".format(global_timesteps))
expected_lr = \
@@ -419,9 +419,9 @@ class TestTrajectoryViewAPI(unittest.TestCase):
results = None
for i in range(num_iterations):
results = trainer.train()
self.assertGreater(results["timesteps_total"],
self.assertGreater(results["agent_timesteps_total"],
num_iterations * config["train_batch_size"])
self.assertLess(results["timesteps_total"],
self.assertLess(results["agent_timesteps_total"],
(num_iterations + 1) * config["train_batch_size"])
trainer.stop()
@@ -5,7 +5,9 @@ from ray.util.iter_metrics import MetricsContext
# Counters for training progress (keys for metrics.counters).
STEPS_SAMPLED_COUNTER = "num_steps_sampled"
AGENT_STEPS_SAMPLED_COUNTER = "num_agent_steps_sampled"
STEPS_TRAINED_COUNTER = "num_steps_trained"
AGENT_STEPS_TRAINED_COUNTER = "num_agent_steps_trained"

# Counters to track target network updates.
LAST_TARGET_UPDATE_TS = "last_target_update_ts"
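For orientation, these counter names are plain string keys into the shared metrics object that execution ops reach via _get_shared_metrics(). A minimal sketch of reading both env-step and agent-step totals from inside an op; the log_sample_totals helper is hypothetical and assumes it is applied with LocalIterator.for_each() inside a running pipeline, where the shared metrics context can resolve:

from ray.rllib.execution.common import (
    AGENT_STEPS_SAMPLED_COUNTER, STEPS_SAMPLED_COUNTER, _get_shared_metrics)


def log_sample_totals(batch):
    # Hypothetical pass-through op: only works while a LocalIterator pipeline
    # is executing, otherwise _get_shared_metrics() has no context.
    metrics = _get_shared_metrics()
    env_steps = metrics.counters[STEPS_SAMPLED_COUNTER]
    # Use .get() with a default, mirroring the CollectMetrics hunk below,
    # since a pipeline may never have written the agent-step key.
    agent_steps = metrics.counters.get(AGENT_STEPS_SAMPLED_COUNTER, 0)
    print("env_steps_sampled={} agent_steps_sampled={}".format(
        env_steps, agent_steps))
    return batch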
@@ -3,8 +3,8 @@ import time
from ray.util.iter import LocalIterator
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
_get_shared_metrics
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.evaluation.worker_set import WorkerSet

@@ -103,6 +103,8 @@ class CollectMetrics:
res.update({
"num_healthy_workers": len(self.workers.remote_workers()),
"timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
"agent_timesteps_total": metrics.counters.get(
AGENT_STEPS_SAMPLED_COUNTER, 0),
})
res["timers"] = timers
res["info"] = info
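The practical effect of the CollectMetrics change is that both totals surface at the top level of the result dict returned by Trainer.train(). A rough usage sketch, assuming a PPO trainer on CartPole; any trainer that reports through CollectMetrics should expose the same keys:

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 1})
result = trainer.train()
# Env steps sampled so far (STEPS_SAMPLED_COUNTER).
print(result["timesteps_total"])
# Agent steps sampled so far (AGENT_STEPS_SAMPLED_COUNTER); for a
# single-agent env this should track timesteps_total.
print(result["agent_timesteps_total"])
trainer.stop()
ray.shutdown()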
@@ -7,9 +7,9 @@ from ray.util.iter_metrics import SharedMetrics
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, LEARNER_INFO, \
SAMPLE_TIMER, GRAD_WAIT_TIMER, _check_sample_batch_type, \
_get_shared_metrics
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
STEPS_SAMPLED_COUNTER, LEARNER_INFO, SAMPLE_TIMER, GRAD_WAIT_TIMER, \
_check_sample_batch_type, _get_shared_metrics
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.sgd import standardized

@@ -60,6 +60,11 @@ def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
def report_timesteps(batch):
metrics = _get_shared_metrics()
metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
if isinstance(batch, MultiAgentBatch):
metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
batch.agent_steps()
else:
metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
return batch

if not workers.remote_workers():
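The branch in report_timesteps rests on the difference between env steps and agent steps: a MultiAgentBatch counts one env step per environment transition but one agent step per acting agent, so agent_steps() can exceed count. A small illustrative sketch; the synthetic SampleBatch contents and the positional MultiAgentBatch constructor call are assumptions for this RLlib version:

from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

# Two agents each act during the same 5 env steps.
batch_a = SampleBatch({"obs": [0, 1, 2, 3, 4]})
batch_b = SampleBatch({"obs": [5, 6, 7, 8, 9]})
ma_batch = MultiAgentBatch({"policy_a": batch_a, "policy_b": batch_b}, 5)

print(ma_batch.count)          # 5 env steps -> STEPS_SAMPLED_COUNTER
print(ma_batch.agent_steps())  # 10 agent steps -> AGENT_STEPS_SAMPLED_COUNTER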
@@ -9,11 +9,11 @@ from ray.rllib.evaluation.metrics import extract_stats, get_learner_stats, \
LEARNER_STATS_KEY
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import \
STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, LEARNER_INFO, \
APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, WORKER_UPDATE_TIMER, \
LEARN_ON_BATCH_TIMER, LOAD_BATCH_TIMER, LAST_TARGET_UPDATE_TS, \
NUM_TARGET_UPDATES, _get_global_vars, _check_sample_batch_type, \
_get_shared_metrics
AGENT_STEPS_TRAINED_COUNTER, APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, \
LAST_TARGET_UPDATE_TS, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \
LOAD_BATCH_TIMER, NUM_TARGET_UPDATES, STEPS_SAMPLED_COUNTER, \
STEPS_TRAINED_COUNTER, WORKER_UPDATE_TIMER, _check_sample_batch_type, \
_get_global_vars, _get_shared_metrics
from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch

@@ -76,6 +76,9 @@ class TrainOneStep:
info, "custom_metrics")
learn_timer.push_units_processed(batch.count)
metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
if isinstance(batch, MultiAgentBatch):
metrics.counters[
AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()
# Update weights - after learning on the local worker - on all remote
# workers.
if self.workers.remote_workers():

@@ -236,6 +239,7 @@ class TrainTFMultiGPU:
learn_timer.push_units_processed(samples.count)

metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
metrics.info[LEARNER_INFO] = fetches
if self.workers.remote_workers():
with metrics.timers[WORKER_UPDATE_TIMER]:
@@ -3,15 +3,16 @@ import platform
from typing import List, Dict, Any

import ray
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
_get_shared_metrics
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actors import create_colocated
from ray.rllib.utils.typing import SampleBatchType, ModelWeights
from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \
from_actors, LocalIterator
from ray.rllib.utils.typing import SampleBatchType, ModelWeights
from ray.rllib.evaluation.worker_set import WorkerSet

logger = logging.getLogger(__name__)

@@ -100,6 +101,11 @@ def gather_experiences_tree_aggregation(workers: WorkerSet,
def record_steps_sampled(batch):
metrics = _get_shared_metrics()
metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
if isinstance(batch, MultiAgentBatch):
metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
batch.agent_steps()
else:
metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
return batch

return train_batches.gather_async().for_each(record_steps_sampled)
@@ -2,6 +2,8 @@ import unittest

import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
STEPS_TRAINED_COUNTER
from ray.rllib.utils.test_utils import framework_iterator

@@ -28,8 +30,8 @@ class TestDistributedExecution(unittest.TestCase):
assert isinstance(result, dict)
assert "info" in result
assert "learner" in result["info"]
assert "num_steps_sampled" in result["info"]
assert "num_steps_trained" in result["info"]
assert STEPS_SAMPLED_COUNTER in result["info"]
assert STEPS_TRAINED_COUNTER in result["info"]
assert "timers" in result
assert "learn_time_ms" in result["timers"]
assert "learn_throughput" in result["timers"]
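These assertion changes are purely symbolic: the imported counter constants are the same string literals that were previously hard-coded (see the common.py hunk earlier), so the dictionary keys being checked do not change. A quick sanity check:

from ray.rllib.execution.common import (
    STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER)

# The constants are plain strings, so the rewritten asserts look up the
# exact same result["info"] keys as before.
assert STEPS_SAMPLED_COUNTER == "num_steps_sampled"
assert STEPS_TRAINED_COUNTER == "num_steps_trained"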
@@ -7,6 +7,8 @@ import ray
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
STEPS_TRAINED_COUNTER
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay

@@ -122,11 +124,11 @@ def test_rollouts(ray_start_regular_shared):
a = ParallelRollouts(workers, mode="bulk_sync")
assert next(a).count == 200
counters = a.shared_metrics.get().counters
assert counters["num_steps_sampled"] == 200, counters
assert counters[STEPS_SAMPLED_COUNTER] == 200, counters
a = ParallelRollouts(workers, mode="async")
assert next(a).count == 100
counters = a.shared_metrics.get().counters
assert counters["num_steps_sampled"] == 100, counters
assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
workers.stop()

@@ -135,7 +137,7 @@ def test_rollouts_local(ray_start_regular_shared):
a = ParallelRollouts(workers, mode="bulk_sync")
assert next(a).count == 100
counters = a.shared_metrics.get().counters
assert counters["num_steps_sampled"] == 100, counters
assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
workers.stop()

@@ -163,7 +165,7 @@ def test_async_grads(ray_start_regular_shared):
res1 = next(a)
assert isinstance(res1, tuple) and len(res1) == 2, res1
counters = a.shared_metrics.get().counters
assert counters["num_steps_sampled"] == 100, counters
assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
workers.stop()

@@ -176,8 +178,8 @@ def test_train_one_step(ray_start_regular_shared):
assert DEFAULT_POLICY_ID in stats
assert "learner_stats" in stats[DEFAULT_POLICY_ID]
counters = a.shared_metrics.get().counters
assert counters["num_steps_sampled"] == 100, counters
assert counters["num_steps_trained"] == 100, counters
assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
assert counters[STEPS_TRAINED_COUNTER] == 100, counters
timers = a.shared_metrics.get().timers
assert "learn" in timers
workers.stop()