[RLlib] PyTorch version of ES (Evolution Strategies). (#8104)

PyTorch version of the Evolution Strategies (ES) algorithm.
Sven Mika 2020-04-20 21:47:28 +02:00 committed by GitHub
parent 9f3e9e7e9f
commit 3812bfedda
17 changed files with 276 additions and 121 deletions

View file

@ -13,7 +13,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
=================== ========== ======================= ================== =========== =====================
`A2C, A3C`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
`ARS`_ tf **Yes** **Yes** No
`ES`_ tf **Yes** **Yes** No
`ES`_ tf + torch **Yes** **Yes** No
`DDPG`_, `TD3`_ tf + torch No **Yes** **Yes**
`APEX-DDPG`_ tf No **Yes** **Yes**
`DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes**
@ -422,7 +422,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll
Evolution Strategies
--------------------
|tensorflow|
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1703.03864>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/es/es.py>`__
Code here is adapted from https://github.com/openai/evolution-strategies-starter to execute in the distributed setting with Ray.
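
As context for the doc change above (not part of the diff), a minimal sketch of training ES with the new torch backend; it only uses config keys that appear elsewhere in this commit (use_pytorch, num_workers, episodes_per_batch) and mirrors the new test_es.py further below:

import ray
import ray.rllib.agents.es as es

ray.init()
config = es.DEFAULT_CONFIG.copy()
config["use_pytorch"] = True      # select the new ESTorchPolicy instead of ESTFPolicy
config["num_workers"] = 2
config["episodes_per_batch"] = 50
trainer = es.ESTrainer(config=config, env="CartPole-v0")
for _ in range(2):
    print(trainer.train()["episode_reward_mean"])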

View file

@ -101,7 +101,7 @@ Algorithms
- |pytorch| |tensorflow| :ref:`Advantage Actor-Critic (A2C, A3C) <a3c>`
- |tensorflow| :ref:`Deep Deterministic Policy Gradients (DDPG, TD3) <ddpg>`
- |pytorch| |tensorflow| :ref:`Deep Deterministic Policy Gradients (DDPG, TD3) <ddpg>`
- |pytorch| |tensorflow| :ref:`Deep Q Networks (DQN, Rainbow, Parametric DQN) <dqn>`
@ -109,13 +109,13 @@ Algorithms
- |pytorch| |tensorflow| :ref:`Proximal Policy Optimization (PPO) <ppo>`
- |tensorflow| :ref:`Soft Actor Critic (SAC) <sac>`
- |pytorch| |tensorflow| :ref:`Soft Actor Critic (SAC) <sac>`
* Derivative-free
- |tensorflow| :ref:`Augmented Random Search (ARS) <ars>`
- |tensorflow| :ref:`Evolution Strategies <es>`
- |pytorch| |tensorflow| :ref:`Evolution Strategies <es>`
* Multi-agent specific
@ -124,7 +124,7 @@ Algorithms
* Offline
- |tensorflow| :ref:`Advantage Re-Weighted Imitation Learning (MARWIL) <marwil>`
- |pytorch| |tensorflow| :ref:`Advantage Re-Weighted Imitation Learning (MARWIL) <marwil>`
* Contextual bandits

View file

@ -103,7 +103,7 @@ py_test(
py_test(
name = "test_apex",
tags = ["agents_dir"],
size = "medium",
size = "large",
srcs = ["agents/dqn/tests/test_apex.py"]
)

View file

@ -1,6 +1,5 @@
from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG)
from ray.rllib.utils import renamed_agent
from ray.rllib.agents.es.es import ESTrainer, DEFAULT_CONFIG
from ray.rllib.agents.es.es_tf_policy import ESTFPolicy
from ray.rllib.agents.es.es_torch_policy import ESTorchPolicy
ESAgent = renamed_agent(ESTrainer)
__all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"]
__all__ = ["ESTFPolicy", "ESTorchPolicy", "ESTrainer", "DEFAULT_CONFIG"]

View file

@ -8,7 +8,8 @@ import time
import ray
from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.es import optimizers, policies, utils
from ray.rllib.agents.es import optimizers, utils
from ray.rllib.agents.es.es_tf_policy import ESTFPolicy, rollout
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import FilterManager
@ -72,7 +73,8 @@ class Worker:
min_task_runtime=0.2):
self.min_task_runtime = min_task_runtime
self.config = config
self.policy_params = policy_params
self.config.update(policy_params)
self.config["single_threaded"] = True
self.noise = SharedNoiseTable(noise)
env_context = EnvContext(config["env_config"] or {}, worker_index)
@ -81,15 +83,13 @@ class Worker:
self.preprocessor = models.ModelCatalog.get_preprocessor(
self.env, config["model"])
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.GenericPolicy(
self.sess, self.env.action_space, self.env.observation_space,
self.preprocessor, config["observation_filter"], config["model"],
**policy_params)
policy_cls = get_policy_class(config)
self.policy = policy_cls(self.env.observation_space,
self.env.action_space, config)
@property
def filters(self):
return {DEFAULT_POLICY_ID: self.policy.get_filter()}
return {DEFAULT_POLICY_ID: self.policy.observation_filter}
def sync_filters(self, new_filters):
for k in self.filters:
@ -104,7 +104,7 @@ class Worker:
return return_filters
def rollout(self, timestep_limit, add_noise=True):
rollout_rewards, rollout_fragment_length = policies.rollout(
rollout_rewards, rollout_fragment_length = rollout(
self.policy,
self.env,
timestep_limit=timestep_limit,
@ -113,7 +113,7 @@ class Worker:
def do_rollouts(self, params, timestep_limit=None):
# Set the network weights.
self.policy.set_weights(params)
self.policy.set_flat_weights(params)
noise_indices, returns, sign_returns, lengths = [], [], [], []
eval_returns, eval_lengths = [], []
@ -125,7 +125,7 @@ class Worker:
if np.random.uniform() < self.config["eval_prob"]:
# Do an evaluation run with no perturbation.
self.policy.set_weights(params)
self.policy.set_flat_weights(params)
rewards, length = self.rollout(timestep_limit, add_noise=False)
eval_returns.append(rewards.sum())
eval_lengths.append(length)
@ -138,10 +138,10 @@ class Worker:
# These two sampling steps could be done in parallel on
# different actors letting us update twice as frequently.
self.policy.set_weights(params + perturbation)
self.policy.set_flat_weights(params + perturbation)
rewards_pos, lengths_pos = self.rollout(timestep_limit)
self.policy.set_weights(params - perturbation)
self.policy.set_flat_weights(params - perturbation)
rewards_neg, lengths_neg = self.rollout(timestep_limit)
noise_indices.append(noise_index)
@ -160,6 +160,15 @@ class Worker:
eval_lengths=eval_lengths)
def get_policy_class(config):
if config["use_pytorch"]:
from ray.rllib.agents.es.es_torch_policy import ESTorchPolicy
policy_cls = ESTorchPolicy
else:
policy_cls = ESTFPolicy
return policy_cls
class ESTrainer(Trainer):
"""Large-scale implementation of Evolution Strategies in Ray."""
@ -168,22 +177,15 @@ class ESTrainer(Trainer):
@override(Trainer)
def _init(self, config, env_creator):
# PyTorch check.
if config["use_pytorch"]:
raise ValueError(
"ES does not support PyTorch yet! Use tf instead.")
policy_params = {"action_noise_std": 0.01}
config.update(policy_params)
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
env = env_creator(env_context)
from ray.rllib import models
preprocessor = models.ModelCatalog.get_preprocessor(env)
self.sess = utils.make_session(single_threaded=False)
self.policy = policies.GenericPolicy(
self.sess, env.action_space, env.observation_space, preprocessor,
config["observation_filter"], config["model"], **policy_params)
policy_cls = get_policy_class(config)
self.policy = policy_cls(
obs_space=env.observation_space,
action_space=env.action_space,
config=config)
self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
self.report_length = config["report_length"]
@ -207,8 +209,9 @@ class ESTrainer(Trainer):
def _train(self):
config = self.config
theta = self.policy.get_weights()
theta = self.policy.get_flat_weights()
assert theta.dtype == np.float32
assert len(theta.shape) == 1
# Put the current policy weights in the object store.
theta_id = ray.put(theta)
@ -264,14 +267,14 @@ class ESTrainer(Trainer):
theta, update_ratio = self.optimizer.update(-g +
config["l2_coeff"] * theta)
# Set the new weights in the local copy of the policy.
self.policy.set_weights(theta)
self.policy.set_flat_weights(theta)
# Store the rewards
if len(all_eval_returns) > 0:
self.reward_list.append(np.mean(eval_returns))
# Now sync the filters
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.observation_filter
}, self._workers)
info = {
@ -293,7 +296,7 @@ class ESTrainer(Trainer):
@override(Trainer)
def compute_action(self, observation, *args, **kwargs):
return self.policy.compute(observation, update=False)[0]
return self.policy.compute_actions(observation, update=False)[0]
@override(Trainer)
def _stop(self):
@ -325,15 +328,15 @@ class ESTrainer(Trainer):
def __getstate__(self):
return {
"weights": self.policy.get_weights(),
"filter": self.policy.get_filter(),
"weights": self.policy.get_flat_weights(),
"filter": self.policy.observation_filter,
"episodes_so_far": self.episodes_so_far,
}
def __setstate__(self, state):
self.episodes_so_far = state["episodes_so_far"]
self.policy.set_weights(state["weights"])
self.policy.set_filter(state["filter"])
self.policy.set_flat_weights(state["weights"])
self.policy.observation_filter = state["filter"]
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.observation_filter
}, self._workers)
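
To make the update that _train delegates to the workers and optimizer easier to follow, here is a self-contained numpy sketch (illustrative only, not part of the commit; centered_ranks mirrors utils.compute_centered_ranks, and eval_returns_fn, num_pairs and sigma are hypothetical names) of the antithetic sampling and centered-rank gradient estimate:

import numpy as np

def centered_ranks(x):
    # Rank-transform returns into [-0.5, 0.5]; this makes the update invariant
    # to the absolute scale of rewards (cf. utils.compute_centered_ranks).
    ranks = np.empty(x.size, dtype=np.float32)
    ranks[x.ravel().argsort()] = np.arange(x.size, dtype=np.float32)
    return (ranks / (x.size - 1) - 0.5).reshape(x.shape)

def es_gradient(theta, eval_returns_fn, num_pairs=16, sigma=0.02):
    # Antithetic sampling, as in Worker.do_rollouts: each noise vector eps
    # is evaluated at theta + sigma * eps and theta - sigma * eps.
    eps = np.random.randn(num_pairs, theta.size).astype(np.float32)
    returns = np.array([
        [eval_returns_fn(theta + sigma * e), eval_returns_fn(theta - sigma * e)]
        for e in eps
    ], dtype=np.float32)
    ranks = centered_ranks(returns)
    # Weighted sum of the noise vectors (cf. utils.batched_weighted_sum);
    # the noise scale is already folded in when perturbing the weights.
    return np.dot(ranks[:, 0] - ranks[:, 1], eps) / returns.size

# ESTrainer._train then applies, via the Adam optimizer:
#   theta, update_ratio = optimizer.update(-g + l2_coeff * theta)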

View file

@ -8,6 +8,7 @@ import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils import try_import_tf
@ -30,7 +31,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
t = 0
observation = env.reset()
for _ in range(timestep_limit or max_timestep_limit):
ac = policy.compute(observation, add_noise=add_noise)[0]
ac = policy.compute_actions(observation, add_noise=add_noise)[0]
observation, rew, done, _ = env.step(ac)
rews.append(rew)
t += 1
@ -40,24 +41,32 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
return rews, t
class GenericPolicy:
def __init__(self, sess, action_space, obs_space, preprocessor,
observation_filter, model_options, action_noise_std):
self.sess = sess
def make_session(single_threaded):
if not single_threaded:
return tf.Session()
return tf.Session(
config=tf.ConfigProto(
inter_op_parallelism_threads=1, intra_op_parallelism_threads=1))
class ESTFPolicy:
def __init__(self, obs_space, action_space, config):
self.action_space = action_space
self.action_noise_std = action_noise_std
self.preprocessor = preprocessor
self.observation_filter = get_filter(observation_filter,
self.action_noise_std = config["action_noise_std"]
self.preprocessor = ModelCatalog.get_preprocessor_for_space(obs_space)
self.observation_filter = get_filter(config["observation_filter"],
self.preprocessor.shape)
self.single_threaded = config.get("single_threaded", False)
self.sess = make_session(single_threaded=self.single_threaded)
self.inputs = tf.placeholder(tf.float32,
[None] + list(self.preprocessor.shape))
# Policy network.
dist_class, dist_dim = ModelCatalog.get_action_dist(
self.action_space, model_options, dist_type="deterministic")
self.action_space, config["model"], dist_type="deterministic")
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, action_space, dist_dim, model_options)
SampleBatch.CUR_OBS: self.inputs
}, obs_space, action_space, dist_dim, config["model"])
dist = dist_class(model.outputs, model)
self.sampler = dist.sample()
@ -69,7 +78,7 @@ class GenericPolicy:
for _, variable in self.variables.variables.items())
self.sess.run(tf.global_variables_initializer())
def compute(self, observation, add_noise=False, update=True):
def compute_actions(self, observation, add_noise=False, update=True):
observation = self.preprocessor.transform(observation)
observation = self.observation_filter(observation[None], update=update)
action = self.sess.run(
@ -79,14 +88,8 @@ class GenericPolicy:
action += np.random.randn(*action.shape) * self.action_noise_std
return action
def set_weights(self, x):
def set_flat_weights(self, x):
self.variables.set_flat(x)
def get_weights(self):
def get_flat_weights(self):
return self.variables.get_flat()
def get_filter(self):
return self.observation_filter
def set_filter(self, observation_filter):
self.observation_filter = observation_filter
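
A short usage sketch of the renamed ESTFPolicy API (compute_actions, get_flat_weights, set_flat_weights); this is illustrative and not part of the diff, and it sets action_noise_std by hand because that key is normally injected by the trainer/workers via config.update(policy_params):

import gym
import numpy as np
import ray.rllib.agents.es as es
from ray.rllib.agents.es.es_tf_policy import ESTFPolicy

env = gym.make("CartPole-v0")
config = es.DEFAULT_CONFIG.copy()
config["action_noise_std"] = 0.01  # normally added by ESTrainer/Worker
policy = ESTFPolicy(env.observation_space, env.action_space, config)

theta = policy.get_flat_weights()  # 1-D float32 vector of all weights
policy.set_flat_weights(theta + 0.01 * np.random.randn(*theta.shape).astype(np.float32))
action = policy.compute_actions(env.reset(), add_noise=True)[0]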

View file

@ -0,0 +1,108 @@
# Code in this file is adapted from:
# https://github.com/openai/evolution-strategies-starter.
import gym
import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
torch, _ = try_import_torch()
def before_init(policy, observation_space, action_space, config):
policy.action_noise_std = config["action_noise_std"]
policy.preprocessor = ModelCatalog.get_preprocessor_for_space(
observation_space)
policy.observation_filter = get_filter(config["observation_filter"],
policy.preprocessor.shape)
policy.single_threaded = config.get("single_threaded", False)
def _set_flat_weights(policy, theta):
pos = 0
theta_dict = policy.model.state_dict()
new_theta_dict = {}
for k in sorted(theta_dict.keys()):
shape = policy.param_shapes[k]
num_params = int(np.prod(shape))
new_theta_dict[k] = torch.from_numpy(
np.reshape(theta[pos:pos + num_params], shape))
pos += num_params
policy.model.load_state_dict(new_theta_dict)
def _get_flat_weights(policy):
# Get the parameter tensors.
theta_dict = policy.model.state_dict()
# Flatten it into a single np.ndarray.
theta_list = []
for k in sorted(theta_dict.keys()):
theta_list.append(torch.reshape(theta_dict[k], (-1, )))
cat = torch.cat(theta_list, dim=0)
return cat.numpy()
type(policy).set_flat_weights = _set_flat_weights
type(policy).get_flat_weights = _get_flat_weights
def _compute_actions(policy, obs_batch, add_noise=False, update=True):
observation = policy.preprocessor.transform(obs_batch)
observation = policy.observation_filter(
observation[None], update=update)
observation = convert_to_torch_tensor(observation)
dist_inputs, _ = policy.model({
SampleBatch.CUR_OBS: observation
}, [], None)
dist = policy.dist_class(dist_inputs, policy.model)
action = dist.sample().detach().numpy()
action = _unbatch_tuple_actions(action)
if add_noise and isinstance(policy.action_space, gym.spaces.Box):
action += np.random.randn(*action.shape) * policy.action_noise_std
return action
type(policy).compute_actions = _compute_actions
def after_init(policy, observation_space, action_space, config):
state_dict = policy.model.state_dict()
policy.param_shapes = {
k: tuple(state_dict[k].size())
for k in sorted(state_dict.keys())
}
policy.num_params = sum(np.prod(s) for s in policy.param_shapes.values())
def make_model_and_action_dist(policy, observation_space, action_space,
config):
# Policy network.
dist_class, dist_dim = ModelCatalog.get_action_dist(
action_space,
config["model"], # model_options
dist_type="deterministic",
framework="torch")
model = ModelCatalog.get_model_v2(
observation_space,
action_space,
num_outputs=dist_dim,
model_config=config["model"],
framework="torch")
# Make all model params not require any gradients.
for p in model.parameters():
p.requires_grad = False
return model, dist_class
ESTorchPolicy = build_torch_policy(
name="ESTorchPolicy",
loss_fn=None,
get_default_config=lambda: ray.rllib.agents.es.es.DEFAULT_CONFIG,
before_init=before_init,
after_init=after_init,
make_model_and_action_dist=make_model_and_action_dist)
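
The flat-weight helpers above pack the model's state_dict into one 1-D float32 vector (parameters sorted by name) and back. A standalone round-trip sketch with a plain torch module, not ESTorchPolicy itself, to illustrate the idea:

import numpy as np
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 10), nn.Tanh(), nn.Linear(10, 2))
shapes = {k: tuple(v.size()) for k, v in model.state_dict().items()}

def get_flat_weights(model):
    sd = model.state_dict()
    # Sorted keys give a stable ordering for flatten/unflatten.
    return torch.cat(
        [torch.reshape(sd[k], (-1, )) for k in sorted(sd.keys())]).numpy()

def set_flat_weights(model, theta):
    pos, new_sd = 0, {}
    for k in sorted(shapes.keys()):
        num = int(np.prod(shapes[k]))
        new_sd[k] = torch.from_numpy(np.reshape(theta[pos:pos + num], shapes[k]))
        pos += num
    model.load_state_dict(new_sd)

theta = get_flat_weights(model)
set_flat_weights(model, theta + 0.01 * np.random.randn(*theta.shape).astype(np.float32))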

View file

@ -13,7 +13,7 @@ class Optimizer:
def update(self, globalg):
self.t += 1
step = self._compute_step(globalg)
theta = self.pi.get_weights()
theta = self.pi.get_flat_weights()
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
return theta + step, ratio

View file

@ -0,0 +1,33 @@
import unittest
import ray
import ray.rllib.agents.es as es
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import framework_iterator
tf = try_import_tf()
class TestES(unittest.TestCase):
def test_es_compilation(self):
"""Test whether an ESTrainer can be built on all frameworks."""
ray.init()
config = es.DEFAULT_CONFIG.copy()
# Keep it simple.
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = None
num_iterations = 2
for _ in framework_iterator(config, ("torch", "tf")):
plain_config = config.copy()
trainer = es.ESTrainer(config=plain_config, env="CartPole-v0")
for i in range(num_iterations):
results = trainer.train()
print(results)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -2,9 +2,6 @@
# https://github.com/openai/evolution-strategies-starter.
import numpy as np
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def compute_ranks(x):
@ -26,14 +23,6 @@ def compute_centered_ranks(x):
return y
def make_session(single_threaded):
if not single_threaded:
return tf.Session()
return tf.Session(
config=tf.ConfigProto(
inter_op_parallelism_threads=1, intra_op_parallelism_threads=1))
def itergroups(items, group_size):
assert group_size >= 1
group = []

View file

@ -18,7 +18,7 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchMultiCategorical, TorchDiagGaussian
TorchMultiCategorical, TorchDeterministic, TorchDiagGaussian
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.error import UnsupportedSpaceException
@ -149,7 +149,8 @@ class ModelCatalog:
if dist_type is None:
dist = DiagGaussian if framework == "tf" else TorchDiagGaussian
elif dist_type == "deterministic":
dist = Deterministic
dist = Deterministic if framework == "tf" else \
TorchDeterministic
# Discrete Space -> Categorical.
elif isinstance(action_space, gym.spaces.Discrete):
dist = Categorical if framework == "tf" else TorchCategorical

View file

@ -31,7 +31,7 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
logger.debug("Constructing fcnet {} {}".format(hiddens, activation))
layers = []
prev_layer_size = np.product(obs_space.shape)
prev_layer_size = int(np.product(obs_space.shape))
self._logits = None
# Create layers 0 to second-last.

View file

@ -28,48 +28,55 @@ def build_torch_policy(name,
apply_gradients_fn=None,
mixins=None,
get_batch_divisibility_req=None):
"""Helper function for creating a torch policy at runtime.
"""Helper function for creating a torch policy class at runtime.
Arguments:
name (str): name of the policy (e.g., "PPOTorchPolicy")
loss_fn (func): function that returns a loss tensor as arguments
(policy, model, dist_class, train_batch)
get_default_config (func): optional function that returns the default
config to merge with any overrides
stats_fn (func): optional function that returns a dict of
values given the policy and batch input tensors
postprocess_fn (func): optional experience postprocessing function
that takes the same args as Policy.postprocess_trajectory()
extra_action_out_fn (func): optional function that returns
a dict of extra values to include in experiences
extra_grad_process_fn (func): optional function that is called after
gradients are computed and returns processing info
optimizer_fn (func): optional function that returns a torch optimizer
given the policy and config
before_init (func): optional function to run at the beginning of
policy init that takes the same arguments as the policy constructor
after_init (func): optional function to run at the end of policy init
that takes the same arguments as the policy constructor
action_sampler_fn (Optional[callable]): A callable returning a sampled
action and its log-likelihood given some (obs and state) inputs.
action_distribution_fn (Optional[callable]): A callable returning
distribution inputs (parameters), a dist-class to generate an
action distribution object from, and internal-state outputs (or an
empty list if not applicable).
make_model_and_action_dist (func): optional func that takes the same
arguments as policy init and returns a tuple of model instance and
torch action distribution class. If not specified, the default
model and action dist from the catalog will be used
apply_gradients_fn (Optional[callable]): An optional callable that
loss_fn (callable): Callable that returns a loss tensor, given the
arguments (policy, model, dist_class, train_batch).
get_default_config (Optional[callable]): Optional callable that returns
the default config to merge with any overrides.
stats_fn (Optional[callable]): Optional callable that returns a dict of
values given the policy and batch input tensors.
postprocess_fn (Optional[callable]): Optional experience postprocessing
function that takes the same args as
Policy.postprocess_trajectory().
extra_action_out_fn (Optional[callable]): Optional callable that
returns a dict of extra values to include in experiences.
extra_grad_process_fn (Optional[callable]): Optional callable that is
called after gradients are computed and returns processing info.
optimizer_fn (Optional[callable]): Optional callable that returns a
torch optimizer given the policy and config.
before_init (Optional[callable]): Optional callable to run at the
beginning of `Policy.__init__` that takes the same arguments as
the Policy constructor.
after_init (Optional[callable]): Optional callable to run at the end of
policy init that takes the same arguments as the policy
constructor.
action_sampler_fn (Optional[callable]): Optional callable returning a
sampled action and its log-likelihood given some (obs and state)
inputs.
action_distribution_fn (Optional[callable]): A callable that takes
the Policy, Model, the observation batch, an explore-flag, a
timestep, and an is_training flag and returns a tuple of
a) distribution inputs (parameters), b) a dist-class to generate
an action distribution object from, and c) internal-state outputs
(empty list if not applicable).
make_model_and_action_dist (Optional[callable]): Optional func that
takes the same arguments as Policy.__init__ and returns a tuple
of model instance and torch action distribution class. If not
specified, the default model and action dist from the catalog will
be used.
apply_gradients_fn (Optional[callable]): Optional callable that
takes a grads list and applies these to the Model's parameters.
mixins (list): list of any class mixins for the returned policy class.
These mixins will be applied in order and will have higher
precedence than the TorchPolicy class
precedence than the TorchPolicy class.
get_batch_divisibility_req (Optional[callable]): Optional callable that
returns the divisibility requirement for sample batches.
Returns:
a TorchPolicy instance that uses the specified args
type: TorchPolicy child class constructed from the specified args.
"""
original_kwargs = locals().copy()
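
The new ESTorchPolicy above is an in-tree example of this template with loss_fn=None; for a gradient-based policy the minimal pattern looks roughly like the following sketch (MyTorchPolicy and the toy squared-error loss are hypothetical):

import torch
import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy

def my_loss_fn(policy, model, dist_class, train_batch):
    # Toy loss: push the first model output toward the recorded actions.
    logits, _ = model.from_batch(train_batch)
    actions = train_batch[SampleBatch.ACTIONS]
    return torch.mean((logits[:, 0] - actions.float()) ** 2)

MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=my_loss_fn,
    get_default_config=lambda: ray.rllib.agents.trainer.COMMON_CONFIG)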

View file

@ -2,7 +2,7 @@
# Runs one or more regression tests. Retries tests up to 3 times.
#
# Example usage:
# $ python run_regression_tests.py regression-tests/cartpole-es.yaml
# $ python run_regression_tests.py regression-tests/cartpole-es-[tf|torch].yaml
#
# When using in BAZEL (with py_test), e.g. see in ray/rllib/BUILD:
# py_test(

View file

@ -1,10 +1,11 @@
cartpole-es:
cartpole-es-tf:
env: CartPole-v0
run: ES
stop:
episode_reward_mean: 75
episode_reward_mean: 150
timesteps_total: 400000
config:
use_pytorch: false
num_workers: 2
noise_size: 25000000
episodes_per_batch: 50

View file

@ -0,0 +1,11 @@
cartpole-es-torch:
env: CartPole-v0
run: ES
stop:
episode_reward_mean: 150
timesteps_total: 400000
config:
use_pytorch: true
num_workers: 2
noise_size: 25000000
episodes_per_batch: 50
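
Equivalently (not part of the diff), the torch tuned example above corresponds to roughly this tune.run call:

import ray
from ray import tune

ray.init()
tune.run(
    "ES",
    stop={"episode_reward_mean": 150, "timesteps_total": 400000},
    config={
        "env": "CartPole-v0",
        "use_pytorch": True,
        "num_workers": 2,
        "noise_size": 25000000,
        "episodes_per_batch": 50,
    })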

View file

@ -214,15 +214,15 @@ def get_activation_fn(name, framework="tf"):
torch.nn.ReLU. Returns None for name="linear".
"""
if framework == "torch":
_, nn = try_import_torch()
if name == "linear":
if name == "linear" or name is None:
return None
elif name == "relu":
_, nn = try_import_torch()
if name == "relu":
return nn.ReLU
elif name == "tanh":
return nn.Tanh
else:
if name == "linear":
if name == "linear" or name is None:
return None
tf = try_import_tf()
fn = getattr(tf.nn, name, None)
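
Finally, a quick illustration of the new None-as-linear handling (a sketch; the import path assumes get_activation_fn lives in ray.rllib.utils.framework, as in current RLlib):

from ray.rllib.utils.framework import get_activation_fn, try_import_torch

_, nn = try_import_torch()
assert get_activation_fn(None, framework="torch") is None      # new: None behaves like "linear"
assert get_activation_fn("linear", framework="torch") is None
assert get_activation_fn("tanh", framework="torch") is nn.Tanh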