ray/rllib/agents/es/es.py

340 lines
12 KiB
Python
Raw Normal View History

# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
from collections import namedtuple
import logging
import numpy as np
import time
import ray
from ray.rllib.agents import Trainer, with_common_config
[rllib] Document "v2" APIs (#2316) * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * wip * tf * update * fix * cleanup * cleanup * spacing * model * fix * dqn * fix ddpg * doc * keep names * update * fix * com * docs * clarify model outputs * Update torch_policy_graph.py * fix obs filter * pass thru worker index * fix * rename * vlad torch comments * fix log action * debug name * fix lstm * remove unused ddpg net * remove conv net * revert lstm * wip * wip * cast * wip * works * fix a3c * works * lstm util test * doc * clean up * update * fix lstm check * move to end * fix sphinx * fix cmd * remove bad doc * envs * vec * doc prep * models * rl * alg * up * clarify * copy * async sa * fix * comments * fix a3c conf * tune lstm * fix reshape * fix * back to 16 * tuned a3c update * update * tuned * optional * merge * wip * fix up * move pg class * rename env * wip * update * tip * alg * readme * fix catalog * readme * doc * 
context * remove prep * comma * add env * link to paper * paper * update * rnn * update * wip * clean up ev creation * fix * fix * fix * fix lint * up * no comma * ma * Update run_multi_node_tests.sh * fix * sphinx is stupid * sphinx is stupid * clarify torch graph * no horizon * fix config * sb * Update test_optimizers.py
2018-07-01 00:05:08 -07:00
from ray.rllib.agents.es import optimizers
from ray.rllib.agents.es import policies
from ray.rllib.agents.es import utils
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager
logger = logging.getLogger(__name__)

# Record returned by Worker.do_rollouts. Each noise_indices entry lines up
# with one entry of every noisy_* list; those entries are two-element
# [+perturbation, -perturbation] pairs. The eval_* lists hold one value per
# unperturbed evaluation rollout.
Result = namedtuple(
    "Result",
    "noise_indices noisy_returns sign_noisy_returns noisy_lengths "
    "eval_returns eval_lengths")


# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # L2 penalty coefficient: l2_coeff * theta is added to the negated ES
    # gradient before each optimizer step.
    "l2_coeff": 0.005,
    # Scale applied to the shared-noise slice that is added to / subtracted
    # from the weights for each perturbation pair.
    "noise_stdev": 0.02,
    # Minimum number of noisy episodes collected per training iteration
    # (each perturbation pair counts as two episodes).
    "episodes_per_batch": 1000,
    # Minimum number of noisy-rollout timesteps collected per iteration.
    "train_batch_size": 10000,
    # Probability that a worker sample is an unperturbed evaluation rollout
    # instead of a perturbation pair; evaluation rollouts feed
    # episode_reward_mean and episode_len_mean.
    "eval_prob": 0.003,
    # Transformation applied to raw returns before they weight the noise;
    # only "centered_rank" is implemented.
    "return_proc_mode": "centered_rank",
    # Number of remote Worker actors that perform the rollouts.
    "num_workers": 10,
    # Step size handed to the Adam optimizer.
    "stepsize": 0.01,
    # Observation filter name passed to every GenericPolicy.
    "observation_filter": "MeanStdFilter",
    # Number of float32 values in the shared noise table.
    "noise_size": 250000000,
    # How many of the most recent per-iteration evaluation reward means are
    # averaged into the reported episode_reward_mean.
    "report_length": 10,
})
# __sphinx_doc_end__
# yapf: enable


@ray.remote
def create_shared_noise(count):
    """Return ``count`` standard-normal samples as a float32 array.

    The generator is seeded with a fixed value (123), so every call yields
    the same table; it is built once and shared with all workers.
    """
    rng = np.random.RandomState(123)
    return rng.randn(count).astype(np.float32)


