[rllib] add augmented random search (#2714)
* added ARS
* functioning ARS with regression test
* added regression tests for ARS
* fixed default config for ARS
* ARS code runs, now time to test
* ARS working and tested, changed std deviation of meanstd filter to initialize to 1
* pep8 fixes
* removed unused linear model
* address comments
* more fixing comments
* post yapf
* fixed support failure
* Update LICENSE
* Update policies.py
* Update test_supported_spaces.py
* Update policies.py
* Update LICENSE
* Update test_supported_spaces.py
* Update policies.py
* Update policies.py
* Update filter.py
parent 5fd44afb8a
commit 6201a6d1c7
11 changed files with 698 additions and 1 deletion
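For context, the sketch below shows roughly how the new agent is meant to be driven once this commit is in place. It is not part of the diff: it assumes the Agent base class of this Ray version accepts `config` and `env` keyword arguments and that gym's CartPole-v0 is installed; the overridden config values are arbitrary illustration.

import ray
from ray.rllib.agents.ars import ARSAgent, DEFAULT_CONFIG

ray.init()
# Override a few defaults (values chosen only for illustration).
config = dict(DEFAULT_CONFIG, num_workers=2, num_deltas=8, deltas_used=8)
agent = ARSAgent(config=config, env="CartPole-v0")
for i in range(5):
    result = agent.train()  # one ARS iteration: antithetic rollouts + parameter update
    print(i, result)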
LICENSE (27 additions)
@@ -243,3 +243,30 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+
+--------------------------------------------------------------------------------
+Code in python/ray/rllib/ars is adapted from https://github.com/modestyachts/ARS
+
+Copyright (c) 2018, ARS contributors (Horia Mania, Aurelia Guy, Benjamin Recht)
+All rights reserved.
+
+Redistribution and use of ARS in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -17,9 +17,10 @@ from ray.rllib.evaluation.sample_batch import SampleBatch


def _register_all():
    for key in [
            "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", "APEX_DDPG",
-           "IMPALA", "A2C", "__fake", "__sigmoid_fake_data",
+           "IMPALA", "ARS", "A2C", "__fake", "__sigmoid_fake_data",
            "__parameter_tuning"
    ]:
        from ray.rllib.agents.agent import get_agent_class
@@ -393,6 +393,9 @@ def get_agent_class(alg):
    elif alg == "ES":
        from ray.rllib.agents import es
        return es.ESAgent
+   elif alg == "ARS":
+       from ray.rllib.agents import ars
+       return ars.ARSAgent
    elif alg == "DQN":
        from ray.rllib.agents import dqn
        return dqn.DQNAgent
python/ray/rllib/agents/ars/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from ray.rllib.agents.ars.ars import (ARSAgent, DEFAULT_CONFIG)

__all__ = ["ARSAgent", "DEFAULT_CONFIG"]
python/ray/rllib/agents/ars/ars.py (new file, 351 lines)
@@ -0,0 +1,351 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter and from
# https://github.com/modestyachts/ARS

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import numpy as np
import os
import pickle
import time

import ray
from ray.rllib.agents import Agent, with_common_config
from ray.tune.trial import Resources

from ray.rllib.agents.ars import optimizers
from ray.rllib.agents.ars import policies
from ray.rllib.agents.es import tabular_logger as tlogger
from ray.rllib.agents.ars import utils

Result = namedtuple("Result", [
    "noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths",
    "eval_returns", "eval_lengths"
])

DEFAULT_CONFIG = with_common_config({
    'noise_stdev': 0.02,  # std deviation of parameter noise
    'num_deltas': 4,  # number of perturbations to try
    'deltas_used': 4,  # number of perturbations to keep in gradient estimate
    'num_workers': 2,
    'stepsize': 0.01,  # sgd step-size
    'observation_filter': "MeanStdFilter",
    'noise_size': 250000000,
    'eval_prob': 0.03,  # probability of evaluating the parameter rewards
    'env_config': {},
    'offset': 0,
    'policy_type': "LinearPolicy",  # ["LinearPolicy", "MLPPolicy"]
    "fcnet_hiddens": [32, 32],  # fcnet structure of MLPPolicy
})


@ray.remote
def create_shared_noise(count):
    """Create a large array of noise to be shared by all workers."""
    seed = 123
    noise = np.random.RandomState(seed).randn(count).astype(np.float32)
    return noise


class SharedNoiseTable(object):
    def __init__(self, noise):
        self.noise = noise
        assert self.noise.dtype == np.float32

    def get(self, i, dim):
        return self.noise[i:i + dim]

    def sample_index(self, dim):
        return np.random.randint(0, len(self.noise) - dim + 1)

    def get_delta(self, dim):
        idx = self.sample_index(dim)
        return idx, self.get(idx, dim)


@ray.remote
class Worker(object):
    def __init__(self,
                 config,
                 policy_params,
                 env_creator,
                 noise,
                 min_task_runtime=0.2):
        self.min_task_runtime = min_task_runtime
        self.config = config
        self.policy_params = policy_params
        self.noise = SharedNoiseTable(noise)

        self.env = env_creator(config["env_config"])
        from ray.rllib import models
        self.preprocessor = models.ModelCatalog.get_preprocessor(self.env)

        self.sess = utils.make_session(single_threaded=True)
        if config["policy_type"] == "LinearPolicy":
            self.policy = policies.LinearPolicy(
                self.sess, self.env.action_space, self.preprocessor,
                config["observation_filter"], **policy_params)
        else:
            self.policy = policies.MLPPolicy(
                self.sess, self.env.action_space, self.preprocessor,
                config["observation_filter"], config["fcnet_hiddens"],
                **policy_params)

    def rollout(self, timestep_limit, add_noise=False):
        rollout_rewards, rollout_length = policies.rollout(
            self.policy,
            self.env,
            timestep_limit=timestep_limit,
            add_noise=add_noise,
            offset=self.config['offset'])
        return rollout_rewards, rollout_length

    def do_rollouts(self, params, timestep_limit=None):
        # Set the network weights.
        self.policy.set_weights(params)

        noise_indices, returns, sign_returns, lengths = [], [], [], []
        eval_returns, eval_lengths = [], []

        # Perform some rollouts with noise.
        while (len(noise_indices) == 0):
            if np.random.uniform() < self.config["eval_prob"]:
                # Do an evaluation run with no perturbation.
                self.policy.set_weights(params)
                rewards, length = self.rollout(timestep_limit, add_noise=False)
                eval_returns.append(rewards.sum())
                eval_lengths.append(length)
            else:
                # Do a regular run with parameter perturbations.
                noise_index = self.noise.sample_index(self.policy.num_params)

                perturbation = self.config["noise_stdev"] * self.noise.get(
                    noise_index, self.policy.num_params)

                # These two sampling steps could be done in parallel on
                # different actors letting us update twice as frequently.
                self.policy.set_weights(params + perturbation)
                rewards_pos, lengths_pos = self.rollout(timestep_limit)

                self.policy.set_weights(params - perturbation)
                rewards_neg, lengths_neg = self.rollout(timestep_limit)

                noise_indices.append(noise_index)
                returns.append([rewards_pos.sum(), rewards_neg.sum()])
                sign_returns.append(
                    [np.sign(rewards_pos).sum(),
                     np.sign(rewards_neg).sum()])
                lengths.append([lengths_pos, lengths_neg])

        return Result(
            noise_indices=noise_indices,
            noisy_returns=returns,
            sign_noisy_returns=sign_returns,
            noisy_lengths=lengths,
            eval_returns=eval_returns,
            eval_lengths=eval_lengths)


class ARSAgent(Agent):
    """Large-scale implementation of Augmented Random Search in Ray."""

    _agent_name = "ARS"
    _default_config = DEFAULT_CONFIG

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])

    def _init(self):
        policy_params = {"action_noise_std": 0.0}

        # register the linear network
        utils.register_linear_network()

        env = self.env_creator(self.config["env_config"])
        from ray.rllib import models
        preprocessor = models.ModelCatalog.get_preprocessor(env)

        self.sess = utils.make_session(single_threaded=False)
        if self.config["policy_type"] == "LinearPolicy":
            self.policy = policies.LinearPolicy(
                self.sess, env.action_space, preprocessor,
                self.config["observation_filter"], **policy_params)
        else:
            self.policy = policies.MLPPolicy(
                self.sess, env.action_space, preprocessor,
                self.config["observation_filter"],
                self.config["fcnet_hiddens"], **policy_params)
        self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])

        self.deltas_used = self.config["deltas_used"]
        self.num_deltas = self.config["num_deltas"]

        # Create the shared noise table.
        print("Creating shared noise table.")
        noise_id = create_shared_noise.remote(self.config["noise_size"])
        self.noise = SharedNoiseTable(ray.get(noise_id))

        # Create the actors.
        print("Creating actors.")
        self.workers = [
            Worker.remote(self.config, policy_params, self.env_creator,
                          noise_id) for _ in range(self.config["num_workers"])
        ]

        self.episodes_so_far = 0
        self.timesteps_so_far = 0
        self.tstart = time.time()

    def _collect_results(self, theta_id, min_episodes):
        num_episodes, num_timesteps = 0, 0
        results = []
        while num_episodes < min_episodes:
            print("Collected {} episodes {} timesteps so far this iter".format(
                num_episodes, num_timesteps))
            rollout_ids = [
                worker.do_rollouts.remote(theta_id) for worker in self.workers
            ]
            # Get the results of the rollouts.
            for result in ray.get(rollout_ids):
                results.append(result)
                # Update the number of episodes and the number of timesteps
                # keeping in mind that result.noisy_lengths is a list of lists,
                # where the inner lists have length 2.
                num_episodes += sum(len(pair) for pair in result.noisy_lengths)
                num_timesteps += sum(
                    sum(pair) for pair in result.noisy_lengths)
        return results, num_episodes, num_timesteps

    def _train(self):
        config = self.config

        step_tstart = time.time()
        theta = self.policy.get_weights()
        assert theta.dtype == np.float32

        # Put the current policy weights in the object store.
        theta_id = ray.put(theta)
        # Use the actors to do rollouts, note that we pass in the ID of the
        # policy weights.
        results, num_episodes, num_timesteps = self._collect_results(
            theta_id, config["num_deltas"])

        all_noise_indices = []
        all_training_returns = []
        all_training_lengths = []
        all_eval_returns = []
        all_eval_lengths = []

        # Loop over the results.
        for result in results:
            all_eval_returns += result.eval_returns
            all_eval_lengths += result.eval_lengths

            all_noise_indices += result.noise_indices
            all_training_returns += result.noisy_returns
            all_training_lengths += result.noisy_lengths

        assert len(all_eval_returns) == len(all_eval_lengths)
        assert (len(all_noise_indices) == len(all_training_returns) ==
                len(all_training_lengths))

        self.episodes_so_far += num_episodes
        self.timesteps_so_far += num_timesteps

        # Assemble the results.
        eval_returns = np.array(all_eval_returns)
        eval_lengths = np.array(all_eval_lengths)
        noise_indices = np.array(all_noise_indices)
        noisy_returns = np.array(all_training_returns)
        noisy_lengths = np.array(all_training_lengths)

        # keep only the best returns
        # select top performing directions if deltas_used < num_deltas
        max_rewards = np.max(noisy_returns, axis=1)
        if self.deltas_used > self.num_deltas:
            self.deltas_used = self.num_deltas

        percentile = 100 * (1 - (self.deltas_used / self.num_deltas))
        idx = np.arange(max_rewards.size)[
            max_rewards >= np.percentile(max_rewards, percentile)]
        noise_idx = noise_indices[idx]
        noisy_returns = noisy_returns[idx, :]

        # Compute and take a step.
        g, count = utils.batched_weighted_sum(
            noisy_returns[:, 0] - noisy_returns[:, 1],
            (self.noise.get(index, self.policy.num_params)
             for index in noise_idx),
            batch_size=min(500, noisy_returns[:, 0].size))
        g /= noise_idx.size
        # scale the returns by their standard deviation
        if not np.isclose(np.std(noisy_returns), 0.0):
            g /= np.std(noisy_returns)
        assert (g.shape == (self.policy.num_params, )
                and g.dtype == np.float32)
        print('the number of policy params is, ', self.policy.num_params)
        # Compute the new weights theta.
        theta, update_ratio = self.optimizer.update(-g)
        # Set the new weights in the local copy of the policy.
        self.policy.set_weights(theta)

        step_tend = time.time()
        tlogger.record_tabular("EvalEpRewMean", eval_returns.mean())
        tlogger.record_tabular("EvalEpRewStd", eval_returns.std())
        tlogger.record_tabular("EvalEpLenMean", eval_lengths.mean())

        tlogger.record_tabular("NoisyEpRewMean", noisy_returns.mean())
        tlogger.record_tabular("NoisyEpRewStd", noisy_returns.std())
        tlogger.record_tabular("NoisyEpLenMean", noisy_lengths.mean())

        tlogger.record_tabular("WeightsNorm", float(np.square(theta).sum()))
        tlogger.record_tabular("WeightsStd", float(np.std(theta)))
        tlogger.record_tabular("Grad2Norm", float(np.sqrt(np.square(g).sum())))
        tlogger.record_tabular("UpdateRatio", float(update_ratio))
        tlogger.dump_tabular()

        info = {
            "weights_norm": np.square(theta).sum(),
            "grad_norm": np.square(g).sum(),
            "update_ratio": update_ratio,
            "episodes_this_iter": noisy_lengths.size,
            "episodes_so_far": self.episodes_so_far,
            "timesteps_so_far": self.timesteps_so_far,
            "time_elapsed_this_iter": step_tend - step_tstart,
            "time_elapsed": step_tend - self.tstart
        }

        result = dict(
            episode_reward_mean=eval_returns.mean(),
            episode_len_mean=eval_lengths.mean(),
            timesteps_this_iter=noisy_lengths.sum(),
            info=info)

        return result

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for w in self.workers:
            w.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "checkpoint-{}".format(self.iteration))
        weights = self.policy.get_weights()
        objects = [weights, self.episodes_so_far, self.timesteps_so_far]
        pickle.dump(objects, open(checkpoint_path, "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        objects = pickle.load(open(checkpoint_path, "rb"))
        self.policy.set_weights(objects[0])
        self.episodes_so_far = objects[1]
        self.timesteps_so_far = objects[2]

    def compute_action(self, observation):
        return self.policy.compute(observation, update=True)[0]
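The gradient estimate in _train above is the ARS update: each direction is evaluated with an antithetic rollout pair, optionally only the top directions are kept, and the estimate is scaled by the standard deviation of the retained returns. The following is a small self-contained numpy sketch of that estimator, not part of the diff, using hypothetical toy returns and random deltas in place of the shared noise table:

import numpy as np

num_params = 3
# One row per direction: [return with +delta, return with -delta] (toy values).
noisy_returns = np.array([[1.0, 0.4], [0.2, 0.9], [0.7, 0.6]])
deltas = np.random.randn(len(noisy_returns), num_params).astype(np.float32)

# Keep only the top directions by max(r+, r-), mirroring deltas_used/num_deltas.
deltas_used = 2
max_rewards = noisy_returns.max(axis=1)
keep = np.argsort(max_rewards)[-deltas_used:]

# Weight each kept delta by (r+ - r-), average, and scale by the return std.
diffs = noisy_returns[keep, 0] - noisy_returns[keep, 1]
g = (diffs[:, None] * deltas[keep]).sum(axis=0) / deltas_used
g /= np.std(noisy_returns[keep])
print(g)  # direction the agent then hands to the Adam optimizer (as -g)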
python/ray/rllib/agents/ars/optimizers.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


class Optimizer(object):
    def __init__(self, pi):
        self.pi = pi
        self.dim = pi.num_params
        self.t = 0

    def update(self, globalg):
        self.t += 1
        step = self._compute_step(globalg)
        theta = self.pi.get_weights()
        ratio = np.linalg.norm(step) / np.linalg.norm(theta)
        return theta + step, ratio

    def _compute_step(self, globalg):
        raise NotImplementedError


class SGD(Optimizer):
    def __init__(self, pi, stepsize, momentum=0.9):
        Optimizer.__init__(self, pi)
        self.v = np.zeros(self.dim, dtype=np.float32)
        self.stepsize, self.momentum = stepsize, momentum

    def _compute_step(self, globalg):
        self.v = self.momentum * self.v + (1. - self.momentum) * globalg
        step = -self.stepsize * self.v
        return step


class Adam(Optimizer):
    def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
        Optimizer.__init__(self, pi)
        self.stepsize = stepsize
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.m = np.zeros(self.dim, dtype=np.float32)
        self.v = np.zeros(self.dim, dtype=np.float32)

    def _compute_step(self, globalg):
        a = self.stepsize * (np.sqrt(1 - self.beta2**self.t) /
                             (1 - self.beta1**self.t))
        self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
        self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
        step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
        return step
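These optimizers only require an object exposing num_params and get_weights(), so they can be exercised in isolation. A quick sketch, not part of the diff, with a hypothetical stand-in policy object (FlatPolicy is invented here purely for illustration):

import numpy as np
from ray.rllib.agents.ars.optimizers import Adam

class FlatPolicy(object):
    """Stand-in for an ARS policy: just a flat float32 weight vector."""
    def __init__(self, num_params):
        self.num_params = num_params
        self.weights = np.ones(num_params, dtype=np.float32) * 0.1
    def get_weights(self):
        return self.weights
    def set_weights(self, w):
        self.weights = w

pi = FlatPolicy(4)
opt = Adam(pi, stepsize=0.01)
fake_grad = np.ones(4, dtype=np.float32)
# Same sign convention as ARSAgent._train, which calls optimizer.update(-g).
new_theta, update_ratio = opt.update(-fake_grad)
pi.set_weights(new_theta)
print(new_theta, update_ratio)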
python/ray/rllib/agents/ars/policies.py (new file, 136 lines)
@@ -0,0 +1,136 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gym
import numpy as np
import tensorflow as tf

import ray
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.models import ModelCatalog


def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0):
    """Do a rollout.

    If add_noise is True, the rollout will take noisy actions with
    noise drawn from that stream. Otherwise, no action noise will be added.

    Parameters
    ----------
    policy: tf object
        policy from which to draw actions
    env: GymEnv
        environment from which to draw rewards, done, and next state
    timestep_limit: int, optional
        steps after which to end the rollout
    add_noise: bool, optional
        indicates whether exploratory action noise should be added
    offset: int, optional
        value to subtract from the reward. For example, survival bonus
        from humanoid
    """
    env_timestep_limit = env.spec.max_episode_steps
    timestep_limit = (env_timestep_limit if timestep_limit is None else min(
        timestep_limit, env_timestep_limit))
    rews = []
    t = 0
    observation = env.reset()
    for _ in range(timestep_limit or 999999):
        ac = policy.compute(observation, add_noise=add_noise, update=True)[0]
        observation, rew, done, _ = env.step(ac)
        rew -= np.abs(offset)
        rews.append(rew)
        t += 1
        if done:
            break
    rews = np.array(rews, dtype=np.float32)
    return rews, t


class GenericPolicy(object):
    def __init__(self,
                 sess,
                 action_space,
                 preprocessor,
                 observation_filter,
                 action_noise_std,
                 options={}):

        if len(preprocessor.shape) > 1:
            raise UnsupportedSpaceException(
                "Observation space {} is not supported with ARS.".format(
                    preprocessor.shape))

        self.sess = sess
        self.action_space = action_space
        self.action_noise_std = action_noise_std
        self.preprocessor = preprocessor
        self.observation_filter = get_filter(observation_filter,
                                             self.preprocessor.shape)
        self.inputs = tf.placeholder(tf.float32,
                                     [None] + list(self.preprocessor.shape))

        # Policy network.
        dist_class, dist_dim = ModelCatalog.get_action_dist(
            action_space, dist_type="deterministic")

        model = ModelCatalog.get_model(self.inputs, dist_dim, options=options)
        dist = dist_class(model.outputs)
        self.sampler = dist.sample()

        self.variables = ray.experimental.TensorFlowVariables(
            model.outputs, self.sess)

        self.num_params = sum(
            np.prod(variable.shape.as_list())
            for _, variable in self.variables.variables.items())
        self.sess.run(tf.global_variables_initializer())

    def compute(self, observation, add_noise=False, update=True):
        observation = self.preprocessor.transform(observation)
        observation = self.observation_filter(observation[None], update=update)
        action = self.sess.run(
            self.sampler, feed_dict={self.inputs: observation})
        if add_noise and isinstance(self.action_space, gym.spaces.Box):
            action += np.random.randn(*action.shape) * self.action_noise_std
        return action

    def set_weights(self, x):
        self.variables.set_flat(x)

    def get_weights(self):
        return self.variables.get_flat()


class LinearPolicy(GenericPolicy):
    def __init__(self, sess, action_space, preprocessor, observation_filter,
                 action_noise_std):
        options = {"custom_model": "LinearNetwork"}
        GenericPolicy.__init__(
            self,
            sess,
            action_space,
            preprocessor,
            observation_filter,
            action_noise_std,
            options=options)


class MLPPolicy(GenericPolicy):
    def __init__(self, sess, action_space, preprocessor, observation_filter,
                 fcnet_hiddens, action_noise_std):
        options = {"fcnet_hiddens": fcnet_hiddens}
        GenericPolicy.__init__(
            self,
            sess,
            action_space,
            preprocessor,
            observation_filter,
            action_noise_std,
            options=options)
python/ray/rllib/agents/ars/utils.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from ray.rllib.models import ModelCatalog, Model
import tensorflow.contrib.slim as slim
from ray.rllib.models.misc import normc_initializer


def compute_ranks(x):
    """Returns ranks in [0, len(x))

    Note: This is different from scipy.stats.rankdata, which returns ranks in
    [1, len(x)].
    """
    assert x.ndim == 1
    ranks = np.empty(len(x), dtype=int)
    ranks[x.argsort()] = np.arange(len(x))
    return ranks


def compute_centered_ranks(x):
    y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= 0.5
    return y


def make_session(single_threaded):
    if not single_threaded:
        return tf.Session()
    return tf.Session(
        config=tf.ConfigProto(
            inter_op_parallelism_threads=1, intra_op_parallelism_threads=1))


def itergroups(items, group_size):
    assert group_size >= 1
    group = []
    for x in items:
        group.append(x)
        if len(group) == group_size:
            yield tuple(group)
            del group[:]
    if group:
        yield tuple(group)


def batched_weighted_sum(weights, vecs, batch_size):
    total = 0
    num_items_summed = 0
    for batch_weights, batch_vecs in zip(
            itergroups(weights, batch_size), itergroups(vecs, batch_size)):
        assert len(batch_weights) == len(batch_vecs) <= batch_size
        total += np.dot(
            np.asarray(batch_weights, dtype=np.float32),
            np.asarray(batch_vecs, dtype=np.float32))
        num_items_summed += len(batch_weights)
    return total, num_items_summed


class LinearNetwork(Model):
    """Generic linear network."""

    def _build_layers(self, inputs, num_outputs, _):
        with tf.name_scope("linear"):
            output = slim.fully_connected(
                inputs,
                num_outputs,
                weights_initializer=normc_initializer(0.01),
                activation_fn=None,
            )
        return output, inputs


def register_linear_network():
    ModelCatalog.register_custom_model("LinearNetwork", LinearNetwork)
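compute_centered_ranks and batched_weighted_sum are plain numpy helpers and can be sanity-checked on their own. A short sketch with toy values, not part of the diff:

import numpy as np
from ray.rllib.agents.ars.utils import compute_centered_ranks, batched_weighted_sum

returns = np.array([3.0, -1.0, 2.0, 0.5])
print(compute_centered_ranks(returns))  # ranks mapped into [-0.5, 0.5]

weights = np.array([1.0, 2.0, 3.0], dtype=np.float32)
vecs = [np.ones(4, dtype=np.float32) * i for i in range(3)]
total, n = batched_weighted_sum(weights, vecs, batch_size=2)
print(total, n)  # weighted sum of the vectors, and how many items were summed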
@@ -116,6 +116,13 @@ class ModelSupportedSpaces(unittest.TestCase):
            "episodes_per_batch": 1,
            "timesteps_per_batch": 1
        }, stats)
+       check_support(
+           "ARS", {
+               "num_workers": 1,
+               "noise_size": 10000000,
+               "num_deltas": 1,
+               "deltas_used": 1
+           }, stats)
        check_support("PG", {"num_workers": 1, "optimizer": {}}, stats)
        num_unexpected_errors = 0
        for (alg, a_name, o_name), stat in sorted(stats.items()):
@@ -0,0 +1,16 @@
cartpole-ars:
    env: CartPole-v0
    run: ARS
    stop:
        episode_reward_mean: 200
        time_total_s: 600
    config:
        noise_stdev: 0.02
        num_deltas: 50
        deltas_used: 25
        num_workers: 2
        stepsize: 0.01
        noise_size: 250000000
        eval_prob: 0.5
        policy_type: MLPPolicy
        fcnet_hiddens: [16, 16]
python/ray/rllib/tuned_examples/swimmer-ars.yaml (new file, 15 lines)
@@ -0,0 +1,15 @@
# can expect improvement to -140 reward in ~300-500k timesteps
pendulum-ars:
    env: Swimmer-v2
    run: ARS
    config:
        noise_stdev: 0.01
        num_deltas: 2
        deltas_used: 1
        num_workers: 1
        stepsize: 0.02
        noise_size: 250000000
        fcnet_hiddens: [32,32]
        policy_type: LinearPolicy
        eval_prob: 0.2
        offset: 0
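These tuned examples are ordinary Tune experiment specs, so they can also be launched from Python. A minimal sketch, not part of the diff, assuming PyYAML is installed and the path below points at the file in a checkout of the repo:

import yaml
import ray
from ray.tune import run_experiments

ray.init()
# Path is relative to a repo checkout; adjust as needed.
with open("python/ray/rllib/tuned_examples/swimmer-ars.yaml") as f:
    experiments = yaml.safe_load(f)  # maps experiment name -> env/run/config spec
run_experiments(experiments)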