Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[rllib] unify writing performance metrics and make it queryable (#708)
* write config to s3
* add train file
* write performance to S3
* writing needs to be fixed, replacing result.json at the moment
* update
* add experiment_id
* more logging and example queries
* update
* add info
* fill in other algorithms
* fix linting
* convert readme to rst
* fixes
* simplejson -> json
* make files executable
* edit README.rst
* unify storing logs in S3 and on local filesystem
* use 'info' entry in TrainingResult for algorithm specific info
* don't install smart_open with ray
* fixes
* linting fixes
This commit is contained in:
parent 8464d77c76
commit c24c07613c
8 changed files with 298 additions and 58 deletions
80 python/ray/rllib/README.rst Normal file
@@ -0,0 +1,80 @@
RLLib: Ray's modular and scalable reinforcement learning library
================================================================

Getting Started
---------------

You can run training with

::

    python train.py --env CartPole-v0 --alg PolicyGradient

The available algorithms are:

- ``PolicyGradient`` is a proximal variant of
  `TRPO <https://arxiv.org/abs/1502.05477>`__.

- ``EvolutionStrategies`` is described in `this
  paper <https://arxiv.org/abs/1703.03864>`__. Our implementation
  borrows code from
  `here <https://github.com/openai/evolution-strategies-starter>`__.

- ``DQN`` is an implementation of `Deep Q
  Networks <https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf>`__ based on
  `OpenAI baselines <https://github.com/openai/baselines>`__.

- ``A3C`` is an implementation of
  `A3C <https://arxiv.org/abs/1602.01783>`__ based on `the OpenAI
  starter agent <https://github.com/openai/universe-starter-agent>`__.

Storing logs
------------

You can store the algorithm configuration (including hyperparameters) and
training results on a filesystem with the ``--upload-dir`` flag. Two protocols
are supported at the moment:

- ``--upload-dir file:///tmp/ray/`` will store the logs on the local filesystem
  in a subdirectory of /tmp/ray which is named after the algorithm name, the
  environment and the current date. This is the default.

- ``--upload-dir s3://bucketname/`` will store the logs in S3. Note that if you
  store the logs in S3, TensorFlow files will not currently be stored because
  TensorFlow doesn't support directly uploading files to S3 at the moment.

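As an illustration, the stored logs can be read back with the standard library alone. This is a minimal sketch; the run directory below is a placeholder (each run creates its own uniquely named subdirectory), and ``result.json`` contains one JSON object per training iteration:

.. code:: python

    import json
    import os

    # Placeholder path; the actual directory is created per run under /tmp/ray.
    logdir = "/tmp/ray/CartPole-v0_PolicyGradient_2017-07-17_12-00-00abcd"

    with open(os.path.join(logdir, "config.json")) as f:
        config = json.load(f)

    # result.json is newline-delimited JSON, one record per call to train().
    with open(os.path.join(logdir, "result.json")) as f:
        results = [json.loads(line) for line in f if line.strip()]

    print(config["alg"], config["env_name"], len(results), "iterations logged")
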
Querying logs with Athena
-------------------------

If you stored the logs in S3 or uploaded them there from the local file system,
they can be queried with Athena. First create tables containing the
experimental results with

.. code:: sql

    CREATE EXTERNAL TABLE IF NOT EXISTS experiments (
      experiment_id STRING,
      env_name STRING,
      alg STRING,
      -- result.json
      training_iteration INT,
      episode_reward_mean FLOAT,
      episode_len_mean FLOAT
    ) ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'
    LOCATION 's3://bucketname/'

and then you can, for example, visualize the results with

.. code:: sql

    SELECT c.experiment_id, c.env_name, c.alg, a.episode_reward_mean, a.episode_len_mean
    FROM experiments a
    LEFT OUTER JOIN experiments b
      ON a.experiment_id = b.experiment_id AND a.training_iteration < b.training_iteration
    INNER JOIN experiments c
      ON a.experiment_id = c.experiment_id
    WHERE b.experiment_id IS NULL AND a.training_iteration IS NOT NULL AND c.alg IS NOT NULL;

This query selects the last iteration from each experiment (see `this
stackoverflow post <https://stackoverflow.com/questions/7745609/sql-select-only-rows-with-max-value-on-a-column>`__).

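For quick local inspection, the same "last iteration per experiment" selection can also be done without Athena. The sketch below assumes ``pandas`` is available and that a ``result.json`` file has been produced by a local run or downloaded from S3; the path is a placeholder:

.. code:: python

    import json

    import pandas as pd

    # Placeholder path to a result.json produced by a run.
    path = "/tmp/ray/CartPole-v0_PolicyGradient_2017-07-17_12-00-00abcd/result.json"
    with open(path) as f:
        df = pd.DataFrame([json.loads(line) for line in f if line.strip()])

    # For each experiment, keep the row with the highest training_iteration.
    last = df.loc[df.groupby("experiment_id")["training_iteration"].idxmax()]
    print(last[["experiment_id", "training_iteration",
                "episode_reward_mean", "episode_len_mean"]])
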
@@ -83,8 +83,9 @@ class Runner(object):
 
 
 class A3C(Algorithm):
-    def __init__(self, env_name, config):
-        Algorithm.__init__(self, env_name, config)
+    def __init__(self, env_name, config, upload_dir=None):
+        config.update({"alg": "A3C"})
+        Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
         self.env = create_env(env_name)
         self.policy = LSTMPolicy(
             self.env.observation_space.shape, self.env.action_space.n, 0)
@@ -123,5 +124,6 @@ class A3C(Algorithm):
         episode_lengths.append(episode.episode_length)
         episode_rewards.append(episode.episode_reward)
         res = TrainingResult(
-            self.iteration, np.mean(episode_rewards), np.mean(episode_lengths))
+            self.experiment_id.hex, self.iteration,
+            np.mean(episode_rewards), np.mean(episode_lengths), dict())
         return res
@@ -2,17 +2,54 @@ from collections import namedtuple
 from datetime import datetime
 import json
 import logging
+import numpy as np
 import os
+import sys
 import tempfile
+import uuid
+import smart_open
+
+if sys.version_info[0] == 2:
+    import cStringIO as StringIO
+elif sys.version_info[0] == 3:
+    import io as StringIO
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
+class RLLibEncoder(json.JSONEncoder):
+    def default(self, value):
+        if isinstance(value, np.float32) or isinstance(value, np.float64):
+            if np.isnan(value):
+                return None
+            else:
+                return float(value)
+
+
+class RLLibLogger(object):
+    """Writing small amounts of data to S3 with real-time updates.
+    """
+
+    def __init__(self, uri):
+        self.result_buffer = StringIO.StringIO()
+        self.uri = uri
+
+    def write(self, b):
+        # TODO(pcm): At the moment we are writing the whole results output from
+        # the beginning in each iteration. This will write O(n^2) bytes where n
+        # is the number of bytes printed so far. Fix this! This should at least
+        # only write the last 5MBs (S3 chunksize).
+        with smart_open.smart_open(self.uri, "w") as f:
+            self.result_buffer.write(b)
+            f.write(self.result_buffer.getvalue())
+
+
 TrainingResult = namedtuple("TrainingResult", [
+    "experiment_id",
     "training_iteration",
     "episode_reward_mean",
     "episode_len_mean",
+    "info"
 ])
 
 
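To make the intent of these helpers concrete, here is a minimal usage sketch; the experiment id and metric values are placeholders (``experiment_id`` is normally ``uuid.uuid4().hex``). ``RLLibEncoder`` turns NaN metrics into JSON ``null``, which is what the Athena setup described in the README expects:

    import json

    import numpy as np

    from ray.rllib.common import RLLibEncoder, TrainingResult

    result = TrainingResult(
        "0123abcd", 1, np.float32(float("nan")), np.float64(173.0), {})
    print(json.dumps(result._asdict(), cls=RLLibEncoder))
    # -> {"experiment_id": "0123abcd", "training_iteration": 1,
    #     "episode_reward_mean": null, "episode_len_mean": 173.0, "info": {}}
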
@@ -30,18 +67,32 @@ class Algorithm(object):
     TODO(ekl): support checkpoint / restore of training state.
     """
 
-    def __init__(self, env_name, config):
+    def __init__(self, env_name, config, upload_dir="file:///tmp/ray"):
+        """Initialize an RLLib algorithm.
+
+        Args:
+            env_name (str): The name of the OpenAI gym environment to use.
+            config (obj): Algorithm-specific configuration data.
+            upload_dir (str): Root directory into which the output directory
+                should be placed. Can be local like file:///tmp/ray/ or on S3
+                like s3://bucketname/.
+        """
+        self.experiment_id = uuid.uuid4()
         self.env_name = env_name
         self.config = config
-        self.logdir = tempfile.mkdtemp(
-            prefix="{}_{}_{}".format(
-                env_name,
-                self.__class__.__name__,
-                datetime.today().strftime("%Y-%m-%d_%H-%M-%S")),
-            dir="/tmp/ray")
-        json.dump(
-            self.config, open(os.path.join(self.logdir, "config.json"), "w"),
-            sort_keys=True, indent=4)
+        self.config.update({"experiment_id": self.experiment_id.hex})
+        self.config.update({"env_name": env_name})
+        prefix = "{}_{}_{}".format(
+            env_name,
+            self.__class__.__name__,
+            datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
+        if upload_dir.startswith("file"):
+            self.logdir = "file://" + tempfile.mkdtemp(prefix=prefix, dir="/tmp/ray")
+        else:
+            self.logdir = os.path.join(upload_dir, prefix)
+        log_path = os.path.join(self.logdir, "config.json")
+        with smart_open.smart_open(log_path, "w") as f:
+            json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
         logger.info(
             "%s algorithm created with logdir '%s'",
             self.__class__.__name__, self.logdir)
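For reference, this new ``upload_dir`` argument is what ``train.py`` (added later in this commit) forwards from the ``--upload-dir`` flag. Below is a minimal sketch of calling it directly; the bucket name is a placeholder:

    import ray
    import ray.rllib.dqn as dqn

    ray.init()
    # upload_dir defaults to file:///tmp/ray; an S3 prefix works the same way.
    alg = dqn.DQN("PongNoFrameskip-v0", dqn.DEFAULT_CONFIG,
                  upload_dir="s3://bucketname/")
    # config.json is written to the log directory at construction time, e.g.
    # s3://bucketname/PongNoFrameskip-v0_DQN_<timestamp>
    print(alg.logdir)
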
@@ -88,8 +88,9 @@ DEFAULT_CONFIG = dict(
 
 
 class DQN(Algorithm):
-    def __init__(self, env_name, config):
-        Algorithm.__init__(self, env_name, config)
+    def __init__(self, env_name, config, upload_dir=None):
+        config.update({"alg": "DQN"})
+        Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
         env = gym.make(env_name)
         env = ScaledFloatFrame(wrap_dqn(env))
         self.env = env
@@ -193,6 +194,15 @@ class DQN(Algorithm):
         mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1)
         mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1)
         num_episodes = len(self.episode_rewards)
+
+        info = {
+            "sample_time": sample_time,
+            "learn_time": learn_time,
+            "steps": self.num_timesteps,
+            "episodes": num_episodes,
+            "exploration": int(100 * self.exploration.value(t))
+        }
+
         logger.record_tabular("sample_time", sample_time)
         logger.record_tabular("learn_time", learn_time)
         logger.record_tabular("steps", self.num_timesteps)
@@ -203,6 +213,7 @@ class DQN(Algorithm):
         logger.dump_tabular()
 
         res = TrainingResult(
-            self.num_iterations, mean_100ep_reward, mean_100ep_length)
+            self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
+            mean_100ep_length, info)
         self.num_iterations += 1
         return res
@@ -21,19 +21,13 @@ from ray.rllib.evolution_strategies import tf_util
 from ray.rllib.evolution_strategies import utils
 
 
-Config = namedtuple("Config", [
-    "l2coeff", "noise_stdev", "episodes_per_batch", "timesteps_per_batch",
-    "calc_obstat_prob", "eval_prob", "snapshot_freq", "return_proc_mode",
-    "episode_cutoff_mode", "num_workers", "stepsize"
-])
-
 Result = namedtuple("Result", [
     "noise_inds_n", "returns_n2", "sign_returns_n2", "lengths_n2",
     "eval_return", "eval_length", "ob_sum", "ob_sumsq", "ob_count"
 ])
 
 
-DEFAULT_CONFIG = Config(
+DEFAULT_CONFIG = dict(
     l2coeff=0.005,
     noise_stdev=0.02,
     episodes_per_batch=10000,
@@ -86,11 +80,11 @@ class Worker(object):
 
         self.rs = np.random.RandomState()
 
-        assert self.policy.needs_ob_stat == (self.config.calc_obstat_prob != 0)
+        assert self.policy.needs_ob_stat == (self.config["calc_obstat_prob"] != 0)
 
     def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat):
-        if (self.policy.needs_ob_stat and self.config.calc_obstat_prob != 0 and
-                self.rs.rand() < self.config.calc_obstat_prob):
+        if (self.policy.needs_ob_stat and self.config["calc_obstat_prob"] != 0 and
+                self.rs.rand() < self.config["calc_obstat_prob"]):
             rollout_rews, rollout_len, obs = self.policy.rollout(
                 self.env, timestep_limit=timestep_limit, save_obs=True,
                 random_stream=self.rs)
@@ -108,7 +102,7 @@ class Worker(object):
         if self.policy.needs_ob_stat:
             self.policy.set_ob_stat(ob_mean, ob_std)
 
-        if self.config.eval_prob != 0:
+        if self.config["eval_prob"] != 0:
             raise NotImplementedError("Eval rollouts are not implemented.")
 
         noise_inds, returns, sign_returns, lengths = [], [], [], []
@@ -120,7 +114,7 @@ class Worker(object):
         while (len(noise_inds) == 0 or
                time.time() - task_tstart < self.min_task_runtime):
             noise_idx = self.noise.sample_index(self.rs, self.policy.num_params)
-            perturbation = self.config.noise_stdev * self.noise.get(
+            perturbation = self.config["noise_stdev"] * self.noise.get(
                 noise_idx, self.policy.num_params)
 
             # These two sampling steps could be done in parallel on different actors
@@ -151,8 +145,10 @@ class Worker(object):
 
 
 class EvolutionStrategies(Algorithm):
-    def __init__(self, env_name, config):
-        Algorithm.__init__(self, env_name, config)
+    def __init__(self, env_name, config, upload_dir=None):
+        config.update({"alg": "EvolutionStrategies"})
+
+        Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
 
         policy_params = {
             "ac_bins": "continuous:",
@@ -170,14 +166,14 @@ class EvolutionStrategies(Algorithm):
         # Create the actors.
         print("Creating actors.")
         self.workers = [Worker.remote(config, policy_params, env_name, noise_id)
-                        for _ in range(config.num_workers)]
+                        for _ in range(config["num_workers"])]
 
         env = gym.make(env_name)
         utils.make_session(single_threaded=False)
         self.policy = policies.MujocoPolicy(
             env.observation_space, env.action_space, **policy_params)
         tf_util.initialize()
-        self.optimizer = optimizers.Adam(self.policy, config.stepsize)
+        self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
         self.ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2)
 
         self.episodes_so_far = 0
@@ -233,10 +229,10 @@ class EvolutionStrategies(Algorithm):
         lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results])
         assert noise_inds_n.shape[0] == returns_n2.shape[0] == lengths_n2.shape[0]
         # Process the returns.
-        if config.return_proc_mode == "centered_rank":
+        if config["return_proc_mode"] == "centered_rank":
             proc_returns_n2 = utils.compute_centered_ranks(returns_n2)
         else:
-            raise NotImplementedError(config.return_proc_mode)
+            raise NotImplementedError(config["return_proc_mode"])
 
         # Compute and take a step.
         g, count = utils.batched_weighted_sum(
@@ -246,7 +242,7 @@ class EvolutionStrategies(Algorithm):
         g /= returns_n2.size
         assert (g.shape == (self.policy.num_params,) and g.dtype == np.float32 and
                 count == len(noise_inds_n))
-        update_ratio = self.optimizer.update(-g + config.l2coeff * theta)
+        update_ratio = self.optimizer.update(-g + config["l2coeff"] * theta)
 
         # Update ob stat (we're never running the policy in the master, but we
         # might be snapshotting the policy).
@@ -274,14 +270,29 @@ class EvolutionStrategies(Algorithm):
         tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
         tlogger.dump_tabular()
 
-        if (config.snapshot_freq != 0 and
-                self.iteration % config.snapshot_freq == 0):
+        if (config["snapshot_freq"] != 0 and
+                self.iteration % config["snapshot_freq"] == 0):
             filename = os.path.join(
                 self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration))
             assert not os.path.exists(filename)
             self.policy.save(filename)
             tlogger.log("Saved snapshot {}".format(filename))
 
-        res = TrainingResult(self.iteration, returns_n2.mean(), lengths_n2.mean())
+        info = {
+            "weights_norm": np.square(self.policy.get_trainable_flat()).sum(),
+            "grad_norm": np.square(g).sum(),
+            "update_ratio": update_ratio,
+            "episodes_this_iter": lengths_n2.size,
+            "episodes_so_far": self.episodes_so_far,
+            "timesteps_this_iter": lengths_n2.sum(),
+            "timesteps_so_far": self.timesteps_so_far,
+            "ob_count": ob_count_this_batch,
+            "time_elapsed_this_iter": step_tend - step_tstart,
+            "time_elapsed": step_tend - self.tstart
+        }
+        res = TrainingResult(self.experiment_id.hex, self.iteration,
+                             returns_n2.mean(), lengths_n2.mean(), info)
 
         self.iteration += 1
 
         return res
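Because these algorithm-specific values travel in the ``info`` field of ``TrainingResult``, a caller can inspect them directly after each iteration (illustrative sketch, assuming an ``EvolutionStrategies`` instance ``alg`` constructed as in ``train.py``):

    result = alg.train()
    print(result.episode_reward_mean, result.info["timesteps_so_far"])
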
@@ -43,8 +43,10 @@ DEFAULT_CONFIG = {
 
 
 class PolicyGradient(Algorithm):
-    def __init__(self, env_name, config):
-        Algorithm.__init__(self, env_name, config)
+    def __init__(self, env_name, config, upload_dir=None):
+        config.update({"alg": "PolicyGradient"})
+
+        Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
 
         # TODO(ekl) the preprocessor should be associated with the env elsewhere
         if self.env_name == "Pong-v0":
@@ -81,13 +83,16 @@ class PolicyGradient(Algorithm):
         if "load_checkpoint" in config:
             saver.restore(model.sess, config["load_checkpoint"])
 
-        file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph)
+        # TF does not support writing logs to S3 at the moment
+        write_tf_logs = self.logdir.startswith("file")
         iter_start = time.time()
-        if config["model_checkpoint_file"]:
-            checkpoint_path = saver.save(
-                model.sess,
-                os.path.join(self.logdir, config["model_checkpoint_file"] % j))
-            print("Checkpoint saved in file: %s" % checkpoint_path)
+        if write_tf_logs:
+            file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph)
+            if config["model_checkpoint_file"]:
+                checkpoint_path = saver.save(
+                    model.sess,
+                    os.path.join(self.logdir, config["model_checkpoint_file"] % j))
+                print("Checkpoint saved in file: %s" % checkpoint_path)
         checkpointing_end = time.time()
         weights = ray.put(model.get_weights())
         [a.load_weights.remote(weights) for a in agents]
@@ -96,14 +101,15 @@ class PolicyGradient(Algorithm):
         print("total reward is ", total_reward)
         print("trajectory length mean is ", traj_len_mean)
         print("timesteps:", trajectory["dones"].shape[0])
-        traj_stats = tf.Summary(value=[
-            tf.Summary.Value(
-                tag="policy_gradient/rollouts/mean_reward",
-                simple_value=total_reward),
-            tf.Summary.Value(
-                tag="policy_gradient/rollouts/traj_len_mean",
-                simple_value=traj_len_mean)])
-        file_writer.add_summary(traj_stats, self.global_step)
+        if write_tf_logs:
+            traj_stats = tf.Summary(value=[
+                tf.Summary.Value(
+                    tag="policy_gradient/rollouts/mean_reward",
+                    simple_value=total_reward),
+                tf.Summary.Value(
+                    tag="policy_gradient/rollouts/traj_len_mean",
+                    simple_value=traj_len_mean)])
+            file_writer.add_summary(traj_stats, self.global_step)
         self.global_step += 1
         trajectory["advantages"] = ((trajectory["advantages"] -
                                      trajectory["advantages"].mean()) /
@@ -135,7 +141,8 @@ class PolicyGradient(Algorithm):
                     batch_index == config["full_trace_nth_sgd_batch"])
                 batch_loss, batch_kl, batch_entropy = model.run_sgd_minibatch(
                     permutation[batch_index] * model.per_device_batch_size,
-                    self.kl_coeff, full_trace, file_writer)
+                    self.kl_coeff, full_trace,
+                    file_writer if write_tf_logs else None)
                 loss.append(batch_loss)
                 kl.append(batch_kl)
                 entropy.append(batch_entropy)
@@ -164,8 +171,9 @@ class PolicyGradient(Algorithm):
                 tf.Summary.Value(
                     tag=metric_prefix + "mean_kl",
                     simple_value=kl)])
-            sgd_stats = tf.Summary(value=values)
-            file_writer.add_summary(sgd_stats, self.global_step)
+            if write_tf_logs:
+                sgd_stats = tf.Summary(value=values)
+                file_writer.add_summary(sgd_stats, self.global_step)
             self.global_step += 1
             sgd_time += sgd_end - sgd_start
             if kl > 2.0 * config["kl_target"]:
@@ -173,6 +181,17 @@ class PolicyGradient(Algorithm):
         elif kl < 0.5 * config["kl_target"]:
             self.kl_coeff *= 0.5
 
+        info = {
+            "kl_divergence": kl,
+            "kl_coefficient": self.kl_coeff,
+            "checkpointing_time": checkpointing_time,
+            "rollouts_time": rollouts_time,
+            "shuffle_time": shuffle_time,
+            "load_time": load_time,
+            "sgd_time": sgd_time,
+            "sample_throughput": len(trajectory["observations"]) / sgd_time
+        }
+
         print("kl div:", kl)
         print("kl coeff:", self.kl_coeff)
         print("checkpointing time:", checkpointing_time)
@@ -182,4 +201,7 @@ class PolicyGradient(Algorithm):
         print("sgd time:", sgd_time)
         print("sgd examples/s:", len(trajectory["observations"]) / sgd_time)
 
-        return TrainingResult(j, total_reward, traj_len_mean)
+        result = TrainingResult(
+            self.experiment_id.hex, j, total_reward, traj_len_mean, info)
+
+        return result
6 python/ray/rllib/test.sh Executable file
@@ -0,0 +1,6 @@
#!/bin/bash

python train.py --env Walker2d-v1 --alg PolicyGradient --upload-dir s3://bucketname/
python train.py --env PongNoFrameskip-v0 --alg DQN --upload-dir s3://bucketname/
python train.py --env PongDeterministic-v0 --alg A3C --upload-dir s3://bucketname/
python train.py --env Humanoid-v1 --alg EvolutionStrategies --upload-dir s3://bucketname/
57 python/ray/rllib/train.py Executable file
@@ -0,0 +1,57 @@
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import os

import ray
import ray.rllib.policy_gradient as pg
import ray.rllib.evolution_strategies as es
import ray.rllib.dqn as dqn
import ray.rllib.a3c as a3c

parser = argparse.ArgumentParser(
    description=("Train a reinforcement learning agent."))
parser.add_argument("--env", required=True, type=str)
parser.add_argument("--alg", required=True, type=str)
parser.add_argument("--upload-dir", default="file:///tmp/ray", type=str)


if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()

    env_name = args.env
    if args.alg == "PolicyGradient":
        alg = pg.PolicyGradient(
            env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir)
    elif args.alg == "EvolutionStrategies":
        alg = es.EvolutionStrategies(
            env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir)
    elif args.alg == "DQN":
        alg = dqn.DQN(
            env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir)
    elif args.alg == "A3C":
        alg = a3c.A3C(
            env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir)
    else:
        assert False, ("Unknown algorithm, check --alg argument. Valid choices "
                       "are PolicyGradient, EvolutionStrategies, "
                       "DQN and A3C.")

    result_logger = ray.rllib.common.RLLibLogger(
        os.path.join(alg.logdir, "result.json"))

    while True:
        result = alg.train()

        # We need to use a custom json serializer class so that NaNs get encoded
        # as null as required by Athena.
        json.dump(result._asdict(), result_logger,
                  cls=ray.rllib.common.RLLibEncoder)
        result_logger.write("\n")