[rllib] Fix stats collection and some docs bugs since the refactoring (#2361)
* fix
* fix pbt example
* fix
* fix
* single thread by default
* vec
* fix
* fix
parent 9a6e329325
commit d24f19fd1e
13 changed files with 60 additions and 34 deletions
@@ -51,7 +51,7 @@ For a full example of a custom model in code, see the `Carla RLlib model <https:
 Custom Preprocessors
 --------------------
 
-Similarly, custom preprocessors should subclass the RLlib `preprocessor class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/preprocessors.py>`__ and registered in the model catalog:
+Similarly, custom preprocessors should subclass the RLlib `preprocessor class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/preprocessors.py>`__ and be registered in the model catalog:
 
 .. code-block:: python
 
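The body of that `.. code-block:: python` directive falls outside the diff context above. As a rough sketch of the registration the sentence describes, assuming the era's `Preprocessor` hooks (`_init` setting `self.shape`, plus `transform`) and the `ModelCatalog.register_custom_preprocessor` helper, with `my_prep` as a made-up name:

```python
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import Preprocessor


class MyPreprocessorClass(Preprocessor):
    def _init(self):
        # Output shape; this sketch assumes a length-4 observation vector.
        self.shape = (4,)

    def transform(self, observation):
        # A real preprocessor would reshape, normalize, or otherwise encode
        # the observation; this sketch passes it through unchanged.
        return observation


# Register the class under a name that agent configs can then reference,
# e.g. {"model": {"custom_preprocessor": "my_prep"}}.
ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
```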
@@ -3,9 +3,9 @@ RLlib: Scalable Reinforcement Learning
 
 RLlib is an open-source library for reinforcement learning that offers both a collection of reference algorithms and scalable primitives for composing new ones.
 
-For an overview of RLlib, see the `documentation <http://ray.readthedocs.io/en/latest/rllib.html>`__.
+For an overview of RLlib, see the [documentation](http://ray.readthedocs.io/en/latest/rllib.html).
 
-If you've found RLlib useful for your research, you can cite the `paper <https://arxiv.org/abs/1712.09381>`__ as follows:
+If you've found RLlib useful for your research, you can cite the [paper](https://arxiv.org/abs/1712.09381) as follows:
 
 ```
 @inproceedings{liang2018rllib,
@@ -9,7 +9,6 @@ import ray
 from ray.rllib.agents.agent import Agent, with_common_config
 from ray.rllib.optimizers import AsyncGradientsOptimizer
 from ray.rllib.utils import FilterManager
-from ray.rllib.evaluation.metrics import collect_metrics
 from ray.tune.trial import Resources
 
 DEFAULT_CONFIG = with_common_config({
@@ -98,12 +97,13 @@ class A3CAgent(Agent):
             self.config["optimizer"])
 
     def _train(self):
+        prev_steps = self.optimizer.num_steps_sampled
         self.optimizer.step()
         FilterManager.synchronize(
             self.local_evaluator.filters, self.remote_evaluators)
-        result = collect_metrics(self.local_evaluator, self.remote_evaluators)
+        result = self.optimizer.collect_metrics()
         result = result._replace(
-            info=self.optimizer.stats())
+            timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
         return result
 
     def _stop(self):
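The `_replace` calls in this and the later hunks rely on the training result being a `namedtuple`-style record: `_replace` does not mutate the record, it returns a copy with only the named fields overridden. A stand-alone sketch (the record and field names here are stand-ins, not the actual RLlib definition):

```python
import collections

# Hypothetical stand-in for the result record that collect_metrics() returns.
TrainingResult = collections.namedtuple(
    "TrainingResult", ["episode_reward_mean", "timesteps_this_iter", "info"])

result = TrainingResult(episode_reward_mean=100.0, timesteps_this_iter=0, info={})

# _replace returns a new record with the given fields overridden, which is
# why the agents above re-assign the return value.
result = result._replace(timesteps_this_iter=2000)
print(result.timesteps_this_iter)  # 2000
```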
@@ -39,8 +39,14 @@ COMMON_CONFIG = {
     "model": {},
     # Arguments to pass to the rllib optimizer
    "optimizer": {},
-    # Override default TF session args if non-empty
-    "tf_session_args": {},
+    # Configure TF for single-process operation by default
+    "tf_session_args": {
+        "intra_op_parallelism_threads": 1,
+        "inter_op_parallelism_threads": 1,
+        "gpu_options": {
+            "allow_growth": True,
+        },
+    },
     # Whether to LZ4 compress observations
     "compress_observations": False,
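For reference, these defaults map onto TF 1.x session options. Assuming RLlib forwards `tf_session_args` into a `tf.ConfigProto` (the forwarding code is not part of this diff), the hand-written equivalent would be roughly:

```python
import tensorflow as tf  # TF 1.x API

# Rough equivalent of the "tf_session_args" defaults introduced above.
config = tf.ConfigProto(
    intra_op_parallelism_threads=1,   # each op runs on a single thread
    inter_op_parallelism_threads=1,   # ops are not scheduled in parallel
    gpu_options=tf.GPUOptions(allow_growth=True),  # claim GPU memory lazily
)
session = tf.Session(config=config)
```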
@@ -185,9 +185,17 @@ class DQNAgent(Agent):
                 e.foreach_policy.remote(lambda p, _: p.set_epsilon(exp_val))
                 exp_vals.append(exp_val)
 
-        result = collect_metrics(
-            self.local_evaluator, self.remote_evaluators)
+        if self.config["per_worker_exploration"]:
+            # Only collect metrics from the third of workers with lowest eps
+            result = collect_metrics(
+                self.local_evaluator,
+                self.remote_evaluators[-len(self.remote_evaluators) // 3:])
+        else:
+            result = collect_metrics(
+                self.local_evaluator, self.remote_evaluators)
+
         return result._replace(
+            timesteps_this_iter=self.global_timestep - start_timestep,
             info=dict({
                 "min_exploration": min(exp_vals),
                 "max_exploration": max(exp_vals),
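A quick stand-alone check of the worker-selection slice above: `-len(x) // 3` floors toward the larger magnitude, so at least a third of the evaluators are kept, and the negative slice takes them from the tail of the list, which (per the comment in the hunk) is where the lowest-epsilon workers sit. The worker names below are placeholders:

```python
remote_evaluators = ["worker_%d" % i for i in range(9)]

# Take the last third of the list; with 9 workers, -9 // 3 == -3.
lowest_eps_workers = remote_evaluators[-len(remote_evaluators) // 3:]
print(lowest_eps_workers)  # ['worker_6', 'worker_7', 'worker_8']
```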
@@ -4,7 +4,6 @@ from __future__ import print_function
 
 from ray.rllib.agents.agent import Agent, with_common_config
 from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
-from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.optimizers import SyncSamplesOptimizer
 from ray.tune.trial import Resources
 
@@ -49,6 +48,7 @@ class PGAgent(Agent):
             self.config["optimizer"])
 
     def _train(self):
+        prev_steps = self.optimizer.num_steps_sampled
         self.optimizer.step()
-        return collect_metrics(
-            self.optimizer.local_evaluator, self.optimizer.remote_evaluators)
+        return self.optimizer.collect_metrics()._replace(
+            timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
@@ -9,7 +9,6 @@ import pickle
 import ray
 from ray.rllib.agents import Agent, with_common_config
 from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicyGraph
-from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.utils import FilterManager
 from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer
 from ray.tune.trial import Resources
@@ -81,6 +80,8 @@ class PPOAgent(Agent):
                 "timesteps_per_batch": self.config["timesteps_per_batch"]})
 
     def _train(self):
+        prev_steps = self.optimizer.num_steps_sampled
+
         def postprocess_samples(batch):
             # Divide by the maximum of value.std() and 1e-4
             # to guard against the case where all values are equal
@@ -92,6 +93,7 @@ class PPOAgent(Agent):
             if not self.config["use_gae"]:
                 batch.data["value_targets"] = dummy
                 batch.data["vf_preds"] = dummy
+
         extra_fetches = self.optimizer.step(postprocess_fn=postprocess_samples)
         kl = np.array(extra_fetches["kl"]).mean(axis=1)[-1]
         total_loss = np.array(extra_fetches["total_loss"]).mean(axis=1)[-1]
@@ -112,8 +114,10 @@ class PPOAgent(Agent):
 
         FilterManager.synchronize(
             self.local_evaluator.filters, self.remote_evaluators)
-        res = collect_metrics(self.local_evaluator, self.remote_evaluators)
-        res = res._replace(info=info)
+        res = self.optimizer.collect_metrics()
+        res = res._replace(
+            timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
+            info=dict(info, **res.info))
         return res
 
     def _stop(self):
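One detail in the PPO hunk above is the `dict(info, **res.info)` merge: PPO's own iteration stats are combined with the optimizer's stats, and the optimizer's entries win on key collisions. A tiny sketch with made-up keys:

```python
ppo_info = {"kl_divergence": 0.012, "sample_throughput": 1500.0}
optimizer_info = {"sample_time_ms": 42.0, "sample_throughput": 1600.0}

# dict(a, **b) copies a and then applies b, so b's values take precedence.
merged = dict(ppo_info, **optimizer_info)
print(merged)
# {'kl_divergence': 0.012, 'sample_throughput': 1600.0, 'sample_time_ms': 42.0}
```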
@@ -33,7 +33,6 @@ def collect_metrics(local_evaluator, remote_evaluators=[]):
         max_reward = float('nan')
     avg_reward = np.mean(episode_rewards)
     avg_length = np.mean(episode_lengths)
-    timesteps = np.sum(episode_lengths)
 
     for policy_id, rewards in policy_rewards.copy().items():
         policy_rewards[policy_id] = np.mean(rewards)
@@ -44,5 +43,4 @@ def collect_metrics(local_evaluator, remote_evaluators=[]):
         episode_reward_mean=avg_reward,
         episode_len_mean=avg_length,
         episodes_total=len(episode_lengths),
-        timesteps_this_iter=timesteps,
         policy_reward_mean=dict(policy_rewards))
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 import ray
+from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.evaluation.sample_batch import MultiAgentBatch
 
 
@@ -104,6 +105,10 @@ class PolicyOptimizer(object):
              for i, ev in enumerate(self.remote_evaluators)])
         return local_result + remote_results
 
+    def collect_metrics(self):
+        res = collect_metrics(self.local_evaluator, self.remote_evaluators)
+        return res._replace(info=self.stats())
+
     def _check_not_multiagent(self, sample_batch):
         if isinstance(sample_batch, MultiAgentBatch):
             raise NotImplementedError(
@@ -17,12 +17,11 @@ class SyncSamplesOptimizer(PolicyOptimizer):
     model weights are then broadcast to all remote evaluators.
     """
 
-    def _init(self, batch_size=32):
+    def _init(self):
         self.update_weights_timer = TimerStat()
         self.sample_timer = TimerStat()
         self.grad_timer = TimerStat()
         self.throughput = RunningStat()
-        self.batch_size = batch_size
 
     def step(self):
         with self.update_weights_timer:
@@ -5,6 +5,7 @@ import gym
 from gym.spaces import Box, Discrete, Tuple
 from gym.envs.registration import EnvSpec
 import numpy as np
+import sys
 
 import ray
 from ray.rllib.agents.agent import get_agent_class
@@ -117,4 +118,12 @@ class ModelSupportedSpaces(unittest.TestCase):
 
 
 if __name__ == "__main__":
+    if len(sys.argv) > 1 and sys.argv[1] == "--smoke":
+        ACTION_SPACES_TO_TEST = {
+            "discrete": Discrete(5),
+        }
+        OBSERVATION_SPACES_TO_TEST = {
+            "vector": Box(0.0, 1.0, (5,), dtype=np.float32),
+            "atari": Box(0.0, 1.0, (210, 160, 3), dtype=np.float32),
+        }
     unittest.main(verbosity=2)
@@ -7,5 +7,7 @@ pong-apex:
     config:
         target_network_update_freq: 50000
         num_workers: 32
+        ## can also enable vectorization within processes
+        # num_envs: 4
         lr: .0001
        gamma: 0.99
@@ -52,28 +52,23 @@ if __name__ == "__main__":
         "env": "Humanoid-v1",
         "repeat": 8,
         "config": {
-            "kl_coeff":
-            1.0,
-            "num_workers":
-            8,
-            "devices": ["/gpu:0"],
+            "kl_coeff": 1.0,
+            "num_workers": 8,
+            "num_gpus": 1,
             "model": {
                 "free_log_std": True
             },
             # These params are tuned from a fixed starting value.
-            "lambda":
-            0.95,
-            "clip_param":
-            0.2,
-            "sgd_stepsize":
-            1e-4,
+            "lambda": 0.95,
+            "clip_param": 0.2,
+            "sgd_stepsize": 1e-4,
             # These params start off randomly drawn from a set.
             "num_sgd_iter":
-                lambda spec: random.choice([10, 20, 30]),
+            lambda spec: random.choice([10, 20, 30]),
             "sgd_batchsize":
-                lambda spec: random.choice([128, 512, 2048]),
+            lambda spec: random.choice([128, 512, 2048]),
             "timesteps_per_batch":
-                lambda spec: random.choice([10000, 20000, 40000])
+            lambda spec: random.choice([10000, 20000, 40000])
         },
     },
 },
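The `lambda spec: random.choice([...])` entries are sampling functions that Tune resolves per generated trial, so each trial starts from its own draw, and (this being the PBT example) the scheduler then perturbs those values during training. A minimal stand-alone sketch of the initial draw, using plain `random` with no Tune dependency:

```python
import random

# Stand-in for the per-trial draw performed by the sampling lambdas above.
def draw_initial_hyperparams():
    return {
        "num_sgd_iter": random.choice([10, 20, 30]),
        "sgd_batchsize": random.choice([128, 512, 2048]),
        "timesteps_per_batch": random.choice([10000, 20000, 40000]),
    }

print(draw_initial_hyperparams())
# e.g. {'num_sgd_iter': 20, 'sgd_batchsize': 512, 'timesteps_per_batch': 10000}
```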