Mirror of https://github.com/vale981/ray
[rllib] Behavior Cloning (#1400)
* Behavior Cloning
* episode_reward_mean -> mean_loss
* removing vestigial code
* punctuation
* unnecessary
* Behavior Cloning
* Behavior Cloning
* Update __init__.py
parent ee36effd8e
commit 4b0ef5eb2c
11 changed files with 390 additions and 83 deletions
@@ -8,7 +8,8 @@ from ray.rllib.agent import get_agent_class

 def _register_all():
-    for key in ["PPO", "ES", "DQN", "A3C", "__fake", "__sigmoid_fake_data"]:
+    for key in [
+            "PPO", "ES", "DQN", "A3C", "BC", "__fake", "__sigmoid_fake_data"]:
         try:
             register_trainable(key, get_agent_class(key))
         except ImportError as e:
@@ -28,7 +28,7 @@ class TFPolicy(Policy):
             self.setup_gradients()
             self.initialize()

-    def _setup_graph(self):
+    def _setup_graph(self, ob_space, ac_space):
         raise NotImplementedError

     def setup_loss(self, action_space):
@@ -128,7 +128,7 @@ class Agent(Trainable):

         self._initialize_ok = True

-    def _init(self, config, env_creator):
+    def _init(self):
         """Subclasses should override this for custom initialization."""

         raise NotImplementedError
@@ -293,7 +293,7 @@ class Agent(Trainable):

         raise NotImplementedError

-    def _restore(self):
+    def _restore(self, checkpoint_path):
         """Subclasses should override this to implement restore()."""

         raise NotImplementedError
@@ -371,6 +371,9 @@ def get_agent_class(alg):
     elif alg == "A3C":
         from ray.rllib import a3c
         return a3c.A3CAgent
+    elif alg == "BC":
+        from ray.rllib import bc
+        return bc.BCAgent
     elif alg == "script":
         from ray.tune import script_runner
         return script_runner.ScriptRunner
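With "BC" added to the registration loop and to get_agent_class above, the new agent resolves by name like the existing algorithms. A minimal sketch that simply exercises the code added in this hunk (assuming a Ray build that contains this commit):

# Look up the newly registered behavior-cloning agent class by name.
from ray.rllib.agent import get_agent_class

cls = get_agent_class("BC")   # returns ray.rllib.bc.BCAgent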
python/ray/rllib/bc/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from ray.rllib.bc.bc import BCAgent, DEFAULT_CONFIG

__all__ = ["BCAgent", "DEFAULT_CONFIG"]
python/ray/rllib/bc/bc.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
from ray.rllib.agent import Agent
from ray.rllib.bc.bc_evaluator import BCEvaluator, GPURemoteBCEvaluator, \
    RemoteBCEvaluator
from ray.rllib.optimizers import AsyncOptimizer
from ray.tune.result import TrainingResult

DEFAULT_CONFIG = {
    # Number of workers (excluding master)
    "num_workers": 4,
    # Size of rollout batch
    "batch_size": 100,
    # Max global norm for each gradient calculated by worker
    "grad_clip": 40.0,
    # Learning rate
    "lr": 0.0001,
    # Whether to place workers on GPUs
    "use_gpu_for_workers": False,
    # Model and preprocessor options
    "model": {
        # (Image statespace) - Converts image to Channels = 1
        "grayscale": True,
        # (Image statespace) - Each pixel
        "zero_mean": False,
        # (Image statespace) - Converts image to (dim, dim, C)
        "dim": 80,
        # (Image statespace) - Converts image shape to (C, dim, dim)
        "channel_major": False
    },
    # Arguments to pass to the rllib optimizer
    "optimizer": {
        # Number of gradients applied for each `train` step
        "grads_per_step": 100,
    },
    # Arguments to pass to the env creator
    "env_config": {},
}


class BCAgent(Agent):
    _agent_name = "BC"
    _default_config = DEFAULT_CONFIG
    _allow_unknown_configs = True

    def _init(self):
        self.local_evaluator = BCEvaluator(
            self.registry, self.env_creator, self.config, self.logdir)
        if self.config["use_gpu_for_workers"]:
            remote_cls = GPURemoteBCEvaluator
        else:
            remote_cls = RemoteBCEvaluator
        self.remote_evaluators = [
            remote_cls.remote(
                self.registry, self.env_creator, self.config, self.logdir)
            for _ in range(self.config["num_workers"])]
        self.optimizer = AsyncOptimizer(
            self.config["optimizer"], self.local_evaluator,
            self.remote_evaluators)

    def _train(self):
        self.optimizer.step()
        metric_lists = [re.get_metrics.remote() for re in
                        self.remote_evaluators]
        total_samples = 0
        total_loss = 0
        for metrics in metric_lists:
            for m in ray.get(metrics):
                total_samples += m["num_samples"]
                total_loss += m["loss"]
        result = TrainingResult(
            mean_loss=total_loss / total_samples,
            timesteps_this_iter=total_samples,
        )
        return result

    def compute_action(self, observation):
        action, info = self.local_evaluator.policy.compute(observation)
        return action
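A hedged sketch of driving BCAgent directly, pieced together from bc.py above. The "dataset_path" key is not in DEFAULT_CONFIG but is read by BCEvaluator (next file); the path and environment name here are placeholders, and the constructor keywords follow the cls(env=...) pattern used in rollout.py further below.

# Sketch only: train the BC agent for a few iterations.
import ray
from ray.rllib import bc

ray.init()
config = dict(bc.DEFAULT_CONFIG)
config["num_workers"] = 2                     # smaller than the default 4
config["dataset_path"] = "/tmp/rollouts.pkl"  # placeholder; consumed by BCEvaluator
agent = bc.BCAgent(config=config, env="CartPole-v0")
for _ in range(10):
    result = agent.train()                    # one AsyncOptimizer step
    print(result.mean_loss, result.timesteps_this_iter)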
python/ray/rllib/bc/bc_evaluator.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pickle
import queue

import ray
from ray.rllib.bc.experience_dataset import ExperienceDataset
from ray.rllib.bc.policy import BCPolicy
from ray.rllib.models import ModelCatalog
from ray.rllib.optimizers import Evaluator


class BCEvaluator(Evaluator):
    def __init__(self, registry, env_creator, config, logdir):
        env = ModelCatalog.get_preprocessor_as_wrapper(registry, env_creator(
            config["env_config"]), config["model"])
        self.dataset = ExperienceDataset(config["dataset_path"])
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = BCPolicy(registry, env.observation_space.shape,
                               env.action_space, config)
        self.config = config
        self.logdir = logdir
        self.metrics_queue = queue.Queue()

    def sample(self):
        return self.dataset.sample(self.config["batch_size"])

    def compute_gradients(self, samples):
        gradient, info = self.policy.compute_gradients(samples)
        self.metrics_queue.put(
            {"num_samples": info["num_samples"], "loss": info["loss"]})
        return gradient

    def apply_gradients(self, grads):
        self.policy.apply_gradients(grads)

    def get_weights(self):
        return self.policy.get_weights()

    def set_weights(self, params):
        self.policy.set_weights(params)

    def save(self):
        weights = self.get_weights()
        return pickle.dumps({
            "weights": weights})

    def restore(self, objs):
        objs = pickle.loads(objs)
        self.set_weights(objs["weights"])

    def get_metrics(self):
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait())
            except queue.Empty:
                break
        return completed


RemoteBCEvaluator = ray.remote(BCEvaluator)
GPURemoteBCEvaluator = ray.remote(num_gpus=1)(BCEvaluator)
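BCEvaluator implements the Evaluator interface that AsyncOptimizer consumes: sample a batch from the dataset, compute gradients (queuing loss metrics), and apply them to the local policy. A simplified, single-process sketch of that cycle, for illustration only; the real AsyncOptimizer drives the remote evaluators asynchronously as Ray actors:

def naive_optimizer_step(local_evaluator, remote_evaluators, grads_per_step=100):
    # Treats evaluators as plain local objects for clarity; the actual
    # remote evaluators are Ray actors invoked with .remote().
    for _ in range(grads_per_step):
        for ev in remote_evaluators:
            samples = ev.sample()                  # batch from ExperienceDataset
            grads = ev.compute_gradients(samples)  # also records loss metrics
            local_evaluator.apply_gradients(grads)
        weights = local_evaluator.get_weights()
        for ev in remote_evaluators:
            ev.set_weights(weights)                # keep workers in sync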
python/ray/rllib/bc/experience_dataset.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import pickle

import numpy as np


class ExperienceDataset(object):
    def __init__(self, dataset_path):
        """Create dataset of experience to imitate.

        Parameters
        ----------
        dataset_path:
          Path of file containing the database as pickled list of trajectories,
          each trajectory being a list of steps,
          each step containing the observation and action as its first two
          elements.
          The file must be available on each machine used by a BCEvaluator.
        """
        self._dataset = list(itertools.chain.from_iterable(
            pickle.load(open(dataset_path, "rb"))))

    def sample(self, batch_size):
        indexes = np.random.choice(len(self._dataset), batch_size)
        samples = {
            'observations': [self._dataset[i][0] for i in indexes],
            'actions': [self._dataset[i][1] for i in indexes]
        }
        return samples
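The docstring above fixes the on-disk format: a pickled list of trajectories, where each step's first two elements are the observation and the action. A sketch of producing such a file with a random stand-in policy (environment name, episode count, and output path are placeholders):

import pickle
import gym

env = gym.make("CartPole-v0")
trajectories = []
for _ in range(10):
    trajectory, obs, done = [], env.reset(), False
    while not done:
        action = env.action_space.sample()        # stand-in for an expert policy
        next_obs, reward, done, _ = env.step(action)
        # Observation and action first, matching ExperienceDataset.sample().
        trajectory.append([obs, action, next_obs, reward, done])
        obs = next_obs
    trajectories.append(trajectory)

with open("/tmp/rollouts.pkl", "wb") as f:
    pickle.dump(trajectories, f)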
python/ray/rllib/bc/policy.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
import tensorflow as tf
from ray.rllib.a3c.policy import Policy
from ray.rllib.models.catalog import ModelCatalog


class BCPolicy(Policy):
    def __init__(self, registry, ob_space, action_space, config, name="local",
                 summarize=True):
        super(BCPolicy, self).__init__(ob_space, action_space, name, summarize)
        self.registry = registry
        self.local_steps = 0
        self.config = config
        self.summarize = summarize
        worker_device = "/job:localhost/replica:0/task:0/cpu:0"
        self.g = tf.Graph()
        with self.g.as_default(), tf.device(worker_device):
            with tf.variable_scope(name):
                self._setup_graph(ob_space, action_space)
            print("Setting up loss")
            self.setup_loss(action_space)
            self.setup_gradients()
            self.initialize()

    def _setup_graph(self, ob_space, ac_space):
        self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
        self._model = ModelCatalog.get_model(
            self.registry, self.x, self.logit_dim, self.config["model"])
        self.logits = self._model.outputs
        self.curr_dist = dist_class(self.logits)
        self.sample = self.curr_dist.sample()
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

    def setup_loss(self, action_space):
        self.ac = tf.placeholder(tf.int64, [None], name="ac")
        log_prob = self.curr_dist.logp(self.ac)
        self.pi_loss = - tf.reduce_sum(log_prob)
        self.loss = self.pi_loss

    def setup_gradients(self):
        grads = tf.gradients(self.loss, self.var_list)
        self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
        grads_and_vars = list(zip(self.grads, self.var_list))
        opt = tf.train.AdamOptimizer(self.config["lr"])
        self._apply_gradients = opt.apply_gradients(grads_and_vars)

    def initialize(self):
        if self.summarize:
            bs = tf.to_float(tf.shape(self.x)[0])
            tf.summary.scalar("model/policy_loss", self.pi_loss / bs)
            tf.summary.scalar("model/grad_gnorm", tf.global_norm(self.grads))
            tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list))
            self.summary_op = tf.summary.merge_all()

        # TODO(rliaw): Can consider exposing these parameters
        self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
            intra_op_parallelism_threads=1, inter_op_parallelism_threads=2,
            gpu_options=tf.GPUOptions(allow_growth=True)))
        self.variables = ray.experimental.TensorFlowVariables(self.loss,
                                                              self.sess)
        self.sess.run(tf.global_variables_initializer())

    def compute_gradients(self, samples):
        info = {}
        feed_dict = {
            self.x: samples["observations"],
            self.ac: samples["actions"]
        }
        self.grads = [g for g in self.grads if g is not None]
        self.local_steps += 1
        if self.summarize:
            loss, grad, summ = self.sess.run(
                [self.loss, self.grads, self.summary_op], feed_dict=feed_dict)
            info["summary"] = summ
        else:
            loss, grad = self.sess.run([self.loss, self.grads],
                                       feed_dict=feed_dict)
        info["num_samples"] = len(samples)
        info["loss"] = loss
        return grad, info

    def apply_gradients(self, grads):
        feed_dict = {self.grads[i]: grads[i]
                     for i in range(len(grads))}
        self.sess.run(self._apply_gradients, feed_dict=feed_dict)

    def get_weights(self):
        weights = self.variables.get_weights()
        return weights

    def set_weights(self, weights):
        self.variables.set_weights(weights)

    def compute(self, ob, *args):
        action = self.sess.run(self.sample, {self.x: [ob]})
        return action, None
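setup_loss above is plain behavior cloning: the loss is the summed negative log-likelihood of the expert actions under the policy's action distribution. A tiny NumPy illustration for a discrete action space (the numbers are made up):

import numpy as np

logits = np.array([[2.0, 0.5, -1.0],   # policy logits for two sampled states
                   [0.1, 0.1, 3.0]])
expert_actions = np.array([0, 2])      # actions recorded in the dataset
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(len(expert_actions)), expert_actions].sum()
print(loss)   # corresponds to pi_loss = -tf.reduce_sum(log_prob) above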
(deleted file, 77 lines)
@@ -1,77 +0,0 @@
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import gym
import json
import os
import ray

from ray.rllib.agent import get_agent_class


EXAMPLE_USAGE = """
example usage:
    ./eval.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN --env CartPole-v0
"""


parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="Evaluates a reinforcement learning agent "
                "given a checkpoint.", epilog=EXAMPLE_USAGE)

parser.add_argument(
    "checkpoint", type=str, help="Checkpoint from which to evaluate.")
required_named = parser.add_argument_group("required named arguments")
required_named.add_argument(
    "--run", type=str, required=True,
    help="The algorithm or model to train. This may refer to the name "
         "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
         "user-defined trainable function or class registered in the "
         "tune registry.")
required_named.add_argument(
    "--env", type=str, help="The gym environment to use.")
parser.add_argument(
    "--no-render", default=False, action="store_const", const=True,
    help="Surpress rendering of the environment.")
parser.add_argument(
    "--loop-forever", default=False, action="store_const", const=True,
    help="Run evaluation of the agent forever.")
parser.add_argument(
    "--config", default="{}", type=json.loads,
    help="Algorithm-specific configuration (e.g. env, hyperparams). "
         "Surpresses loading of configuration from checkpoint.")

if __name__ == "__main__":
    args = parser.parse_args()

    if not args.config:
        # Load configuration from file
        config_dir = os.path.dirname(args.checkpoint)
        config_path = os.path.join(config_dir, "params.json")
        with open(config_path) as f:
            args.config = json.load(f)

    if not args.env:
        if not args.config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = args.config.get("env")

    ray.init()

    cls = get_agent_class(args.run)
    agent = cls(env=args.env)
    agent.restore(args.checkpoint)

    env = gym.make(args.env)
    state = env.reset()
    done = False
    while args.loop_forever or not done:
        action = agent.compute_action(state)
        state, reward, done, _ = env.step(action)
        if not args.no_render:
            env.render()
@@ -7,8 +7,8 @@ class Optimizer(object):
     """RLlib optimizers encapsulate distributed RL optimization strategies.

     For example, AsyncOptimizer is used for A3C, and LocalMultiGPUOptimizer is
-    used for PPO. These optimizers are all pluggable however, it is possible
-    to mix as match as needed.
+    used for PPO. These optimizers are all pluggable, and it is possible
+    to mix and match as needed.

     In order for an algorithm to use an RLlib optimizer, it must implement
     the Evaluator interface and pass a number of Evaluators to its Optimizer
python/ray/rllib/rollout.py (new file, 95 lines)
@@ -0,0 +1,95 @@
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import os
import pickle

import gym
import ray
from ray.rllib.agent import get_agent_class
from ray.rllib.models import ModelCatalog
from ray.tune.registry import get_registry

EXAMPLE_USAGE = """
example usage:
    ./rollout.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN """
    """--env CartPole-v0 --steps 1000000 --out rollouts.pkl
"""

parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="Roll out a reinforcement learning agent "
                "given a checkpoint.", epilog=EXAMPLE_USAGE)

parser.add_argument(
    "checkpoint", type=str, help="Checkpoint from which to roll out.")
required_named = parser.add_argument_group("required named arguments")
required_named.add_argument(
    "--run", type=str, required=True,
    help="The algorithm or model to train. This may refer to the name "
         "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
         "user-defined trainable function or class registered in the "
         "tune registry.")
required_named.add_argument(
    "--env", type=str, help="The gym environment to use.")
parser.add_argument(
    "--no-render", default=False, action="store_const", const=True,
    help="Surpress rendering of the environment.")
parser.add_argument(
    "--steps", default=None, help="Number of steps to roll out.")
parser.add_argument(
    "--out", default=None, help="Output filename.")
parser.add_argument(
    "--config", default="{}", type=json.loads,
    help="Algorithm-specific configuration (e.g. env, hyperparams). "
         "Surpresses loading of configuration from checkpoint.")

if __name__ == "__main__":
    args = parser.parse_args()

    if not args.config:
        # Load configuration from file
        config_dir = os.path.dirname(args.checkpoint)
        config_path = os.path.join(config_dir, "params.json")
        with open(config_path) as f:
            args.config = json.load(f)

    if not args.env:
        if not args.config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = args.config.get("env")

    ray.init()

    cls = get_agent_class(args.run)
    agent = cls(env=args.env)
    agent.restore(args.checkpoint)
    num_steps = int(args.steps)

    env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(),
                                                   gym.make(args.env))
    if args.out is not None:
        rollouts = []
    steps = 0
    while steps < (num_steps or steps + 1):
        if args.out is not None:
            rollout = []
        state = env.reset()
        done = False
        while not done and steps < (num_steps or steps + 1):
            action = agent.compute_action(state)
            next_state, reward, done, _ = env.step(action)
            if not args.no_render:
                env.render()
            if args.out is not None:
                rollout.append([state, action, next_state, reward, done])
            steps += 1
        if args.out is not None:
            rollouts.append(rollout)
    if args.out is not None:
        pickle.dump(rollouts, open(args.out, "wb"))
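Since rollout.py appends [state, action, next_state, reward, done] per step and pickles a list of rollouts, the file written by --out is already in the trajectory format ExperienceDataset expects, so rollouts from a trained agent can be fed back in as a behavior-cloning dataset. A short sketch (the path is a placeholder):

from ray.rllib.bc.experience_dataset import ExperienceDataset

dataset = ExperienceDataset("/tmp/rollouts.pkl")  # written by rollout.py --out
batch = dataset.sample(100)
print(len(batch["observations"]), len(batch["actions"]))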