mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Fix atari reward calculations, add LR annealing, explained var stat for A2C / impala (#2700)
Changes needed to reproduce Atari plots in IMPALA / A2C: https://github.com/ray-project/rl-experiments
This commit is contained in:
parent
1b3de31ff1
commit
aa014af85b
35 changed files with 483 additions and 148 deletions
|
@ -16,11 +16,14 @@ Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/mas
|
||||||
Advantage Actor-Critic (A2C, A3C)
|
Advantage Actor-Critic (A2C, A3C)
|
||||||
---------------------------------
|
---------------------------------
|
||||||
`[paper] <https://arxiv.org/abs/1602.01783>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a3c.py>`__
|
`[paper] <https://arxiv.org/abs/1602.01783>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a3c.py>`__
|
||||||
RLlib's A3C uses the AsyncGradientsOptimizer to apply gradients computed remotely on policy evaluation actors. It scales to up to 16-32 worker processes, depending on the environment. Both a TensorFlow (LSTM), and PyTorch version are available.
|
RLlib implements A2C and A3C using SyncSamplesOptimizer and AsyncGradientsOptimizer respectively for policy optimization. These algorithms scale to up to 16-32 worker processes depending on the environment. Both a TensorFlow (LSTM), and PyTorch version are available.
|
||||||
|
|
||||||
Note that if you have a GPU, `IMPALA <#importance-weighted-actor-learner-architecture>`__ probably will perform better than A3C. You can also use the synchronous variant of A3C, `A2C <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a2c.py>`__.
|
.. note::
|
||||||
|
In most cases, `IMPALA <#importance-weighted-actor-learner-architecture-impala>`__ will outperform A2C / A3C. In our `benchmarks <https://github.com/ray-project/rl-experiments>`__, IMPALA is almost 10x faster than A2C in wallclock time, with similar sample efficiency.
|
||||||
|
|
||||||
Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c.yaml>`__, `A2C variant <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a2c.yaml>`__, `PyTorch version <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml>`__
|
Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c.yaml>`__, `PyTorch version <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-a2c.yaml>`__
|
||||||
|
|
||||||
|
See also our `Atari results <https://github.com/ray-project/rl-experiments>`__.
|
||||||
|
|
||||||
Deep Deterministic Policy Gradients (DDPG)
|
Deep Deterministic Policy Gradients (DDPG)
|
||||||
------------------------------------------
|
------------------------------------------
|
||||||
|
@ -56,12 +59,14 @@ Importance Weighted Actor-Learner Architecture (IMPALA)
|
||||||
`[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/impala/impala.py>`__
|
`[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/impala/impala.py>`__
|
||||||
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models>`__.
|
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models>`__.
|
||||||
|
|
||||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala.yaml>`__, `vectorized configuration <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml>`__
|
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala.yaml>`__, `vectorized configuration <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-impala.yaml>`__
|
||||||
|
|
||||||
|
See also our `Atari results <https://github.com/ray-project/rl-experiments>`__.
|
||||||
|
|
||||||
.. figure:: impala.png
|
.. figure:: impala.png
|
||||||
:align: center
|
:align: center
|
||||||
|
|
||||||
RLlib's IMPALA implementation scales from 16 to 128 workers on PongNoFrameskip-v4. With vectorization, similar learning performance to 128 workers can be achieved with only 32 workers. This is about an order of magnitude faster than A3C (not shown here), with similar sample efficiency.
|
IMPALA solves Atari about 10x faster than A2C / A3C, with similar sample efficiency. Here IMPALA scales from 16 to 128 workers on PongNoFrameskip-v4.
|
||||||
|
|
||||||
Policy Gradients
|
Policy Gradients
|
||||||
----------------
|
----------------
|
||||||
|
@ -74,7 +79,9 @@ Proximal Policy Optimization (PPO)
|
||||||
`[paper] <https://arxiv.org/abs/1707.06347>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/ppo/ppo.py>`__
|
`[paper] <https://arxiv.org/abs/1707.06347>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/ppo/ppo.py>`__
|
||||||
PPO's clipped objective supports multiple SGD passes over the same batch of experiences. RLlib's multi-GPU optimizer pins that data in GPU memory to avoid unnecessary transfers from host memory, substantially improving performance over a naive implementation. RLlib's PPO scales out using multiple workers for experience collection, and also with multiple GPUs for SGD.
|
PPO's clipped objective supports multiple SGD passes over the same batch of experiences. RLlib's multi-GPU optimizer pins that data in GPU memory to avoid unnecessary transfers from host memory, substantially improving performance over a naive implementation. RLlib's PPO scales out using multiple workers for experience collection, and also with multiple GPUs for SGD.
|
||||||
|
|
||||||
Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml>`__, `Hopper-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/hopper-ppo.yaml>`__, `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pendulum-ppo.yaml>`__, `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-ppo.yaml>`__, `Walker2d-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/walker2d-ppo.yaml>`__
|
Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml>`__, `Hopper-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/hopper-ppo.yaml>`__, `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pendulum-ppo.yaml>`__, `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-ppo.yaml>`__, `Walker2d-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/walker2d-ppo.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-ppo.yaml>`__
|
||||||
|
|
||||||
|
See also our `Atari results <https://github.com/ray-project/rl-experiments>`__.
|
||||||
|
|
||||||
.. figure:: ppo.png
|
.. figure:: ppo.png
|
||||||
:width: 500px
|
:width: 500px
|
||||||
|
|
|
@ -7,6 +7,7 @@ ray.rllib.agents
|
||||||
.. automodule:: ray.rllib.agents
|
.. automodule:: ray.rllib.agents
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: ray.rllib.agents.a3c.A2CAgent
|
||||||
.. autoclass:: ray.rllib.agents.a3c.A3CAgent
|
.. autoclass:: ray.rllib.agents.a3c.A3CAgent
|
||||||
.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGAgent
|
.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGAgent
|
||||||
.. autoclass:: ray.rllib.agents.ddpg.DDPGAgent
|
.. autoclass:: ray.rllib.agents.ddpg.DDPGAgent
|
||||||
|
|
|
@ -26,7 +26,6 @@ training process with TensorBoard by running
|
||||||
|
|
||||||
tensorboard --logdir=~/ray_results
|
tensorboard --logdir=~/ray_results
|
||||||
|
|
||||||
|
|
||||||
The ``train.py`` script has a number of options you can show by running
|
The ``train.py`` script has a number of options you can show by running
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
@ -44,14 +43,12 @@ Specifying Parameters
|
||||||
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/agent.py>`__. See the
|
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/agent.py>`__. See the
|
||||||
`algorithms documentation <rllib-algorithms.html>`__ for more information.
|
`algorithms documentation <rllib-algorithms.html>`__ for more information.
|
||||||
|
|
||||||
In an example below, we train A3C by specifying 8 workers through the config flag.
|
In an example below, we train A2C by specifying 8 workers through the config flag. We also set ``"monitor": true`` to save episode videos to the result dir:
|
||||||
function that creates the env to refer to it by name. The contents of the env_config agent config field will be passed to that function to allow the environment to be configured. The return type should be an OpenAI gym.Env. For example:
|
|
||||||
|
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
|
python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
|
||||||
--run=A3C --config '{"num_workers": 8}'
|
--run=A2C --config '{"num_workers": 8, "monitor": true}'
|
||||||
|
|
||||||
Evaluating Trained Agents
|
Evaluating Trained Agents
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
|
@ -13,9 +13,7 @@ A2C_DEFAULT_CONFIG = merge_dicts(
|
||||||
"gpu": False,
|
"gpu": False,
|
||||||
"sample_batch_size": 20,
|
"sample_batch_size": 20,
|
||||||
"min_iter_time_s": 10,
|
"min_iter_time_s": 10,
|
||||||
"optimizer": {
|
"sample_async": False,
|
||||||
"timesteps_per_batch": 200,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||||
from ray.rllib.agents.agent import Agent, with_common_config
|
from ray.rllib.agents.agent import Agent, with_common_config
|
||||||
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
||||||
from ray.rllib.utils import FilterManager, merge_dicts
|
from ray.rllib.utils import FilterManager, merge_dicts
|
||||||
|
@ -23,14 +24,14 @@ DEFAULT_CONFIG = with_common_config({
|
||||||
"grad_clip": 40.0,
|
"grad_clip": 40.0,
|
||||||
# Learning rate
|
# Learning rate
|
||||||
"lr": 0.0001,
|
"lr": 0.0001,
|
||||||
|
# Learning rate schedule
|
||||||
|
"lr_schedule": None,
|
||||||
# Value Function Loss coefficient
|
# Value Function Loss coefficient
|
||||||
"vf_loss_coeff": 0.5,
|
"vf_loss_coeff": 0.5,
|
||||||
# Entropy coefficient
|
# Entropy coefficient
|
||||||
"entropy_coeff": -0.01,
|
"entropy_coeff": -0.01,
|
||||||
# Whether to place workers on GPUs
|
# Whether to place workers on GPUs
|
||||||
"use_gpu_for_workers": False,
|
"use_gpu_for_workers": False,
|
||||||
# Whether to emit extra summary stats
|
|
||||||
"summarize": False,
|
|
||||||
# Min time per iteration
|
# Min time per iteration
|
||||||
"min_iter_time_s": 5,
|
"min_iter_time_s": 5,
|
||||||
# Workers sample async. Note that this increases the effective
|
# Workers sample async. Note that this increases the effective
|
||||||
|
@ -67,6 +68,7 @@ class A3CAgent(Agent):
|
||||||
|
|
||||||
_agent_name = "A3C"
|
_agent_name = "A3C"
|
||||||
_default_config = DEFAULT_CONFIG
|
_default_config = DEFAULT_CONFIG
|
||||||
|
_policy_graph = A3CPolicyGraph
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_resource_request(cls, config):
|
def default_resource_request(cls, config):
|
||||||
|
@ -83,8 +85,7 @@ class A3CAgent(Agent):
|
||||||
A3CTorchPolicyGraph
|
A3CTorchPolicyGraph
|
||||||
policy_cls = A3CTorchPolicyGraph
|
policy_cls = A3CTorchPolicyGraph
|
||||||
else:
|
else:
|
||||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
policy_cls = self._policy_graph
|
||||||
policy_cls = A3CPolicyGraph
|
|
||||||
|
|
||||||
self.local_evaluator = self.make_local_evaluator(
|
self.local_evaluator = self.make_local_evaluator(
|
||||||
self.env_creator, policy_cls)
|
self.env_creator, policy_cls)
|
||||||
|
|
|
@ -9,8 +9,10 @@ import gym
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||||
|
from ray.rllib.utils.explained_variance import explained_variance
|
||||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||||
|
LearningRateSchedule
|
||||||
from ray.rllib.models.misc import linear, normc_initializer
|
from ray.rllib.models.misc import linear, normc_initializer
|
||||||
from ray.rllib.models.catalog import ModelCatalog
|
from ray.rllib.models.catalog import ModelCatalog
|
||||||
|
|
||||||
|
@ -36,7 +38,7 @@ class A3CLoss(object):
|
||||||
self.entropy * entropy_coeff)
|
self.entropy * entropy_coeff)
|
||||||
|
|
||||||
|
|
||||||
class A3CPolicyGraph(TFPolicyGraph):
|
class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||||
def __init__(self, observation_space, action_space, config):
|
def __init__(self, observation_space, action_space, config):
|
||||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -67,8 +69,8 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||||
"Action space {} is not supported for A3C.".format(
|
"Action space {} is not supported for A3C.".format(
|
||||||
action_space))
|
action_space))
|
||||||
advantages = tf.placeholder(tf.float32, [None], name="advantages")
|
advantages = tf.placeholder(tf.float32, [None], name="advantages")
|
||||||
v_target = tf.placeholder(tf.float32, [None], name="v_target")
|
self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
|
||||||
self.loss = A3CLoss(action_dist, actions, advantages, v_target,
|
self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
|
||||||
self.vf, self.config["vf_loss_coeff"],
|
self.vf, self.config["vf_loss_coeff"],
|
||||||
self.config["entropy_coeff"])
|
self.config["entropy_coeff"])
|
||||||
|
|
||||||
|
@ -77,8 +79,10 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||||
("obs", self.observations),
|
("obs", self.observations),
|
||||||
("actions", actions),
|
("actions", actions),
|
||||||
("advantages", advantages),
|
("advantages", advantages),
|
||||||
("value_targets", v_target),
|
("value_targets", self.v_target),
|
||||||
]
|
]
|
||||||
|
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||||
|
self.config["lr_schedule"])
|
||||||
TFPolicyGraph.__init__(
|
TFPolicyGraph.__init__(
|
||||||
self,
|
self,
|
||||||
observation_space,
|
observation_space,
|
||||||
|
@ -93,6 +97,18 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||||
seq_lens=self.model.seq_lens,
|
seq_lens=self.model.seq_lens,
|
||||||
max_seq_len=self.config["model"]["max_seq_len"])
|
max_seq_len=self.config["model"]["max_seq_len"])
|
||||||
|
|
||||||
|
self.stats_fetches = {
|
||||||
|
"stats": {
|
||||||
|
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||||
|
"policy_loss": self.loss.pi_loss,
|
||||||
|
"policy_entropy": self.loss.entropy,
|
||||||
|
"grad_gnorm": tf.global_norm(self._grads),
|
||||||
|
"var_gnorm": tf.global_norm(self.var_list),
|
||||||
|
"vf_loss": self.loss.vf_loss,
|
||||||
|
"vf_explained_var": explained_variance(self.v_target, self.vf),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
self.sess.run(tf.global_variables_initializer())
|
self.sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
def extra_compute_action_fetches(self):
|
def extra_compute_action_fetches(self):
|
||||||
|
@ -107,9 +123,6 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||||
vf = self.sess.run(self.vf, feed_dict)
|
vf = self.sess.run(self.vf, feed_dict)
|
||||||
return vf[0]
|
return vf[0]
|
||||||
|
|
||||||
def optimizer(self):
|
|
||||||
return tf.train.AdamOptimizer(self.config["lr"])
|
|
||||||
|
|
||||||
def gradients(self, optimizer):
|
def gradients(self, optimizer):
|
||||||
grads = tf.gradients(self.loss.total_loss, self.var_list)
|
grads = tf.gradients(self.loss.total_loss, self.var_list)
|
||||||
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
||||||
|
@ -117,18 +130,7 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||||
return clipped_grads
|
return clipped_grads
|
||||||
|
|
||||||
def extra_compute_grad_fetches(self):
|
def extra_compute_grad_fetches(self):
|
||||||
if self.config.get("summarize"):
|
return self.stats_fetches
|
||||||
return {
|
|
||||||
"stats": {
|
|
||||||
"policy_loss": self.loss.pi_loss,
|
|
||||||
"value_loss": self.loss.vf_loss,
|
|
||||||
"entropy": self.loss.entropy,
|
|
||||||
"grad_gnorm": tf.global_norm(self._grads),
|
|
||||||
"var_gnorm": tf.global_norm(self.var_list),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def get_initial_state(self):
|
def get_initial_state(self):
|
||||||
return self.model.state_init
|
return self.model.state_init
|
||||||
|
|
|
@ -10,7 +10,8 @@ import pickle
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||||
from ray.rllib.utils import deep_update
|
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||||
|
from ray.rllib.utils import deep_update, merge_dicts
|
||||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||||
from ray.tune.trainable import Trainable
|
from ray.tune.trainable import Trainable
|
||||||
|
|
||||||
|
@ -61,6 +62,8 @@ COMMON_CONFIG = {
|
||||||
},
|
},
|
||||||
# Whether to LZ4 compress observations
|
# Whether to LZ4 compress observations
|
||||||
"compress_observations": False,
|
"compress_observations": False,
|
||||||
|
# Whether to write episode stats and videos to the agent log dir
|
||||||
|
"monitor": False,
|
||||||
|
|
||||||
# === Multiagent ===
|
# === Multiagent ===
|
||||||
"multiagent": {
|
"multiagent": {
|
||||||
|
@ -103,8 +106,19 @@ class Agent(Trainable):
|
||||||
def make_local_evaluator(self, env_creator, policy_graph):
|
def make_local_evaluator(self, env_creator, policy_graph):
|
||||||
"""Convenience method to return configured local evaluator."""
|
"""Convenience method to return configured local evaluator."""
|
||||||
|
|
||||||
return self._make_evaluator(PolicyEvaluator, env_creator, policy_graph,
|
return self._make_evaluator(
|
||||||
0)
|
PolicyEvaluator,
|
||||||
|
env_creator,
|
||||||
|
policy_graph,
|
||||||
|
0,
|
||||||
|
# important: allow local tf to use multiple CPUs for optimization
|
||||||
|
merge_dicts(
|
||||||
|
self.config, {
|
||||||
|
"tf_session_args": {
|
||||||
|
"intra_op_parallelism_threads": None,
|
||||||
|
"inter_op_parallelism_threads": None,
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
def make_remote_evaluators(self, env_creator, policy_graph, count,
|
def make_remote_evaluators(self, env_creator, policy_graph, count,
|
||||||
remote_args):
|
remote_args):
|
||||||
|
@ -112,13 +126,12 @@ class Agent(Trainable):
|
||||||
|
|
||||||
cls = PolicyEvaluator.as_remote(**remote_args).remote
|
cls = PolicyEvaluator.as_remote(**remote_args).remote
|
||||||
return [
|
return [
|
||||||
self._make_evaluator(cls, env_creator, policy_graph, i + 1)
|
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
|
||||||
for i in range(count)
|
self.config) for i in range(count)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index):
|
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
|
||||||
config = self.config
|
config):
|
||||||
|
|
||||||
def session_creator():
|
def session_creator():
|
||||||
return tf.Session(
|
return tf.Session(
|
||||||
config=tf.ConfigProto(**config["tf_session_args"]))
|
config=tf.ConfigProto(**config["tf_session_args"]))
|
||||||
|
@ -142,7 +155,8 @@ class Agent(Trainable):
|
||||||
env_config=config["env_config"],
|
env_config=config["env_config"],
|
||||||
model_config=config["model"],
|
model_config=config["model"],
|
||||||
policy_config=config,
|
policy_config=config,
|
||||||
worker_index=worker_index)
|
worker_index=worker_index,
|
||||||
|
monitor_path=self.logdir if config["monitor"] else None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resource_help(cls, config):
|
def resource_help(cls, config):
|
||||||
|
@ -164,10 +178,25 @@ class Agent(Trainable):
|
||||||
|
|
||||||
config = config or {}
|
config = config or {}
|
||||||
|
|
||||||
|
# Vars to synchronize to evaluators on each train call
|
||||||
|
self.global_vars = {"timestep": 0}
|
||||||
|
|
||||||
# Agents allow env ids to be passed directly to the constructor.
|
# Agents allow env ids to be passed directly to the constructor.
|
||||||
self._env_id = env or config.get("env")
|
self._env_id = env or config.get("env")
|
||||||
Trainable.__init__(self, config, logger_creator)
|
Trainable.__init__(self, config, logger_creator)
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
"""Overrides super.train to synchronize global vars."""
|
||||||
|
|
||||||
|
if hasattr(self, "optimizer") and isinstance(self.optimizer,
|
||||||
|
PolicyOptimizer):
|
||||||
|
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
|
||||||
|
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
|
||||||
|
for ev in self.optimizer.remote_evaluators:
|
||||||
|
ev.set_global_vars.remote(self.global_vars)
|
||||||
|
|
||||||
|
return Trainable.train(self)
|
||||||
|
|
||||||
def _setup(self):
|
def _setup(self):
|
||||||
env = self._env_id
|
env = self._env_id
|
||||||
if env:
|
if env:
|
||||||
|
|
|
@ -29,7 +29,6 @@ DEFAULT_CONFIG = with_common_config({
|
||||||
"sample_batch_size": 50,
|
"sample_batch_size": 50,
|
||||||
"train_batch_size": 500,
|
"train_batch_size": 500,
|
||||||
"min_iter_time_s": 10,
|
"min_iter_time_s": 10,
|
||||||
"summarize": False,
|
|
||||||
"gpu": True,
|
"gpu": True,
|
||||||
"num_workers": 2,
|
"num_workers": 2,
|
||||||
"num_cpus_per_worker": 1,
|
"num_cpus_per_worker": 1,
|
||||||
|
@ -40,6 +39,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||||
# either "adam" or "rmsprop"
|
# either "adam" or "rmsprop"
|
||||||
"opt_type": "adam",
|
"opt_type": "adam",
|
||||||
"lr": 0.0005,
|
"lr": 0.0005,
|
||||||
|
"lr_schedule": None,
|
||||||
# rmsprop considered
|
# rmsprop considered
|
||||||
"decay": 0.99,
|
"decay": 0.99,
|
||||||
"momentum": 0.0,
|
"momentum": 0.0,
|
||||||
|
@ -62,6 +62,7 @@ class ImpalaAgent(Agent):
|
||||||
|
|
||||||
_agent_name = "IMPALA"
|
_agent_name = "IMPALA"
|
||||||
_default_config = DEFAULT_CONFIG
|
_default_config = DEFAULT_CONFIG
|
||||||
|
_policy_graph = VTracePolicyGraph
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_resource_request(cls, config):
|
def default_resource_request(cls, config):
|
||||||
|
@ -77,7 +78,7 @@ class ImpalaAgent(Agent):
|
||||||
if k not in self.config["optimizer"]:
|
if k not in self.config["optimizer"]:
|
||||||
self.config["optimizer"][k] = self.config[k]
|
self.config["optimizer"][k] = self.config[k]
|
||||||
if self.config["vtrace"]:
|
if self.config["vtrace"]:
|
||||||
policy_cls = VTracePolicyGraph
|
policy_cls = self._policy_graph
|
||||||
else:
|
else:
|
||||||
policy_cls = A3CPolicyGraph
|
policy_cls = A3CPolicyGraph
|
||||||
self.local_evaluator = self.make_local_evaluator(
|
self.local_evaluator = self.make_local_evaluator(
|
||||||
|
|
|
@ -11,10 +11,12 @@ import gym
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib.agents.impala import vtrace
|
from ray.rllib.agents.impala import vtrace
|
||||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||||
|
LearningRateSchedule
|
||||||
from ray.rllib.models.catalog import ModelCatalog
|
from ray.rllib.models.catalog import ModelCatalog
|
||||||
from ray.rllib.models.misc import linear, normc_initializer
|
from ray.rllib.models.misc import linear, normc_initializer
|
||||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||||
|
from ray.rllib.utils.explained_variance import explained_variance
|
||||||
|
|
||||||
|
|
||||||
class VTraceLoss(object):
|
class VTraceLoss(object):
|
||||||
|
@ -54,7 +56,7 @@ class VTraceLoss(object):
|
||||||
|
|
||||||
# Compute vtrace on the CPU for better perf.
|
# Compute vtrace on the CPU for better perf.
|
||||||
with tf.device("/cpu:0"):
|
with tf.device("/cpu:0"):
|
||||||
vtrace_returns = vtrace.from_logits(
|
self.vtrace_returns = vtrace.from_logits(
|
||||||
behaviour_policy_logits=behaviour_logits,
|
behaviour_policy_logits=behaviour_logits,
|
||||||
target_policy_logits=target_logits,
|
target_policy_logits=target_logits,
|
||||||
actions=tf.cast(actions, tf.int32),
|
actions=tf.cast(actions, tf.int32),
|
||||||
|
@ -68,10 +70,10 @@ class VTraceLoss(object):
|
||||||
|
|
||||||
# The policy gradients loss
|
# The policy gradients loss
|
||||||
self.pi_loss = -tf.reduce_sum(
|
self.pi_loss = -tf.reduce_sum(
|
||||||
actions_logp * vtrace_returns.pg_advantages)
|
actions_logp * self.vtrace_returns.pg_advantages)
|
||||||
|
|
||||||
# The baseline loss
|
# The baseline loss
|
||||||
delta = values - vtrace_returns.vs
|
delta = values - self.vtrace_returns.vs
|
||||||
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
|
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
|
||||||
|
|
||||||
# The entropy loss
|
# The entropy loss
|
||||||
|
@ -82,9 +84,9 @@ class VTraceLoss(object):
|
||||||
self.entropy * entropy_coeff)
|
self.entropy * entropy_coeff)
|
||||||
|
|
||||||
|
|
||||||
class VTracePolicyGraph(TFPolicyGraph):
|
class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||||
def __init__(self, observation_space, action_space, config):
|
def __init__(self, observation_space, action_space, config):
|
||||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
|
||||||
assert config["batch_mode"] == "truncate_episodes", \
|
assert config["batch_mode"] == "truncate_episodes", \
|
||||||
"Must use `truncate_episodes` batch mode with V-trace."
|
"Must use `truncate_episodes` batch mode with V-trace."
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -162,6 +164,8 @@ class VTracePolicyGraph(TFPolicyGraph):
|
||||||
("rewards", rewards),
|
("rewards", rewards),
|
||||||
("obs", self.observations),
|
("obs", self.observations),
|
||||||
]
|
]
|
||||||
|
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||||
|
self.config["lr_schedule"])
|
||||||
TFPolicyGraph.__init__(
|
TFPolicyGraph.__init__(
|
||||||
self,
|
self,
|
||||||
observation_space,
|
observation_space,
|
||||||
|
@ -178,13 +182,27 @@ class VTracePolicyGraph(TFPolicyGraph):
|
||||||
|
|
||||||
self.sess.run(tf.global_variables_initializer())
|
self.sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
|
self.stats_fetches = {
|
||||||
|
"stats": {
|
||||||
|
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||||
|
"policy_loss": self.loss.pi_loss,
|
||||||
|
"entropy": self.loss.entropy,
|
||||||
|
"grad_gnorm": tf.global_norm(self._grads),
|
||||||
|
"var_gnorm": tf.global_norm(self.var_list),
|
||||||
|
"vf_loss": self.loss.vf_loss,
|
||||||
|
"vf_explained_var": explained_variance(
|
||||||
|
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
|
||||||
|
tf.reshape(to_batches(values)[:-1], [-1])),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def optimizer(self):
|
def optimizer(self):
|
||||||
if self.config["opt_type"] == "adam":
|
if self.config["opt_type"] == "adam":
|
||||||
return tf.train.AdamOptimizer(self.config["lr"])
|
return tf.train.AdamOptimizer(self.cur_lr)
|
||||||
else:
|
else:
|
||||||
return tf.train.RMSPropOptimizer(
|
return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"],
|
||||||
self.config["lr"], self.config["decay"],
|
self.config["momentum"],
|
||||||
self.config["momentum"], self.config["epsilon"])
|
self.config["epsilon"])
|
||||||
|
|
||||||
def gradients(self, optimizer):
|
def gradients(self, optimizer):
|
||||||
grads = tf.gradients(self.loss.total_loss, self.var_list)
|
grads = tf.gradients(self.loss.total_loss, self.var_list)
|
||||||
|
@ -196,18 +214,7 @@ class VTracePolicyGraph(TFPolicyGraph):
|
||||||
return {"behaviour_logits": self.model.outputs}
|
return {"behaviour_logits": self.model.outputs}
|
||||||
|
|
||||||
def extra_compute_grad_fetches(self):
|
def extra_compute_grad_fetches(self):
|
||||||
if self.config.get("summarize"):
|
return self.stats_fetches
|
||||||
return {
|
|
||||||
"stats": {
|
|
||||||
"policy_loss": self.loss.pi_loss,
|
|
||||||
"value_loss": self.loss.vf_loss,
|
|
||||||
"entropy": self.loss.entropy,
|
|
||||||
"grad_gnorm": tf.global_norm(self._grads),
|
|
||||||
"var_gnorm": tf.global_norm(self.var_list),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
|
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
|
||||||
del sample_batch.data["new_obs"] # not used, so save some bandwidth
|
del sample_batch.data["new_obs"] # not used, so save some bandwidth
|
||||||
|
|
|
@ -32,6 +32,7 @@ class PGAgent(Agent):
|
||||||
|
|
||||||
_agent_name = "PG"
|
_agent_name = "PG"
|
||||||
_default_config = DEFAULT_CONFIG
|
_default_config = DEFAULT_CONFIG
|
||||||
|
_policy_graph = PGPolicyGraph
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_resource_request(cls, config):
|
def default_resource_request(cls, config):
|
||||||
|
@ -40,9 +41,10 @@ class PGAgent(Agent):
|
||||||
|
|
||||||
def _init(self):
|
def _init(self):
|
||||||
self.local_evaluator = self.make_local_evaluator(
|
self.local_evaluator = self.make_local_evaluator(
|
||||||
self.env_creator, PGPolicyGraph)
|
self.env_creator, self._policy_graph)
|
||||||
self.remote_evaluators = self.make_remote_evaluators(
|
self.remote_evaluators = self.make_remote_evaluators(
|
||||||
self.env_creator, PGPolicyGraph, self.config["num_workers"], {})
|
self.env_creator, self._policy_graph, self.config["num_workers"],
|
||||||
|
{})
|
||||||
self.optimizer = SyncSamplesOptimizer(self.local_evaluator,
|
self.optimizer = SyncSamplesOptimizer(self.local_evaluator,
|
||||||
self.remote_evaluators,
|
self.remote_evaluators,
|
||||||
self.config["optimizer"])
|
self.config["optimizer"])
|
||||||
|
|
|
@ -26,6 +26,10 @@ DEFAULT_CONFIG = with_common_config({
|
||||||
"num_sgd_iter": 30,
|
"num_sgd_iter": 30,
|
||||||
# Stepsize of SGD
|
# Stepsize of SGD
|
||||||
"sgd_stepsize": 5e-5,
|
"sgd_stepsize": 5e-5,
|
||||||
|
# Learning rate schedule
|
||||||
|
"lr_schedule": None,
|
||||||
|
# Share layers for value function
|
||||||
|
"vf_share_layers": False,
|
||||||
# Total SGD batch size across all devices for SGD (multi-gpu only)
|
# Total SGD batch size across all devices for SGD (multi-gpu only)
|
||||||
"sgd_batchsize": 128,
|
"sgd_batchsize": 128,
|
||||||
# Coefficient of the value function loss
|
# Coefficient of the value function loss
|
||||||
|
@ -63,6 +67,7 @@ class PPOAgent(Agent):
|
||||||
|
|
||||||
_agent_name = "PPO"
|
_agent_name = "PPO"
|
||||||
_default_config = DEFAULT_CONFIG
|
_default_config = DEFAULT_CONFIG
|
||||||
|
_policy_graph = PPOPolicyGraph
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_resource_request(cls, config):
|
def default_resource_request(cls, config):
|
||||||
|
@ -75,9 +80,9 @@ class PPOAgent(Agent):
|
||||||
|
|
||||||
def _init(self):
|
def _init(self):
|
||||||
self.local_evaluator = self.make_local_evaluator(
|
self.local_evaluator = self.make_local_evaluator(
|
||||||
self.env_creator, PPOPolicyGraph)
|
self.env_creator, self._policy_graph)
|
||||||
self.remote_evaluators = self.make_remote_evaluators(
|
self.remote_evaluators = self.make_remote_evaluators(
|
||||||
self.env_creator, PPOPolicyGraph, self.config["num_workers"], {
|
self.env_creator, self._policy_graph, self.config["num_workers"], {
|
||||||
"num_cpus": self.config["num_cpus_per_worker"],
|
"num_cpus": self.config["num_cpus_per_worker"],
|
||||||
"num_gpus": self.config["num_gpus_per_worker"]
|
"num_gpus": self.config["num_gpus_per_worker"]
|
||||||
})
|
})
|
||||||
|
@ -91,7 +96,6 @@ class PPOAgent(Agent):
|
||||||
self.optimizer = LocalMultiGPUOptimizer(
|
self.optimizer = LocalMultiGPUOptimizer(
|
||||||
self.local_evaluator, self.remote_evaluators, {
|
self.local_evaluator, self.remote_evaluators, {
|
||||||
"sgd_batch_size": self.config["sgd_batchsize"],
|
"sgd_batch_size": self.config["sgd_batchsize"],
|
||||||
"sgd_stepsize": self.config["sgd_stepsize"],
|
|
||||||
"num_sgd_iter": self.config["num_sgd_iter"],
|
"num_sgd_iter": self.config["num_sgd_iter"],
|
||||||
"num_gpus": self.config["num_gpus"],
|
"num_gpus": self.config["num_gpus"],
|
||||||
"timesteps_per_batch": self.config["timesteps_per_batch"],
|
"timesteps_per_batch": self.config["timesteps_per_batch"],
|
||||||
|
|
|
@ -6,8 +6,11 @@ import tensorflow as tf
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||||
|
LearningRateSchedule
|
||||||
from ray.rllib.models.catalog import ModelCatalog
|
from ray.rllib.models.catalog import ModelCatalog
|
||||||
|
from ray.rllib.models.misc import linear, normc_initializer
|
||||||
|
from ray.rllib.utils.explained_variance import explained_variance
|
||||||
|
|
||||||
|
|
||||||
class PPOLoss(object):
|
class PPOLoss(object):
|
||||||
|
@ -83,7 +86,7 @@ class PPOLoss(object):
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
|
|
||||||
|
|
||||||
class PPOPolicyGraph(TFPolicyGraph):
|
class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
observation_space,
|
observation_space,
|
||||||
action_space,
|
action_space,
|
||||||
|
@ -126,6 +129,7 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||||
tf.float32, name="value_targets", shape=(None, ))
|
tf.float32, name="value_targets", shape=(None, ))
|
||||||
existing_state_in = None
|
existing_state_in = None
|
||||||
existing_seq_lens = None
|
existing_seq_lens = None
|
||||||
|
self.observations = obs_ph
|
||||||
|
|
||||||
self.loss_in = [
|
self.loss_in = [
|
||||||
("obs", obs_ph),
|
("obs", obs_ph),
|
||||||
|
@ -154,16 +158,21 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||||
curr_action_dist = dist_cls(self.logits)
|
curr_action_dist = dist_cls(self.logits)
|
||||||
self.sampler = curr_action_dist.sample()
|
self.sampler = curr_action_dist.sample()
|
||||||
if self.config["use_gae"]:
|
if self.config["use_gae"]:
|
||||||
vf_config = self.config["model"].copy()
|
if self.config["vf_share_layers"]:
|
||||||
# Do not split the last layer of the value function into
|
self.value_function = tf.reshape(
|
||||||
# mean parameters and standard deviation parameters and
|
linear(self.model.last_layer, 1, "value",
|
||||||
# do not make the standard deviations free variables.
|
normc_initializer(1.0)), [-1])
|
||||||
vf_config["free_log_std"] = False
|
else:
|
||||||
vf_config["use_lstm"] = False
|
vf_config = self.config["model"].copy()
|
||||||
with tf.variable_scope("value_function"):
|
# Do not split the last layer of the value function into
|
||||||
self.value_function = ModelCatalog.get_model(
|
# mean parameters and standard deviation parameters and
|
||||||
obs_ph, 1, vf_config).outputs
|
# do not make the standard deviations free variables.
|
||||||
self.value_function = tf.reshape(self.value_function, [-1])
|
vf_config["free_log_std"] = False
|
||||||
|
vf_config["use_lstm"] = False
|
||||||
|
with tf.variable_scope("value_function"):
|
||||||
|
self.value_function = ModelCatalog.get_model(
|
||||||
|
obs_ph, 1, vf_config).outputs
|
||||||
|
self.value_function = tf.reshape(self.value_function, [-1])
|
||||||
else:
|
else:
|
||||||
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])
|
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])
|
||||||
|
|
||||||
|
@ -179,9 +188,11 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||||
self.kl_coeff,
|
self.kl_coeff,
|
||||||
entropy_coeff=self.config["entropy_coeff"],
|
entropy_coeff=self.config["entropy_coeff"],
|
||||||
clip_param=self.config["clip_param"],
|
clip_param=self.config["clip_param"],
|
||||||
vf_loss_coeff=self.config["kl_target"],
|
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||||
use_gae=self.config["use_gae"])
|
use_gae=self.config["use_gae"])
|
||||||
|
|
||||||
|
LearningRateSchedule.__init__(self, self.config["sgd_stepsize"],
|
||||||
|
self.config["lr_schedule"])
|
||||||
TFPolicyGraph.__init__(
|
TFPolicyGraph.__init__(
|
||||||
self,
|
self,
|
||||||
observation_space,
|
observation_space,
|
||||||
|
@ -197,6 +208,17 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||||
max_seq_len=config["model"]["max_seq_len"])
|
max_seq_len=config["model"]["max_seq_len"])
|
||||||
|
|
||||||
self.sess.run(tf.global_variables_initializer())
|
self.sess.run(tf.global_variables_initializer())
|
||||||
|
self.explained_variance = explained_variance(value_targets_ph,
|
||||||
|
self.value_function)
|
||||||
|
self.stats_fetches = {
|
||||||
|
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||||
|
"total_loss": self.loss_obj.loss,
|
||||||
|
"policy_loss": self.loss_obj.mean_policy_loss,
|
||||||
|
"vf_loss": self.loss_obj.mean_vf_loss,
|
||||||
|
"vf_explained_var": self.explained_variance,
|
||||||
|
"kl": self.loss_obj.mean_kl,
|
||||||
|
"entropy": self.loss_obj.mean_entropy
|
||||||
|
}
|
||||||
|
|
||||||
def copy(self, existing_inputs):
|
def copy(self, existing_inputs):
|
||||||
"""Creates a copy of self using existing input placeholders."""
|
"""Creates a copy of self using existing input placeholders."""
|
||||||
|
@ -210,13 +232,7 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||||
return {"vf_preds": self.value_function, "logits": self.logits}
|
return {"vf_preds": self.value_function, "logits": self.logits}
|
||||||
|
|
||||||
def extra_compute_grad_fetches(self):
|
def extra_compute_grad_fetches(self):
|
||||||
return {
|
return self.stats_fetches
|
||||||
"total_loss": self.loss_obj.loss,
|
|
||||||
"policy_loss": self.loss_obj.mean_policy_loss,
|
|
||||||
"vf_loss": self.loss_obj.mean_vf_loss,
|
|
||||||
"kl": self.loss_obj.mean_kl,
|
|
||||||
"entropy": self.loss_obj.mean_entropy
|
|
||||||
}
|
|
||||||
|
|
||||||
def update_kl(self, sampled_kl):
|
def update_kl(self, sampled_kl):
|
||||||
if sampled_kl > 2.0 * self.kl_target:
|
if sampled_kl > 2.0 * self.kl_target:
|
||||||
|
@ -226,8 +242,24 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||||
self.kl_coeff.load(self.kl_coeff_val, session=self.sess)
|
self.kl_coeff.load(self.kl_coeff_val, session=self.sess)
|
||||||
return self.kl_coeff_val
|
return self.kl_coeff_val
|
||||||
|
|
||||||
|
def value(self, ob, *args):
|
||||||
|
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
|
||||||
|
assert len(args) == len(self.model.state_in), \
|
||||||
|
(args, self.model.state_in)
|
||||||
|
for k, v in zip(self.model.state_in, args):
|
||||||
|
feed_dict[k] = v
|
||||||
|
vf = self.sess.run(self.value_function, feed_dict)
|
||||||
|
return vf[0]
|
||||||
|
|
||||||
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
|
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
|
||||||
last_r = 0.0
|
completed = sample_batch["dones"][-1]
|
||||||
|
if completed:
|
||||||
|
last_r = 0.0
|
||||||
|
else:
|
||||||
|
next_state = []
|
||||||
|
for i in range(len(self.model.state_in)):
|
||||||
|
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||||
|
last_r = self.value(sample_batch["new_obs"][-1], *next_state)
|
||||||
batch = compute_advantages(
|
batch = compute_advantages(
|
||||||
sample_batch,
|
sample_batch,
|
||||||
last_r,
|
last_r,
|
||||||
|
@ -236,9 +268,6 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||||
use_gae=self.config["use_gae"])
|
use_gae=self.config["use_gae"])
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def optimizer(self):
|
|
||||||
return tf.train.AdamOptimizer(self.config["sgd_stepsize"])
|
|
||||||
|
|
||||||
def gradients(self, optimizer):
|
def gradients(self, optimizer):
|
||||||
return optimizer.compute_gradients(
|
return optimizer.compute_gradients(
|
||||||
self._loss, colocate_gradients_with_ops=True)
|
self._loss, colocate_gradients_with_ops=True)
|
||||||
|
|
6
python/ray/rllib/env/async_vector_env.py
vendored
6
python/ray/rllib/env/async_vector_env.py
vendored
|
@ -123,12 +123,12 @@ class AsyncVectorEnv(object):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_unwrapped(self):
|
def get_unwrapped(self):
|
||||||
"""Return a reference to some underlying gym env, if any.
|
"""Return a reference to the underlying gym envs, if any.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
env (gym.Env|None): Underlying gym env or None.
|
envs (list): Underlying gym envs or [].
|
||||||
"""
|
"""
|
||||||
return None
|
return []
|
||||||
|
|
||||||
|
|
||||||
# Fixed agent identifier when there is only the single agent in the env
|
# Fixed agent identifier when there is only the single agent in the env
|
||||||
|
|
69
python/ray/rllib/env/atari_wrappers.py
vendored
69
python/ray/rllib/env/atari_wrappers.py
vendored
|
@ -10,6 +10,68 @@ def is_atari(env):
|
||||||
return hasattr(env, "unwrapped") and hasattr(env.unwrapped, "ale")
|
return hasattr(env, "unwrapped") and hasattr(env.unwrapped, "ale")
|
||||||
|
|
||||||
|
|
||||||
|
def get_wrapper_by_cls(env, cls):
|
||||||
|
"""Returns the gym env wrapper of the given class, or None."""
|
||||||
|
currentenv = env
|
||||||
|
while True:
|
||||||
|
if isinstance(currentenv, cls):
|
||||||
|
return currentenv
|
||||||
|
elif isinstance(currentenv, gym.Wrapper):
|
||||||
|
currentenv = currentenv.env
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MonitorEnv(gym.Wrapper):
|
||||||
|
def __init__(self, env=None):
|
||||||
|
"""Record episodes stats prior to EpisodicLifeEnv, etc."""
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
self._current_reward = None
|
||||||
|
self._num_steps = None
|
||||||
|
self._total_steps = None
|
||||||
|
self._episode_rewards = []
|
||||||
|
self._episode_lengths = []
|
||||||
|
self._num_episodes = 0
|
||||||
|
self._num_returned = 0
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
obs = self.env.reset(**kwargs)
|
||||||
|
|
||||||
|
if self._total_steps is None:
|
||||||
|
self._total_steps = sum(self._episode_lengths)
|
||||||
|
|
||||||
|
if self._current_reward is not None:
|
||||||
|
self._episode_rewards.append(self._current_reward)
|
||||||
|
self._episode_lengths.append(self._num_steps)
|
||||||
|
self._num_episodes += 1
|
||||||
|
|
||||||
|
self._current_reward = 0
|
||||||
|
self._num_steps = 0
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, rew, done, info = self.env.step(action)
|
||||||
|
self._current_reward += rew
|
||||||
|
self._num_steps += 1
|
||||||
|
self._total_steps += 1
|
||||||
|
return (obs, rew, done, info)
|
||||||
|
|
||||||
|
def get_episode_rewards(self):
|
||||||
|
return self._episode_rewards
|
||||||
|
|
||||||
|
def get_episode_lengths(self):
|
||||||
|
return self._episode_lengths
|
||||||
|
|
||||||
|
def get_total_steps(self):
|
||||||
|
return self._total_steps
|
||||||
|
|
||||||
|
def next_episode_results(self):
|
||||||
|
for i in range(self._num_returned, len(self._episode_rewards)):
|
||||||
|
yield (self._episode_rewards[i], self._episode_lengths[i])
|
||||||
|
self._num_returned = len(self._episode_rewards)
|
||||||
|
|
||||||
|
|
||||||
class NoopResetEnv(gym.Wrapper):
|
class NoopResetEnv(gym.Wrapper):
|
||||||
def __init__(self, env, noop_max=30):
|
def __init__(self, env, noop_max=30):
|
||||||
"""Sample initial states by taking random number of no-ops on reset.
|
"""Sample initial states by taking random number of no-ops on reset.
|
||||||
|
@ -201,14 +263,16 @@ class ScaledFloatFrame(gym.ObservationWrapper):
|
||||||
return np.array(observation).astype(np.float32) / 255.0
|
return np.array(observation).astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
def wrap_deepmind(env, dim=84):
|
def wrap_deepmind(env, dim=84, framestack=True):
|
||||||
"""Configure environment for DeepMind-style Atari.
|
"""Configure environment for DeepMind-style Atari.
|
||||||
|
|
||||||
Note that we assume reward clipping is done outside the wrapper.
|
Note that we assume reward clipping is done outside the wrapper.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dim (int): Dimension to resize observations to (dim x dim).
|
dim (int): Dimension to resize observations to (dim x dim).
|
||||||
|
framestack (bool): Whether to framestack observations.
|
||||||
"""
|
"""
|
||||||
|
env = MonitorEnv(env)
|
||||||
env = NoopResetEnv(env, noop_max=30)
|
env = NoopResetEnv(env, noop_max=30)
|
||||||
if 'NoFrameskip' in env.spec.id:
|
if 'NoFrameskip' in env.spec.id:
|
||||||
env = MaxAndSkipEnv(env, skip=4)
|
env = MaxAndSkipEnv(env, skip=4)
|
||||||
|
@ -218,5 +282,6 @@ def wrap_deepmind(env, dim=84):
|
||||||
env = WarpFrame(env, dim)
|
env = WarpFrame(env, dim)
|
||||||
# env = ScaledFloatFrame(env) # TODO: use for dqn?
|
# env = ScaledFloatFrame(env) # TODO: use for dqn?
|
||||||
# env = ClipRewardEnv(env) # reward clipping is handled by policy eval
|
# env = ClipRewardEnv(env) # reward clipping is handled by policy eval
|
||||||
env = FrameStack(env, 4)
|
if framestack:
|
||||||
|
env = FrameStack(env, 4)
|
||||||
return env
|
return env
|
||||||
|
|
4
python/ray/rllib/env/vector_env.py
vendored
4
python/ray/rllib/env/vector_env.py
vendored
|
@ -49,7 +49,7 @@ class VectorEnv(object):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_unwrapped(self):
|
def get_unwrapped(self):
|
||||||
"""Returns a single instance of the underlying env."""
|
"""Returns the underlying env instances."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,4 +87,4 @@ class _VectorizedGymEnv(VectorEnv):
|
||||||
return obs_batch, rew_batch, done_batch, info_batch
|
return obs_batch, rew_batch, done_batch, info_batch
|
||||||
|
|
||||||
def get_unwrapped(self):
|
def get_unwrapped(self):
|
||||||
return self.envs[0]
|
return self.envs
|
||||||
|
|
|
@ -12,21 +12,36 @@ from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||||
def collect_metrics(local_evaluator, remote_evaluators=[]):
|
def collect_metrics(local_evaluator, remote_evaluators=[]):
|
||||||
"""Gathers episode metrics from PolicyEvaluator instances."""
|
"""Gathers episode metrics from PolicyEvaluator instances."""
|
||||||
|
|
||||||
episode_rewards = []
|
episodes = collect_episodes(local_evaluator, remote_evaluators)
|
||||||
episode_lengths = []
|
return summarize_episodes(episodes)
|
||||||
policy_rewards = collections.defaultdict(list)
|
|
||||||
|
|
||||||
|
def collect_episodes(local_evaluator, remote_evaluators=[]):
|
||||||
|
"""Gathers new episodes metrics tuples from the given evaluators."""
|
||||||
|
|
||||||
metric_lists = ray.get([
|
metric_lists = ray.get([
|
||||||
a.apply.remote(lambda ev: ev.sampler.get_metrics())
|
a.apply.remote(lambda ev: ev.sampler.get_metrics())
|
||||||
for a in remote_evaluators
|
for a in remote_evaluators
|
||||||
])
|
])
|
||||||
metric_lists.append(local_evaluator.sampler.get_metrics())
|
metric_lists.append(local_evaluator.sampler.get_metrics())
|
||||||
|
episodes = []
|
||||||
for metrics in metric_lists:
|
for metrics in metric_lists:
|
||||||
for episode in metrics:
|
episodes.extend(metrics)
|
||||||
episode_lengths.append(episode.episode_length)
|
return episodes
|
||||||
episode_rewards.append(episode.episode_reward)
|
|
||||||
for (_, policy_id), reward in episode.agent_rewards.items():
|
|
||||||
if policy_id != DEFAULT_POLICY_ID:
|
def summarize_episodes(episodes):
|
||||||
policy_rewards[policy_id].append(reward)
|
"""Summarizes a set of episode metrics tuples."""
|
||||||
|
|
||||||
|
episode_rewards = []
|
||||||
|
episode_lengths = []
|
||||||
|
policy_rewards = collections.defaultdict(list)
|
||||||
|
for episode in episodes:
|
||||||
|
episode_lengths.append(episode.episode_length)
|
||||||
|
episode_rewards.append(episode.episode_reward)
|
||||||
|
for (_, policy_id), reward in episode.agent_rewards.items():
|
||||||
|
if policy_id != DEFAULT_POLICY_ID:
|
||||||
|
policy_rewards[policy_id].append(reward)
|
||||||
if episode_rewards:
|
if episode_rewards:
|
||||||
min_reward = min(episode_rewards)
|
min_reward = min(episode_rewards)
|
||||||
max_reward = max(episode_rewards)
|
max_reward = max(episode_rewards)
|
||||||
|
|
|
@ -100,7 +100,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||||
env_config=None,
|
env_config=None,
|
||||||
model_config=None,
|
model_config=None,
|
||||||
policy_config=None,
|
policy_config=None,
|
||||||
worker_index=0):
|
worker_index=0,
|
||||||
|
monitor_path=None):
|
||||||
"""Initialize a policy evaluator.
|
"""Initialize a policy evaluator.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
@ -158,6 +159,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||||
worker_index (int): For remote evaluators, this should be set to a
|
worker_index (int): For remote evaluators, this should be set to a
|
||||||
non-zero and unique value. This index is passed to created envs
|
non-zero and unique value. This index is passed to created envs
|
||||||
through EnvContext so that envs can be configured per worker.
|
through EnvContext so that envs can be configured per worker.
|
||||||
|
monitor_path (str): Write out episode stats and videos to this
|
||||||
|
directory if specified.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
env_context = EnvContext(env_config or {}, worker_index)
|
env_context = EnvContext(env_config or {}, worker_index)
|
||||||
|
@ -184,12 +187,22 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||||
preprocessor_pref == "deepmind":
|
preprocessor_pref == "deepmind":
|
||||||
|
|
||||||
def wrap(env):
|
def wrap(env):
|
||||||
return wrap_deepmind(env, dim=model_config.get("dim", 84))
|
env = wrap_deepmind(
|
||||||
|
env,
|
||||||
|
dim=model_config.get("dim", 84),
|
||||||
|
framestack=not model_config.get("use_lstm")
|
||||||
|
and not model_config.get("no_framestack"))
|
||||||
|
if monitor_path:
|
||||||
|
env = _monitor(env, monitor_path)
|
||||||
|
return env
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def wrap(env):
|
def wrap(env):
|
||||||
return ModelCatalog.get_preprocessor_as_wrapper(
|
env = ModelCatalog.get_preprocessor_as_wrapper(
|
||||||
env, model_config)
|
env, model_config)
|
||||||
|
if monitor_path:
|
||||||
|
env = _monitor(env, monitor_path)
|
||||||
|
return env
|
||||||
|
|
||||||
self.env = wrap(self.env)
|
self.env = wrap(self.env)
|
||||||
|
|
||||||
|
@ -452,6 +465,9 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||||
for pid, state in objs["state"].items():
|
for pid, state in objs["state"].items():
|
||||||
self.policy_map[pid].set_state(state)
|
self.policy_map[pid].set_state(state)
|
||||||
|
|
||||||
|
def set_global_vars(self, global_vars):
|
||||||
|
self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
|
||||||
|
|
||||||
|
|
||||||
def _validate_and_canonicalize(policy_graph, env):
|
def _validate_and_canonicalize(policy_graph, env):
|
||||||
if isinstance(policy_graph, dict):
|
if isinstance(policy_graph, dict):
|
||||||
|
@ -489,6 +505,10 @@ def _validate_and_canonicalize(policy_graph, env):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _monitor(env, path):
|
||||||
|
return gym.wrappers.Monitor(env, path, resume=True)
|
||||||
|
|
||||||
|
|
||||||
def _has_tensorflow_graph(policy_dict):
|
def _has_tensorflow_graph(policy_dict):
|
||||||
for policy, _, _, _ in policy_dict.values():
|
for policy, _, _, _ in policy_dict.values():
|
||||||
if issubclass(policy, TFPolicyGraph):
|
if issubclass(policy, TFPolicyGraph):
|
||||||
|
|
|
@ -174,3 +174,11 @@ class PolicyGraph(object):
|
||||||
state (obj): Serialized local state.
|
state (obj): Serialized local state.
|
||||||
"""
|
"""
|
||||||
self.set_weights(state)
|
self.set_weights(state)
|
||||||
|
|
||||||
|
def on_global_var_update(self, global_vars):
|
||||||
|
"""Called on an update to global vars.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
global_vars (dict): Global variables broadcast from the driver.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
|
@ -11,6 +11,7 @@ from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder, \
|
||||||
MultiAgentBatch
|
MultiAgentBatch
|
||||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||||
|
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
|
||||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||||
|
|
||||||
RolloutMetrics = namedtuple(
|
RolloutMetrics = namedtuple(
|
||||||
|
@ -214,7 +215,8 @@ def _env_runner(async_vector_env,
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not horizon:
|
if not horizon:
|
||||||
horizon = async_vector_env.get_unwrapped().spec.max_episode_steps
|
horizon = (
|
||||||
|
async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Warning, no horizon specified, assuming infinite")
|
print("Warning, no horizon specified, assuming infinite")
|
||||||
if not horizon:
|
if not horizon:
|
||||||
|
@ -259,8 +261,13 @@ def _env_runner(async_vector_env,
|
||||||
# Check episode termination conditions
|
# Check episode termination conditions
|
||||||
if dones[env_id]["__all__"] or episode.length >= horizon:
|
if dones[env_id]["__all__"] or episode.length >= horizon:
|
||||||
all_done = True
|
all_done = True
|
||||||
yield RolloutMetrics(episode.length, episode.total_reward,
|
atari_metrics = _fetch_atari_metrics(async_vector_env)
|
||||||
dict(episode.agent_rewards))
|
if atari_metrics is not None:
|
||||||
|
for m in atari_metrics:
|
||||||
|
yield m
|
||||||
|
else:
|
||||||
|
yield RolloutMetrics(episode.length, episode.total_reward,
|
||||||
|
dict(episode.agent_rewards))
|
||||||
else:
|
else:
|
||||||
all_done = False
|
all_done = False
|
||||||
# At least send an empty dict if not done
|
# At least send an empty dict if not done
|
||||||
|
@ -384,6 +391,24 @@ def _env_runner(async_vector_env,
|
||||||
async_vector_env.send_actions(dict(actions_to_send))
|
async_vector_env.send_actions(dict(actions_to_send))
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_atari_metrics(async_vector_env):
|
||||||
|
"""Atari games have multiple logical episodes, one per life.
|
||||||
|
|
||||||
|
However for metrics reporting we count full episodes all lives included.
|
||||||
|
"""
|
||||||
|
unwrapped = async_vector_env.get_unwrapped()
|
||||||
|
if not unwrapped:
|
||||||
|
return None
|
||||||
|
atari_out = []
|
||||||
|
for u in unwrapped:
|
||||||
|
monitor = get_wrapper_by_cls(u, MonitorEnv)
|
||||||
|
if not monitor:
|
||||||
|
return None
|
||||||
|
for eps_rew, eps_len in monitor.next_episode_results():
|
||||||
|
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}))
|
||||||
|
return atari_out
|
||||||
|
|
||||||
|
|
||||||
def _to_column_format(rnn_state_rows):
|
def _to_column_format(rnn_state_rows):
|
||||||
num_cols = len(rnn_state_rows[0])
|
num_cols = len(rnn_state_rows[0])
|
||||||
return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
|
return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
|
||||||
|
|
|
@ -9,6 +9,7 @@ import ray
|
||||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||||
from ray.rllib.models.lstm import chop_into_sequences
|
from ray.rllib.models.lstm import chop_into_sequences
|
||||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||||
|
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||||
|
|
||||||
|
|
||||||
class TFPolicyGraph(PolicyGraph):
|
class TFPolicyGraph(PolicyGraph):
|
||||||
|
@ -229,3 +230,24 @@ class TFPolicyGraph(PolicyGraph):
|
||||||
|
|
||||||
def loss_inputs(self):
|
def loss_inputs(self):
|
||||||
return self._loss_inputs
|
return self._loss_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class LearningRateSchedule(object):
|
||||||
|
"""Mixin for TFPolicyGraph that adds a learning rate schedule."""
|
||||||
|
|
||||||
|
def __init__(self, lr, lr_schedule):
|
||||||
|
self.cur_lr = tf.get_variable("lr", initializer=lr)
|
||||||
|
if lr_schedule is None:
|
||||||
|
self.lr_schedule = ConstantSchedule(lr)
|
||||||
|
else:
|
||||||
|
self.lr_schedule = PiecewiseSchedule(
|
||||||
|
lr_schedule, outside_value=lr_schedule[-1][-1])
|
||||||
|
|
||||||
|
def on_global_var_update(self, global_vars):
|
||||||
|
super(LearningRateSchedule, self).on_global_var_update(global_vars)
|
||||||
|
self.cur_lr.load(
|
||||||
|
self.lr_schedule.value(global_vars["timestep"]),
|
||||||
|
session=self._sess)
|
||||||
|
|
||||||
|
def optimizer(self):
|
||||||
|
return tf.train.AdamOptimizer(self.cur_lr)
|
||||||
|
|
|
@ -16,7 +16,6 @@ more info.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow.contrib.rnn as rnn
|
import tensorflow.contrib.rnn as rnn
|
||||||
import distutils.version
|
|
||||||
|
|
||||||
from ray.rllib.models.misc import linear, normc_initializer
|
from ray.rllib.models.misc import linear, normc_initializer
|
||||||
from ray.rllib.models.model import Model
|
from ray.rllib.models.model import Model
|
||||||
|
@ -137,15 +136,10 @@ class LSTM(Model):
|
||||||
|
|
||||||
def _build_layers(self, inputs, num_outputs, options):
|
def _build_layers(self, inputs, num_outputs, options):
|
||||||
cell_size = options.get("lstm_cell_size", 256)
|
cell_size = options.get("lstm_cell_size", 256)
|
||||||
use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
|
||||||
distutils.version.LooseVersion("1.0.0"))
|
|
||||||
last_layer = add_time_dimension(inputs, self.seq_lens)
|
last_layer = add_time_dimension(inputs, self.seq_lens)
|
||||||
|
|
||||||
# Setup the LSTM cell
|
# Setup the LSTM cell
|
||||||
if use_tf100_api:
|
lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
|
||||||
lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
|
|
||||||
else:
|
|
||||||
lstm = rnn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
|
|
||||||
self.state_init = [
|
self.state_init = [
|
||||||
np.zeros(lstm.state_size.c, np.float32),
|
np.zeros(lstm.state_size.c, np.float32),
|
||||||
np.zeros(lstm.state_size.h, np.float32)
|
np.zeros(lstm.state_size.h, np.float32)
|
||||||
|
@ -162,16 +156,13 @@ class LSTM(Model):
|
||||||
self.state_in = [c_in, h_in]
|
self.state_in = [c_in, h_in]
|
||||||
|
|
||||||
# Setup LSTM outputs
|
# Setup LSTM outputs
|
||||||
if use_tf100_api:
|
|
||||||
state_in = rnn.LSTMStateTuple(c_in, h_in)
|
|
||||||
else:
|
|
||||||
state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
|
|
||||||
lstm_out, lstm_state = tf.nn.dynamic_rnn(
|
lstm_out, lstm_state = tf.nn.dynamic_rnn(
|
||||||
lstm,
|
lstm,
|
||||||
last_layer,
|
last_layer,
|
||||||
initial_state=state_in,
|
|
||||||
sequence_length=self.seq_lens,
|
sequence_length=self.seq_lens,
|
||||||
time_major=False)
|
time_major=False,
|
||||||
|
dtype=tf.float32)
|
||||||
|
|
||||||
self.state_out = list(lstm_state)
|
self.state_out = list(lstm_state)
|
||||||
|
|
||||||
# Compute outputs
|
# Compute outputs
|
||||||
|
|
|
@ -32,13 +32,11 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||||
|
|
||||||
def _init(self,
|
def _init(self,
|
||||||
sgd_batch_size=128,
|
sgd_batch_size=128,
|
||||||
sgd_stepsize=5e-5,
|
|
||||||
num_sgd_iter=10,
|
num_sgd_iter=10,
|
||||||
timesteps_per_batch=1024,
|
timesteps_per_batch=1024,
|
||||||
num_gpus=0,
|
num_gpus=0,
|
||||||
standardize_fields=[]):
|
standardize_fields=[]):
|
||||||
self.batch_size = sgd_batch_size
|
self.batch_size = sgd_batch_size
|
||||||
self.sgd_stepsize = sgd_stepsize
|
|
||||||
self.num_sgd_iter = num_sgd_iter
|
self.num_sgd_iter = num_sgd_iter
|
||||||
self.timesteps_per_batch = timesteps_per_batch
|
self.timesteps_per_batch = timesteps_per_batch
|
||||||
if not num_gpus:
|
if not num_gpus:
|
||||||
|
@ -81,8 +79,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||||
else:
|
else:
|
||||||
rnn_inputs = []
|
rnn_inputs = []
|
||||||
self.par_opt = LocalSyncParallelOptimizer(
|
self.par_opt = LocalSyncParallelOptimizer(
|
||||||
tf.train.AdamOptimizer(
|
self.policy.optimizer(), self.devices,
|
||||||
self.sgd_stepsize), self.devices,
|
|
||||||
[v for _, v in self.policy.loss_inputs()], rnn_inputs,
|
[v for _, v in self.policy.loss_inputs()], rnn_inputs,
|
||||||
self.per_device_batch_size, self.policy.copy,
|
self.per_device_batch_size, self.policy.copy,
|
||||||
os.getcwd())
|
os.getcwd())
|
||||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import print_function
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||||
from ray.rllib.evaluation.metrics import collect_metrics
|
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
|
||||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,6 +45,7 @@ class PolicyOptimizer(object):
|
||||||
"""
|
"""
|
||||||
self.local_evaluator = local_evaluator
|
self.local_evaluator = local_evaluator
|
||||||
self.remote_evaluators = remote_evaluators or []
|
self.remote_evaluators = remote_evaluators or []
|
||||||
|
self.episode_history = []
|
||||||
self.config = config or {}
|
self.config = config or {}
|
||||||
self._init(**self.config)
|
self._init(**self.config)
|
||||||
|
|
||||||
|
@ -78,14 +79,26 @@ class PolicyOptimizer(object):
|
||||||
"num_steps_sampled": self.num_steps_sampled,
|
"num_steps_sampled": self.num_steps_sampled,
|
||||||
}
|
}
|
||||||
|
|
||||||
def collect_metrics(self):
|
def collect_metrics(self, min_history=100):
|
||||||
"""Returns evaluator and optimizer stats.
|
"""Returns evaluator and optimizer stats.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
min_history (int): Min history length to smooth results over.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
res (dict): A training result dict from evaluator metrics with
|
res (dict): A training result dict from evaluator metrics with
|
||||||
`info` replaced with stats from self.
|
`info` replaced with stats from self.
|
||||||
"""
|
"""
|
||||||
res = collect_metrics(self.local_evaluator, self.remote_evaluators)
|
episodes = collect_episodes(self.local_evaluator,
|
||||||
|
self.remote_evaluators)
|
||||||
|
orig_episodes = list(episodes)
|
||||||
|
missing = min_history - len(episodes)
|
||||||
|
if missing > 0:
|
||||||
|
episodes.extend(self.episode_history[-missing:])
|
||||||
|
assert len(episodes) <= min_history
|
||||||
|
self.episode_history.extend(orig_episodes)
|
||||||
|
self.episode_history = self.episode_history[-min_history:]
|
||||||
|
res = summarize_episodes(episodes)
|
||||||
res.update(info=self.stats())
|
res.update(info=self.stats())
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||||
self.throughput = RunningStat()
|
self.throughput = RunningStat()
|
||||||
self.num_sgd_iter = num_sgd_iter
|
self.num_sgd_iter = num_sgd_iter
|
||||||
self.timesteps_per_batch = timesteps_per_batch
|
self.timesteps_per_batch = timesteps_per_batch
|
||||||
|
self.learner_stats = {}
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
with self.update_weights_timer:
|
with self.update_weights_timer:
|
||||||
|
@ -48,6 +49,8 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||||
with self.grad_timer:
|
with self.grad_timer:
|
||||||
for i in range(self.num_sgd_iter):
|
for i in range(self.num_sgd_iter):
|
||||||
fetches = self.local_evaluator.compute_apply(samples)
|
fetches = self.local_evaluator.compute_apply(samples)
|
||||||
|
if "stats" in fetches:
|
||||||
|
self.learner_stats = fetches["stats"]
|
||||||
if self.num_sgd_iter > 1:
|
if self.num_sgd_iter > 1:
|
||||||
print(i, fetches)
|
print(i, fetches)
|
||||||
self.grad_timer.push_units_processed(samples.count)
|
self.grad_timer.push_units_processed(samples.count)
|
||||||
|
@ -68,4 +71,5 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||||
"sample_peak_throughput": round(
|
"sample_peak_throughput": round(
|
||||||
self.sample_timer.mean_throughput, 3),
|
self.sample_timer.mean_throughput, 3),
|
||||||
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
|
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
|
||||||
|
"learner": self.learner_stats,
|
||||||
})
|
})
|
||||||
|
|
|
@ -8,6 +8,7 @@ import unittest
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib.agents.pg import PGAgent
|
from ray.rllib.agents.pg import PGAgent
|
||||||
|
from ray.rllib.agents.a3c import A2CAgent
|
||||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||||
from ray.rllib.evaluation.metrics import collect_metrics
|
from ray.rllib.evaluation.metrics import collect_metrics
|
||||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||||
|
@ -96,6 +97,9 @@ class MockVectorEnv(VectorEnv):
|
||||||
info_batch.append(info)
|
info_batch.append(info)
|
||||||
return obs_batch, rew_batch, done_batch, info_batch
|
return obs_batch, rew_batch, done_batch, info_batch
|
||||||
|
|
||||||
|
def get_unwrapped(self):
|
||||||
|
return self.envs
|
||||||
|
|
||||||
|
|
||||||
class TestPolicyEvaluator(unittest.TestCase):
|
class TestPolicyEvaluator(unittest.TestCase):
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
|
@ -107,6 +111,17 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||||
self.assertIn(key, batch)
|
self.assertIn(key, batch)
|
||||||
self.assertGreater(batch["advantages"][0], 1)
|
self.assertGreater(batch["advantages"][0], 1)
|
||||||
|
|
||||||
|
def testGlobalVarsUpdate(self):
|
||||||
|
agent = A2CAgent(
|
||||||
|
env="CartPole-v0",
|
||||||
|
config={
|
||||||
|
"lr_schedule": [[0, 0.1], [400, 0.000001]],
|
||||||
|
})
|
||||||
|
result = agent.train()
|
||||||
|
self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
|
||||||
|
result2 = agent.train()
|
||||||
|
self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001)
|
||||||
|
|
||||||
def testQueryEvaluators(self):
|
def testQueryEvaluators(self):
|
||||||
register_env("test", lambda _: gym.make("CartPole-v0"))
|
register_env("test", lambda _: gym.make("CartPole-v0"))
|
||||||
pg = PGAgent(
|
pg = PGAgent(
|
||||||
|
|
19
python/ray/rllib/tuned_examples/atari-a2c.yaml
Normal file
19
python/ray/rllib/tuned_examples/atari-a2c.yaml
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
# Runs on a single g3.16xl node
|
||||||
|
# See https://github.com/ray-project/rl-experiments for results
|
||||||
|
atari-a2c:
|
||||||
|
env:
|
||||||
|
grid_search:
|
||||||
|
- BreakoutNoFrameskip-v4
|
||||||
|
- BeamRiderNoFrameskip-v4
|
||||||
|
- QbertNoFrameskip-v4
|
||||||
|
- SpaceInvadersNoFrameskip-v4
|
||||||
|
run: A2C
|
||||||
|
config:
|
||||||
|
sample_batch_size: 100
|
||||||
|
num_workers: 5
|
||||||
|
num_envs_per_worker: 5
|
||||||
|
gpu: true
|
||||||
|
lr_schedule: [
|
||||||
|
[0, 0.0007],
|
||||||
|
[20000000, 0.000000000001],
|
||||||
|
]
|
19
python/ray/rllib/tuned_examples/atari-impala.yaml
Normal file
19
python/ray/rllib/tuned_examples/atari-impala.yaml
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
# Runs on a g3.16xl node with 3 m4.16xl workers
|
||||||
|
# See https://github.com/ray-project/rl-experiments for results
|
||||||
|
atari-impala:
|
||||||
|
env:
|
||||||
|
grid_search:
|
||||||
|
- BreakoutNoFrameskip-v4
|
||||||
|
- BeamRiderNoFrameskip-v4
|
||||||
|
- QbertNoFrameskip-v4
|
||||||
|
- SpaceInvadersNoFrameskip-v4
|
||||||
|
run: IMPALA
|
||||||
|
config:
|
||||||
|
sample_batch_size: 250 # 50 * num_envs_per_worker
|
||||||
|
train_batch_size: 500
|
||||||
|
num_workers: 32
|
||||||
|
num_envs_per_worker: 5
|
||||||
|
lr_schedule: [
|
||||||
|
[0, 0.0005],
|
||||||
|
[20000000, 0.000000000001],
|
||||||
|
]
|
29
python/ray/rllib/tuned_examples/atari-ppo.yaml
Normal file
29
python/ray/rllib/tuned_examples/atari-ppo.yaml
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Runs on a single g3.16xl node
|
||||||
|
# See https://github.com/ray-project/rl-experiments for results
|
||||||
|
atari-ppo:
|
||||||
|
env:
|
||||||
|
grid_search:
|
||||||
|
- BreakoutNoFrameskip-v4
|
||||||
|
- BeamRiderNoFrameskip-v4
|
||||||
|
- QbertNoFrameskip-v4
|
||||||
|
- SpaceInvadersNoFrameskip-v4
|
||||||
|
run: PPO
|
||||||
|
config:
|
||||||
|
lambda: 0.95
|
||||||
|
kl_coeff: 0.5
|
||||||
|
clip_param: 0.1
|
||||||
|
entropy_coeff: 0.01
|
||||||
|
timesteps_per_batch: 5000
|
||||||
|
sample_batch_size: 500
|
||||||
|
sgd_batchsize: 500
|
||||||
|
num_sgd_iter: 10
|
||||||
|
num_workers: 10
|
||||||
|
num_envs_per_worker: 5
|
||||||
|
batch_mode: truncate_episodes
|
||||||
|
observation_filter: NoFilter
|
||||||
|
vf_share_layers: true
|
||||||
|
num_gpus: 1
|
||||||
|
lr_schedule: [
|
||||||
|
[0, 0.0007],
|
||||||
|
[20000000, 0.000000000001],
|
||||||
|
]
|
|
@ -47,6 +47,5 @@ halfcheetah-ddpg:
|
||||||
num_workers: 0
|
num_workers: 0
|
||||||
num_gpus_per_worker: 0
|
num_gpus_per_worker: 0
|
||||||
optimizer_class: "SyncReplayOptimizer"
|
optimizer_class: "SyncReplayOptimizer"
|
||||||
optimizer_config: {}
|
|
||||||
per_worker_exploration: False
|
per_worker_exploration: False
|
||||||
worker_side_prioritization: False
|
worker_side_prioritization: False
|
||||||
|
|
|
@ -47,6 +47,5 @@ mountaincarcontinuous-ddpg:
|
||||||
num_workers: 0
|
num_workers: 0
|
||||||
num_gpus_per_worker: 0
|
num_gpus_per_worker: 0
|
||||||
optimizer_class: "SyncReplayOptimizer"
|
optimizer_class: "SyncReplayOptimizer"
|
||||||
optimizer_config: {}
|
|
||||||
per_worker_exploration: False
|
per_worker_exploration: False
|
||||||
worker_side_prioritization: False
|
worker_side_prioritization: False
|
||||||
|
|
|
@ -47,6 +47,5 @@ pendulum-ddpg:
|
||||||
num_workers: 0
|
num_workers: 0
|
||||||
num_gpus_per_worker: 0
|
num_gpus_per_worker: 0
|
||||||
optimizer_class: "SyncReplayOptimizer"
|
optimizer_class: "SyncReplayOptimizer"
|
||||||
optimizer_config: {}
|
|
||||||
per_worker_exploration: False
|
per_worker_exploration: False
|
||||||
worker_side_prioritization: False
|
worker_side_prioritization: False
|
||||||
|
|
|
@ -14,6 +14,7 @@ pong-a3c:
|
||||||
lambda: 1.0
|
lambda: 1.0
|
||||||
lr: 0.0001
|
lr: 0.0001
|
||||||
observation_filter: NoFilter
|
observation_filter: NoFilter
|
||||||
|
preprocessor_pref: rllib
|
||||||
model:
|
model:
|
||||||
use_lstm: true
|
use_lstm: true
|
||||||
conv_activation: elu
|
conv_activation: elu
|
||||||
|
@ -27,5 +28,3 @@ pong-a3c:
|
||||||
[32, [3, 3], 2],
|
[32, [3, 3], 2],
|
||||||
[32, [3, 3], 2],
|
[32, [3, 3], 2],
|
||||||
]
|
]
|
||||||
optimizer:
|
|
||||||
grads_per_step: 1000
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ pong-deterministic-dqn:
|
||||||
episode_reward_mean: 20
|
episode_reward_mean: 20
|
||||||
time_total_s: 7200
|
time_total_s: 7200
|
||||||
config:
|
config:
|
||||||
|
gpu: True
|
||||||
gamma: 0.99
|
gamma: 0.99
|
||||||
lr: .0001
|
lr: .0001
|
||||||
learning_starts: 10000
|
learning_starts: 10000
|
||||||
|
|
11
python/ray/rllib/utils/explained_variance.py
Normal file
11
python/ray/rllib/utils/explained_variance.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def explained_variance(y, pred):
|
||||||
|
_, y_var = tf.nn.moments(y, axes=[0])
|
||||||
|
_, diff_var = tf.nn.moments(y - pred, axes=[0])
|
||||||
|
return tf.maximum(-1.0, 1 - (diff_var / y_var))
|
|
@ -390,6 +390,7 @@ def stop():
|
||||||
help=("Override the configured max worker node count for the cluster."))
|
help=("Override the configured max worker node count for the cluster."))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--cluster-name",
|
"--cluster-name",
|
||||||
|
"-n",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
help=("Override the configured cluster name."))
|
help=("Override the configured cluster name."))
|
||||||
|
@ -423,6 +424,7 @@ def create_or_update(cluster_config_file, min_workers, max_workers, no_restart,
|
||||||
help=("Don't ask for confirmation."))
|
help=("Don't ask for confirmation."))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--cluster-name",
|
"--cluster-name",
|
||||||
|
"-n",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
help=("Override the configured cluster name."))
|
help=("Override the configured cluster name."))
|
||||||
|
@ -439,6 +441,7 @@ def teardown(cluster_config_file, yes, workers_only, cluster_name):
|
||||||
help=("Start the cluster if needed."))
|
help=("Start the cluster if needed."))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--cluster-name",
|
"--cluster-name",
|
||||||
|
"-n",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
help=("Override the configured cluster name."))
|
help=("Override the configured cluster name."))
|
||||||
|
@ -452,6 +455,7 @@ def attach(cluster_config_file, start, cluster_name):
|
||||||
@click.argument("target", required=True, type=str)
|
@click.argument("target", required=True, type=str)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--cluster-name",
|
"--cluster-name",
|
||||||
|
"-n",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
help=("Override the configured cluster name."))
|
help=("Override the configured cluster name."))
|
||||||
|
@ -465,6 +469,7 @@ def rsync_down(cluster_config_file, source, target, cluster_name):
|
||||||
@click.argument("target", required=True, type=str)
|
@click.argument("target", required=True, type=str)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--cluster-name",
|
"--cluster-name",
|
||||||
|
"-n",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
help=("Override the configured cluster name."))
|
help=("Override the configured cluster name."))
|
||||||
|
@ -492,6 +497,7 @@ def rsync_up(cluster_config_file, source, target, cluster_name):
|
||||||
help=("Run the command in a screen."))
|
help=("Run the command in a screen."))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--cluster-name",
|
"--cluster-name",
|
||||||
|
"-n",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
help=("Override the configured cluster name."))
|
help=("Override the configured cluster name."))
|
||||||
|
@ -507,6 +513,7 @@ def exec_cmd(cluster_config_file, cmd, screen, stop, start, cluster_name,
|
||||||
@click.argument("cluster_config_file", required=True, type=str)
|
@click.argument("cluster_config_file", required=True, type=str)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--cluster-name",
|
"--cluster-name",
|
||||||
|
"-n",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
help=("Override the configured cluster name."))
|
help=("Override the configured cluster name."))
|
||||||
|
|
Loading…
Add table
Reference in a new issue