Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[RLlib] PyTorch version of ARS (Augmented Random Search). (#8106)
This PR implements a PyTorch version of RLlib's ARS algorithm using RLlib's functional algo builder API. It also adds a regression test for ARS (torch) on CartPole.
This commit is contained in:
parent d66d12661b
commit d15609ba2a
16 changed files with 199 additions and 292 deletions
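
Not part of the diff: the regression test added for ARS (torch) on CartPole (cartpole-ars-torch.yaml, at the bottom of this commit) corresponds to roughly the following Tune call. This is a usage sketch only, assuming Ray/RLlib with PyTorch installed; the config values mirror that YAML.

    import ray
    from ray import tune

    ray.init()
    tune.run(
        "ARS",  # registered RLlib trainable, resolves to ARSTrainer
        stop={"episode_reward_mean": 150, "timesteps_total": 500000},
        config={
            "env": "CartPole-v0",
            "use_pytorch": True,  # selects ARSTorchPolicy via get_policy_class()
            "noise_stdev": 0.02,
            "num_rollouts": 50,
            "rollouts_used": 25,
            "num_workers": 2,
            "sgd_stepsize": 0.01,
            "noise_size": 25000000,
            "eval_prob": 0.5,
            "model": {"fcnet_hiddens": [64, 64]},
        },
    )
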
@@ -12,7 +12,7 @@ Feature Compatibility Matrix
 Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support
 =================== ========== ======================= ================== =========== =====================
 `A2C, A3C`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
-`ARS`_ tf **Yes** **Yes** No
+`ARS`_ tf + torch **Yes** **Yes** No
 `ES`_ tf + torch **Yes** **Yes** No
 `DDPG`_, `TD3`_ tf + torch No **Yes** **Yes**
 `APEX-DDPG`_ tf No **Yes** **Yes**

@@ -405,7 +405,7 @@ Derivative-free
 Augmented Random Search (ARS)
 -----------------------------
-|tensorflow|
+|pytorch| |tensorflow|
 `[paper] <https://arxiv.org/abs/1803.07055>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/ars/ars.py>`__
 ARS is a random search method for training linear policies for continuous control problems. Code here is adapted from https://github.com/modestyachts/ARS to integrate with RLlib APIs.

@@ -113,7 +113,7 @@ Algorithms
 * Derivative-free
-- |tensorflow| :ref:`Augmented Random Search (ARS) <ars>`
+- |pytorch| |tensorflow| :ref:`Augmented Random Search (ARS) <ars>`
 - |pytorch| |tensorflow| :ref:`Evolution Strategies <es>`

@@ -1,6 +1,10 @@
-from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG)
-from ray.rllib.utils import renamed_agent
+from ray.rllib.agents.ars.ars import ARSTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.ars.ars_tf_policy import ARSTFPolicy
+from ray.rllib.agents.ars.ars_torch_policy import ARSTorchPolicy
-ARSAgent = renamed_agent(ARSTrainer)
-__all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"]
+__all__ = [
+"ARSTFPolicy",
+"ARSTorchPolicy",
+"ARSTrainer",
+"DEFAULT_CONFIG",
+]

@@ -10,9 +10,11 @@ import time
 import ray
 from ray.rllib.agents import Trainer, with_common_config
-from ray.rllib.agents.ars import optimizers
-from ray.rllib.agents.ars import policies
-from ray.rllib.agents.ars import utils
+from ray.rllib.agents.ars.ars_tf_policy import ARSTFPolicy
+from ray.rllib.agents.es import optimizers
+from ray.rllib.agents.es import utils
+from ray.rllib.agents.es.es_tf_policy import rollout
+from ray.rllib.env.env_context import EnvContext
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.memory import ray_get_and_free

@@ -28,6 +30,7 @@ Result = namedtuple("Result", [
 # yapf: disable
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
+"action_noise_std": 0.0,
 "noise_stdev": 0.02, # std deviation of parameter noise
 "num_rollouts": 32, # number of perturbs to try
 "rollouts_used": 32, # number of perturbs to keep in gradient estimate

@@ -69,23 +72,29 @@ class SharedNoiseTable:
 @ray.remote
 class Worker:
-def __init__(self, config, env_creator, noise, min_task_runtime=0.2):
+def __init__(self,
+config,
+env_creator,
+noise,
+worker_index,
+min_task_runtime=0.2):
 self.min_task_runtime = min_task_runtime
 self.config = config
+self.config["single_threaded"] = True
 self.noise = SharedNoiseTable(noise)
-self.env = env_creator(config["env_config"])
+env_context = EnvContext(config["env_config"] or {}, worker_index)
+self.env = env_creator(env_context)
 from ray.rllib import models
 self.preprocessor = models.ModelCatalog.get_preprocessor(self.env)
-self.sess = utils.make_session(single_threaded=True)
-self.policy = policies.GenericPolicy(
-self.sess, self.env.action_space, self.env.observation_space,
-self.preprocessor, config["observation_filter"], config["model"])
+policy_cls = get_policy_class(config)
+self.policy = policy_cls(self.env.observation_space,
+self.env.action_space, config)
 @property
 def filters(self):
-return {DEFAULT_POLICY_ID: self.policy.get_filter()}
+return {DEFAULT_POLICY_ID: self.policy.observation_filter}
 def sync_filters(self, new_filters):
 for k in self.filters:

@@ -100,7 +109,7 @@ class Worker:
 return return_filters
 def rollout(self, timestep_limit, add_noise=False):
-rollout_rewards, rollout_fragment_length = policies.rollout(
+rollout_rewards, rollout_fragment_length = rollout(
 self.policy,
 self.env,
 timestep_limit=timestep_limit,

@@ -110,7 +119,7 @@ class Worker:
 def do_rollouts(self, params, timestep_limit=None):
 # Set the network weights.
-self.policy.set_weights(params)
+self.policy.set_flat_weights(params)
 noise_indices, returns, sign_returns, lengths = [], [], [], []
 eval_returns, eval_lengths = [], []

@@ -119,7 +128,7 @@ class Worker:
 while (len(noise_indices) == 0):
 if np.random.uniform() < self.config["eval_prob"]:
 # Do an evaluation run with no perturbation.
-self.policy.set_weights(params)
+self.policy.set_flat_weights(params)
 rewards, length = self.rollout(timestep_limit, add_noise=False)
 eval_returns.append(rewards.sum())
 eval_lengths.append(length)

@@ -132,10 +141,10 @@ class Worker:
 # These two sampling steps could be done in parallel on
 # different actors letting us update twice as frequently.
-self.policy.set_weights(params + perturbation)
+self.policy.set_flat_weights(params + perturbation)
 rewards_pos, lengths_pos = self.rollout(timestep_limit)
-self.policy.set_weights(params - perturbation)
+self.policy.set_flat_weights(params - perturbation)
 rewards_neg, lengths_neg = self.rollout(timestep_limit)
 noise_indices.append(noise_index)

@@ -154,6 +163,15 @@ class Worker:
 eval_lengths=eval_lengths)
+def get_policy_class(config):
+if config["use_pytorch"]:
+from ray.rllib.agents.ars.ars_torch_policy import ARSTorchPolicy
+policy_cls = ARSTorchPolicy
+else:
+policy_cls = ARSTFPolicy
+return policy_cls
 class ARSTrainer(Trainer):
 """Large-scale implementation of Augmented Random Search in Ray."""

@@ -162,19 +180,12 @@ class ARSTrainer(Trainer):
 @override(Trainer)
 def _init(self, config, env_creator):
-# PyTorch check.
-if config["use_pytorch"]:
-raise ValueError(
-"ARS does not support PyTorch yet! Use tf instead.")
+env_context = EnvContext(config["env_config"] or {}, worker_index=0)
+env = env_creator(env_context)
-env = env_creator(config["env_config"])
-from ray.rllib import models
-preprocessor = models.ModelCatalog.get_preprocessor(env)
-self.sess = utils.make_session(single_threaded=False)
-self.policy = policies.GenericPolicy(
-self.sess, env.action_space, env.observation_space, preprocessor,
-config["observation_filter"], config["model"])
+policy_cls = get_policy_class(config)
+self.policy = policy_cls(env.observation_space, env.action_space,
+config)
 self.optimizer = optimizers.SGD(self.policy, config["sgd_stepsize"])
 self.rollouts_used = config["rollouts_used"]

@@ -189,8 +200,8 @@ class ARSTrainer(Trainer):
 # Create the actors.
 logger.info("Creating actors.")
 self.workers = [
-Worker.remote(config, env_creator, noise_id)
-for _ in range(config["num_workers"])
+Worker.remote(config, env_creator, noise_id, idx + 1)
+for idx in range(config["num_workers"])
 ]
 self.episodes_so_far = 0

@@ -201,8 +212,9 @@ class ARSTrainer(Trainer):
 def _train(self):
 config = self.config
-theta = self.policy.get_weights()
+theta = self.policy.get_flat_weights()
 assert theta.dtype == np.float32
+assert len(theta.shape) == 1
 # Put the current policy weights in the object store.
 theta_id = ray.put(theta)

@@ -266,14 +278,14 @@ class ARSTrainer(Trainer):
 # Compute the new weights theta.
 theta, update_ratio = self.optimizer.update(-g)
 # Set the new weights in the local copy of the policy.
-self.policy.set_weights(theta)
+self.policy.set_flat_weights(theta)
 # update the reward list
 if len(all_eval_returns) > 0:
 self.reward_list.append(eval_returns.mean())
 # Now sync the filters
 FilterManager.synchronize({
-DEFAULT_POLICY_ID: self.policy.get_filter()
+DEFAULT_POLICY_ID: self.policy.observation_filter
 }, self.workers)
 info = {

@@ -301,7 +313,7 @@ class ARSTrainer(Trainer):
 @override(Trainer)
 def compute_action(self, observation, *args, **kwargs):
-return self.policy.compute(observation, update=True)[0]
+return self.policy.compute_actions(observation, update=True)[0]
 def _collect_results(self, theta_id, min_episodes):
 num_episodes, num_timesteps = 0, 0

@@ -327,15 +339,15 @@ class ARSTrainer(Trainer):
 def __getstate__(self):
 return {
-"weights": self.policy.get_weights(),
-"filter": self.policy.get_filter(),
+"weights": self.policy.get_flat_weights(),
+"filter": self.policy.observation_filter,
 "episodes_so_far": self.episodes_so_far,
 }
 def __setstate__(self, state):
 self.episodes_so_far = state["episodes_so_far"]
-self.policy.set_weights(state["weights"])
-self.policy.set_filter(state["filter"])
+self.policy.set_flat_weights(state["weights"])
+self.policy.observation_filter = state["filter"]
 FilterManager.synchronize({
-DEFAULT_POLICY_ID: self.policy.get_filter()
+DEFAULT_POLICY_ID: self.policy.observation_filter
 }, self.workers)

rllib/agents/ars/ars_tf_policy.py (new file, 68 additions)
@@ -0,0 +1,68 @@
+# Code in this file is copied and adapted from
+# https://github.com/openai/evolution-strategies-starter.
+import gym
+import numpy as np
+import ray
+import ray.experimental.tf_utils
+from ray.rllib.agents.es.es_tf_policy import make_session
+from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
+from ray.rllib.models import ModelCatalog
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.filter import get_filter
+from ray.rllib.utils.framework import try_import_tf
+tf = try_import_tf()
+class ARSTFPolicy:
+def __init__(self, obs_space, action_space, config):
+self.observation_space = obs_space
+self.action_space = action_space
+self.action_noise_std = config["action_noise_std"]
+self.preprocessor = ModelCatalog.get_preprocessor_for_space(
+self.observation_space)
+self.observation_filter = get_filter(config["observation_filter"],
+self.preprocessor.shape)
+self.single_threaded = config.get("single_threaded", False)
+self.sess = make_session(single_threaded=self.single_threaded)
+self.inputs = tf.placeholder(tf.float32,
+[None] + list(self.preprocessor.shape))
+# Policy network.
+dist_class, dist_dim = ModelCatalog.get_action_dist(
+self.action_space, config["model"], dist_type="deterministic")
+model = ModelCatalog.get_model({
+SampleBatch.CUR_OBS: self.inputs
+}, self.observation_space, self.action_space, dist_dim,
+config["model"])
+dist = dist_class(model.outputs, model)
+self.sampler = dist.sample()
+self.variables = ray.experimental.tf_utils.TensorFlowVariables(
+model.outputs, self.sess)
+self.num_params = sum(
+np.prod(variable.shape.as_list())
+for _, variable in self.variables.variables.items())
+self.sess.run(tf.global_variables_initializer())
+def compute_actions(self, observation, add_noise=False, update=True):
+observation = self.preprocessor.transform(observation)
+observation = self.observation_filter(observation[None], update=update)
+action = self.sess.run(
+self.sampler, feed_dict={self.inputs: observation})
+action = _unbatch_tuple_actions(action)
+if add_noise and isinstance(self.action_space, gym.spaces.Box):
+action += np.random.randn(*action.shape) * self.action_noise_std
+return action
+def set_flat_weights(self, x):
+self.variables.set_flat(x)
+def get_flat_weights(self):
+return self.variables.get_flat()

rllib/agents/ars/ars_torch_policy.py (new file, 15 additions)
@@ -0,0 +1,15 @@
+# Code in this file is adapted from:
+# https://github.com/openai/evolution-strategies-starter.
+import ray
+from ray.rllib.agents.es.es_torch_policy import after_init, before_init, \
+make_model_and_action_dist
+from ray.rllib.policy.torch_policy_template import build_torch_policy
+ARSTorchPolicy = build_torch_policy(
+name="ARSTorchPolicy",
+loss_fn=None,
+get_default_config=lambda: ray.rllib.agents.ars.ars.DEFAULT_CONFIG,
+before_init=before_init,
+after_init=after_init,
+make_model_and_action_dist=make_model_and_action_dist)

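Not part of the diff: all the trainer and workers above need from a policy is the flat-weights interface (get_flat_weights / set_flat_weights) plus compute_actions, so the new torch policy can also be exercised standalone. A minimal sketch, assuming gym and PyTorch are installed:

    import gym
    import numpy as np

    from ray.rllib.agents.ars.ars import DEFAULT_CONFIG
    from ray.rllib.agents.ars.ars_torch_policy import ARSTorchPolicy

    env = gym.make("CartPole-v0")
    config = DEFAULT_CONFIG.copy()
    config["use_pytorch"] = True

    # Same construction pattern as Worker.__init__ / ARSTrainer._init above.
    policy = ARSTorchPolicy(env.observation_space, env.action_space, config)

    theta = policy.get_flat_weights()  # 1-D float32 vector (asserted in _train)
    perturbation = config["noise_stdev"] * np.random.randn(theta.size)
    policy.set_flat_weights(theta + perturbation.astype(np.float32))
    action = policy.compute_actions(env.reset(), add_noise=True, update=True)[0]
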
@@ -1,53 +0,0 @@
-# Code in this file is copied and adapted from
-# https://github.com/openai/evolution-strategies-starter.
-import numpy as np
-class Optimizer:
-def __init__(self, policy):
-self.policy = policy
-self.dim = policy.num_params
-self.t = 0
-def update(self, globalg):
-self.t += 1
-step = self._compute_step(globalg)
-theta = self.policy.get_weights()
-ratio = np.linalg.norm(step) / np.linalg.norm(theta)
-return theta + step, ratio
-def _compute_step(self, globalg):
-raise NotImplementedError
-class SGD(Optimizer):
-def __init__(self, policy, stepsize, momentum=0.0):
-Optimizer.__init__(self, policy)
-self.v = np.zeros(self.dim, dtype=np.float32)
-self.stepsize, self.momentum = stepsize, momentum
-def _compute_step(self, globalg):
-self.v = self.momentum * self.v + (1. - self.momentum) * globalg
-step = -self.stepsize * self.v
-return step
-class Adam(Optimizer):
-def __init__(self, policy, stepsize, beta1=0.9, beta2=0.999,
-epsilon=1e-08):
-Optimizer.__init__(self, policy)
-self.stepsize = stepsize
-self.beta1 = beta1
-self.beta2 = beta2
-self.epsilon = epsilon
-self.m = np.zeros(self.dim, dtype=np.float32)
-self.v = np.zeros(self.dim, dtype=np.float32)
-def _compute_step(self, globalg):
-a = self.stepsize * (np.sqrt(1 - self.beta2**self.t) /
-(1 - self.beta1**self.t))
-self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
-self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
-step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
-return step

@@ -1,111 +0,0 @@
-# Code in this file is copied and adapted from
-# https://github.com/openai/evolution-strategies-starter.
-import gym
-import numpy as np
-import ray
-import ray.experimental.tf_utils
-from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
-from ray.rllib.utils.filter import get_filter
-from ray.rllib.models import ModelCatalog
-from ray.rllib.utils import try_import_tf
-tf = try_import_tf()
-def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0):
-"""Do a rollout.
-If add_noise is True, the rollout will take noisy actions with
-noise drawn from that stream. Otherwise, no action noise will be added.
-Parameters
-----------
-policy: tf object
-policy from which to draw actions
-env: GymEnv
-environment from which to draw rewards, done, and next state
-timestep_limit: int, optional
-steps after which to end the rollout
-add_noise: bool, optional
-indicates whether exploratory action noise should be added
-offset: int, optional
-value to subtract from the reward. For example, survival bonus
-from humanoid
-"""
-env_timestep_limit = env.spec.max_episode_steps
-timestep_limit = (env_timestep_limit if timestep_limit is None else min(
-timestep_limit, env_timestep_limit))
-rews = []
-t = 0
-observation = env.reset()
-for _ in range(timestep_limit or 999999):
-ac = policy.compute(observation, add_noise=add_noise, update=True)[0]
-observation, rew, done, _ = env.step(ac)
-rew -= np.abs(offset)
-rews.append(rew)
-t += 1
-if done:
-break
-rews = np.array(rews, dtype=np.float32)
-return rews, t
-class GenericPolicy:
-def __init__(self,
-sess,
-action_space,
-obs_space,
-preprocessor,
-observation_filter,
-model_config,
-action_noise_std=0.0):
-self.sess = sess
-self.action_space = action_space
-self.action_noise_std = action_noise_std
-self.preprocessor = preprocessor
-self.observation_filter = get_filter(observation_filter,
-self.preprocessor.shape)
-self.inputs = tf.placeholder(tf.float32,
-[None] + list(self.preprocessor.shape))
-# Policy network.
-dist_class, dist_dim = ModelCatalog.get_action_dist(
-action_space, model_config, dist_type="deterministic")
-model = ModelCatalog.get_model({
-"obs": self.inputs
-}, obs_space, action_space, dist_dim, model_config)
-dist = dist_class(model.outputs, model)
-self.sampler = dist.sample()
-self.variables = ray.experimental.tf_utils.TensorFlowVariables(
-model.outputs, self.sess)
-self.num_params = sum(
-np.prod(variable.shape.as_list())
-for _, variable in self.variables.variables.items())
-self.sess.run(tf.global_variables_initializer())
-def compute(self, observation, add_noise=False, update=True):
-observation = self.preprocessor.transform(observation)
-observation = self.observation_filter(observation[None], update=update)
-action = self.sess.run(
-self.sampler, feed_dict={self.inputs: observation})
-action = _unbatch_tuple_actions(action)
-if add_noise and isinstance(self.action_space, gym.spaces.Box):
-action += np.random.randn(*action.shape) * self.action_noise_std
-return action
-def set_weights(self, x):
-self.variables.set_flat(x)
-def set_filter(self, obs_filter):
-self.observation_filter = obs_filter
-def get_filter(self):
-return self.observation_filter
-def get_weights(self):
-return self.variables.get_flat()

@@ -1,59 +0,0 @@
-# Code in this file is copied and adapted from
-# https://github.com/openai/evolution-strategies-starter.
-import numpy as np
-from ray.rllib.utils import try_import_tf
-tf = try_import_tf()
-def compute_ranks(x):
-"""Returns ranks in [0, len(x))
-Note: This is different from scipy.stats.rankdata, which returns ranks in
-[1, len(x)].
-"""
-assert x.ndim == 1
-ranks = np.empty(len(x), dtype=int)
-ranks[x.argsort()] = np.arange(len(x))
-return ranks
-def compute_centered_ranks(x):
-y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
-y /= (x.size - 1)
-y -= 0.5
-return y
-def make_session(single_threaded):
-if not single_threaded:
-return tf.Session()
-return tf.Session(
-config=tf.ConfigProto(
-inter_op_parallelism_threads=1, intra_op_parallelism_threads=1))
-def itergroups(items, group_size):
-assert group_size >= 1
-group = []
-for x in items:
-group.append(x)
-if len(group) == group_size:
-yield tuple(group)
-del group[:]
-if group:
-yield tuple(group)
-def batched_weighted_sum(weights, vecs, batch_size):
-total = 0
-num_items_summed = 0
-for batch_weights, batch_vecs in zip(
-itergroups(weights, batch_size), itergroups(vecs, batch_size)):
-assert len(batch_weights) == len(batch_vecs) <= batch_size
-total += np.dot(
-np.asarray(batch_weights, dtype=np.float32),
-np.asarray(batch_vecs, dtype=np.float32))
-num_items_summed += len(batch_weights)
-return total, num_items_summed

@ -26,6 +26,7 @@ Result = namedtuple("Result", [
|
|||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
"action_noise_std": 0.01,
|
||||
"l2_coeff": 0.005,
|
||||
"noise_stdev": 0.02,
|
||||
"episodes_per_batch": 1000,
|
||||
|
@ -177,8 +178,6 @@ class ESTrainer(Trainer):
|
|||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
policy_params = {"action_noise_std": 0.01}
|
||||
config.update(policy_params)
|
||||
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
|
||||
env = env_creator(env_context)
|
||||
policy_cls = get_policy_class(config)
|
||||
|
@ -197,8 +196,8 @@ class ESTrainer(Trainer):
|
|||
# Create the actors.
|
||||
logger.info("Creating actors.")
|
||||
self._workers = [
|
||||
Worker.remote(config, policy_params, env_creator, noise_id,
|
||||
idx + 1) for idx in range(config["num_workers"])
|
||||
Worker.remote(config, {}, env_creator, noise_id, idx + 1)
|
||||
for idx in range(config["num_workers"])
|
||||
]
|
||||
|
||||
self.episodes_so_far = 0
|
||||
|
|
|
@@ -10,16 +10,27 @@ from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
 from ray.rllib.models import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.filter import get_filter
-from ray.rllib.utils import try_import_tf
+from ray.rllib.utils.framework import try_import_tf
 tf = try_import_tf()
-def rollout(policy, env, timestep_limit=None, add_noise=False):
+def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0.0):
 """Do a rollout.
 If add_noise is True, the rollout will take noisy actions with
 noise drawn from that stream. Otherwise, no action noise will be added.
+Args:
+policy (Policy): Rllib Policy from which to draw actions.
+env (gym.Env): Environment from which to draw rewards, done, and
+next state.
+timestep_limit (Optional[int]): Steps after which to end the rollout.
+If None, use `env.spec.max_episode_steps` or 999999.
+add_noise (bool): Indicates whether exploratory action noise should be
+added.
+offset (float): Value to subtract from the reward (e.g. survival bonus
+from humanoid).
 """
 max_timestep_limit = 999999
 env_timestep_limit = env.spec.max_episode_steps if (

@@ -27,18 +38,21 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
 else max_timestep_limit
 timestep_limit = (env_timestep_limit if timestep_limit is None else min(
 timestep_limit, env_timestep_limit))
-rews = []
+rewards = []
 t = 0
 observation = env.reset()
 for _ in range(timestep_limit or max_timestep_limit):
-ac = policy.compute_actions(observation, add_noise=add_noise)[0]
-observation, rew, done, _ = env.step(ac)
-rews.append(rew)
+ac = policy.compute_actions(
+observation, add_noise=add_noise, update=True)[0]
+observation, r, done, _ = env.step(ac)
+if offset != 0.0:
+r -= np.abs(offset)
+rewards.append(r)
 t += 1
 if done:
 break
-rews = np.array(rews, dtype=np.float32)
-return rews, t
+rewards = np.array(rewards, dtype=np.float32)
+return rewards, t
 def make_session(single_threaded):

@@ -5,7 +5,6 @@ import gym
 import numpy as np
 import ray
 import ray.experimental.tf_utils
-from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
 from ray.rllib.models import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch

@@ -5,15 +5,15 @@ import numpy as np
 class Optimizer:
-def __init__(self, pi):
-self.pi = pi
-self.dim = pi.num_params
+def __init__(self, policy):
+self.policy = policy
+self.dim = policy.num_params
 self.t = 0
 def update(self, globalg):
 self.t += 1
 step = self._compute_step(globalg)
-theta = self.pi.get_flat_weights()
+theta = self.policy.get_flat_weights()
 ratio = np.linalg.norm(step) / np.linalg.norm(theta)
 return theta + step, ratio

@@ -22,8 +22,8 @@ class Optimizer:
 class SGD(Optimizer):
-def __init__(self, pi, stepsize, momentum=0.9):
-Optimizer.__init__(self, pi)
+def __init__(self, policy, stepsize, momentum=0.0):
+Optimizer.__init__(self, policy)
 self.v = np.zeros(self.dim, dtype=np.float32)
 self.stepsize, self.momentum = stepsize, momentum

@@ -34,8 +34,9 @@ class SGD(Optimizer):
 class Adam(Optimizer):
-def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
-Optimizer.__init__(self, pi)
+def __init__(self, policy, stepsize, beta1=0.9, beta2=0.999,
+epsilon=1e-08):
+Optimizer.__init__(self, policy)
 self.stepsize = stepsize
 self.beta1 = beta1
 self.beta2 = beta2

@@ -67,7 +67,7 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
 prev_layer_size = hiddens[-1]
 if self.num_outputs:
 self._logits = SlimFC(
-in_size=hiddens[-1],
+in_size=prev_layer_size,
 out_size=self.num_outputs,
 initializer=normc_initializer(0.01),
 activation_fn=None)

@@ -1,10 +1,11 @@
-cartpole-ars:
+cartpole-ars-tf:
 env: CartPole-v0
 run: ARS
 stop:
 episode_reward_mean: 50
 timesteps_total: 500000
 config:
+use_pytorch: false
 noise_stdev: 0.02
 num_rollouts: 50
 rollouts_used: 25

@@ -0,0 +1,17 @@
+cartpole-ars-torch:
+env: CartPole-v0
+run: ARS
+stop:
+episode_reward_mean: 150
+timesteps_total: 500000
+config:
+use_pytorch: true
+noise_stdev: 0.02
+num_rollouts: 50
+rollouts_used: 25
+num_workers: 2
+sgd_stepsize: 0.01
+noise_size: 25000000
+eval_prob: 0.5
+model:
+fcnet_hiddens: [64, 64]

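Not part of the diff: regression configs like the two YAMLs above can be launched through the RLlib CLI, e.g. "rllib train -f cartpole-ars-torch.yaml" (with the path adjusted to where the file lives in the checkout), or programmatically via tune.run as sketched near the top of this page.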