[RLlib] PyTorch version of ARS (Augmented Random Search). (#8106)

This PR implements a PyTorch version of RLlib's ARS algorithm using RLlib's functional algo builder API. It also adds a regression test for ARS (torch) on CartPole.
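
For illustration, a minimal sketch of how the new torch mode could be exercised from the Python API (hedged: it assumes a Ray build containing this PR; every config key shown appears in this diff or in RLlib's common config):

import ray
from ray.rllib.agents.ars import ARSTrainer

ray.init()
trainer = ARSTrainer(
    env="CartPole-v0",
    config={
        "use_pytorch": True,  # selects ARSTorchPolicy instead of ARSTFPolicy
        "num_workers": 2,
        "eval_prob": 0.5,     # evaluate unperturbed rollouts often, as in the tuned example
    })
result = trainer.train()
print(result["episode_reward_mean"])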
Sven Mika 2020-04-21 09:47:52 +02:00 committed by GitHub
parent d66d12661b
commit d15609ba2a
16 changed files with 199 additions and 292 deletions

@ -12,7 +12,7 @@ Feature Compatibility Matrix
Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support
=================== ========== ======================= ================== =========== =====================
`A2C, A3C`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
`ARS`_ tf **Yes** **Yes** No
`ARS`_ tf + torch **Yes** **Yes** No
`ES`_ tf + torch **Yes** **Yes** No
`DDPG`_, `TD3`_ tf + torch No **Yes** **Yes**
`APEX-DDPG`_ tf No **Yes** **Yes**
@ -405,7 +405,7 @@ Derivative-free
Augmented Random Search (ARS)
-----------------------------
|tensorflow|
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1803.07055>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/ars/ars.py>`__
ARS is a random search method for training linear policies for continuous control problems. Code here is adapted from https://github.com/modestyachts/ARS to integrate with RLlib APIs.
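
To make the update rule concrete, here is a small self-contained NumPy sketch of one ARS-style step (illustration only, not RLlib's implementation; it leaves out refinements such as return normalization and the observation filter):

import numpy as np

def ars_step(theta, evaluate, num_rollouts=8, rollouts_used=4,
             noise_stdev=0.1, step_size=1.0):
    # Sample perturbations of the flat parameter vector and score each
    # direction by the return difference of the +delta and -delta policies.
    deltas = [noise_stdev * np.random.randn(*theta.shape)
              for _ in range(num_rollouts)]
    scored = []
    for delta in deltas:
        r_pos = evaluate(theta + delta)  # return of the +delta policy
        r_neg = evaluate(theta - delta)  # return of the -delta policy
        scored.append((max(r_pos, r_neg), r_pos - r_neg, delta))
    # Keep only the best-performing directions; this is what `rollouts_used` controls.
    scored.sort(key=lambda s: s[0], reverse=True)
    top = scored[:rollouts_used]
    grad = sum(diff * delta for _, diff, delta in top) / len(top)
    return theta + step_size * grad

# Toy objective: maximize -||theta||^2 (optimum at theta = 0).
def objective(w):
    return -float(np.sum(w ** 2))

theta = np.ones(5, dtype=np.float32)
print("before:", objective(theta))
for _ in range(200):
    theta = ars_step(theta, objective)
print("after:", objective(theta))  # should be much closer to 0.0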

@ -113,7 +113,7 @@ Algorithms
* Derivative-free
- |tensorflow| :ref:`Augmented Random Search (ARS) <ars>`
- |pytorch| |tensorflow| :ref:`Augmented Random Search (ARS) <ars>`
- |pytorch| |tensorflow| :ref:`Evolution Strategies <es>`

@ -1,6 +1,10 @@
from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG)
from ray.rllib.utils import renamed_agent
from ray.rllib.agents.ars.ars import ARSTrainer, DEFAULT_CONFIG
from ray.rllib.agents.ars.ars_tf_policy import ARSTFPolicy
from ray.rllib.agents.ars.ars_torch_policy import ARSTorchPolicy
ARSAgent = renamed_agent(ARSTrainer)
__all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"]
__all__ = [
"ARSTFPolicy",
"ARSTorchPolicy",
"ARSTrainer",
"DEFAULT_CONFIG",
]

@ -10,9 +10,11 @@ import time
import ray
from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.ars import optimizers
from ray.rllib.agents.ars import policies
from ray.rllib.agents.ars import utils
from ray.rllib.agents.ars.ars_tf_policy import ARSTFPolicy
from ray.rllib.agents.es import optimizers
from ray.rllib.agents.es import utils
from ray.rllib.agents.es.es_tf_policy import rollout
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
@ -28,6 +30,7 @@ Result = namedtuple("Result", [
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
"action_noise_std": 0.0,
"noise_stdev": 0.02, # std deviation of parameter noise
"num_rollouts": 32, # number of perturbs to try
"rollouts_used": 32, # number of perturbs to keep in gradient estimate
@ -69,23 +72,29 @@ class SharedNoiseTable:
@ray.remote
class Worker:
def __init__(self, config, env_creator, noise, min_task_runtime=0.2):
def __init__(self,
config,
env_creator,
noise,
worker_index,
min_task_runtime=0.2):
self.min_task_runtime = min_task_runtime
self.config = config
self.config["single_threaded"] = True
self.noise = SharedNoiseTable(noise)
self.env = env_creator(config["env_config"])
env_context = EnvContext(config["env_config"] or {}, worker_index)
self.env = env_creator(env_context)
from ray.rllib import models
self.preprocessor = models.ModelCatalog.get_preprocessor(self.env)
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.GenericPolicy(
self.sess, self.env.action_space, self.env.observation_space,
self.preprocessor, config["observation_filter"], config["model"])
policy_cls = get_policy_class(config)
self.policy = policy_cls(self.env.observation_space,
self.env.action_space, config)
@property
def filters(self):
return {DEFAULT_POLICY_ID: self.policy.get_filter()}
return {DEFAULT_POLICY_ID: self.policy.observation_filter}
def sync_filters(self, new_filters):
for k in self.filters:
@ -100,7 +109,7 @@ class Worker:
return return_filters
def rollout(self, timestep_limit, add_noise=False):
rollout_rewards, rollout_fragment_length = policies.rollout(
rollout_rewards, rollout_fragment_length = rollout(
self.policy,
self.env,
timestep_limit=timestep_limit,
@ -110,7 +119,7 @@ class Worker:
def do_rollouts(self, params, timestep_limit=None):
# Set the network weights.
self.policy.set_weights(params)
self.policy.set_flat_weights(params)
noise_indices, returns, sign_returns, lengths = [], [], [], []
eval_returns, eval_lengths = [], []
@ -119,7 +128,7 @@ class Worker:
while (len(noise_indices) == 0):
if np.random.uniform() < self.config["eval_prob"]:
# Do an evaluation run with no perturbation.
self.policy.set_weights(params)
self.policy.set_flat_weights(params)
rewards, length = self.rollout(timestep_limit, add_noise=False)
eval_returns.append(rewards.sum())
eval_lengths.append(length)
@ -132,10 +141,10 @@ class Worker:
# These two sampling steps could be done in parallel on
# different actors letting us update twice as frequently.
self.policy.set_weights(params + perturbation)
self.policy.set_flat_weights(params + perturbation)
rewards_pos, lengths_pos = self.rollout(timestep_limit)
self.policy.set_weights(params - perturbation)
self.policy.set_flat_weights(params - perturbation)
rewards_neg, lengths_neg = self.rollout(timestep_limit)
noise_indices.append(noise_index)
@ -154,6 +163,15 @@ class Worker:
eval_lengths=eval_lengths)
def get_policy_class(config):
if config["use_pytorch"]:
from ray.rllib.agents.ars.ars_torch_policy import ARSTorchPolicy
policy_cls = ARSTorchPolicy
else:
policy_cls = ARSTFPolicy
return policy_cls
class ARSTrainer(Trainer):
"""Large-scale implementation of Augmented Random Search in Ray."""
@ -162,19 +180,12 @@ class ARSTrainer(Trainer):
@override(Trainer)
def _init(self, config, env_creator):
# PyTorch check.
if config["use_pytorch"]:
raise ValueError(
"ARS does not support PyTorch yet! Use tf instead.")
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
env = env_creator(env_context)
env = env_creator(config["env_config"])
from ray.rllib import models
preprocessor = models.ModelCatalog.get_preprocessor(env)
self.sess = utils.make_session(single_threaded=False)
self.policy = policies.GenericPolicy(
self.sess, env.action_space, env.observation_space, preprocessor,
config["observation_filter"], config["model"])
policy_cls = get_policy_class(config)
self.policy = policy_cls(env.observation_space, env.action_space,
config)
self.optimizer = optimizers.SGD(self.policy, config["sgd_stepsize"])
self.rollouts_used = config["rollouts_used"]
@ -189,8 +200,8 @@ class ARSTrainer(Trainer):
# Create the actors.
logger.info("Creating actors.")
self.workers = [
Worker.remote(config, env_creator, noise_id)
for _ in range(config["num_workers"])
Worker.remote(config, env_creator, noise_id, idx + 1)
for idx in range(config["num_workers"])
]
self.episodes_so_far = 0
@ -201,8 +212,9 @@ class ARSTrainer(Trainer):
def _train(self):
config = self.config
theta = self.policy.get_weights()
theta = self.policy.get_flat_weights()
assert theta.dtype == np.float32
assert len(theta.shape) == 1
# Put the current policy weights in the object store.
theta_id = ray.put(theta)
@ -266,14 +278,14 @@ class ARSTrainer(Trainer):
# Compute the new weights theta.
theta, update_ratio = self.optimizer.update(-g)
# Set the new weights in the local copy of the policy.
self.policy.set_weights(theta)
self.policy.set_flat_weights(theta)
# update the reward list
if len(all_eval_returns) > 0:
self.reward_list.append(eval_returns.mean())
# Now sync the filters
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.observation_filter
}, self.workers)
info = {
@ -301,7 +313,7 @@ class ARSTrainer(Trainer):
@override(Trainer)
def compute_action(self, observation, *args, **kwargs):
return self.policy.compute(observation, update=True)[0]
return self.policy.compute_actions(observation, update=True)[0]
def _collect_results(self, theta_id, min_episodes):
num_episodes, num_timesteps = 0, 0
@ -327,15 +339,15 @@ class ARSTrainer(Trainer):
def __getstate__(self):
return {
"weights": self.policy.get_weights(),
"filter": self.policy.get_filter(),
"weights": self.policy.get_flat_weights(),
"filter": self.policy.observation_filter,
"episodes_so_far": self.episodes_so_far,
}
def __setstate__(self, state):
self.episodes_so_far = state["episodes_so_far"]
self.policy.set_weights(state["weights"])
self.policy.set_filter(state["filter"])
self.policy.set_flat_weights(state["weights"])
self.policy.observation_filter = state["filter"]
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.observation_filter
}, self.workers)
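
The __getstate__/__setstate__ pair shown here is what Trainer checkpointing serializes; a brief hedged sketch using the standard Trainer save/restore API (not specific to this diff):

from ray.rllib.agents.ars import ARSTrainer

# Assuming ray.init() has already been called:
trainer = ARSTrainer(env="CartPole-v0", config={"use_pytorch": True})
trainer.train()
checkpoint = trainer.save()    # captures flat weights, observation filter, episode count
trainer.restore(checkpoint)    # feeds that state back through __setstate__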

@ -0,0 +1,68 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
import gym
import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.agents.es.es_tf_policy import make_session
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class ARSTFPolicy:
def __init__(self, obs_space, action_space, config):
self.observation_space = obs_space
self.action_space = action_space
self.action_noise_std = config["action_noise_std"]
self.preprocessor = ModelCatalog.get_preprocessor_for_space(
self.observation_space)
self.observation_filter = get_filter(config["observation_filter"],
self.preprocessor.shape)
self.single_threaded = config.get("single_threaded", False)
self.sess = make_session(single_threaded=self.single_threaded)
self.inputs = tf.placeholder(tf.float32,
[None] + list(self.preprocessor.shape))
# Policy network.
dist_class, dist_dim = ModelCatalog.get_action_dist(
self.action_space, config["model"], dist_type="deterministic")
model = ModelCatalog.get_model({
SampleBatch.CUR_OBS: self.inputs
}, self.observation_space, self.action_space, dist_dim,
config["model"])
dist = dist_class(model.outputs, model)
self.sampler = dist.sample()
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
model.outputs, self.sess)
self.num_params = sum(
np.prod(variable.shape.as_list())
for _, variable in self.variables.variables.items())
self.sess.run(tf.global_variables_initializer())
def compute_actions(self, observation, add_noise=False, update=True):
observation = self.preprocessor.transform(observation)
observation = self.observation_filter(observation[None], update=update)
action = self.sess.run(
self.sampler, feed_dict={self.inputs: observation})
action = _unbatch_tuple_actions(action)
if add_noise and isinstance(self.action_space, gym.spaces.Box):
action += np.random.randn(*action.shape) * self.action_noise_std
return action
def set_flat_weights(self, x):
self.variables.set_flat(x)
def get_flat_weights(self):
return self.variables.get_flat()
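
A hedged sketch of the flat-weights interface that the Worker and ARSTrainer code above drives. Standalone construction like this is an assumption (it presumes gym and this era's TF1 stack are available), but every call shown appears in the diff:

import gym
import numpy as np
from ray.rllib.agents.ars import ARSTFPolicy, DEFAULT_CONFIG

env = gym.make("CartPole-v0")
policy = ARSTFPolicy(env.observation_space, env.action_space,
                     DEFAULT_CONFIG.copy())

theta = policy.get_flat_weights()  # 1-D float32 vector of all trainable parameters
perturbation = 0.02 * np.random.randn(*theta.shape).astype(np.float32)
policy.set_flat_weights(theta + perturbation)
action = policy.compute_actions(env.reset(), add_noise=False)
print(action)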

@ -0,0 +1,15 @@
# Code in this file is adapted from:
# https://github.com/openai/evolution-strategies-starter.
import ray
from ray.rllib.agents.es.es_torch_policy import after_init, before_init, \
make_model_and_action_dist
from ray.rllib.policy.torch_policy_template import build_torch_policy
ARSTorchPolicy = build_torch_policy(
name="ARSTorchPolicy",
loss_fn=None,
get_default_config=lambda: ray.rllib.agents.ars.ars.DEFAULT_CONFIG,
before_init=before_init,
after_init=after_init,
make_model_and_action_dist=make_model_and_action_dist)

@ -1,53 +0,0 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
import numpy as np
class Optimizer:
def __init__(self, policy):
self.policy = policy
self.dim = policy.num_params
self.t = 0
def update(self, globalg):
self.t += 1
step = self._compute_step(globalg)
theta = self.policy.get_weights()
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
return theta + step, ratio
def _compute_step(self, globalg):
raise NotImplementedError
class SGD(Optimizer):
def __init__(self, policy, stepsize, momentum=0.0):
Optimizer.__init__(self, policy)
self.v = np.zeros(self.dim, dtype=np.float32)
self.stepsize, self.momentum = stepsize, momentum
def _compute_step(self, globalg):
self.v = self.momentum * self.v + (1. - self.momentum) * globalg
step = -self.stepsize * self.v
return step
class Adam(Optimizer):
def __init__(self, policy, stepsize, beta1=0.9, beta2=0.999,
epsilon=1e-08):
Optimizer.__init__(self, policy)
self.stepsize = stepsize
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.m = np.zeros(self.dim, dtype=np.float32)
self.v = np.zeros(self.dim, dtype=np.float32)
def _compute_step(self, globalg):
a = self.stepsize * (np.sqrt(1 - self.beta2**self.t) /
(1 - self.beta1**self.t))
self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
return step

@ -1,111 +0,0 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
import gym
import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.utils.filter import get_filter
from ray.rllib.models import ModelCatalog
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0):
"""Do a rollout.
If add_noise is True, the rollout will take noisy actions with
noise drawn from that stream. Otherwise, no action noise will be added.
Parameters
----------
policy: tf object
policy from which to draw actions
env: GymEnv
environment from which to draw rewards, done, and next state
timestep_limit: int, optional
steps after which to end the rollout
add_noise: bool, optional
indicates whether exploratory action noise should be added
offset: int, optional
value to subtract from the reward. For example, survival bonus
from humanoid
"""
env_timestep_limit = env.spec.max_episode_steps
timestep_limit = (env_timestep_limit if timestep_limit is None else min(
timestep_limit, env_timestep_limit))
rews = []
t = 0
observation = env.reset()
for _ in range(timestep_limit or 999999):
ac = policy.compute(observation, add_noise=add_noise, update=True)[0]
observation, rew, done, _ = env.step(ac)
rew -= np.abs(offset)
rews.append(rew)
t += 1
if done:
break
rews = np.array(rews, dtype=np.float32)
return rews, t
class GenericPolicy:
def __init__(self,
sess,
action_space,
obs_space,
preprocessor,
observation_filter,
model_config,
action_noise_std=0.0):
self.sess = sess
self.action_space = action_space
self.action_noise_std = action_noise_std
self.preprocessor = preprocessor
self.observation_filter = get_filter(observation_filter,
self.preprocessor.shape)
self.inputs = tf.placeholder(tf.float32,
[None] + list(self.preprocessor.shape))
# Policy network.
dist_class, dist_dim = ModelCatalog.get_action_dist(
action_space, model_config, dist_type="deterministic")
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, action_space, dist_dim, model_config)
dist = dist_class(model.outputs, model)
self.sampler = dist.sample()
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
model.outputs, self.sess)
self.num_params = sum(
np.prod(variable.shape.as_list())
for _, variable in self.variables.variables.items())
self.sess.run(tf.global_variables_initializer())
def compute(self, observation, add_noise=False, update=True):
observation = self.preprocessor.transform(observation)
observation = self.observation_filter(observation[None], update=update)
action = self.sess.run(
self.sampler, feed_dict={self.inputs: observation})
action = _unbatch_tuple_actions(action)
if add_noise and isinstance(self.action_space, gym.spaces.Box):
action += np.random.randn(*action.shape) * self.action_noise_std
return action
def set_weights(self, x):
self.variables.set_flat(x)
def set_filter(self, obs_filter):
self.observation_filter = obs_filter
def get_filter(self):
return self.observation_filter
def get_weights(self):
return self.variables.get_flat()

@ -1,59 +0,0 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
import numpy as np
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def compute_ranks(x):
"""Returns ranks in [0, len(x))
Note: This is different from scipy.stats.rankdata, which returns ranks in
[1, len(x)].
"""
assert x.ndim == 1
ranks = np.empty(len(x), dtype=int)
ranks[x.argsort()] = np.arange(len(x))
return ranks
def compute_centered_ranks(x):
y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
y /= (x.size - 1)
y -= 0.5
return y
def make_session(single_threaded):
if not single_threaded:
return tf.Session()
return tf.Session(
config=tf.ConfigProto(
inter_op_parallelism_threads=1, intra_op_parallelism_threads=1))
def itergroups(items, group_size):
assert group_size >= 1
group = []
for x in items:
group.append(x)
if len(group) == group_size:
yield tuple(group)
del group[:]
if group:
yield tuple(group)
def batched_weighted_sum(weights, vecs, batch_size):
total = 0
num_items_summed = 0
for batch_weights, batch_vecs in zip(
itergroups(weights, batch_size), itergroups(vecs, batch_size)):
assert len(batch_weights) == len(batch_vecs) <= batch_size
total += np.dot(
np.asarray(batch_weights, dtype=np.float32),
np.asarray(batch_vecs, dtype=np.float32))
num_items_summed += len(batch_weights)
return total, num_items_summed

@ -26,6 +26,7 @@ Result = namedtuple("Result", [
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
"action_noise_std": 0.01,
"l2_coeff": 0.005,
"noise_stdev": 0.02,
"episodes_per_batch": 1000,
@ -177,8 +178,6 @@ class ESTrainer(Trainer):
@override(Trainer)
def _init(self, config, env_creator):
policy_params = {"action_noise_std": 0.01}
config.update(policy_params)
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
env = env_creator(env_context)
policy_cls = get_policy_class(config)
@ -197,8 +196,8 @@ class ESTrainer(Trainer):
# Create the actors.
logger.info("Creating actors.")
self._workers = [
Worker.remote(config, policy_params, env_creator, noise_id,
idx + 1) for idx in range(config["num_workers"])
Worker.remote(config, {}, env_creator, noise_id, idx + 1)
for idx in range(config["num_workers"])
]
self.episodes_so_far = 0

@ -10,16 +10,27 @@ from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
def rollout(policy, env, timestep_limit=None, add_noise=False):
def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0.0):
"""Do a rollout.
If add_noise is True, the rollout will take noisy actions with
noise drawn from that stream. Otherwise, no action noise will be added.
Args:
policy (Policy): RLlib Policy from which to draw actions.
env (gym.Env): Environment from which to draw rewards, done, and
next state.
timestep_limit (Optional[int]): Steps after which to end the rollout.
If None, use `env.spec.max_episode_steps` or 999999.
add_noise (bool): Indicates whether exploratory action noise should be
added.
offset (float): Value to subtract from the reward (e.g. survival bonus
from humanoid).
"""
max_timestep_limit = 999999
env_timestep_limit = env.spec.max_episode_steps if (
@ -27,18 +38,21 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
else max_timestep_limit
timestep_limit = (env_timestep_limit if timestep_limit is None else min(
timestep_limit, env_timestep_limit))
rews = []
rewards = []
t = 0
observation = env.reset()
for _ in range(timestep_limit or max_timestep_limit):
ac = policy.compute_actions(observation, add_noise=add_noise)[0]
observation, rew, done, _ = env.step(ac)
rews.append(rew)
ac = policy.compute_actions(
observation, add_noise=add_noise, update=True)[0]
observation, r, done, _ = env.step(ac)
if offset != 0.0:
r -= np.abs(offset)
rewards.append(r)
t += 1
if done:
break
rews = np.array(rews, dtype=np.float32)
return rews, t
rewards = np.array(rewards, dtype=np.float32)
return rewards, t
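
As a hedged aside on the contract rollout() expects from its policy argument (the RandomPolicy stand-in below is hypothetical, not part of RLlib): any object whose compute_actions() returns an indexable whose first element is the action can be rolled out.

import gym
from ray.rllib.agents.es.es_tf_policy import rollout

class RandomPolicy:
    """Hypothetical stand-in policy for illustrating rollout()."""
    def __init__(self, action_space):
        self.action_space = action_space

    def compute_actions(self, observation, add_noise=False, update=True):
        # rollout() takes element [0] of the result as the action.
        return [self.action_space.sample()]

env = gym.make("CartPole-v0")
rewards, length = rollout(RandomPolicy(env.action_space), env,
                          timestep_limit=200, add_noise=False)
print(rewards.sum(), length)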
def make_session(single_threaded):

@ -5,7 +5,6 @@ import gym
import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch

@ -5,15 +5,15 @@ import numpy as np
class Optimizer:
def __init__(self, pi):
self.pi = pi
self.dim = pi.num_params
def __init__(self, policy):
self.policy = policy
self.dim = policy.num_params
self.t = 0
def update(self, globalg):
self.t += 1
step = self._compute_step(globalg)
theta = self.pi.get_flat_weights()
theta = self.policy.get_flat_weights()
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
return theta + step, ratio
@ -22,8 +22,8 @@ class Optimizer:
class SGD(Optimizer):
def __init__(self, pi, stepsize, momentum=0.9):
Optimizer.__init__(self, pi)
def __init__(self, policy, stepsize, momentum=0.0):
Optimizer.__init__(self, policy)
self.v = np.zeros(self.dim, dtype=np.float32)
self.stepsize, self.momentum = stepsize, momentum
@ -34,8 +34,9 @@ class SGD(Optimizer):
class Adam(Optimizer):
def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
Optimizer.__init__(self, pi)
def __init__(self, policy, stepsize, beta1=0.9, beta2=0.999,
epsilon=1e-08):
Optimizer.__init__(self, policy)
self.stepsize = stepsize
self.beta1 = beta1
self.beta2 = beta2

@ -67,7 +67,7 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
prev_layer_size = hiddens[-1]
if self.num_outputs:
self._logits = SlimFC(
in_size=hiddens[-1],
in_size=prev_layer_size,
out_size=self.num_outputs,
initializer=normc_initializer(0.01),
activation_fn=None)

@ -1,10 +1,11 @@
cartpole-ars:
cartpole-ars-tf:
env: CartPole-v0
run: ARS
stop:
episode_reward_mean: 50
timesteps_total: 500000
config:
use_pytorch: false
noise_stdev: 0.02
num_rollouts: 50
rollouts_used: 25

@ -0,0 +1,17 @@
cartpole-ars-torch:
env: CartPole-v0
run: ARS
stop:
episode_reward_mean: 150
timesteps_total: 500000
config:
use_pytorch: true
noise_stdev: 0.02
num_rollouts: 50
rollouts_used: 25
num_workers: 2
sgd_stepsize: 0.01
noise_size: 25000000
eval_prob: 0.5
model:
fcnet_hiddens: [64, 64]
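
For reference, this torch tuned example corresponds roughly to the following tune.run call (a sketch; the YAML above is the source of truth):

import ray
from ray import tune

ray.init()
tune.run(
    "ARS",
    stop={"episode_reward_mean": 150, "timesteps_total": 500000},
    config={
        "env": "CartPole-v0",
        "use_pytorch": True,
        "noise_stdev": 0.02,
        "num_rollouts": 50,
        "rollouts_used": 25,
        "num_workers": 2,
        "sgd_stepsize": 0.01,
        "noise_size": 25000000,
        "eval_prob": 0.5,
        "model": {"fcnet_hiddens": [64, 64]},
    })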