class SharedNoiseTable:
    """Indexable view over a flat float32 noise buffer.

    Callers record only the start offset of a slice, and later recover the
    same perturbation with ``get(offset, dim)``.
    """

    def __init__(self, noise):
        assert noise.dtype == np.float32
        self.noise = noise

    def get(self, i, dim):
        """Return the ``dim`` consecutive noise values starting at ``i``."""
        return self.noise[slice(i, i + dim)]

    def sample_index(self, dim):
        """Draw a uniformly random start offset whose ``dim``-long slice fits."""
        highest_exclusive = len(self.noise) - dim + 1
        return np.random.randint(0, highest_exclusive)


@ray.remote
class Worker:
    """Ray actor that runs perturbed and evaluation rollouts of the policy.

    Each actor owns its own environment, single-threaded session and
    GenericPolicy. It wraps ``noise`` in a SharedNoiseTable, so a sampled
    perturbation can be reported back to the trainer as a start index.
    """

    def __init__(self,
                 config,
                 policy_params,
                 env_creator,
                 noise,
                 min_task_runtime=0.2):
        self.config = config
        self.policy_params = policy_params
        self.min_task_runtime = min_task_runtime
        self.noise = SharedNoiseTable(noise)

        env = env_creator(config["env_config"])
        self.env = env

        from ray.rllib import models
        self.preprocessor = models.ModelCatalog.get_preprocessor(
            env, config["model"])

        self.sess = utils.make_session(single_threaded=True)
        self.policy = policies.GenericPolicy(
            self.sess,
            env.action_space,
            env.observation_space,
            self.preprocessor,
            config["observation_filter"],
            config["model"],
            **policy_params)

    @property
    def filters(self):
        """Map DEFAULT_POLICY_ID to this worker's observation filter."""
        return {DEFAULT_POLICY_ID: self.policy.get_filter()}

    def sync_filters(self, new_filters):
        """Sync each local filter from the matching entry of new_filters."""
        for policy_id, local_filter in self.filters.items():
            local_filter.sync(new_filters[policy_id])

    def get_filters(self, flush_after=False):
        """Return serializable filter copies keyed by policy id.

        If ``flush_after`` is true, each filter's buffer is cleared right
        after it has been serialized.
        """
        serialized = {}
        for policy_id, obs_filter in self.filters.items():
            serialized[policy_id] = obs_filter.as_serializable()
            if flush_after:
                obs_filter.clear_buffer()
        return serialized

    def rollout(self, timestep_limit, add_noise=True):
        """Run one rollout in self.env; return (rewards, length)."""
        rewards, length = policies.rollout(
            self.policy,
            self.env,
            timestep_limit=timestep_limit,
            add_noise=add_noise)
        return rewards, length

    def do_rollouts(self, params, timestep_limit=None):
        """Sample rollouts around the flat weight vector ``params``.

        Sampling continues until at least one perturbation pair has been
        run and ``min_task_runtime`` seconds have elapsed. Each iteration is
        either an unperturbed evaluation rollout (with probability
        ``config["eval_prob"]``) or a mirrored pair of rollouts at
        ``params + eps`` and ``params - eps``. Here ``eps`` is a
        ``noise_stdev``-scaled slice of the shared noise table.

        Returns:
            Result: per-pair noise indices, summed returns, summed reward
            signs and lengths, plus returns/lengths of evaluation rollouts.
        """
        self.policy.set_weights(params)

        noise_indices, returns, sign_returns, lengths = [], [], [], []
        eval_returns, eval_lengths = [], []

        started_at = time.time()
        while (not noise_indices
               or time.time() - started_at < self.min_task_runtime):
            if np.random.uniform() < self.config["eval_prob"]:
                # Evaluation run at the unperturbed weights.
                self.policy.set_weights(params)
                rewards, length = self.rollout(timestep_limit, add_noise=False)
                eval_returns.append(rewards.sum())
                eval_lengths.append(length)
                continue

            index = self.noise.sample_index(self.policy.num_params)
            eps = self.config["noise_stdev"] * self.noise.get(
                index, self.policy.num_params)

            # Mirrored sampling: +eps first, then -eps. These two rollouts
            # could be spread over different actors to update twice as
            # often.
            pair_rewards, pair_lengths = [], []
            for candidate in (params + eps, params - eps):
                self.policy.set_weights(candidate)
                rewards, length = self.rollout(timestep_limit)
                pair_rewards.append(rewards)
                pair_lengths.append(length)

            noise_indices.append(index)
            returns.append([r.sum() for r in pair_rewards])
            sign_returns.append([np.sign(r).sum() for r in pair_rewards])
            lengths.append(pair_lengths)

        return Result(
            noise_indices=noise_indices,
            noisy_returns=returns,
            sign_noisy_returns=sign_returns,
            noisy_lengths=lengths,
            eval_returns=eval_returns,
            eval_lengths=eval_lengths)


