[rllib] Fix atari reward calculations, add LR annealing, explained var stat for A2C / impala (#2700)

Changes needed to reproduce Atari plots in IMPALA / A2C: https://github.com/ray-project/rl-experiments
This commit is contained in:
Eric Liang 2018-08-23 17:49:10 -07:00 committed by GitHub
parent 1b3de31ff1
commit aa014af85b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 483 additions and 148 deletions

View file

@ -16,11 +16,14 @@ Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/mas
Advantage Actor-Critic (A2C, A3C)
---------------------------------
`[paper] <https://arxiv.org/abs/1602.01783>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a3c.py>`__
RLlib's A3C uses the AsyncGradientsOptimizer to apply gradients computed remotely on policy evaluation actors. It scales to up to 16-32 worker processes, depending on the environment. Both a TensorFlow (LSTM), and PyTorch version are available.
RLlib implements A2C and A3C using SyncSamplesOptimizer and AsyncGradientsOptimizer respectively for policy optimization. These algorithms scale to up to 16-32 worker processes depending on the environment. Both a TensorFlow (LSTM), and PyTorch version are available.
Note that if you have a GPU, `IMPALA <#importance-weighted-actor-learner-architecture>`__ probably will perform better than A3C. You can also use the synchronous variant of A3C, `A2C <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a2c.py>`__.
.. note::
In most cases, `IMPALA <#importance-weighted-actor-learner-architecture-impala>`__ will outperform A2C / A3C. In our `benchmarks <https://github.com/ray-project/rl-experiments>`__, IMPALA is almost 10x faster than A2C in wallclock time, with similar sample efficiency.
Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c.yaml>`__, `A2C variant <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a2c.yaml>`__, `PyTorch version <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml>`__
Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c.yaml>`__, `PyTorch version <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-a2c.yaml>`__
See also our `Atari results <https://github.com/ray-project/rl-experiments>`__.
Deep Deterministic Policy Gradients (DDPG)
------------------------------------------
@ -56,12 +59,14 @@ Importance Weighted Actor-Learner Architecture (IMPALA)
`[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/impala/impala.py>`__
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models>`__.
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala.yaml>`__, `vectorized configuration <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml>`__
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala.yaml>`__, `vectorized configuration <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-impala.yaml>`__
See also our `Atari results <https://github.com/ray-project/rl-experiments>`__.
.. figure:: impala.png
:align: center
RLlib's IMPALA implementation scales from 16 to 128 workers on PongNoFrameskip-v4. With vectorization, similar learning performance to 128 workers can be achieved with only 32 workers. This is about an order of magnitude faster than A3C (not shown here), with similar sample efficiency.
IMPALA solves Atari about 10x faster than A2C / A3C, with similar sample efficiency. Here IMPALA scales from 16 to 128 workers on PongNoFrameskip-v4.
Policy Gradients
----------------
@ -74,7 +79,9 @@ Proximal Policy Optimization (PPO)
`[paper] <https://arxiv.org/abs/1707.06347>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/ppo/ppo.py>`__
PPO's clipped objective supports multiple SGD passes over the same batch of experiences. RLlib's multi-GPU optimizer pins that data in GPU memory to avoid unnecessary transfers from host memory, substantially improving performance over a naive implementation. RLlib's PPO scales out using multiple workers for experience collection, and also with multiple GPUs for SGD.
Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml>`__, `Hopper-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/hopper-ppo.yaml>`__, `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pendulum-ppo.yaml>`__, `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-ppo.yaml>`__, `Walker2d-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/walker2d-ppo.yaml>`__
Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml>`__, `Hopper-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/hopper-ppo.yaml>`__, `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pendulum-ppo.yaml>`__, `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-ppo.yaml>`__, `Walker2d-v1 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/walker2d-ppo.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-ppo.yaml>`__
See also our `Atari results <https://github.com/ray-project/rl-experiments>`__.
.. figure:: ppo.png
:width: 500px

View file

@ -6,7 +6,8 @@ ray.rllib.agents
.. automodule:: ray.rllib.agents
:members:
.. autoclass:: ray.rllib.agents.a3c.A2CAgent
.. autoclass:: ray.rllib.agents.a3c.A3CAgent
.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGAgent
.. autoclass:: ray.rllib.agents.ddpg.DDPGAgent

View file

@ -26,7 +26,6 @@ training process with TensorBoard by running
tensorboard --logdir=~/ray_results
The ``train.py`` script has a number of options you can show by running
.. code-block:: bash
@ -44,14 +43,12 @@ Specifying Parameters
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/agent.py>`__. See the
`algorithms documentation <rllib-algorithms.html>`__ for more information.
In an example below, we train A3C by specifying 8 workers through the config flag.
function that creates the env to refer to it by name. The contents of the env_config agent config field will be passed to that function to allow the environment to be configured. The return type should be an OpenAI gym.Env. For example:
In an example below, we train A2C by specifying 8 workers through the config flag. We also set ``"monitor": true`` to save episode videos to the result dir:
.. code-block:: bash
python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
--run=A3C --config '{"num_workers": 8}'
--run=A2C --config '{"num_workers": 8, "monitor": true}'
Evaluating Trained Agents
~~~~~~~~~~~~~~~~~~~~~~~~~

View file

@ -13,9 +13,7 @@ A2C_DEFAULT_CONFIG = merge_dicts(
"gpu": False,
"sample_batch_size": 20,
"min_iter_time_s": 10,
"optimizer": {
"timesteps_per_batch": 200,
},
"sample_async": False,
},
)

View file

@ -7,6 +7,7 @@ import os
import time
import ray
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.optimizers import AsyncGradientsOptimizer
from ray.rllib.utils import FilterManager, merge_dicts
@ -23,14 +24,14 @@ DEFAULT_CONFIG = with_common_config({
"grad_clip": 40.0,
# Learning rate
"lr": 0.0001,
# Learning rate schedule
"lr_schedule": None,
# Value Function Loss coefficient
"vf_loss_coeff": 0.5,
# Entropy coefficient
"entropy_coeff": -0.01,
# Whether to place workers on GPUs
"use_gpu_for_workers": False,
# Whether to emit extra summary stats
"summarize": False,
# Min time per iteration
"min_iter_time_s": 5,
# Workers sample async. Note that this increases the effective
@ -67,6 +68,7 @@ class A3CAgent(Agent):
_agent_name = "A3C"
_default_config = DEFAULT_CONFIG
_policy_graph = A3CPolicyGraph
@classmethod
def default_resource_request(cls, config):
@ -83,8 +85,7 @@ class A3CAgent(Agent):
A3CTorchPolicyGraph
policy_cls = A3CTorchPolicyGraph
else:
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
policy_cls = A3CPolicyGraph
policy_cls = self._policy_graph
self.local_evaluator = self.make_local_evaluator(
self.env_creator, policy_cls)

View file

@ -9,8 +9,10 @@ import gym
import ray
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.catalog import ModelCatalog
@ -36,7 +38,7 @@ class A3CLoss(object):
self.entropy * entropy_coeff)
class A3CPolicyGraph(TFPolicyGraph):
class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
@ -67,8 +69,8 @@ class A3CPolicyGraph(TFPolicyGraph):
"Action space {} is not supported for A3C.".format(
action_space))
advantages = tf.placeholder(tf.float32, [None], name="advantages")
v_target = tf.placeholder(tf.float32, [None], name="v_target")
self.loss = A3CLoss(action_dist, actions, advantages, v_target,
self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
self.vf, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])
@ -77,8 +79,10 @@ class A3CPolicyGraph(TFPolicyGraph):
("obs", self.observations),
("actions", actions),
("advantages", advantages),
("value_targets", v_target),
("value_targets", self.v_target),
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
self,
observation_space,
@ -93,6 +97,18 @@ class A3CPolicyGraph(TFPolicyGraph):
seq_lens=self.model.seq_lens,
max_seq_len=self.config["model"]["max_seq_len"])
self.stats_fetches = {
"stats": {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"policy_entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
"vf_loss": self.loss.vf_loss,
"vf_explained_var": explained_variance(self.v_target, self.vf),
},
}
self.sess.run(tf.global_variables_initializer())
def extra_compute_action_fetches(self):
@ -107,9 +123,6 @@ class A3CPolicyGraph(TFPolicyGraph):
vf = self.sess.run(self.vf, feed_dict)
return vf[0]
def optimizer(self):
return tf.train.AdamOptimizer(self.config["lr"])
def gradients(self, optimizer):
grads = tf.gradients(self.loss.total_loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
@ -117,18 +130,7 @@ class A3CPolicyGraph(TFPolicyGraph):
return clipped_grads
def extra_compute_grad_fetches(self):
if self.config.get("summarize"):
return {
"stats": {
"policy_loss": self.loss.pi_loss,
"value_loss": self.loss.vf_loss,
"entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
},
}
else:
return {}
return self.stats_fetches
def get_initial_state(self):
return self.model.state_init

View file

@ -10,7 +10,8 @@ import pickle
import tensorflow as tf
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.utils import deep_update
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils import deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.trainable import Trainable
@ -61,6 +62,8 @@ COMMON_CONFIG = {
},
# Whether to LZ4 compress observations
"compress_observations": False,
# Whether to write episode stats and videos to the agent log dir
"monitor": False,
# === Multiagent ===
"multiagent": {
@ -103,8 +106,19 @@ class Agent(Trainable):
def make_local_evaluator(self, env_creator, policy_graph):
"""Convenience method to return configured local evaluator."""
return self._make_evaluator(PolicyEvaluator, env_creator, policy_graph,
0)
return self._make_evaluator(
PolicyEvaluator,
env_creator,
policy_graph,
0,
# important: allow local tf to use multiple CPUs for optimization
merge_dicts(
self.config, {
"tf_session_args": {
"intra_op_parallelism_threads": None,
"inter_op_parallelism_threads": None,
}
}))
def make_remote_evaluators(self, env_creator, policy_graph, count,
remote_args):
@ -112,13 +126,12 @@ class Agent(Trainable):
cls = PolicyEvaluator.as_remote(**remote_args).remote
return [
self._make_evaluator(cls, env_creator, policy_graph, i + 1)
for i in range(count)
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
self.config) for i in range(count)
]
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index):
config = self.config
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def session_creator():
return tf.Session(
config=tf.ConfigProto(**config["tf_session_args"]))
@ -142,7 +155,8 @@ class Agent(Trainable):
env_config=config["env_config"],
model_config=config["model"],
policy_config=config,
worker_index=worker_index)
worker_index=worker_index,
monitor_path=self.logdir if config["monitor"] else None)
@classmethod
def resource_help(cls, config):
@ -164,10 +178,25 @@ class Agent(Trainable):
config = config or {}
# Vars to synchronize to evaluators on each train call
self.global_vars = {"timestep": 0}
# Agents allow env ids to be passed directly to the constructor.
self._env_id = env or config.get("env")
Trainable.__init__(self, config, logger_creator)
def train(self):
"""Overrides super.train to synchronize global vars."""
if hasattr(self, "optimizer") and isinstance(self.optimizer,
PolicyOptimizer):
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
for ev in self.optimizer.remote_evaluators:
ev.set_global_vars.remote(self.global_vars)
return Trainable.train(self)
def _setup(self):
env = self._env_id
if env:

View file

@ -29,7 +29,6 @@ DEFAULT_CONFIG = with_common_config({
"sample_batch_size": 50,
"train_batch_size": 500,
"min_iter_time_s": 10,
"summarize": False,
"gpu": True,
"num_workers": 2,
"num_cpus_per_worker": 1,
@ -40,6 +39,7 @@ DEFAULT_CONFIG = with_common_config({
# either "adam" or "rmsprop"
"opt_type": "adam",
"lr": 0.0005,
"lr_schedule": None,
# rmsprop considered
"decay": 0.99,
"momentum": 0.0,
@ -62,6 +62,7 @@ class ImpalaAgent(Agent):
_agent_name = "IMPALA"
_default_config = DEFAULT_CONFIG
_policy_graph = VTracePolicyGraph
@classmethod
def default_resource_request(cls, config):
@ -77,7 +78,7 @@ class ImpalaAgent(Agent):
if k not in self.config["optimizer"]:
self.config["optimizer"][k] = self.config[k]
if self.config["vtrace"]:
policy_cls = VTracePolicyGraph
policy_cls = self._policy_graph
else:
policy_cls = A3CPolicyGraph
self.local_evaluator = self.make_local_evaluator(

View file

@ -11,10 +11,12 @@ import gym
import ray
from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
class VTraceLoss(object):
@ -54,7 +56,7 @@ class VTraceLoss(object):
# Compute vtrace on the CPU for better perf.
with tf.device("/cpu:0"):
vtrace_returns = vtrace.from_logits(
self.vtrace_returns = vtrace.from_logits(
behaviour_policy_logits=behaviour_logits,
target_policy_logits=target_logits,
actions=tf.cast(actions, tf.int32),
@ -68,10 +70,10 @@ class VTraceLoss(object):
# The policy gradients loss
self.pi_loss = -tf.reduce_sum(
actions_logp * vtrace_returns.pg_advantages)
actions_logp * self.vtrace_returns.pg_advantages)
# The baseline loss
delta = values - vtrace_returns.vs
delta = values - self.vtrace_returns.vs
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
# The entropy loss
@ -82,9 +84,9 @@ class VTraceLoss(object):
self.entropy * entropy_coeff)
class VTracePolicyGraph(TFPolicyGraph):
class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
assert config["batch_mode"] == "truncate_episodes", \
"Must use `truncate_episodes` batch mode with V-trace."
self.config = config
@ -162,6 +164,8 @@ class VTracePolicyGraph(TFPolicyGraph):
("rewards", rewards),
("obs", self.observations),
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
self,
observation_space,
@ -178,13 +182,27 @@ class VTracePolicyGraph(TFPolicyGraph):
self.sess.run(tf.global_variables_initializer())
self.stats_fetches = {
"stats": {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
"vf_loss": self.loss.vf_loss,
"vf_explained_var": explained_variance(
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
tf.reshape(to_batches(values)[:-1], [-1])),
},
}
def optimizer(self):
if self.config["opt_type"] == "adam":
return tf.train.AdamOptimizer(self.config["lr"])
return tf.train.AdamOptimizer(self.cur_lr)
else:
return tf.train.RMSPropOptimizer(
self.config["lr"], self.config["decay"],
self.config["momentum"], self.config["epsilon"])
return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"],
self.config["momentum"],
self.config["epsilon"])
def gradients(self, optimizer):
grads = tf.gradients(self.loss.total_loss, self.var_list)
@ -196,18 +214,7 @@ class VTracePolicyGraph(TFPolicyGraph):
return {"behaviour_logits": self.model.outputs}
def extra_compute_grad_fetches(self):
if self.config.get("summarize"):
return {
"stats": {
"policy_loss": self.loss.pi_loss,
"value_loss": self.loss.vf_loss,
"entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
},
}
else:
return {}
return self.stats_fetches
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
del sample_batch.data["new_obs"] # not used, so save some bandwidth

View file

@ -32,6 +32,7 @@ class PGAgent(Agent):
_agent_name = "PG"
_default_config = DEFAULT_CONFIG
_policy_graph = PGPolicyGraph
@classmethod
def default_resource_request(cls, config):
@ -40,9 +41,10 @@ class PGAgent(Agent):
def _init(self):
self.local_evaluator = self.make_local_evaluator(
self.env_creator, PGPolicyGraph)
self.env_creator, self._policy_graph)
self.remote_evaluators = self.make_remote_evaluators(
self.env_creator, PGPolicyGraph, self.config["num_workers"], {})
self.env_creator, self._policy_graph, self.config["num_workers"],
{})
self.optimizer = SyncSamplesOptimizer(self.local_evaluator,
self.remote_evaluators,
self.config["optimizer"])

View file

@ -26,6 +26,10 @@ DEFAULT_CONFIG = with_common_config({
"num_sgd_iter": 30,
# Stepsize of SGD
"sgd_stepsize": 5e-5,
# Learning rate schedule
"lr_schedule": None,
# Share layers for value function
"vf_share_layers": False,
# Total SGD batch size across all devices for SGD (multi-gpu only)
"sgd_batchsize": 128,
# Coefficient of the value function loss
@ -63,6 +67,7 @@ class PPOAgent(Agent):
_agent_name = "PPO"
_default_config = DEFAULT_CONFIG
_policy_graph = PPOPolicyGraph
@classmethod
def default_resource_request(cls, config):
@ -75,9 +80,9 @@ class PPOAgent(Agent):
def _init(self):
self.local_evaluator = self.make_local_evaluator(
self.env_creator, PPOPolicyGraph)
self.env_creator, self._policy_graph)
self.remote_evaluators = self.make_remote_evaluators(
self.env_creator, PPOPolicyGraph, self.config["num_workers"], {
self.env_creator, self._policy_graph, self.config["num_workers"], {
"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"]
})
@ -91,7 +96,6 @@ class PPOAgent(Agent):
self.optimizer = LocalMultiGPUOptimizer(
self.local_evaluator, self.remote_evaluators, {
"sgd_batch_size": self.config["sgd_batchsize"],
"sgd_stepsize": self.config["sgd_stepsize"],
"num_sgd_iter": self.config["num_sgd_iter"],
"num_gpus": self.config["num_gpus"],
"timesteps_per_batch": self.config["timesteps_per_batch"],

View file

@ -6,8 +6,11 @@ import tensorflow as tf
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.utils.explained_variance import explained_variance
class PPOLoss(object):
@ -83,7 +86,7 @@ class PPOLoss(object):
self.loss = loss
class PPOPolicyGraph(TFPolicyGraph):
class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
def __init__(self,
observation_space,
action_space,
@ -126,6 +129,7 @@ class PPOPolicyGraph(TFPolicyGraph):
tf.float32, name="value_targets", shape=(None, ))
existing_state_in = None
existing_seq_lens = None
self.observations = obs_ph
self.loss_in = [
("obs", obs_ph),
@ -154,16 +158,21 @@ class PPOPolicyGraph(TFPolicyGraph):
curr_action_dist = dist_cls(self.logits)
self.sampler = curr_action_dist.sample()
if self.config["use_gae"]:
vf_config = self.config["model"].copy()
# Do not split the last layer of the value function into
# mean parameters and standard deviation parameters and
# do not make the standard deviations free variables.
vf_config["free_log_std"] = False
vf_config["use_lstm"] = False
with tf.variable_scope("value_function"):
self.value_function = ModelCatalog.get_model(
obs_ph, 1, vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
if self.config["vf_share_layers"]:
self.value_function = tf.reshape(
linear(self.model.last_layer, 1, "value",
normc_initializer(1.0)), [-1])
else:
vf_config = self.config["model"].copy()
# Do not split the last layer of the value function into
# mean parameters and standard deviation parameters and
# do not make the standard deviations free variables.
vf_config["free_log_std"] = False
vf_config["use_lstm"] = False
with tf.variable_scope("value_function"):
self.value_function = ModelCatalog.get_model(
obs_ph, 1, vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
else:
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])
@ -179,9 +188,11 @@ class PPOPolicyGraph(TFPolicyGraph):
self.kl_coeff,
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_loss_coeff=self.config["kl_target"],
vf_loss_coeff=self.config["vf_loss_coeff"],
use_gae=self.config["use_gae"])
LearningRateSchedule.__init__(self, self.config["sgd_stepsize"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
self,
observation_space,
@ -197,6 +208,17 @@ class PPOPolicyGraph(TFPolicyGraph):
max_seq_len=config["model"]["max_seq_len"])
self.sess.run(tf.global_variables_initializer())
self.explained_variance = explained_variance(value_targets_ph,
self.value_function)
self.stats_fetches = {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"total_loss": self.loss_obj.loss,
"policy_loss": self.loss_obj.mean_policy_loss,
"vf_loss": self.loss_obj.mean_vf_loss,
"vf_explained_var": self.explained_variance,
"kl": self.loss_obj.mean_kl,
"entropy": self.loss_obj.mean_entropy
}
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders."""
@ -210,13 +232,7 @@ class PPOPolicyGraph(TFPolicyGraph):
return {"vf_preds": self.value_function, "logits": self.logits}
def extra_compute_grad_fetches(self):
return {
"total_loss": self.loss_obj.loss,
"policy_loss": self.loss_obj.mean_policy_loss,
"vf_loss": self.loss_obj.mean_vf_loss,
"kl": self.loss_obj.mean_kl,
"entropy": self.loss_obj.mean_entropy
}
return self.stats_fetches
def update_kl(self, sampled_kl):
if sampled_kl > 2.0 * self.kl_target:
@ -226,8 +242,24 @@ class PPOPolicyGraph(TFPolicyGraph):
self.kl_coeff.load(self.kl_coeff_val, session=self.sess)
return self.kl_coeff_val
def value(self, ob, *args):
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
vf = self.sess.run(self.value_function, feed_dict)
return vf[0]
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
last_r = 0.0
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(len(self.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = self.value(sample_batch["new_obs"][-1], *next_state)
batch = compute_advantages(
sample_batch,
last_r,
@ -236,9 +268,6 @@ class PPOPolicyGraph(TFPolicyGraph):
use_gae=self.config["use_gae"])
return batch
def optimizer(self):
return tf.train.AdamOptimizer(self.config["sgd_stepsize"])
def gradients(self, optimizer):
return optimizer.compute_gradients(
self._loss, colocate_gradients_with_ops=True)

View file

@ -123,12 +123,12 @@ class AsyncVectorEnv(object):
return None
def get_unwrapped(self):
"""Return a reference to some underlying gym env, if any.
"""Return a reference to the underlying gym envs, if any.
Returns:
env (gym.Env|None): Underlying gym env or None.
envs (list): Underlying gym envs or [].
"""
return None
return []
# Fixed agent identifier when there is only the single agent in the env

View file

@ -10,6 +10,68 @@ def is_atari(env):
return hasattr(env, "unwrapped") and hasattr(env.unwrapped, "ale")
def get_wrapper_by_cls(env, cls):
"""Returns the gym env wrapper of the given class, or None."""
currentenv = env
while True:
if isinstance(currentenv, cls):
return currentenv
elif isinstance(currentenv, gym.Wrapper):
currentenv = currentenv.env
else:
return None
class MonitorEnv(gym.Wrapper):
def __init__(self, env=None):
"""Record episodes stats prior to EpisodicLifeEnv, etc."""
gym.Wrapper.__init__(self, env)
self._current_reward = None
self._num_steps = None
self._total_steps = None
self._episode_rewards = []
self._episode_lengths = []
self._num_episodes = 0
self._num_returned = 0
def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
if self._total_steps is None:
self._total_steps = sum(self._episode_lengths)
if self._current_reward is not None:
self._episode_rewards.append(self._current_reward)
self._episode_lengths.append(self._num_steps)
self._num_episodes += 1
self._current_reward = 0
self._num_steps = 0
return obs
def step(self, action):
obs, rew, done, info = self.env.step(action)
self._current_reward += rew
self._num_steps += 1
self._total_steps += 1
return (obs, rew, done, info)
def get_episode_rewards(self):
return self._episode_rewards
def get_episode_lengths(self):
return self._episode_lengths
def get_total_steps(self):
return self._total_steps
def next_episode_results(self):
for i in range(self._num_returned, len(self._episode_rewards)):
yield (self._episode_rewards[i], self._episode_lengths[i])
self._num_returned = len(self._episode_rewards)
class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
@ -201,14 +263,16 @@ class ScaledFloatFrame(gym.ObservationWrapper):
return np.array(observation).astype(np.float32) / 255.0
def wrap_deepmind(env, dim=84):
def wrap_deepmind(env, dim=84, framestack=True):
"""Configure environment for DeepMind-style Atari.
Note that we assume reward clipping is done outside the wrapper.
Args:
dim (int): Dimension to resize observations to (dim x dim).
framestack (bool): Whether to framestack observations.
"""
env = MonitorEnv(env)
env = NoopResetEnv(env, noop_max=30)
if 'NoFrameskip' in env.spec.id:
env = MaxAndSkipEnv(env, skip=4)
@ -218,5 +282,6 @@ def wrap_deepmind(env, dim=84):
env = WarpFrame(env, dim)
# env = ScaledFloatFrame(env) # TODO: use for dqn?
# env = ClipRewardEnv(env) # reward clipping is handled by policy eval
env = FrameStack(env, 4)
if framestack:
env = FrameStack(env, 4)
return env

View file

@ -49,7 +49,7 @@ class VectorEnv(object):
raise NotImplementedError
def get_unwrapped(self):
"""Returns a single instance of the underlying env."""
"""Returns the underlying env instances."""
raise NotImplementedError
@ -87,4 +87,4 @@ class _VectorizedGymEnv(VectorEnv):
return obs_batch, rew_batch, done_batch, info_batch
def get_unwrapped(self):
return self.envs[0]
return self.envs

View file

@ -12,21 +12,36 @@ from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
def collect_metrics(local_evaluator, remote_evaluators=[]):
"""Gathers episode metrics from PolicyEvaluator instances."""
episode_rewards = []
episode_lengths = []
policy_rewards = collections.defaultdict(list)
episodes = collect_episodes(local_evaluator, remote_evaluators)
return summarize_episodes(episodes)
def collect_episodes(local_evaluator, remote_evaluators=[]):
"""Gathers new episodes metrics tuples from the given evaluators."""
metric_lists = ray.get([
a.apply.remote(lambda ev: ev.sampler.get_metrics())
for a in remote_evaluators
])
metric_lists.append(local_evaluator.sampler.get_metrics())
episodes = []
for metrics in metric_lists:
for episode in metrics:
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
for (_, policy_id), reward in episode.agent_rewards.items():
if policy_id != DEFAULT_POLICY_ID:
policy_rewards[policy_id].append(reward)
episodes.extend(metrics)
return episodes
def summarize_episodes(episodes):
"""Summarizes a set of episode metrics tuples."""
episode_rewards = []
episode_lengths = []
policy_rewards = collections.defaultdict(list)
for episode in episodes:
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
for (_, policy_id), reward in episode.agent_rewards.items():
if policy_id != DEFAULT_POLICY_ID:
policy_rewards[policy_id].append(reward)
if episode_rewards:
min_reward = min(episode_rewards)
max_reward = max(episode_rewards)

View file

@ -100,7 +100,8 @@ class PolicyEvaluator(EvaluatorInterface):
env_config=None,
model_config=None,
policy_config=None,
worker_index=0):
worker_index=0,
monitor_path=None):
"""Initialize a policy evaluator.
Arguments:
@ -158,6 +159,8 @@ class PolicyEvaluator(EvaluatorInterface):
worker_index (int): For remote evaluators, this should be set to a
non-zero and unique value. This index is passed to created envs
through EnvContext so that envs can be configured per worker.
monitor_path (str): Write out episode stats and videos to this
directory if specified.
"""
env_context = EnvContext(env_config or {}, worker_index)
@ -184,12 +187,22 @@ class PolicyEvaluator(EvaluatorInterface):
preprocessor_pref == "deepmind":
def wrap(env):
return wrap_deepmind(env, dim=model_config.get("dim", 84))
env = wrap_deepmind(
env,
dim=model_config.get("dim", 84),
framestack=not model_config.get("use_lstm")
and not model_config.get("no_framestack"))
if monitor_path:
env = _monitor(env, monitor_path)
return env
else:
def wrap(env):
return ModelCatalog.get_preprocessor_as_wrapper(
env = ModelCatalog.get_preprocessor_as_wrapper(
env, model_config)
if monitor_path:
env = _monitor(env, monitor_path)
return env
self.env = wrap(self.env)
@ -452,6 +465,9 @@ class PolicyEvaluator(EvaluatorInterface):
for pid, state in objs["state"].items():
self.policy_map[pid].set_state(state)
def set_global_vars(self, global_vars):
self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
def _validate_and_canonicalize(policy_graph, env):
if isinstance(policy_graph, dict):
@ -489,6 +505,10 @@ def _validate_and_canonicalize(policy_graph, env):
}
def _monitor(env, path):
return gym.wrappers.Monitor(env, path, resume=True)
def _has_tensorflow_graph(policy_dict):
for policy, _, _, _ in policy_dict.values():
if issubclass(policy, TFPolicyGraph):

View file

@ -174,3 +174,11 @@ class PolicyGraph(object):
state (obj): Serialized local state.
"""
self.set_weights(state)
def on_global_var_update(self, global_vars):
"""Called on an update to global vars.
Arguments:
global_vars (dict): Global variables broadcast from the driver.
"""
pass

View file

@ -11,6 +11,7 @@ from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder, \
MultiAgentBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.utils.tf_run_builder import TFRunBuilder
RolloutMetrics = namedtuple(
@ -214,7 +215,8 @@ def _env_runner(async_vector_env,
try:
if not horizon:
horizon = async_vector_env.get_unwrapped().spec.max_episode_steps
horizon = (
async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
except Exception:
print("Warning, no horizon specified, assuming infinite")
if not horizon:
@ -259,8 +261,13 @@ def _env_runner(async_vector_env,
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
all_done = True
yield RolloutMetrics(episode.length, episode.total_reward,
dict(episode.agent_rewards))
atari_metrics = _fetch_atari_metrics(async_vector_env)
if atari_metrics is not None:
for m in atari_metrics:
yield m
else:
yield RolloutMetrics(episode.length, episode.total_reward,
dict(episode.agent_rewards))
else:
all_done = False
# At least send an empty dict if not done
@ -384,6 +391,24 @@ def _env_runner(async_vector_env,
async_vector_env.send_actions(dict(actions_to_send))
def _fetch_atari_metrics(async_vector_env):
"""Atari games have multiple logical episodes, one per life.
However for metrics reporting we count full episodes all lives included.
"""
unwrapped = async_vector_env.get_unwrapped()
if not unwrapped:
return None
atari_out = []
for u in unwrapped:
monitor = get_wrapper_by_cls(u, MonitorEnv)
if not monitor:
return None
for eps_rew, eps_len in monitor.next_episode_results():
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}))
return atari_out
def _to_column_format(rnn_state_rows):
num_cols = len(rnn_state_rows[0])
return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]

View file

@ -9,6 +9,7 @@ import ray
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
class TFPolicyGraph(PolicyGraph):
@ -229,3 +230,24 @@ class TFPolicyGraph(PolicyGraph):
def loss_inputs(self):
return self._loss_inputs
class LearningRateSchedule(object):
"""Mixin for TFPolicyGraph that adds a learning rate schedule."""
def __init__(self, lr, lr_schedule):
self.cur_lr = tf.get_variable("lr", initializer=lr)
if lr_schedule is None:
self.lr_schedule = ConstantSchedule(lr)
else:
self.lr_schedule = PiecewiseSchedule(
lr_schedule, outside_value=lr_schedule[-1][-1])
def on_global_var_update(self, global_vars):
super(LearningRateSchedule, self).on_global_var_update(global_vars)
self.cur_lr.load(
self.lr_schedule.value(global_vars["timestep"]),
session=self._sess)
def optimizer(self):
return tf.train.AdamOptimizer(self.cur_lr)

View file

@ -16,7 +16,6 @@ more info.
import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import distutils.version
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.model import Model
@ -137,15 +136,10 @@ class LSTM(Model):
def _build_layers(self, inputs, num_outputs, options):
cell_size = options.get("lstm_cell_size", 256)
use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.0.0"))
last_layer = add_time_dimension(inputs, self.seq_lens)
# Setup the LSTM cell
if use_tf100_api:
lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
else:
lstm = rnn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
self.state_init = [
np.zeros(lstm.state_size.c, np.float32),
np.zeros(lstm.state_size.h, np.float32)
@ -162,16 +156,13 @@ class LSTM(Model):
self.state_in = [c_in, h_in]
# Setup LSTM outputs
if use_tf100_api:
state_in = rnn.LSTMStateTuple(c_in, h_in)
else:
state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
lstm_out, lstm_state = tf.nn.dynamic_rnn(
lstm,
last_layer,
initial_state=state_in,
sequence_length=self.seq_lens,
time_major=False)
time_major=False,
dtype=tf.float32)
self.state_out = list(lstm_state)
# Compute outputs

View file

@ -32,13 +32,11 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
def _init(self,
sgd_batch_size=128,
sgd_stepsize=5e-5,
num_sgd_iter=10,
timesteps_per_batch=1024,
num_gpus=0,
standardize_fields=[]):
self.batch_size = sgd_batch_size
self.sgd_stepsize = sgd_stepsize
self.num_sgd_iter = num_sgd_iter
self.timesteps_per_batch = timesteps_per_batch
if not num_gpus:
@ -81,8 +79,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
else:
rnn_inputs = []
self.par_opt = LocalSyncParallelOptimizer(
tf.train.AdamOptimizer(
self.sgd_stepsize), self.devices,
self.policy.optimizer(), self.devices,
[v for _, v in self.policy.loss_inputs()], rnn_inputs,
self.per_device_batch_size, self.policy.copy,
os.getcwd())

View file

@ -4,7 +4,7 @@ from __future__ import print_function
import ray
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
@ -45,6 +45,7 @@ class PolicyOptimizer(object):
"""
self.local_evaluator = local_evaluator
self.remote_evaluators = remote_evaluators or []
self.episode_history = []
self.config = config or {}
self._init(**self.config)
@ -78,14 +79,26 @@ class PolicyOptimizer(object):
"num_steps_sampled": self.num_steps_sampled,
}
def collect_metrics(self):
def collect_metrics(self, min_history=100):
"""Returns evaluator and optimizer stats.
Arguments:
min_history (int): Min history length to smooth results over.
Returns:
res (dict): A training result dict from evaluator metrics with
`info` replaced with stats from self.
"""
res = collect_metrics(self.local_evaluator, self.remote_evaluators)
episodes = collect_episodes(self.local_evaluator,
self.remote_evaluators)
orig_episodes = list(episodes)
missing = min_history - len(episodes)
if missing > 0:
episodes.extend(self.episode_history[-missing:])
assert len(episodes) <= min_history
self.episode_history.extend(orig_episodes)
self.episode_history = self.episode_history[-min_history:]
res = summarize_episodes(episodes)
res.update(info=self.stats())
return res

View file

@ -24,6 +24,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
self.throughput = RunningStat()
self.num_sgd_iter = num_sgd_iter
self.timesteps_per_batch = timesteps_per_batch
self.learner_stats = {}
def step(self):
with self.update_weights_timer:
@ -48,6 +49,8 @@ class SyncSamplesOptimizer(PolicyOptimizer):
with self.grad_timer:
for i in range(self.num_sgd_iter):
fetches = self.local_evaluator.compute_apply(samples)
if "stats" in fetches:
self.learner_stats = fetches["stats"]
if self.num_sgd_iter > 1:
print(i, fetches)
self.grad_timer.push_units_processed(samples.count)
@ -68,4 +71,5 @@ class SyncSamplesOptimizer(PolicyOptimizer):
"sample_peak_throughput": round(
self.sample_timer.mean_throughput, 3),
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
"learner": self.learner_stats,
})

View file

@ -8,6 +8,7 @@ import unittest
import ray
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.a3c import A2CAgent
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.policy_graph import PolicyGraph
@ -96,6 +97,9 @@ class MockVectorEnv(VectorEnv):
info_batch.append(info)
return obs_batch, rew_batch, done_batch, info_batch
def get_unwrapped(self):
return self.envs
class TestPolicyEvaluator(unittest.TestCase):
def testBasic(self):
@ -107,6 +111,17 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertIn(key, batch)
self.assertGreater(batch["advantages"][0], 1)
def testGlobalVarsUpdate(self):
agent = A2CAgent(
env="CartPole-v0",
config={
"lr_schedule": [[0, 0.1], [400, 0.000001]],
})
result = agent.train()
self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
result2 = agent.train()
self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001)
def testQueryEvaluators(self):
register_env("test", lambda _: gym.make("CartPole-v0"))
pg = PGAgent(

View file

@ -0,0 +1,19 @@
# Runs on a single g3.16xl node
# See https://github.com/ray-project/rl-experiments for results
atari-a2c:
env:
grid_search:
- BreakoutNoFrameskip-v4
- BeamRiderNoFrameskip-v4
- QbertNoFrameskip-v4
- SpaceInvadersNoFrameskip-v4
run: A2C
config:
sample_batch_size: 100
num_workers: 5
num_envs_per_worker: 5
gpu: true
lr_schedule: [
[0, 0.0007],
[20000000, 0.000000000001],
]

View file

@ -0,0 +1,19 @@
# Runs on a g3.16xl node with 3 m4.16xl workers
# See https://github.com/ray-project/rl-experiments for results
atari-impala:
env:
grid_search:
- BreakoutNoFrameskip-v4
- BeamRiderNoFrameskip-v4
- QbertNoFrameskip-v4
- SpaceInvadersNoFrameskip-v4
run: IMPALA
config:
sample_batch_size: 250 # 50 * num_envs_per_worker
train_batch_size: 500
num_workers: 32
num_envs_per_worker: 5
lr_schedule: [
[0, 0.0005],
[20000000, 0.000000000001],
]

View file

@ -0,0 +1,29 @@
# Runs on a single g3.16xl node
# See https://github.com/ray-project/rl-experiments for results
atari-ppo:
env:
grid_search:
- BreakoutNoFrameskip-v4
- BeamRiderNoFrameskip-v4
- QbertNoFrameskip-v4
- SpaceInvadersNoFrameskip-v4
run: PPO
config:
lambda: 0.95
kl_coeff: 0.5
clip_param: 0.1
entropy_coeff: 0.01
timesteps_per_batch: 5000
sample_batch_size: 500
sgd_batchsize: 500
num_sgd_iter: 10
num_workers: 10
num_envs_per_worker: 5
batch_mode: truncate_episodes
observation_filter: NoFilter
vf_share_layers: true
num_gpus: 1
lr_schedule: [
[0, 0.0007],
[20000000, 0.000000000001],
]

View file

@ -47,6 +47,5 @@ halfcheetah-ddpg:
num_workers: 0
num_gpus_per_worker: 0
optimizer_class: "SyncReplayOptimizer"
optimizer_config: {}
per_worker_exploration: False
worker_side_prioritization: False

View file

@ -47,6 +47,5 @@ mountaincarcontinuous-ddpg:
num_workers: 0
num_gpus_per_worker: 0
optimizer_class: "SyncReplayOptimizer"
optimizer_config: {}
per_worker_exploration: False
worker_side_prioritization: False

View file

@ -47,6 +47,5 @@ pendulum-ddpg:
num_workers: 0
num_gpus_per_worker: 0
optimizer_class: "SyncReplayOptimizer"
optimizer_config: {}
per_worker_exploration: False
worker_side_prioritization: False

View file

@ -14,6 +14,7 @@ pong-a3c:
lambda: 1.0
lr: 0.0001
observation_filter: NoFilter
preprocessor_pref: rllib
model:
use_lstm: true
conv_activation: elu
@ -27,5 +28,3 @@ pong-a3c:
[32, [3, 3], 2],
[32, [3, 3], 2],
]
optimizer:
grads_per_step: 1000

View file

@ -6,6 +6,7 @@ pong-deterministic-dqn:
episode_reward_mean: 20
time_total_s: 7200
config:
gpu: True
gamma: 0.99
lr: .0001
learning_starts: 10000

View file

@ -0,0 +1,11 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def explained_variance(y, pred):
_, y_var = tf.nn.moments(y, axes=[0])
_, diff_var = tf.nn.moments(y - pred, axes=[0])
return tf.maximum(-1.0, 1 - (diff_var / y_var))

View file

@ -390,6 +390,7 @@ def stop():
help=("Override the configured max worker node count for the cluster."))
@click.option(
"--cluster-name",
"-n",
required=False,
type=str,
help=("Override the configured cluster name."))
@ -423,6 +424,7 @@ def create_or_update(cluster_config_file, min_workers, max_workers, no_restart,
help=("Don't ask for confirmation."))
@click.option(
"--cluster-name",
"-n",
required=False,
type=str,
help=("Override the configured cluster name."))
@ -439,6 +441,7 @@ def teardown(cluster_config_file, yes, workers_only, cluster_name):
help=("Start the cluster if needed."))
@click.option(
"--cluster-name",
"-n",
required=False,
type=str,
help=("Override the configured cluster name."))
@ -452,6 +455,7 @@ def attach(cluster_config_file, start, cluster_name):
@click.argument("target", required=True, type=str)
@click.option(
"--cluster-name",
"-n",
required=False,
type=str,
help=("Override the configured cluster name."))
@ -465,6 +469,7 @@ def rsync_down(cluster_config_file, source, target, cluster_name):
@click.argument("target", required=True, type=str)
@click.option(
"--cluster-name",
"-n",
required=False,
type=str,
help=("Override the configured cluster name."))
@ -492,6 +497,7 @@ def rsync_up(cluster_config_file, source, target, cluster_name):
help=("Run the command in a screen."))
@click.option(
"--cluster-name",
"-n",
required=False,
type=str,
help=("Override the configured cluster name."))
@ -507,6 +513,7 @@ def exec_cmd(cluster_config_file, cmd, screen, stop, start, cluster_name,
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
"--cluster-name",
"-n",
required=False,
type=str,
help=("Override the configured cluster name."))