From 81a4be8f651842eef60902bf87e9dcd825a70438 Mon Sep 17 00:00:00 2001
From: alvkao58
Date: Sat, 10 Feb 2018 13:54:51 -0800
Subject: [PATCH] [rllib] Added vanilla policy gradient (#1497)

---
 doc/source/rllib-dev.rst                    |  2 +-
 python/ray/rllib/__init__.py                |  2 +-
 python/ray/rllib/agent.py                   |  3 +
 python/ray/rllib/pg/__init__.py             |  3 +
 python/ray/rllib/pg/pg.py                   | 83 +++++++++++++++++++
 python/ray/rllib/pg/pg_evaluator.py         | 60 ++++++++++++++
 python/ray/rllib/pg/policy.py               | 82 ++++++++++++++++++
 .../ray/rllib/test/test_supported_spaces.py |  4 +
 test/jenkins_tests/run_multi_node_tests.sh  | 21 +++++
 9 files changed, 258 insertions(+), 2 deletions(-)
 create mode 100644 python/ray/rllib/pg/__init__.py
 create mode 100644 python/ray/rllib/pg/pg.py
 create mode 100644 python/ray/rllib/pg/pg_evaluator.py
 create mode 100644 python/ray/rllib/pg/policy.py

diff --git a/doc/source/rllib-dev.rst b/doc/source/rllib-dev.rst
index b12bf9062..2b438ca60 100644
--- a/doc/source/rllib-dev.rst
+++ b/doc/source/rllib-dev.rst
@@ -10,7 +10,7 @@ Recipe for an RLlib algorithm
 
 Here are the steps for implementing a new algorithm in RLlib:
 
-1. Define an algorithm-specific `Policy evaluator class <#policy-evaluators-and-optimizers>`__ (the core of the algorithm). Evaluators encapsulate framework-specific components such as the policy and loss functions. For an example, see the `A3C Evaluator implementation `__.
+1. Define an algorithm-specific `Policy evaluator class <#policy-evaluators-and-optimizers>`__ (the core of the algorithm). Evaluators encapsulate framework-specific components such as the policy and loss functions. For an example, see the `simple policy gradient evaluator example `__.
 
 2. Pick an appropriate `Policy optimizer class <#policy-evaluators-and-optimizers>`__. Optimizers manage the parallel execution of the algorithm. RLlib provides several built-in optimizers for gradient-based algorithms. Advanced algorithms may find it beneficial to implement their own optimizers.
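As a companion to step 1 of the recipe above, the sketch below shows the shape of the evaluator interface an algorithm fills in. The method names mirror the PGEvaluator added by this patch in python/ray/rllib/pg/pg_evaluator.py; the class name MyAlgoEvaluator and the placeholder bodies are purely illustrative and are not part of the patch.

# Illustrative skeleton only; see PGEvaluator further below for a concrete
# implementation. `MyAlgoEvaluator` is a hypothetical name.
from ray.rllib.optimizers import Evaluator


class MyAlgoEvaluator(Evaluator):
    def __init__(self, registry, env_creator, config):
        # Build the env, the policy/loss, and a sampler here.
        pass

    def sample(self):
        # Collect and return a batch of experience from the env.
        raise NotImplementedError

    def compute_gradients(self, samples):
        # Return gradients of the loss w.r.t. the sampled batch.
        raise NotImplementedError

    def apply_gradients(self, grads):
        # Apply the given gradients to the local model.
        raise NotImplementedError

    def get_weights(self):
        # Return the current model weights.
        raise NotImplementedError

    def set_weights(self, weights):
        # Overwrite the current model weights.
        raise NotImplementedError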
diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py
index 1c66d4d69..6a45ff50a 100644
--- a/python/ray/rllib/__init__.py
+++ b/python/ray/rllib/__init__.py
@@ -8,7 +8,7 @@ from ray.tune.registry import register_trainable
 
 
 def _register_all():
-    for key in ["PPO", "ES", "DQN", "A3C", "BC", "__fake",
+    for key in ["PPO", "ES", "DQN", "A3C", "BC", "PG", "__fake",
                 "__sigmoid_fake_data", "__parameter_tuning"]:
         try:
             from ray.rllib.agent import get_agent_class
diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py
index 3cb850410..47c270299 100644
--- a/python/ray/rllib/agent.py
+++ b/python/ray/rllib/agent.py
@@ -229,6 +229,9 @@ def get_agent_class(alg):
     elif alg == "BC":
         from ray.rllib import bc
         return bc.BCAgent
+    elif alg == "PG":
+        from ray.rllib import pg
+        return pg.PGAgent
     elif alg == "script":
         from ray.tune import script_runner
         return script_runner.ScriptRunner
diff --git a/python/ray/rllib/pg/__init__.py b/python/ray/rllib/pg/__init__.py
new file mode 100644
index 000000000..fa566536d
--- /dev/null
+++ b/python/ray/rllib/pg/__init__.py
@@ -0,0 +1,3 @@
+from ray.rllib.pg.pg import PGAgent, DEFAULT_CONFIG
+
+__all__ = ["PGAgent", "DEFAULT_CONFIG"]
diff --git a/python/ray/rllib/pg/pg.py b/python/ray/rllib/pg/pg.py
new file mode 100644
index 000000000..aa82681cc
--- /dev/null
+++ b/python/ray/rllib/pg/pg.py
@@ -0,0 +1,83 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import ray
+from ray.rllib.optimizers import LocalSyncOptimizer
+from ray.rllib.pg.pg_evaluator import PGEvaluator, RemotePGEvaluator
+from ray.rllib.agent import Agent
+from ray.tune.result import TrainingResult
+
+DEFAULT_CONFIG = {
+    # Number of workers (excluding master)
+    "num_workers": 4,
+    # Size of rollout batch
+    "batch_size": 512,
+    # Discount factor of MDP
+    "gamma": 0.99,
+    # Number of steps after which the rollout gets cut
+    "horizon": 500,
+    # Learning rate
+    "lr": 0.0004,
+    # Arguments to pass to the rllib optimizer
+    "optimizer": {
+        # Number of gradients applied for each `train` step
+        "grads_per_step": 1,
+    },
+    # Model parameters
+    "model": {"fcnet_hiddens": [128, 128]},
+    # Arguments to pass to the env creator
+    "env_config": {},
+}
+
+
+class PGAgent(Agent):
+
+    """Simple policy gradient agent.
+
+    This is an example agent to show how to implement algorithms in RLlib.
+    In most cases, you will probably want to use the PPO agent instead.
+    """
+
+    _agent_name = "PG"
+    _default_config = DEFAULT_CONFIG
+
+    def _init(self):
+        self.local_evaluator = PGEvaluator(
+            self.registry, self.env_creator, self.config)
+        self.remote_evaluators = [
+            RemotePGEvaluator.remote(
+                self.registry, self.env_creator, self.config)
+            for _ in range(self.config["num_workers"])]
+        self.optimizer = LocalSyncOptimizer(
+            self.config["optimizer"], self.local_evaluator,
+            self.remote_evaluators)
+
+    def _train(self):
+        self.optimizer.step()
+
+        episode_rewards = []
+        episode_lengths = []
+        metric_lists = [a.get_completed_rollout_metrics.remote()
+                        for a in self.remote_evaluators]
+        for metrics in metric_lists:
+            for episode in ray.get(metrics):
+                episode_lengths.append(episode.episode_length)
+                episode_rewards.append(episode.episode_reward)
+        avg_reward = np.mean(episode_rewards)
+        avg_length = np.mean(episode_lengths)
+        timesteps = np.sum(episode_lengths)
+
+        result = TrainingResult(
+            episode_reward_mean=avg_reward,
+            episode_len_mean=avg_length,
+            timesteps_this_iter=timesteps,
+            info={})
+
+        return result
+
+    def compute_action(self, obs):
+        action, info = self.local_evaluator.policy.compute(obs)
+        return action
diff --git a/python/ray/rllib/pg/pg_evaluator.py b/python/ray/rllib/pg/pg_evaluator.py
new file mode 100644
index 000000000..6ab6b0b67
--- /dev/null
+++ b/python/ray/rllib/pg/pg_evaluator.py
@@ -0,0 +1,60 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ray
+from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.optimizers import Evaluator
+from ray.rllib.pg.policy import PGPolicy
+from ray.rllib.utils.filter import NoFilter
+from ray.rllib.utils.process_rollout import process_rollout
+from ray.rllib.utils.sampler import SyncSampler
+
+
+class PGEvaluator(Evaluator):
+    """Evaluator for simple policy gradient."""
+
+    def __init__(self, registry, env_creator, config):
+        self.env = ModelCatalog.get_preprocessor_as_wrapper(
+            registry, env_creator(config["env_config"]), config["model"])
+        self.config = config
+
+        self.policy = PGPolicy(registry, self.env.observation_space,
+                               self.env.action_space, config)
+        self.sampler = SyncSampler(
+            self.env, self.policy, NoFilter(),
+            config["batch_size"], horizon=config["horizon"])
+
+    def sample(self):
+        rollout = self.sampler.get_data()
+        samples = process_rollout(
+            rollout, NoFilter(),
+            gamma=self.config["gamma"], use_gae=False)
+        return samples
+
+    def get_completed_rollout_metrics(self):
+        """Returns metrics on previously completed rollouts.
+
+        Calling this clears the queue of completed rollout metrics.
+        """
+        return self.sampler.get_metrics()
+
+    def compute_gradients(self, samples):
+        """Returns gradient w.r.t. samples."""
+        gradient, info = self.policy.compute_gradients(samples)
+        return gradient
+
+    def apply_gradients(self, grads):
+        """Applies gradients to evaluator weights."""
+        self.policy.apply_gradients(grads)
+
+    def get_weights(self):
+        """Returns model weights."""
+        return self.policy.get_weights()
+
+    def set_weights(self, weights):
+        """Sets model weights."""
+        return self.policy.set_weights(weights)
+
+
+RemotePGEvaluator = ray.remote(PGEvaluator)
diff --git a/python/ray/rllib/pg/policy.py b/python/ray/rllib/pg/policy.py
new file mode 100644
index 000000000..18b2cd71c
--- /dev/null
+++ b/python/ray/rllib/pg/policy.py
@@ -0,0 +1,82 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import ray
+from ray.rllib.models.catalog import ModelCatalog
+
+
+class PGPolicy():
+
+    other_output = []
+    is_recurrent = False
+
+    def __init__(self, registry, ob_space, ac_space, config):
+        self.config = config
+        self.registry = registry
+        with tf.variable_scope("local"):
+            self._setup_graph(ob_space, ac_space)
+        print("Setting up loss")
+        self._setup_loss(ac_space)
+        self._setup_gradients()
+        self.initialize()
+
+    def _setup_graph(self, ob_space, ac_space):
+        self.x = tf.placeholder(tf.float32, shape=[None]+list(ob_space.shape))
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        self.model = ModelCatalog.get_model(
+            self.registry, self.x, self.logit_dim,
+            options=self.config["model"])
+        self.action_logits = self.model.outputs  # logit for each action
+        self.dist = dist_class(self.action_logits)
+        self.sample = self.dist.sample()
+        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+                                          tf.get_variable_scope().name)
+
+    def _setup_loss(self, action_space):
+        self.ac = ModelCatalog.get_action_placeholder(action_space)
+        self.adv = tf.placeholder(tf.float32, [None], name="adv")
+
+        log_prob = self.dist.logp(self.ac)
+
+        # policy loss
+        self.loss = -tf.reduce_mean(log_prob * self.adv)
+
+    def _setup_gradients(self):
+        self.grads = tf.gradients(self.loss, self.var_list)
+        grads_and_vars = list(zip(self.grads, self.var_list))
+        opt = tf.train.AdamOptimizer(self.config["lr"])
+        self._apply_gradients = opt.apply_gradients(grads_and_vars)
+
+    def initialize(self):
+        self.sess = tf.Session()
+        self.variables = ray.experimental.TensorFlowVariables(
+            self.loss, self.sess)
+        self.sess.run(tf.global_variables_initializer())
+
+    def compute_gradients(self, samples):
+        info = {}
+        feed_dict = {
+            self.x: samples["observations"],
+            self.ac: samples["actions"],
+            self.adv: samples["advantages"],
+        }
+        self.grads = [g for g in self.grads if g is not None]
+        grad = self.sess.run(self.grads, feed_dict=feed_dict)
+        return grad, info
+
+    def apply_gradients(self, grads):
+        feed_dict = dict(zip(self.grads, grads))
+        self.sess.run(self._apply_gradients, feed_dict=feed_dict)
+
+    def get_weights(self):
+        return self.variables.get_weights()
+
+    def set_weights(self, weights):
+        self.variables.set_weights(weights)
+
+    def compute(self, ob, *args):
+        action = self.sess.run(self.sample, {self.x: [ob]})
+        return action[0], {}
diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py
index e68aa8242..b9f9ff933 100644
--- a/python/ray/rllib/test/test_supported_spaces.py
+++ b/python/ray/rllib/test/test_supported_spaces.py
@@ -120,6 +120,10 @@ class ModelSupportedSpaces(unittest.TestCase):
             {"num_workers": 1, "noise_size": 10000000,
              "episodes_per_batch": 1, "timesteps_per_batch": 1},
             stats)
+        check_support(
+            "PG",
+            {"num_workers": 1, "optimizer": {"grads_per_step": 1}},
+            stats)
         num_unexpected_errors = 0
         num_unexpected_success = 0
         for (alg, a_name, o_name), stat in sorted(stats.items()):
diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh
index d7b036ce4..0165215a3 100755
--- a/test/jenkins_tests/run_multi_node_tests.sh
+++ b/test/jenkins_tests/run_multi_node_tests.sh
@@ -153,6 +153,27 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
 #     --stop '{"training_iteration": 2}' \
 #     --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
 
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env CartPole-v0 \
+    --run PG \
+    --stop '{"training_iteration": 2}' \
+    --config '{"batch_size": 500, "num_workers": 1}'
+
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env Pong-v0 \
+    --run PG \
+    --stop '{"training_iteration": 2}' \
+    --config '{"batch_size": 500, "num_workers": 1}'
+
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env FrozenLake-v0 \
+    --run PG \
+    --stop '{"training_iteration": 2}' \
+    --config '{"batch_size": 500, "num_workers": 1}'
+
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     sh /ray/test/jenkins_tests/multi_node_tests/test_rllib_eval.sh