[RLlib] PyTorch version of ES (Evolution Strategies). (#8104)
PyTorch version of the Evolution Strategies (ES) algorithm.
Parent: 9f3e9e7e9f
Commit: 3812bfedda
17 changed files with 276 additions and 121 deletions
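
For orientation (not part of the diff): the new backend is selected with the `use_pytorch` config key. A minimal sketch based on the test and tuned-example configs further down; the two training iterations are arbitrary.

import ray
import ray.rllib.agents.es as es

ray.init()

# `use_pytorch: True` routes ESTrainer to the new ESTorchPolicy via
# get_policy_class() in rllib/agents/es/es.py; False keeps the TF policy.
config = es.DEFAULT_CONFIG.copy()
config["use_pytorch"] = True
config["num_workers"] = 2

trainer = es.ESTrainer(config=config, env="CartPole-v0")
for _ in range(2):
    print(trainer.train())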
@@ -13,7 +13,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
 =================== ========== ======================= ================== =========== =====================
 `A2C, A3C`_         tf + torch **Yes** `+parametric`_   **Yes**            **Yes**     `+RNN`_, `+autoreg`_
 `ARS`_              tf         **Yes**                  **Yes**            No
-`ES`_               tf         **Yes**                  **Yes**            No
+`ES`_               tf + torch **Yes**                  **Yes**            No
 `DDPG`_, `TD3`_     tf + torch No                       **Yes**            **Yes**
 `APEX-DDPG`_        tf         No                       **Yes**            **Yes**
 `DQN`_, `Rainbow`_  tf + torch **Yes** `+parametric`_   No                 **Yes**
@@ -422,7 +422,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll

 Evolution Strategies
 --------------------
-|tensorflow|
+|pytorch| |tensorflow|
 `[paper] <https://arxiv.org/abs/1703.03864>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/es/es.py>`__
 Code here is adapted from https://github.com/openai/evolution-strategies-starter to execute in the distributed setting with Ray.

@@ -101,7 +101,7 @@ Algorithms

 - |pytorch| |tensorflow| :ref:`Advantage Actor-Critic (A2C, A3C) <a3c>`

-- |tensorflow| :ref:`Deep Deterministic Policy Gradients (DDPG, TD3) <ddpg>`
+- |pytorch| |tensorflow| :ref:`Deep Deterministic Policy Gradients (DDPG, TD3) <ddpg>`

 - |pytorch| |tensorflow| :ref:`Deep Q Networks (DQN, Rainbow, Parametric DQN) <dqn>`

@@ -109,13 +109,13 @@ Algorithms

 - |pytorch| |tensorflow| :ref:`Proximal Policy Optimization (PPO) <ppo>`

-- |tensorflow| :ref:`Soft Actor Critic (SAC) <sac>`
+- |pytorch| |tensorflow| :ref:`Soft Actor Critic (SAC) <sac>`

 * Derivative-free

 - |tensorflow| :ref:`Augmented Random Search (ARS) <ars>`

-- |tensorflow| :ref:`Evolution Strategies <es>`
+- |pytorch| |tensorflow| :ref:`Evolution Strategies <es>`

 * Multi-agent specific

@@ -124,7 +124,7 @@ Algorithms

 * Offline

-- |tensorflow| :ref:`Advantage Re-Weighted Imitation Learning (MARWIL) <marwil>`
+- |pytorch| |tensorflow| :ref:`Advantage Re-Weighted Imitation Learning (MARWIL) <marwil>`

 * Contextual bandits

@@ -103,7 +103,7 @@ py_test(
 py_test(
     name = "test_apex",
     tags = ["agents_dir"],
-    size = "medium",
+    size = "large",
     srcs = ["agents/dqn/tests/test_apex.py"]
 )

@@ -1,6 +1,5 @@
-from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG)
-from ray.rllib.utils import renamed_agent
+from ray.rllib.agents.es.es import ESTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.es.es_tf_policy import ESTFPolicy
+from ray.rllib.agents.es.es_torch_policy import ESTorchPolicy

-ESAgent = renamed_agent(ESTrainer)
-
-__all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"]
+__all__ = ["ESTFPolicy", "ESTorchPolicy", "ESTrainer", "DEFAULT_CONFIG"]
@@ -8,7 +8,8 @@ import time

 import ray
 from ray.rllib.agents import Trainer, with_common_config
-from ray.rllib.agents.es import optimizers, policies, utils
+from ray.rllib.agents.es import optimizers, utils
+from ray.rllib.agents.es.es_tf_policy import ESTFPolicy, rollout
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils import FilterManager
@@ -72,7 +73,8 @@ class Worker:
                  min_task_runtime=0.2):
         self.min_task_runtime = min_task_runtime
         self.config = config
-        self.policy_params = policy_params
+        self.config.update(policy_params)
+        self.config["single_threaded"] = True
         self.noise = SharedNoiseTable(noise)

         env_context = EnvContext(config["env_config"] or {}, worker_index)
@@ -81,15 +83,13 @@ class Worker:
         self.preprocessor = models.ModelCatalog.get_preprocessor(
             self.env, config["model"])

-        self.sess = utils.make_session(single_threaded=True)
-        self.policy = policies.GenericPolicy(
-            self.sess, self.env.action_space, self.env.observation_space,
-            self.preprocessor, config["observation_filter"], config["model"],
-            **policy_params)
+        policy_cls = get_policy_class(config)
+        self.policy = policy_cls(self.env.observation_space,
+                                 self.env.action_space, config)

     @property
     def filters(self):
-        return {DEFAULT_POLICY_ID: self.policy.get_filter()}
+        return {DEFAULT_POLICY_ID: self.policy.observation_filter}

     def sync_filters(self, new_filters):
         for k in self.filters:
@@ -104,7 +104,7 @@ class Worker:
         return return_filters

     def rollout(self, timestep_limit, add_noise=True):
-        rollout_rewards, rollout_fragment_length = policies.rollout(
+        rollout_rewards, rollout_fragment_length = rollout(
             self.policy,
             self.env,
             timestep_limit=timestep_limit,
@@ -113,7 +113,7 @@ class Worker:

     def do_rollouts(self, params, timestep_limit=None):
         # Set the network weights.
-        self.policy.set_weights(params)
+        self.policy.set_flat_weights(params)

         noise_indices, returns, sign_returns, lengths = [], [], [], []
         eval_returns, eval_lengths = [], []
@@ -125,7 +125,7 @@ class Worker:

             if np.random.uniform() < self.config["eval_prob"]:
                 # Do an evaluation run with no perturbation.
-                self.policy.set_weights(params)
+                self.policy.set_flat_weights(params)
                 rewards, length = self.rollout(timestep_limit, add_noise=False)
                 eval_returns.append(rewards.sum())
                 eval_lengths.append(length)
@@ -138,10 +138,10 @@ class Worker:

                 # These two sampling steps could be done in parallel on
                 # different actors letting us update twice as frequently.
-                self.policy.set_weights(params + perturbation)
+                self.policy.set_flat_weights(params + perturbation)
                 rewards_pos, lengths_pos = self.rollout(timestep_limit)

-                self.policy.set_weights(params - perturbation)
+                self.policy.set_flat_weights(params - perturbation)
                 rewards_neg, lengths_neg = self.rollout(timestep_limit)

                 noise_indices.append(noise_index)
@@ -160,6 +160,15 @@ class Worker:
             eval_lengths=eval_lengths)


+def get_policy_class(config):
+    if config["use_pytorch"]:
+        from ray.rllib.agents.es.es_torch_policy import ESTorchPolicy
+        policy_cls = ESTorchPolicy
+    else:
+        policy_cls = ESTFPolicy
+    return policy_cls
+
+
 class ESTrainer(Trainer):
     """Large-scale implementation of Evolution Strategies in Ray."""

@@ -168,22 +177,15 @@ class ESTrainer(Trainer):

     @override(Trainer)
     def _init(self, config, env_creator):
-        # PyTorch check.
-        if config["use_pytorch"]:
-            raise ValueError(
-                "ES does not support PyTorch yet! Use tf instead.")
-
         policy_params = {"action_noise_std": 0.01}

+        config.update(policy_params)
         env_context = EnvContext(config["env_config"] or {}, worker_index=0)
         env = env_creator(env_context)
-        from ray.rllib import models
-        preprocessor = models.ModelCatalog.get_preprocessor(env)
-
-        self.sess = utils.make_session(single_threaded=False)
-        self.policy = policies.GenericPolicy(
-            self.sess, env.action_space, env.observation_space, preprocessor,
-            config["observation_filter"], config["model"], **policy_params)
+        policy_cls = get_policy_class(config)
+        self.policy = policy_cls(
+            obs_space=env.observation_space,
+            action_space=env.action_space,
+            config=config)
         self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
         self.report_length = config["report_length"]

@@ -207,8 +209,9 @@ class ESTrainer(Trainer):
     def _train(self):
         config = self.config

-        theta = self.policy.get_weights()
+        theta = self.policy.get_flat_weights()
         assert theta.dtype == np.float32
+        assert len(theta.shape) == 1

         # Put the current policy weights in the object store.
         theta_id = ray.put(theta)
@@ -264,14 +267,14 @@ class ESTrainer(Trainer):
         theta, update_ratio = self.optimizer.update(-g +
                                                     config["l2_coeff"] * theta)
         # Set the new weights in the local copy of the policy.
-        self.policy.set_weights(theta)
+        self.policy.set_flat_weights(theta)
         # Store the rewards
         if len(all_eval_returns) > 0:
             self.reward_list.append(np.mean(eval_returns))

         # Now sync the filters
         FilterManager.synchronize({
-            DEFAULT_POLICY_ID: self.policy.get_filter()
+            DEFAULT_POLICY_ID: self.policy.observation_filter
         }, self._workers)

         info = {
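
Background note (not part of the diff): `-g + config["l2_coeff"] * theta` is handed to the optimizer as a quantity to descend, so the resulting step ascends the ES reward-gradient estimate with L2 weight decay. In the notation of the linked Salimans et al. (2017) paper, with the mirrored perturbations +/- sigma*epsilon from the Worker.do_rollouts hunk above:

g \;\approx\; \frac{1}{n\sigma} \sum_{i=1}^{n} F(\theta + \sigma\epsilon_i)\,\epsilon_i,
\qquad
\theta \;\leftarrow\; \theta + \alpha\,\bigl(g - \beta\,\theta\bigr)

where alpha is the Adam `stepsize`, beta is `l2_coeff`, and the returns F are typically centered-rank transformed (cf. compute_centered_ranks in the utils hunk below).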
@@ -293,7 +296,7 @@ class ESTrainer(Trainer):

     @override(Trainer)
     def compute_action(self, observation, *args, **kwargs):
-        return self.policy.compute(observation, update=False)[0]
+        return self.policy.compute_actions(observation, update=False)[0]

     @override(Trainer)
     def _stop(self):
@@ -325,15 +328,15 @@ class ESTrainer(Trainer):

     def __getstate__(self):
         return {
-            "weights": self.policy.get_weights(),
-            "filter": self.policy.get_filter(),
+            "weights": self.policy.get_flat_weights(),
+            "filter": self.policy.observation_filter,
             "episodes_so_far": self.episodes_so_far,
         }

     def __setstate__(self, state):
         self.episodes_so_far = state["episodes_so_far"]
-        self.policy.set_weights(state["weights"])
-        self.policy.set_filter(state["filter"])
+        self.policy.set_flat_weights(state["weights"])
+        self.policy.observation_filter = state["filter"]
         FilterManager.synchronize({
-            DEFAULT_POLICY_ID: self.policy.get_filter()
+            DEFAULT_POLICY_ID: self.policy.observation_filter
         }, self._workers)
@@ -8,6 +8,7 @@ import ray
 import ray.experimental.tf_utils
 from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
 from ray.rllib.models import ModelCatalog
+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils import try_import_tf

@@ -30,7 +31,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
     t = 0
     observation = env.reset()
     for _ in range(timestep_limit or max_timestep_limit):
-        ac = policy.compute(observation, add_noise=add_noise)[0]
+        ac = policy.compute_actions(observation, add_noise=add_noise)[0]
         observation, rew, done, _ = env.step(ac)
         rews.append(rew)
         t += 1
@@ -40,24 +41,32 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
     return rews, t


-class GenericPolicy:
-    def __init__(self, sess, action_space, obs_space, preprocessor,
-                 observation_filter, model_options, action_noise_std):
-        self.sess = sess
+def make_session(single_threaded):
+    if not single_threaded:
+        return tf.Session()
+    return tf.Session(
+        config=tf.ConfigProto(
+            inter_op_parallelism_threads=1, intra_op_parallelism_threads=1))
+
+
+class ESTFPolicy:
+    def __init__(self, obs_space, action_space, config):
         self.action_space = action_space
-        self.action_noise_std = action_noise_std
-        self.preprocessor = preprocessor
-        self.observation_filter = get_filter(observation_filter,
+        self.action_noise_std = config["action_noise_std"]
+        self.preprocessor = ModelCatalog.get_preprocessor_for_space(obs_space)
+        self.observation_filter = get_filter(config["observation_filter"],
                                              self.preprocessor.shape)
+        self.single_threaded = config.get("single_threaded", False)
+        self.sess = make_session(single_threaded=self.single_threaded)
         self.inputs = tf.placeholder(tf.float32,
                                      [None] + list(self.preprocessor.shape))

         # Policy network.
         dist_class, dist_dim = ModelCatalog.get_action_dist(
-            self.action_space, model_options, dist_type="deterministic")
+            self.action_space, config["model"], dist_type="deterministic")
         model = ModelCatalog.get_model({
-            "obs": self.inputs
+            SampleBatch.CUR_OBS: self.inputs
-        }, obs_space, action_space, dist_dim, model_options)
+        }, obs_space, action_space, dist_dim, config["model"])
         dist = dist_class(model.outputs, model)
         self.sampler = dist.sample()
@@ -69,7 +78,7 @@ class GenericPolicy:
             for _, variable in self.variables.variables.items())
         self.sess.run(tf.global_variables_initializer())

-    def compute(self, observation, add_noise=False, update=True):
+    def compute_actions(self, observation, add_noise=False, update=True):
         observation = self.preprocessor.transform(observation)
         observation = self.observation_filter(observation[None], update=update)
         action = self.sess.run(
@@ -79,14 +88,8 @@ class GenericPolicy:
             action += np.random.randn(*action.shape) * self.action_noise_std
         return action

-    def set_weights(self, x):
+    def set_flat_weights(self, x):
         self.variables.set_flat(x)

-    def get_weights(self):
+    def get_flat_weights(self):
         return self.variables.get_flat()
-
-    def get_filter(self):
-        return self.observation_filter
-
-    def set_filter(self, observation_filter):
-        self.observation_filter = observation_filter
rllib/agents/es/es_torch_policy.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+# Code in this file is adapted from:
+# https://github.com/openai/evolution-strategies-starter.
+
+import gym
+import numpy as np
+
+import ray
+import ray.experimental.tf_utils
+from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
+from ray.rllib.models import ModelCatalog
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.policy.torch_policy_template import build_torch_policy
+from ray.rllib.utils.filter import get_filter
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.torch_ops import convert_to_torch_tensor
+
+torch, _ = try_import_torch()
+
+
+def before_init(policy, observation_space, action_space, config):
+    policy.action_noise_std = config["action_noise_std"]
+    policy.preprocessor = ModelCatalog.get_preprocessor_for_space(
+        observation_space)
+    policy.observation_filter = get_filter(config["observation_filter"],
+                                           policy.preprocessor.shape)
+    policy.single_threaded = config.get("single_threaded", False)
+
+    def _set_flat_weights(policy, theta):
+        pos = 0
+        theta_dict = policy.model.state_dict()
+        new_theta_dict = {}
+
+        for k in sorted(theta_dict.keys()):
+            shape = policy.param_shapes[k]
+            num_params = int(np.prod(shape))
+            new_theta_dict[k] = torch.from_numpy(
+                np.reshape(theta[pos:pos + num_params], shape))
+            pos += num_params
+        policy.model.load_state_dict(new_theta_dict)
+
+    def _get_flat_weights(policy):
+        # Get the parameter tensors.
+        theta_dict = policy.model.state_dict()
+        # Flatten it into a single np.ndarray.
+        theta_list = []
+        for k in sorted(theta_dict.keys()):
+            theta_list.append(torch.reshape(theta_dict[k], (-1, )))
+        cat = torch.cat(theta_list, dim=0)
+        return cat.numpy()
+
+    type(policy).set_flat_weights = _set_flat_weights
+    type(policy).get_flat_weights = _get_flat_weights
+
+    def _compute_actions(policy, obs_batch, add_noise=False, update=True):
+        observation = policy.preprocessor.transform(obs_batch)
+        observation = policy.observation_filter(
+            observation[None], update=update)
+
+        observation = convert_to_torch_tensor(observation)
+        dist_inputs, _ = policy.model({
+            SampleBatch.CUR_OBS: observation
+        }, [], None)
+        dist = policy.dist_class(dist_inputs, policy.model)
+        action = dist.sample().detach().numpy()
+        action = _unbatch_tuple_actions(action)
+        if add_noise and isinstance(policy.action_space, gym.spaces.Box):
+            action += np.random.randn(*action.shape) * policy.action_noise_std
+        return action
+
+    type(policy).compute_actions = _compute_actions
+
+
+def after_init(policy, observation_space, action_space, config):
+    state_dict = policy.model.state_dict()
+    policy.param_shapes = {
+        k: tuple(state_dict[k].size())
+        for k in sorted(state_dict.keys())
+    }
+    policy.num_params = sum(np.prod(s) for s in policy.param_shapes.values())
+
+
+def make_model_and_action_dist(policy, observation_space, action_space,
+                               config):
+    # Policy network.
+    dist_class, dist_dim = ModelCatalog.get_action_dist(
+        action_space,
+        config["model"],  # model_options
+        dist_type="deterministic",
+        framework="torch")
+    model = ModelCatalog.get_model_v2(
+        observation_space,
+        action_space,
+        num_outputs=dist_dim,
+        model_config=config["model"],
+        framework="torch")
+    # Make all model params not require any gradients.
+    for p in model.parameters():
+        p.requires_grad = False
+    return model, dist_class
+
+
+ESTorchPolicy = build_torch_policy(
+    name="ESTorchPolicy",
+    loss_fn=None,
+    get_default_config=lambda: ray.rllib.agents.es.es.DEFAULT_CONFIG,
+    before_init=before_init,
+    after_init=after_init,
+    make_model_and_action_dist=make_model_and_action_dist)
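
Illustration (not part of the diff): a self-contained sketch of the flatten/unflatten round trip that `_get_flat_weights()` / `_set_flat_weights()` implement above. The two-layer module is a hypothetical stand-in for the policy model.

import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for the policy model; any nn.Module works.
model = nn.Sequential(nn.Linear(4, 10), nn.Tanh(), nn.Linear(10, 2))
for p in model.parameters():
    p.requires_grad_(False)  # ES treats weights as plain data, as in the diff

# Flatten: concatenate all state_dict tensors in sorted key order.
state = model.state_dict()
flat = torch.cat(
    [torch.reshape(state[k], (-1, )) for k in sorted(state.keys())],
    dim=0).numpy()

# Unflatten: slice the 1-D vector back into the recorded shapes.
shapes = {k: tuple(state[k].size()) for k in sorted(state.keys())}
pos, new_state = 0, {}
for k in sorted(state.keys()):
    n = int(np.prod(shapes[k]))
    new_state[k] = torch.from_numpy(np.reshape(flat[pos:pos + n], shapes[k]))
    pos += n
model.load_state_dict(new_state)  # round trip restores identical weights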
@@ -13,7 +13,7 @@ class Optimizer:
     def update(self, globalg):
         self.t += 1
         step = self._compute_step(globalg)
-        theta = self.pi.get_weights()
+        theta = self.pi.get_flat_weights()
         ratio = np.linalg.norm(step) / np.linalg.norm(theta)
         return theta + step, ratio

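
Aside (not part of the diff): the `update()` contract above is framework-agnostic because it only touches the flat weight vector. A minimal hypothetical optimizer honoring the same interface, for illustration only:

import numpy as np

class FlatSGD:
    """Hypothetical plain-SGD variant of the Optimizer interface used here:
    read the policy's flat weights, return (new_theta, update_ratio)."""

    def __init__(self, pi, stepsize):
        self.pi = pi  # anything exposing get_flat_weights()
        self.stepsize = stepsize
        self.t = 0

    def _compute_step(self, globalg):
        # Plain gradient-descent step on the passed "gradient" vector.
        return -self.stepsize * globalg

    def update(self, globalg):
        self.t += 1
        step = self._compute_step(globalg)
        theta = self.pi.get_flat_weights()
        ratio = np.linalg.norm(step) / np.linalg.norm(theta)
        return theta + step, ratio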
rllib/agents/es/tests/test_es.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+import unittest
+
+import ray
+import ray.rllib.agents.es as es
+from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.test_utils import framework_iterator
+
+tf = try_import_tf()
+
+
+class TestES(unittest.TestCase):
+    def test_es_compilation(self):
+        """Test whether an ESTrainer can be built on all frameworks."""
+        ray.init()
+        config = es.DEFAULT_CONFIG.copy()
+        # Keep it simple.
+        config["model"]["fcnet_hiddens"] = [10]
+        config["model"]["fcnet_activation"] = None
+
+        num_iterations = 2
+
+        for _ in framework_iterator(config, ("torch", "tf")):
+            plain_config = config.copy()
+            trainer = es.ESTrainer(config=plain_config, env="CartPole-v0")
+            for i in range(num_iterations):
+                results = trainer.train()
+                print(results)
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))
@@ -2,9 +2,6 @@
 # https://github.com/openai/evolution-strategies-starter.

 import numpy as np
-from ray.rllib.utils import try_import_tf
-
-tf = try_import_tf()


 def compute_ranks(x):
@@ -26,14 +23,6 @@ def compute_centered_ranks(x):
     return y


-def make_session(single_threaded):
-    if not single_threaded:
-        return tf.Session()
-    return tf.Session(
-        config=tf.ConfigProto(
-            inter_op_parallelism_threads=1, intra_op_parallelism_threads=1))
-
-
 def itergroups(items, group_size):
     assert group_size >= 1
     group = []
@@ -18,7 +18,7 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
-    TorchMultiCategorical, TorchDiagGaussian
+    TorchMultiCategorical, TorchDeterministic, TorchDiagGaussian
 from ray.rllib.utils import try_import_tf
 from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
 from ray.rllib.utils.error import UnsupportedSpaceException
@@ -149,7 +149,8 @@ class ModelCatalog:
             if dist_type is None:
                 dist = DiagGaussian if framework == "tf" else TorchDiagGaussian
             elif dist_type == "deterministic":
-                dist = Deterministic
+                dist = Deterministic if framework == "tf" else \
+                    TorchDeterministic
         # Discrete Space -> Categorical.
         elif isinstance(action_space, gym.spaces.Discrete):
             dist = Categorical if framework == "tf" else TorchCategorical
@@ -31,7 +31,7 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):

         logger.debug("Constructing fcnet {} {}".format(hiddens, activation))
         layers = []
-        prev_layer_size = np.product(obs_space.shape)
+        prev_layer_size = int(np.product(obs_space.shape))
         self._logits = None

         # Create layers 0 to second-last.
@@ -28,48 +28,55 @@ def build_torch_policy(name,
                        apply_gradients_fn=None,
                        mixins=None,
                        get_batch_divisibility_req=None):
-    """Helper function for creating a torch policy at runtime.
+    """Helper function for creating a torch policy class at runtime.

     Arguments:
         name (str): name of the policy (e.g., "PPOTorchPolicy")
-        loss_fn (func): function that returns a loss tensor as arguments
-            (policy, model, dist_class, train_batch)
-        get_default_config (func): optional function that returns the default
-            config to merge with any overrides
-        stats_fn (func): optional function that returns a dict of
-            values given the policy and batch input tensors
-        postprocess_fn (func): optional experience postprocessing function
-            that takes the same args as Policy.postprocess_trajectory()
-        extra_action_out_fn (func): optional function that returns
-            a dict of extra values to include in experiences
-        extra_grad_process_fn (func): optional function that is called after
-            gradients are computed and returns processing info
-        optimizer_fn (func): optional function that returns a torch optimizer
-            given the policy and config
-        before_init (func): optional function to run at the beginning of
-            policy init that takes the same arguments as the policy constructor
-        after_init (func): optional function to run at the end of policy init
-            that takes the same arguments as the policy constructor
-        action_sampler_fn (Optional[callable]): A callable returning a sampled
-            action and its log-likelihood given some (obs and state) inputs.
-        action_distribution_fn (Optional[callable]): A callable returning
-            distribution inputs (parameters), a dist-class to generate an
-            action distribution object from, and internal-state outputs (or an
-            empty list if not applicable).
-        make_model_and_action_dist (func): optional func that takes the same
-            arguments as policy init and returns a tuple of model instance and
-            torch action distribution class. If not specified, the default
-            model and action dist from the catalog will be used
-        apply_gradients_fn (Optional[callable]): An optional callable that
+        loss_fn (callable): Callable that returns a loss tensor as arguments
+            given (policy, model, dist_class, train_batch).
+        get_default_config (Optional[callable]): Optional callable that returns
+            the default config to merge with any overrides.
+        stats_fn (Optional[callable]): Optional callable that returns a dict of
+            values given the policy and batch input tensors.
+        postprocess_fn (Optional[callable]): Optional experience postprocessing
+            function that takes the same args as
+            Policy.postprocess_trajectory().
+        extra_action_out_fn (Optional[callable]): Optional callable that
+            returns a dict of extra values to include in experiences.
+        extra_grad_process_fn (Optional[callable]): Optional callable that is
+            called after gradients are computed and returns processing info.
+        optimizer_fn (Optional[callable]): Optional callable that returns a
+            torch optimizer given the policy and config.
+        before_init (Optional[callable]): Optional callable to run at the
+            beginning of `Policy.__init__` that takes the same arguments as
+            the Policy constructor.
+        after_init (Optional[callable]): Optional callable to run at the end of
+            policy init that takes the same arguments as the policy
+            constructor.
+        action_sampler_fn (Optional[callable]): Optional callable returning a
+            sampled action and its log-likelihood given some (obs and state)
+            inputs.
+        action_distribution_fn (Optional[callable]): A callable that takes
+            the Policy, Model, the observation batch, an explore-flag, a
+            timestep, and an is_training flag and returns a tuple of
+            a) distribution inputs (parameters), b) a dist-class to generate
+            an action distribution object from, and c) internal-state outputs
+            (empty list if not applicable).
+        make_model_and_action_dist (Optional[callable]): Optional func that
+            takes the same arguments as Policy.__init__ and returns a tuple
+            of model instance and torch action distribution class. If not
+            specified, the default model and action dist from the catalog will
+            be used.
+        apply_gradients_fn (Optional[callable]): Optional callable that
             takes a grads list and applies these to the Model's parameters.
         mixins (list): list of any class mixins for the returned policy class.
             These mixins will be applied in order and will have higher
-            precedence than the TorchPolicy class
+            precedence than the TorchPolicy class.
         get_batch_divisibility_req (Optional[callable]): Optional callable that
             returns the divisibility requirement for sample batches.

     Returns:
-        a TorchPolicy instance that uses the specified args
+        type: TorchPolicy child class constructed from the specified args.
     """

     original_kwargs = locals().copy()
@@ -2,7 +2,7 @@
 # Runs one or more regression tests. Retries tests up to 3 times.
 #
 # Example usage:
-# $ python run_regression_tests.py regression-tests/cartpole-es.yaml
+# $ python run_regression_tests.py regression-tests/cartpole-es-[tf|torch].yaml
 #
 # When using in BAZEL (with py_test), e.g. see in ray/rllib/BUILD:
 # py_test(
@@ -1,10 +1,11 @@
-cartpole-es:
+cartpole-es-tf:
     env: CartPole-v0
     run: ES
     stop:
-        episode_reward_mean: 75
+        episode_reward_mean: 150
         timesteps_total: 400000
     config:
+        use_pytorch: false
        num_workers: 2
        noise_size: 25000000
        episodes_per_batch: 50
rllib/tuned_examples/regression_tests/cartpole-es-torch.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
+cartpole-es-torch:
+    env: CartPole-v0
+    run: ES
+    stop:
+        episode_reward_mean: 150
+        timesteps_total: 400000
+    config:
+        use_pytorch: true
+        num_workers: 2
+        noise_size: 25000000
+        episodes_per_batch: 50
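
Equivalent Python entry point (not part of the diff), assuming the standard `tune.run` API; it mirrors the torch YAML above:

from ray import tune

tune.run(
    "ES",
    stop={"episode_reward_mean": 150, "timesteps_total": 400000},
    config={
        "env": "CartPole-v0",
        "use_pytorch": True,
        "num_workers": 2,
        "noise_size": 25000000,
        "episodes_per_batch": 50,
    },
)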
@@ -214,15 +214,15 @@ def get_activation_fn(name, framework="tf"):
         torch.nn.ReLU. Returns None for name="linear".
     """
     if framework == "torch":
-        _, nn = try_import_torch()
-        if name == "linear":
+        if name == "linear" or name is None:
             return None
-        elif name == "relu":
+        _, nn = try_import_torch()
+        if name == "relu":
             return nn.ReLU
         elif name == "tanh":
             return nn.Tanh
     else:
-        if name == "linear":
+        if name == "linear" or name is None:
             return None
         tf = try_import_tf()
         fn = getattr(tf.nn, name, None)
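
Quick check of the changed behavior (not part of the diff); the import path is assumed to be ray.rllib.utils.framework, which this diff view does not name explicitly:

from ray.rllib.utils.framework import get_activation_fn

# name=None (as set by the ES test config above) now means "no activation"
# for both frameworks, same as name="linear".
assert get_activation_fn(None, framework="torch") is None
assert get_activation_fn(None, framework="tf") is None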