[rllib] Added vanilla policy gradient (#1497)
parent 1ab2e63dbd
commit 81a4be8f65
9 changed files with 258 additions and 2 deletions
@@ -10,7 +10,7 @@ Recipe for an RLlib algorithm
 
 Here are the steps for implementing a new algorithm in RLlib:
 
-1. Define an algorithm-specific `Policy evaluator class <#policy-evaluators-and-optimizers>`__ (the core of the algorithm). Evaluators encapsulate framework-specific components such as the policy and loss functions. For an example, see the `A3C Evaluator implementation <https://github.com/ray-project/ray/blob/master/python/ray/rllib/a3c/a3c_evaluator.py>`__.
+1. Define an algorithm-specific `Policy evaluator class <#policy-evaluators-and-optimizers>`__ (the core of the algorithm). Evaluators encapsulate framework-specific components such as the policy and loss functions. For an example, see the `simple policy gradient evaluator example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/pg/pg_evaluator.py>`__.
 
 2. Pick an appropriate `Policy optimizer class <#policy-evaluators-and-optimizers>`__. Optimizers manage the parallel execution of the algorithm. RLlib provides several built-in optimizers for gradient-based algorithms. Advanced algorithms may find it beneficial to implement their own optimizers.
 
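The evaluator in step 1 only has to expose the handful of methods that the optimizers in step 2 call. Below is a minimal sketch of that interface, abstracted from the PG evaluator added in this commit; the method names and signatures follow pg_evaluator.py further down, while the class name and the bodies are placeholders, not RLlib code.

# Sketch of the evaluator interface that RLlib's optimizers drive.
# Method names follow pg_evaluator.py in this commit; everything else
# (class name, bodies) is illustrative only.
class MyAlgoEvaluator(object):

    def __init__(self, registry, env_creator, config):
        # Build the env, the policy/loss, and the sampler here.
        self.config = config

    def sample(self):
        """Returns a batch of experience collected from the environment."""
        raise NotImplementedError

    def compute_gradients(self, samples):
        """Returns the gradient of the loss with respect to the samples."""
        raise NotImplementedError

    def apply_gradients(self, grads):
        """Applies a gradient update to the local policy weights."""
        raise NotImplementedError

    def get_weights(self):
        """Returns the current policy weights, used to sync workers."""
        raise NotImplementedError

    def set_weights(self, weights):
        """Overwrites the local policy weights."""
        raise NotImplementedError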
@@ -8,7 +8,7 @@ from ray.tune.registry import register_trainable
 
 
 def _register_all():
-    for key in ["PPO", "ES", "DQN", "A3C", "BC", "__fake",
+    for key in ["PPO", "ES", "DQN", "A3C", "BC", "PG", "__fake",
                 "__sigmoid_fake_data", "__parameter_tuning"]:
         try:
             from ray.rllib.agent import get_agent_class
@@ -229,6 +229,9 @@ def get_agent_class(alg):
     elif alg == "BC":
         from ray.rllib import bc
         return bc.BCAgent
+    elif alg == "PG":
+        from ray.rllib import pg
+        return pg.PGAgent
    elif alg == "script":
         from ray.tune import script_runner
         return script_runner.ScriptRunner
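With the "PG" key registered above and this dispatch branch in place, the new agent class is resolvable by its string name alone. A small sketch of that lookup, using only what the diff shows:

# Resolving the algorithm by its registered name.
from ray.rllib.agent import get_agent_class

cls = get_agent_class("PG")      # returns ray.rllib.pg.PGAgent per the branch above
assert cls.__name__ == "PGAgent"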
3  python/ray/rllib/pg/__init__.py  Normal file

@@ -0,0 +1,3 @@
+from ray.rllib.pg.pg import PGAgent, DEFAULT_CONFIG
+
+__all__ = ["PGAgent", "DEFAULT_CONFIG"]
83  python/ray/rllib/pg/pg.py  Normal file

@@ -0,0 +1,83 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import ray
+from ray.rllib.optimizers import LocalSyncOptimizer
+from ray.rllib.pg.pg_evaluator import PGEvaluator, RemotePGEvaluator
+from ray.rllib.agent import Agent
+from ray.tune.result import TrainingResult
+
+DEFAULT_CONFIG = {
+    # Number of workers (excluding master)
+    "num_workers": 4,
+    # Size of rollout batch
+    "batch_size": 512,
+    # Discount factor of MDP
+    "gamma": 0.99,
+    # Number of steps after which the rollout gets cut
+    "horizon": 500,
+    # Learning rate
+    "lr": 0.0004,
+    # Arguments to pass to the rllib optimizer
+    "optimizer": {
+        # Number of gradients applied for each `train` step
+        "grads_per_step": 1,
+    },
+    # Model parameters
+    "model": {"fcnet_hiddens": [128, 128]},
+    # Arguments to pass to the env creator
+    "env_config": {},
+}
+
+
+class PGAgent(Agent):
+
+    """Simple policy gradient agent.
+
+    This is an example agent to show how to implement algorithms in RLlib.
+    In most cases, you will probably want to use the PPO agent instead.
+    """
+
+    _agent_name = "PG"
+    _default_config = DEFAULT_CONFIG
+
+    def _init(self):
+        self.local_evaluator = PGEvaluator(
+            self.registry, self.env_creator, self.config)
+        self.remote_evaluators = [
+            RemotePGEvaluator.remote(
+                self.registry, self.env_creator, self.config)
+            for _ in range(self.config["num_workers"])]
+        self.optimizer = LocalSyncOptimizer(
+            self.config["optimizer"], self.local_evaluator,
+            self.remote_evaluators)
+
+    def _train(self):
+        self.optimizer.step()
+
+        episode_rewards = []
+        episode_lengths = []
+        metric_lists = [a.get_completed_rollout_metrics.remote()
+                        for a in self.remote_evaluators]
+        for metrics in metric_lists:
+            for episode in ray.get(metrics):
+                episode_lengths.append(episode.episode_length)
+                episode_rewards.append(episode.episode_reward)
+        avg_reward = np.mean(episode_rewards)
+        avg_length = np.mean(episode_lengths)
+        timesteps = np.sum(episode_lengths)
+
+        result = TrainingResult(
+            episode_reward_mean=avg_reward,
+            episode_len_mean=avg_length,
+            timesteps_this_iter=timesteps,
+            info={})
+
+        return result
+
+    def compute_action(self, obs):
+        action, info = self.local_evaluator.policy.compute(obs)
+        return action
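As the docstring notes, PGAgent is primarily a template for new algorithms. A usage sketch follows; the DEFAULT_CONFIG overrides and the train()/TrainingResult fields come from pg.py above, while the Agent constructor signature (a config dict plus an env name) is an assumption carried over from other RLlib agents of this era and is not shown in this diff.

# Usage sketch; the PGAgent constructor arguments are assumed, everything
# else follows pg.py above.
import ray

from ray.rllib.pg import DEFAULT_CONFIG, PGAgent

ray.init()

config = DEFAULT_CONFIG.copy()                # shallow copy of the defaults
config["num_workers"] = 2                     # fewer rollout workers
config["optimizer"] = {"grads_per_step": 10}  # more updates per train() call

agent = PGAgent(config=config, env="CartPole-v0")   # constructor args assumed
for i in range(3):
    result = agent.train()                    # one optimizer round plus metric collection
    print(i, result.episode_reward_mean, result.timesteps_this_iter)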
60  python/ray/rllib/pg/pg_evaluator.py  Normal file

@@ -0,0 +1,60 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ray
+from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.optimizers import Evaluator
+from ray.rllib.pg.policy import PGPolicy
+from ray.rllib.utils.filter import NoFilter
+from ray.rllib.utils.process_rollout import process_rollout
+from ray.rllib.utils.sampler import SyncSampler
+
+
+class PGEvaluator(Evaluator):
+    """Evaluator for simple policy gradient."""
+
+    def __init__(self, registry, env_creator, config):
+        self.env = ModelCatalog.get_preprocessor_as_wrapper(
+            registry, env_creator(config["env_config"]), config["model"])
+        self.config = config
+
+        self.policy = PGPolicy(registry, self.env.observation_space,
+                               self.env.action_space, config)
+        self.sampler = SyncSampler(
+            self.env, self.policy, NoFilter(),
+            config["batch_size"], horizon=config["horizon"])
+
+    def sample(self):
+        rollout = self.sampler.get_data()
+        samples = process_rollout(
+            rollout, NoFilter(),
+            gamma=self.config["gamma"], use_gae=False)
+        return samples
+
+    def get_completed_rollout_metrics(self):
+        """Returns metrics on previously completed rollouts.
+
+        Calling this clears the queue of completed rollout metrics.
+        """
+        return self.sampler.get_metrics()
+
+    def compute_gradients(self, samples):
+        """ Returns gradient w.r.t. samples."""
+        gradient, info = self.policy.compute_gradients(samples)
+        return gradient
+
+    def apply_gradients(self, grads):
+        """Applies gradients to evaluator weights."""
+        self.policy.apply_gradients(grads)
+
+    def get_weights(self):
+        """Returns model weights."""
+        return self.policy.get_weights()
+
+    def set_weights(self, weights):
+        """Sets model weights."""
+        return self.policy.set_weights(weights)
+
+
+RemotePGEvaluator = ray.remote(PGEvaluator)
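These methods are exactly what a synchronous optimizer needs: push weights to the workers, pull sample batches back, and apply the resulting gradients locally. The following schematic shows one such step against this evaluator API; it mirrors the role LocalSyncOptimizer plays in pg.py, but it is an illustration, not the RLlib implementation.

# Schematic of one synchronous optimization step over the evaluator API
# above (illustration only; not LocalSyncOptimizer's actual code).
import ray


def sync_step(local_evaluator, remote_evaluators):
    # 1. Broadcast the current weights to every remote worker.
    weights = ray.put(local_evaluator.get_weights())
    for ev in remote_evaluators:
        ev.set_weights.remote(weights)

    # 2. Collect a rollout batch from each worker in parallel.
    batches = ray.get([ev.sample.remote() for ev in remote_evaluators])

    # 3. Compute and apply gradients on the local evaluator.
    for batch in batches:
        grads = local_evaluator.compute_gradients(batch)
        local_evaluator.apply_gradients(grads)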
82  python/ray/rllib/pg/policy.py  Normal file

@@ -0,0 +1,82 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import ray
+from ray.rllib.models.catalog import ModelCatalog
+
+
+class PGPolicy():
+
+    other_output = []
+    is_recurrent = False
+
+    def __init__(self, registry, ob_space, ac_space, config):
+        self.config = config
+        self.registry = registry
+        with tf.variable_scope("local"):
+            self._setup_graph(ob_space, ac_space)
+        print("Setting up loss")
+        self._setup_loss(ac_space)
+        self._setup_gradients()
+        self.initialize()
+
+    def _setup_graph(self, ob_space, ac_space):
+        self.x = tf.placeholder(tf.float32, shape=[None]+list(ob_space.shape))
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        self.model = ModelCatalog.get_model(
+            self.registry, self.x, self.logit_dim,
+            options=self.config["model"])
+        self.action_logits = self.model.outputs  # logit for each action
+        self.dist = dist_class(self.action_logits)
+        self.sample = self.dist.sample()
+        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+                                          tf.get_variable_scope().name)
+
+    def _setup_loss(self, action_space):
+        self.ac = ModelCatalog.get_action_placeholder(action_space)
+        self.adv = tf.placeholder(tf.float32, [None], name="adv")
+
+        log_prob = self.dist.logp(self.ac)
+
+        # policy loss
+        self.loss = -tf.reduce_mean(log_prob * self.adv)
+
+    def _setup_gradients(self):
+        self.grads = tf.gradients(self.loss, self.var_list)
+        grads_and_vars = list(zip(self.grads, self.var_list))
+        opt = tf.train.AdamOptimizer(self.config["lr"])
+        self._apply_gradients = opt.apply_gradients(grads_and_vars)
+
+    def initialize(self):
+        self.sess = tf.Session()
+        self.variables = ray.experimental.TensorFlowVariables(
+            self.loss, self.sess)
+        self.sess.run(tf.global_variables_initializer())
+
+    def compute_gradients(self, samples):
+        info = {}
+        feed_dict = {
+            self.x: samples["observations"],
+            self.ac: samples["actions"],
+            self.adv: samples["advantages"],
+        }
+        self.grads = [g for g in self.grads if g is not None]
+        grad = self.sess.run(self.grads, feed_dict=feed_dict)
+        return grad, info
+
+    def apply_gradients(self, grads):
+        feed_dict = dict(zip(self.grads, grads))
+        self.sess.run(self._apply_gradients, feed_dict=feed_dict)
+
+    def get_weights(self):
+        return self.variables.get_weights()
+
+    def set_weights(self, weights):
+        self.variables.set_weights(weights)
+
+    def compute(self, ob, *args):
+        action = self.sess.run(self.sample, {self.x: [ob]})
+        return action[0], {}
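The loss built in _setup_loss is the classic REINFORCE objective: minimize -E[log pi(a_t | s_t) * A_t], so that actions with positive advantage become more likely. A small numpy sketch of the same scalar the TensorFlow graph computes, with illustrative values only (the advantages come from process_rollout with use_gae=False in the evaluator above):

# Numpy sketch of the value that self.loss evaluates to for one batch.
import numpy as np

log_prob = np.array([-0.7, -1.2, -0.3])  # log pi(a_t | s_t) of the sampled actions
adv = np.array([1.5, -0.4, 2.0])         # values fed into the self.adv placeholder

loss = -np.mean(log_prob * adv)          # same form as -tf.reduce_mean(log_prob * self.adv)
print(loss)                              # here: -(1/3) * (-1.05 + 0.48 - 0.6) = 0.39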
@@ -120,6 +120,10 @@ class ModelSupportedSpaces(unittest.TestCase):
             {"num_workers": 1, "noise_size": 10000000,
              "episodes_per_batch": 1, "timesteps_per_batch": 1},
             stats)
+        check_support(
+            "PG",
+            {"num_workers": 1, "optimizer": {"grads_per_step": 1}},
+            stats)
         num_unexpected_errors = 0
         num_unexpected_success = 0
         for (alg, a_name, o_name), stat in sorted(stats.items()):
@@ -153,6 +153,27 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
 #    --stop '{"training_iteration": 2}' \
 #    --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
 
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env CartPole-v0 \
+    --run PG \
+    --stop '{"training_iteration": 2}' \
+    --config '{"batch_size": 500, "num_workers": 1}'
+
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env Pong-v0 \
+    --run PG \
+    --stop '{"training_iteration": 2}' \
+    --config '{"batch_size": 500, "num_workers": 1}'
+
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env FrozenLake-v0 \
+    --run PG \
+    --stop '{"training_iteration": 2}' \
+    --config '{"batch_size": 500, "num_workers": 1}'
+
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     sh /ray/test/jenkins_tests/multi_node_tests/test_rllib_eval.sh
 