class ESTrainer(Trainer):
    """Large-scale implementation of Evolution Strategies in Ray."""

    _name = "ES"
    _default_config = DEFAULT_CONFIG

    @override(Trainer)
    def _init(self, config, env_creator):
        """Build the local policy, optimizer, shared noise table and workers.

        Raises:
            ValueError: if ``config["use_pytorch"]`` is set.
        """
        # PyTorch check.
        if config["use_pytorch"]:
            raise ValueError(
                "ES does not support PyTorch yet! Use tf instead."
            )

        policy_params = {"action_noise_std": 0.01}

        env = env_creator(config["env_config"])
        from ray.rllib import models
        # Build the preprocessor from the same model options the workers
        # use. Weights and noise slices travel between this policy and the
        # workers' policies as flat vectors of length num_params, so both
        # sides must construct identical networks. A preprocessor built from
        # different options could change the observation shape.
        preprocessor = models.ModelCatalog.get_preprocessor(
            env, config["model"])

        self.sess = utils.make_session(single_threaded=False)
        self.policy = policies.GenericPolicy(
            self.sess, env.action_space, env.observation_space, preprocessor,
            config["observation_filter"], config["model"], **policy_params)
        self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
        self.report_length = config["report_length"]

        # Create the shared noise table.
        logger.info("Creating shared noise table.")
        noise_id = create_shared_noise.remote(config["noise_size"])
        self.noise = SharedNoiseTable(ray.get(noise_id))

        # Create the actors.
        logger.info("Creating actors.")
        self._workers = [
            Worker.remote(config, policy_params, env_creator, noise_id)
            for _ in range(config["num_workers"])
        ]

        self.episodes_so_far = 0
        self.reward_list = []
        self.tstart = time.time()

    @override(Trainer)
    def _train(self):
        """Run one ES iteration: collect rollouts, estimate grad, step."""
        config = self.config

        theta = self.policy.get_weights()
        assert theta.dtype == np.float32

        # Put the current policy weights in the object store.
        theta_id = ray.put(theta)
        # Use the actors to do rollouts, note that we pass in the ID of the
        # policy weights.
        results, num_episodes, num_timesteps = self._collect_results(
            theta_id, config["episodes_per_batch"], config["train_batch_size"])

        all_noise_indices = []
        all_training_returns = []
        all_training_lengths = []
        all_eval_returns = []
        all_eval_lengths = []

        # Loop over the results.
        for result in results:
            all_eval_returns += result.eval_returns
            all_eval_lengths += result.eval_lengths

            all_noise_indices += result.noise_indices
            all_training_returns += result.noisy_returns
            all_training_lengths += result.noisy_lengths

        assert len(all_eval_returns) == len(all_eval_lengths)
        assert (len(all_noise_indices) == len(all_training_returns) ==
                len(all_training_lengths))

        self.episodes_so_far += num_episodes

        # Assemble the results.
        eval_returns = np.array(all_eval_returns)
        eval_lengths = np.array(all_eval_lengths)
        noise_indices = np.array(all_noise_indices)
        noisy_returns = np.array(all_training_returns)
        noisy_lengths = np.array(all_training_lengths)

        # Process the returns.
        if config["return_proc_mode"] == "centered_rank":
            proc_noisy_returns = utils.compute_centered_ranks(noisy_returns)
        else:
            raise NotImplementedError(config["return_proc_mode"])

        # Compute and take a step. Each noise slice is weighted by the
        # difference between its +eps and -eps processed returns.
        g, count = utils.batched_weighted_sum(
            proc_noisy_returns[:, 0] - proc_noisy_returns[:, 1],
            (self.noise.get(index, self.policy.num_params)
             for index in noise_indices),
            batch_size=500)
        g /= noisy_returns.size
        assert (g.shape == (self.policy.num_params, ) and g.dtype == np.float32
                and count == len(noise_indices))
        # Compute the new weights theta (the optimizer is given the negated
        # ES gradient plus the L2 penalty term).
        theta, update_ratio = self.optimizer.update(
            -g + config["l2_coeff"] * theta)
        # Set the new weights in the local copy of the policy.
        self.policy.set_weights(theta)
        # Store the rewards
        if len(all_eval_returns) > 0:
            self.reward_list.append(np.mean(eval_returns))

        # Now sync the filters
        FilterManager.synchronize({
            DEFAULT_POLICY_ID: self.policy.get_filter()
        }, self._workers)

        info = {
            "weights_norm": np.square(theta).sum(),
            "grad_norm": np.square(g).sum(),
            "update_ratio": update_ratio,
            # Counts both halves of every perturbation pair.
            "episodes_this_iter": noisy_lengths.size,
            "episodes_so_far": self.episodes_so_far,
        }

        # Evaluation rollouts only happen with probability eval_prob, so
        # either of these may be empty. Report NaN explicitly rather than
        # letting NumPy warn about the mean of an empty slice.
        recent_rewards = self.reward_list[-self.report_length:]
        reward_mean = np.mean(recent_rewards) if recent_rewards else np.nan
        len_mean = eval_lengths.mean() if eval_lengths.size else np.nan

        result = dict(
            episode_reward_mean=reward_mean,
            episode_len_mean=len_mean,
            timesteps_this_iter=noisy_lengths.sum(),
            info=info)

        return result

    @override(Trainer)
    def compute_action(self, observation):
        """Call policy.compute with update=False; return its first element."""
        return self.policy.compute(observation, update=False)[0]

    @override(Trainer)
    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for w in self._workers:
            w.__ray_terminate__.remote()

    def _collect_results(self, theta_id, min_episodes, min_timesteps):
        """Dispatch rounds of worker rollouts until both minimums are met.

        Returns:
            (results, num_episodes, num_timesteps), where the counts cover
            the noisy (perturbed) rollouts only.

        Raises:
            ValueError: if there are no workers; otherwise this loop would
                never make progress.
        """
        if not self._workers:
            raise ValueError(
                "ES requires num_workers >= 1 to collect rollouts.")

        num_episodes, num_timesteps = 0, 0
        results = []
        while num_episodes < min_episodes or num_timesteps < min_timesteps:
            logger.info(
                "Collected %s episodes %s timesteps so far this iter",
                num_episodes, num_timesteps)
            rollout_ids = [
                worker.do_rollouts.remote(theta_id) for worker in self._workers
            ]
            # Get the results of the rollouts.
            for result in ray_get_and_free(rollout_ids):
                results.append(result)
                # Update the number of episodes and the number of timesteps
                # keeping in mind that result.noisy_lengths is a list of lists,
                # where the inner lists have length 2.
                num_episodes += sum(len(pair) for pair in result.noisy_lengths)
                num_timesteps += sum(
                    sum(pair) for pair in result.noisy_lengths)

        return results, num_episodes, num_timesteps

    def __getstate__(self):
        """Snapshot weights, observation filter and episode counter."""
        return {
            "weights": self.policy.get_weights(),
            "filter": self.policy.get_filter(),
            "episodes_so_far": self.episodes_so_far,
        }

    def __setstate__(self, state):
        """Restore a snapshot and push the restored filter to all workers."""
        self.episodes_so_far = state["episodes_so_far"]
        self.policy.set_weights(state["weights"])
        self.policy.set_filter(state["filter"])
        FilterManager.synchronize({
            DEFAULT_POLICY_ID: self.policy.get_filter()
        }, self._workers)