Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
Exploration API (+EpsilonGreedy sub-class).

* Exploration API (+EpsilonGreedy sub-class).
* Cleanup/LINT.
* Add `deterministic` to generic Trainer config (NOTE: this is still ignored by most Agents).
* Add `error` option to deprecation_warning().
* WIP.
* Bug fix: Get exploration-info for tf framework. Bug fix: Properly deprecate some DQN config keys.
* WIP.
* LINT.
* WIP.
* Split PerWorkerEpsilonGreedy out of EpsilonGreedy. Docstrings.
* Fix bug in sampler.py in case Policy has self.exploration = None.
* Update rllib/agents/dqn/dqn.py (Co-Authored-By: Eric Liang <ekhliang@gmail.com>)
* WIP.
* Update rllib/agents/trainer.py (Co-Authored-By: Eric Liang <ekhliang@gmail.com>)
* WIP.
* Change requests.
* LINT.
* In tune/utils/util.py::deep_update(): only keep deep-updating if both original and value are dicts; if value is not a dict, simply overwrite the original value.
* Completely obsolete sync_replay_optimizer.py's parameters schedule_max_timesteps AND beta_annealing_fraction (replaced with prioritized_replay_beta_annealing_timesteps).
* Update rllib/evaluation/worker_set.py (Co-Authored-By: Eric Liang <ekhliang@gmail.com>)
* Review fixes.
* Fix default value for DQN's exploration spec.
* LINT.
* Fix recursion bug (wrong parent c'tor).
* Do not pass timestep to get_exploration_info.
* Update tf_policy.py.
* Fix some remaining issues with test cases and remove more deprecated DQN/APEX exploration configs.
* Bug fix: tf-action-dist.
* DDPG incompatibility bug fix with new DQN exploration handling (which is imported by DDPG).
* Switch off exploration when getting action probs from the off-policy estimator's policy.
* LINT.
* Fix test_checkpoint_restore.py.
* Deprecate all SAC exploration (unused) configs.
* Properly use `model.last_output()` everywhere instead of `model._last_output`.
* WIP.
* Take out set_epsilon from multi-agent-env test (not needed, decays anyway).
* WIP.
* Trigger re-test (flaky checkpoint-restore test).
* WIP.
* Add test case for deterministic action sampling in PPO.
* Bug fix.
* Added deterministic test cases for different Agents.
* Fix problem with TupleActions in dynamic-tf-policy.
* Separate supported_spaces tests so they can be run separately for easier debugging.
* LINT.
* Fix autoregressive_action_dist.py test case.
* Re-test.
* Fix.
* Remove duplicate py_test rule from bazel.
* LINT.
* WIP.
* SAC fix.
* WIP.
* FIX 2 examples tests.
* WIP.
* Fix.
* LINT.
* Renamed test file.
* WIP.
* Add unittest.main.
* Make action_dist_class mandatory.
* Fix.
* WIP.
* Fix.
* Fix explorations test case (contextlib cannot find its own nullcontext??).
* Force torch to be installed for QMIX.
* LINT.
* Fix determine_tests_to_run.py.
* WIP.
* Add Random exploration component to tests (fixed issue with "static-graph randomness" via py_function).
* Rename some stuff.
* WIP.
* Fix SAC.
* Fix strange tf-error in ray core tests.
* Fix strange ray-core tf-error in test_memory_scheduling test case.
* Fix test_io.py.
* LINT.
* Update SAC yaml files' config.

Co-authored-by: Eric Liang <ekhliang@gmail.com>
parent: 4c2de7be54
commit: 0db2046b0a
35 changed files with 768 additions and 232 deletions
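The commit message above introduces the Exploration API: a policy now owns an `exploration` object and asks it for an action via `get_exploration_action(...)`, which returns the sampled action together with its log-prob (see the DQN and eager-policy hunks below). As a rough, self-contained illustration of what the EpsilonGreedy sub-class does conceptually (plain NumPy sketch, not RLlib's actual class):

    import numpy as np

    def epsilon_greedy_action(q_values, epsilon, explore=True):
        """Return argmax(q_values), or a uniform random action with prob. epsilon."""
        q_values = np.asarray(q_values)
        if explore and np.random.rand() < epsilon:
            return int(np.random.randint(len(q_values)))  # exploratory action
        return int(np.argmax(q_values))                    # greedy action

    print(epsilon_greedy_action([0.1, 2.0, -0.5], epsilon=0.05, explore=False))  # -> 1

With `explore=False` (the new deterministic switch) the greedy action is always returned; PerWorkerEpsilonGreedy, split out in this commit, presumably assigns each rollout worker its own epsilon.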
@@ -79,7 +79,7 @@ fi
 if [[ "$RLLIB_TESTING" == "1" ]]; then
   pip install -q tensorflow-probability==$tfp_version gast==0.2.2 \
     torch==$torch_version torchvision \
-    gym[atari] atari_py smart_open
+    gym[atari] atari_py smart_open lz4
 fi

 if [[ "$PYTHON" == "3.6" ]] || [[ "$MAC_WHEELS" == "1" ]]; then
@@ -4,7 +4,6 @@ import logging
 import os
 import yaml
 import numbers
 
 import numpy as np
 
 import ray.cloudpickle as cloudpickle
rllib/BUILD (25 changed lines)

@@ -796,7 +796,6 @@ py_test(
     ]
 )
 
-
 # --------------------------------------------------------------------
 # Models and Distributions
 # rllib/models/

@@ -811,6 +810,20 @@ py_test(
     srcs = ["models/tests/test_distributions.py"]
 )
 
+# --------------------------------------------------------------------
+# Policies
+# rllib/policy/
+#
+# Tag: policy
+# --------------------------------------------------------------------
+
+py_test(
+    name = "policy/tests/test_compute_log_likelihoods",
+    tags = ["policy"],
+    size = "small",
+    srcs = ["policy/tests/test_compute_log_likelihoods.py"]
+)
+
 # --------------------------------------------------------------------
 # Utils:
 # rllib/utils/

@@ -880,14 +893,6 @@ py_test(
     srcs = ["tests/test_dependency.py"]
 )
 
-# PR 7086
-#py_test(
-#    name = "tests/test_deterministic_support",
-#    tags = ["tests_dir", "tests_dir_D"],
-#    size = "small",
-#    srcs = ["tests/test_deterministic_support.py"]
-#)
-
 py_test(
     name = "tests/test_eager_support",
     tags = ["tests_dir", "tests_dir_E"],

@@ -912,7 +917,7 @@ py_test(
 
 py_test(
     name = "tests/test_explorations",
-    tags = ["tests_dir", "tests_dir_E"],
+    tags = ["tests_dir", "tests_dir_E", "explorations"],
     size = "medium",
     srcs = ["tests/test_explorations.py"]
 )
@@ -291,7 +291,7 @@ class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
             self.config,
             self.sess,
             obs_input=self.cur_observations,
-            action_sampler=self.output_actions,
+            sampled_action=self.output_actions,
             loss=self.actor_loss + self.critic_loss,
             loss_inputs=self.loss_inputs,
             update_ops=q_batchnorm_update_ops + policy_batchnorm_update_ops)
@@ -202,20 +202,29 @@ def build_q_model(policy, obs_space, action_space, config):
     return policy.q_model
 
 
-def sample_action_from_q_network(policy, q_model, input_dict, obs_space,
-                                 action_space, explore, config, timestep):
+def get_log_likelihood(policy, q_model, actions, input_dict, obs_space,
+                       action_space, config):
     # Action Q network.
     q_vals = _compute_q_values(policy, q_model,
                                input_dict[SampleBatch.CUR_OBS], obs_space,
                                action_space)
+    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
+    action_dist = Categorical(q_vals, q_model)
+    return action_dist.logp(actions)
+
+
+def sample_action_from_q_network(policy, q_model, input_dict, obs_space,
+                                 action_space, explore, config, timestep):
+    # Action Q network.
+    q_vals = _compute_q_values(policy, q_model,
+                               input_dict[SampleBatch.CUR_OBS], obs_space,
+                               action_space)
     policy.q_values = q_vals[0] if isinstance(q_vals, tuple) else q_vals
     policy.q_func_vars = q_model.variables()
 
-    policy.output_actions, policy.action_logp = \
+    policy.output_actions, policy.sampled_action_logp = \
         policy.exploration.get_exploration_action(
-            policy.q_values, q_model, Categorical, explore, timestep)
+            policy.q_values, Categorical, q_model, explore, timestep)
 
     # Noise vars for Q network except for layer normalization vars.
     if config["parameter_noise"]:

@@ -224,7 +233,7 @@ def sample_action_from_q_network(policy, q_model, input_dict, obs_space,
         [var for var in policy.q_func_vars if "LayerNorm" not in var.name])
     policy.action_probs = tf.nn.softmax(policy.q_values)
 
-    return policy.output_actions, policy.action_logp
+    return policy.output_actions, policy.sampled_action_logp
 
 
 def _build_parameter_noise(policy, pnet_params):

@@ -448,6 +457,7 @@ DQNTFPolicy = build_tf_policy(
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
     make_model=build_q_model,
     action_sampler_fn=sample_action_from_q_network,
+    log_likelihood_fn=get_log_likelihood,
     loss_fn=build_q_losses,
     stats_fn=build_q_stats,
     postprocess_fn=postprocess_nstep_and_prio,
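The new `get_log_likelihood` above treats the Q-values as the logits of a Categorical distribution and returns `Categorical(q_vals, q_model).logp(actions)`. In plain NumPy this boils down to a log-softmax lookup; a small self-contained sketch (illustration only, not RLlib code):

    import numpy as np

    def categorical_logp(q_values, actions):
        """Log-prob of `actions` under a softmax(Q) categorical distribution."""
        q_values = np.asarray(q_values, dtype=np.float64)
        z = q_values - q_values.max(axis=-1, keepdims=True)      # stable log-softmax
        log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
        return log_probs[np.arange(len(actions)), actions]

    q = np.array([[1.0, 2.0, 0.5]])
    print(categorical_logp(q, np.array([1])))  # log-likelihood of action 1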
@@ -88,6 +88,17 @@ def build_q_models(policy, obs_space, action_space, config):
     return policy.q_model
 
 
+def get_log_likelihood(policy, q_model, actions, input_dict, obs_space,
+                       action_space, config):
+    # Action Q network.
+    q_vals = _compute_q_values(policy, q_model,
+                               input_dict[SampleBatch.CUR_OBS], obs_space,
+                               action_space)
+    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
+    action_dist = Categorical(q_vals, q_model)
+    return action_dist.logp(actions)
+
+
 def simple_sample_action_from_q_network(policy, q_model, input_dict, obs_space,
                                         action_space, explore, config,
                                         timestep):

@@ -95,15 +106,14 @@ def simple_sample_action_from_q_network(policy, q_model, input_dict, obs_space,
     q_vals = _compute_q_values(policy, q_model,
                                input_dict[SampleBatch.CUR_OBS], obs_space,
                                action_space)
 
     policy.q_values = q_vals[0] if isinstance(q_vals, tuple) else q_vals
     policy.q_func_vars = q_model.variables()
 
-    policy.output_actions, policy.action_logp = \
+    policy.output_actions, policy.sampled_action_logp = \
         policy.exploration.get_exploration_action(
-            policy.q_values, q_model, Categorical, explore, timestep)
+            policy.q_values, Categorical, q_model, explore, timestep)
 
-    return policy.output_actions, policy.action_logp
+    return policy.output_actions, policy.sampled_action_logp
 
 
 def build_q_losses(policy, model, dist_class, train_batch):

@@ -167,13 +177,11 @@ SimpleQPolicy = build_tf_policy(
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
     make_model=build_q_models,
     action_sampler_fn=simple_sample_action_from_q_network,
+    log_likelihood_fn=get_log_likelihood,
     loss_fn=build_q_losses,
     extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
     extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
     before_init=setup_early_mixins,
     after_init=setup_late_mixins,
     obs_include_prev_action_reward=False,
-    mixins=[
-        ParameterNoiseMixin,
-        TargetNetworkMixin,
-    ])
+    mixins=[ParameterNoiseMixin, TargetNetworkMixin])
@@ -1,8 +1,6 @@
 from gym.spaces import Tuple, Discrete, Dict
 import logging
 import numpy as np
-import torch as th
-import torch.nn as nn
 from torch.optim import RMSprop
 from torch.distributions import Categorical
 

@@ -16,9 +14,13 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.model import _unpack_obs
 from ray.rllib.env.constants import GROUP_REWARDS
+from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.tuple_actions import TupleActions
 
+# Torch must be installed.
+torch, nn = try_import_torch(error=True)
+
 logger = logging.getLogger(__name__)
 
 # if the obs space is Dict type, look for the global state under this key

@@ -85,7 +87,7 @@ class QMixLoss(nn.Module):
         mac_out = _unroll_mac(self.model, obs)
 
         # Pick the Q-Values for the actions taken -> [B * n_agents, T]
-        chosen_action_qvals = th.gather(
+        chosen_action_qvals = torch.gather(
             mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)
 
         # Calculate the Q-Values necessary for the target

@@ -114,8 +116,8 @@ class QMixLoss(nn.Module):
 
             # use the target network to estimate the Q-values of policy
             # network's selected actions
-            target_max_qvals = th.gather(target_mac_out, 3,
+            target_max_qvals = torch.gather(target_mac_out, 3,
                                          cur_max_actions).squeeze(3)
         else:
             target_max_qvals = target_mac_out.max(dim=3)[0]
 

@@ -167,8 +169,8 @@ class QMixTorchPolicy(Policy):
         self.h_size = config["model"]["lstm_cell_size"]
         self.has_env_global_state = False
         self.has_action_mask = False
-        self.device = (th.device("cuda")
-                       if th.cuda.is_available() else th.device("cpu"))
+        self.device = (torch.device("cuda")
+                       if torch.cuda.is_available() else torch.device("cpu"))
 
         agent_obs_space = obs_space.original_space.spaces[0]
         if isinstance(agent_obs_space, Dict):

@@ -262,20 +264,21 @@ class QMixTorchPolicy(Policy):
         # to compute actions
 
         # Compute actions
-        with th.no_grad():
+        with torch.no_grad():
             q_values, hiddens = _mac(
                 self.model,
-                th.as_tensor(obs_batch, dtype=th.float, device=self.device), [
-                    th.as_tensor(
-                        np.array(s), dtype=th.float, device=self.device)
-                    for s in state_batches
-                ])
-            avail = th.as_tensor(
-                action_mask, dtype=th.float, device=self.device)
+                torch.as_tensor(
+                    obs_batch, dtype=torch.float, device=self.device), [
+                        torch.as_tensor(
+                            np.array(s), dtype=torch.float, device=self.device)
+                        for s in state_batches
+                    ])
+            avail = torch.as_tensor(
+                action_mask, dtype=torch.float, device=self.device)
             masked_q_values = q_values.clone()
             masked_q_values[avail == 0.0] = -float("inf")
             # epsilon-greedy action selector
-            random_numbers = th.rand_like(q_values[:, :, 0])
+            random_numbers = torch.rand_like(q_values[:, :, 0])
             pick_random = (random_numbers < (self.cur_epsilon
                                              if explore else 0.0)).long()
             random_actions = Categorical(avail).sample().long()

@@ -286,6 +289,16 @@ class QMixTorchPolicy(Policy):
 
         return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}
 
+    @override(Policy)
+    def compute_log_likelihoods(self,
+                                actions,
+                                obs_batch,
+                                state_batches=None,
+                                prev_action_batch=None,
+                                prev_reward_batch=None):
+        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
+        return np.zeros(obs_batch.size()[0])
+
     @override(Policy)
     def learn_on_batch(self, samples):
         obs_batch, action_mask, env_global_state = self._unpack_observation(

@@ -323,31 +336,32 @@ class QMixTorchPolicy(Policy):
 
         def to_batches(arr, dtype):
             new_shape = [B, T] + list(arr.shape[1:])
-            return th.as_tensor(
+            return torch.as_tensor(
                 np.reshape(arr, new_shape), dtype=dtype, device=self.device)
 
-        rewards = to_batches(rew, th.float)
-        actions = to_batches(act, th.long)
-        obs = to_batches(obs, th.float).reshape(
+        rewards = to_batches(rew, torch.float)
+        actions = to_batches(act, torch.long)
+        obs = to_batches(obs, torch.float).reshape(
             [B, T, self.n_agents, self.obs_size])
-        action_mask = to_batches(action_mask, th.float)
-        next_obs = to_batches(next_obs, th.float).reshape(
+        action_mask = to_batches(action_mask, torch.float)
+        next_obs = to_batches(next_obs, torch.float).reshape(
             [B, T, self.n_agents, self.obs_size])
-        next_action_mask = to_batches(next_action_mask, th.float)
+        next_action_mask = to_batches(next_action_mask, torch.float)
         if self.has_env_global_state:
-            env_global_state = to_batches(env_global_state, th.float)
-            next_env_global_state = to_batches(next_env_global_state, th.float)
+            env_global_state = to_batches(env_global_state, torch.float)
+            next_env_global_state = to_batches(next_env_global_state,
+                                               torch.float)
 
         # TODO(ekl) this treats group termination as individual termination
-        terminated = to_batches(dones, th.float).unsqueeze(2).expand(
+        terminated = to_batches(dones, torch.float).unsqueeze(2).expand(
             B, T, self.n_agents)
 
         # Create mask for where index is < unpadded sequence length
         filled = np.reshape(
             np.tile(np.arange(T, dtype=np.float32), B),
             [B, T]) < np.expand_dims(seq_lens, 1)
-        mask = th.as_tensor(
-            filled, dtype=th.float, device=self.device).unsqueeze(2).expand(
+        mask = torch.as_tensor(
+            filled, dtype=torch.float, device=self.device).unsqueeze(2).expand(
             B, T, self.n_agents)
 
         # Compute loss

@@ -359,7 +373,7 @@ class QMixTorchPolicy(Policy):
         # Optimise
         self.optimiser.zero_grad()
         loss_out.backward()
-        grad_norm = th.nn.utils.clip_grad_norm_(
+        grad_norm = torch.nn.utils.clip_grad_norm_(
             self.params, self.config["grad_norm_clipping"])
         self.optimiser.step()
 

@@ -432,7 +446,7 @@ class QMixTorchPolicy(Policy):
 
     def _device_dict(self, state_dict):
         return {
-            k: th.as_tensor(v, device=self.device)
+            k: torch.as_tensor(v, device=self.device)
            for k, v in state_dict.items()
        }
 

@@ -539,7 +553,7 @@ def _unroll_mac(model, obs_tensor):
     for t in range(T):
         q, h = _mac(model, obs_tensor[:, t], h)
         mac_out.append(q)
-    mac_out = th.stack(mac_out, dim=1)  # Concat over time
+    mac_out = torch.stack(mac_out, dim=1)  # Concat over time
 
     return mac_out
 
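Most of the QMIX hunks rename the torch alias (`th` -> `torch`, now obtained via `try_import_torch(error=True)`) and add a `compute_log_likelihoods` stub; the surrounding context also shows QMIX's epsilon-greedy selector with an action-availability mask. A minimal NumPy sketch of that selection idea (illustration only; the real code draws the random action from a Categorical over the mask, batched over agents):

    import numpy as np

    def masked_epsilon_greedy(q_values, action_mask, epsilon, explore=True):
        """Epsilon-greedy over valid actions only (mask == 1 means allowed)."""
        masked_q = np.where(action_mask == 1, q_values, -np.inf)
        valid = np.flatnonzero(action_mask)
        if explore and np.random.rand() < epsilon:
            return int(np.random.choice(valid))   # random *valid* action
        return int(np.argmax(masked_q))           # greedy among valid actions

    q_vals = np.array([0.3, 1.2, -0.4, 0.9])
    mask = np.array([1, 0, 1, 1])                 # action 1 currently unavailable
    print(masked_epsilon_greedy(q_vals, mask, epsilon=0.1, explore=False))  # -> 3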
@@ -80,7 +80,7 @@ class SACModel(TFModelV2):
         shift_and_log_scale_diag = tf.keras.Sequential([
             tf.keras.layers.Dense(
                 units=hidden,
-                activation=getattr(tf.nn, actor_hidden_activation),
+                activation=getattr(tf.nn, actor_hidden_activation, None),
                 name="action_hidden_{}".format(i))
             for i, hidden in enumerate(actor_hiddens)
         ] + [
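The only change here adds a default of `None` to the `getattr` lookup, so an activation name that `tf.nn` does not define falls through to `activation=None` (linear) instead of raising `AttributeError`. A tiny stand-alone illustration of that difference (hypothetical attribute names, not the actual tf.nn module):

    class FakeNN:                 # stand-in for a module like tf.nn
        relu = "relu-fn"

    print(getattr(FakeNN, "relu", None))    # -> 'relu-fn'
    print(getattr(FakeNN, "linear", None))  # -> None, no exception
    # getattr(FakeNN, "linear") without the default would raise AttributeError.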
@@ -86,6 +86,30 @@ def postprocess_trajectory(policy,
     return postprocess_nstep_and_prio(policy, sample_batch)
 
 
+def unsquash_actions(actions, action_space):
+    # Use sigmoid to scale to [0,1], but also double magnitude of input to
+    # emulate behaviour of tanh activation used in SAC and TD3 papers.
+    sigmoid_out = tf.nn.sigmoid(2 * actions)
+    # Rescale to actual env policy scale
+    # (shape of sigmoid_out is [batch_size, dim_actions], so we reshape to
+    # get same dims)
+    action_range = (action_space.high - action_space.low)[None]
+    low_action = action_space.low[None]
+    unsquashed_actions = action_range * sigmoid_out + low_action
+
+    return unsquashed_actions
+
+
+def get_log_likelihood(policy, model, actions, input_dict, obs_space,
+                       action_space, config):
+    model_out, _ = model({
+        "obs": input_dict[SampleBatch.CUR_OBS],
+        "is_training": policy._get_is_training_placeholder(),
+    }, [], None)
+    log_pis = policy.model.log_pis_model((model_out, actions))
+    return log_pis
+
+
 def build_action_output(policy, model, input_dict, obs_space, action_space,
                         explore, config, timestep):
     model_out, _ = model({

@@ -93,28 +117,16 @@ def build_action_output(policy, model, input_dict, obs_space, action_space,
         "is_training": policy._get_is_training_placeholder(),
     }, [], None)
 
-    def unsquash_actions(actions):
-        # Use sigmoid to scale to [0,1], but also double magnitude of input to
-        # emulate behaviour of tanh activation used in SAC and TD3 papers.
-        sigmoid_out = tf.nn.sigmoid(2 * actions)
-        # Rescale to actual env policy scale
-        # (shape of sigmoid_out is [batch_size, dim_actions], so we reshape to
-        # get same dims)
-        action_range = (action_space.high - action_space.low)[None]
-        low_action = action_space.low[None]
-        unsquashed_actions = action_range * sigmoid_out + low_action
-
-        return unsquashed_actions
-
     squashed_stochastic_actions, log_pis = policy.model.get_policy_output(
         model_out, deterministic=False)
     stochastic_actions = squashed_stochastic_actions if config[
-        "normalize_actions"] else unsquash_actions(squashed_stochastic_actions)
+        "normalize_actions"] else unsquash_actions(squashed_stochastic_actions,
+                                                   action_space)
     squashed_deterministic_actions, _ = policy.model.get_policy_output(
         model_out, deterministic=True)
     deterministic_actions = squashed_deterministic_actions if config[
         "normalize_actions"] else unsquash_actions(
-            squashed_deterministic_actions)
+            squashed_deterministic_actions, action_space)
 
     actions = tf.cond(
         tf.constant(explore) if isinstance(explore, bool) else explore,

@@ -409,6 +421,7 @@ SACTFPolicy = build_tf_policy(
     make_model=build_sac_model,
     postprocess_fn=postprocess_trajectory,
     action_sampler_fn=build_action_output,
+    log_likelihood_fn=get_log_likelihood,
     loss_fn=actor_critic_loss,
     stats_fn=stats,
     gradients_fn=gradients,
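`unsquash_actions` (now a module-level helper taking `action_space` explicitly) maps squashed network outputs back into the environment's action bounds via `sigmoid(2x)`, which equals `(tanh(x) + 1) / 2`. A simplified, scalar-bounds NumPy sketch with a worked example (Pendulum-style bounds of +/-2.0 assumed for illustration):

    import numpy as np

    def unsquash_actions(actions, low, high):
        """Map squashed actions back into [low, high] (NumPy version of the helper above)."""
        sigmoid_out = 1.0 / (1.0 + np.exp(-2.0 * actions))   # sigmoid(2x) == (tanh(x)+1)/2
        return (high - low) * sigmoid_out + low

    print(unsquash_actions(np.array([-1.0, 0.0, 1.0]), low=-2.0, high=2.0))
    # -> approximately [-1.52  0.    1.52]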
@@ -245,7 +245,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
             config=config,
             sess=self.sess,
             obs_input=obs_ph_n[agent_id],
-            action_sampler=act_sampler,
+            sampled_action=act_sampler,
             loss=actor_loss + critic_loss,
             loss_inputs=loss_inputs)
 

@@ -339,9 +339,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
         out = tf.concat(obs_n + act_n, axis=1)
 
         for hidden in hiddens:
-            out = tf.layers.dense(
-                out, units=hidden, activation=activation
-            )
+            out = tf.layers.dense(out, units=hidden, activation=activation)
         feature = out
         out = tf.layers.dense(feature, units=1, activation=None)
 

@@ -367,9 +365,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
         out = obs
 
         for hidden in hiddens:
-            out = tf.layers.dense(
-                out, units=hidden, activation=activation
-            )
+            out = tf.layers.dense(out, units=hidden, activation=activation)
         feature = tf.layers.dense(
             out, units=act_space.shape[0], activation=None)
         sampler = tfp.distributions.RelaxedOneHotCategorical(
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray import tune
|
from ray import tune
|
||||||
from ray.rllib.policy import Policy
|
|
||||||
from ray.rllib.evaluation import RolloutWorker
|
from ray.rllib.evaluation import RolloutWorker
|
||||||
from ray.rllib.evaluation.metrics import collect_metrics
|
from ray.rllib.evaluation.metrics import collect_metrics
|
||||||
from ray.rllib.policy.sample_batch import SampleBatch
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.policy.tests.test_policy import TestPolicy
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpu", action="store_true")
|
parser.add_argument("--gpu", action="store_true")
|
||||||
|
@ -22,7 +22,7 @@ parser.add_argument("--num-workers", type=int, default=2)
|
||||||
parser.add_argument("--num-cpus", type=int, default=0)
|
parser.add_argument("--num-cpus", type=int, default=0)
|
||||||
|
|
||||||
|
|
||||||
class CustomPolicy(Policy):
|
class CustomPolicy(TestPolicy):
|
||||||
"""Example of a custom policy written from scratch.
|
"""Example of a custom policy written from scratch.
|
||||||
|
|
||||||
You might find it more convenient to extend TF/TorchPolicy instead
|
You might find it more convenient to extend TF/TorchPolicy instead
|
||||||
|
@ -30,7 +30,7 @@ class CustomPolicy(Policy):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, observation_space, action_space, config):
|
def __init__(self, observation_space, action_space, config):
|
||||||
Policy.__init__(self, observation_space, action_space, config)
|
super().__init__(observation_space, action_space, config)
|
||||||
# example parameter
|
# example parameter
|
||||||
self.w = 1.0
|
self.w = 1.0
|
||||||
|
|
||||||
|
|
|
@@ -1,14 +1,22 @@
-import unittest
 import numpy as np
+from gym.spaces import Box
+from scipy.stats import norm
+from tensorflow.python.eager.context import eager_mode
+import unittest
 
-from ray.rllib.models.tf.tf_action_dist import Categorical
+from ray.rllib.models.tf.tf_action_dist import Categorical, SquashedGaussian
 from ray.rllib.utils import try_import_tf
+from ray.rllib.utils.numpy import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT
+from ray.rllib.utils.test_utils import check
 
 tf = try_import_tf()
 
 
 class TestDistributions(unittest.TestCase):
+    """Tests ActionDistribution classes."""
 
     def test_categorical(self):
+        """Tests the Categorical ActionDistribution (tf only)."""
         num_samples = 100000
         logits = tf.placeholder(tf.float32, shape=(None, 10))
         z = 8 * (np.random.rand(10) - 0.5)

@@ -24,6 +32,76 @@ class TestDistributions(unittest.TestCase):
         probs = np.exp(z) / np.sum(np.exp(z))
         self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01)
 
+    def test_squashed_gaussian(self):
+        """Tests the SquashedGaussia ActionDistribution (tf-eager only)."""
+        with eager_mode():
+            input_space = Box(-1.0, 1.0, shape=(200, 10))
+            low, high = -2.0, 1.0
+
+            # Batch of size=n and deterministic.
+            inputs = input_space.sample()
+            means, _ = np.split(inputs, 2, axis=-1)
+            squashed_distribution = SquashedGaussian(
+                inputs, {}, low=low, high=high)
+            expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
+            # Sample n times, expect always mean value (deterministic draw).
+            out = squashed_distribution.deterministic_sample()
+            check(out, expected)
+
+            # Batch of size=n and non-deterministic -> expect roughly the mean.
+            inputs = input_space.sample()
+            means, log_stds = np.split(inputs, 2, axis=-1)
+            squashed_distribution = SquashedGaussian(
+                inputs, {}, low=low, high=high)
+            expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
+            values = squashed_distribution.sample()
+            self.assertTrue(np.max(values) < high)
+            self.assertTrue(np.min(values) > low)
+
+            check(np.mean(values), expected.mean(), decimals=1)
+
+            # Test log-likelihood outputs.
+            sampled_action_logp = squashed_distribution.sampled_action_logp()
+            # Convert to parameters for distr.
+            stds = np.exp(
+                np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
+            # Unsquash values, then get log-llh from regular gaussian.
+            unsquashed_values = np.arctanh((values - low) /
+                                           (high - low) * 2.0 - 1.0)
+            log_prob_unsquashed = \
+                np.sum(np.log(norm.pdf(unsquashed_values, means, stds)), -1)
+            log_prob = log_prob_unsquashed - \
+                np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
+                       axis=-1)
+            check(np.mean(sampled_action_logp), np.mean(log_prob), rtol=0.01)
+
+            # NN output.
+            means = np.array([[0.1, 0.2, 0.3, 0.4, 50.0],
+                              [-0.1, -0.2, -0.3, -0.4, -1.0]])
+            log_stds = np.array([[0.8, -0.2, 0.3, -1.0, 2.0],
+                                 [0.7, -0.3, 0.4, -0.9, 2.0]])
+            squashed_distribution = SquashedGaussian(
+                np.concatenate([means, log_stds], axis=-1), {},
+                low=low,
+                high=high)
+            # Convert to parameters for distr.
+            stds = np.exp(log_stds)
+            # Values to get log-likelihoods for.
+            values = np.array([[0.9, 0.2, 0.4, -0.1, -1.05],
+                               [-0.9, -0.2, 0.4, -0.1, -1.05]])
+
+            # Unsquash values, then get log-llh from regular gaussian.
+            unsquashed_values = np.arctanh((values - low) /
+                                           (high - low) * 2.0 - 1.0)
+            log_prob_unsquashed = \
+                np.sum(np.log(norm.pdf(unsquashed_values, means, stds)), -1)
+            log_prob = log_prob_unsquashed - \
+                np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
+                       axis=-1)
+
+            out = squashed_distribution.logp(values)
+            check(out, log_prob)
+
+
 if __name__ == "__main__":
     import unittest
@@ -3,10 +3,12 @@ import functools
 
 from ray.rllib.models.action_dist import ActionDistribution
 from ray.rllib.utils.annotations import override, DeveloperAPI
-from ray.rllib.utils import try_import_tf
+from ray.rllib.utils import try_import_tf, try_import_tfp, SMALL_NUMBER, \
+    MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT
 from ray.rllib.utils.tuple_actions import TupleActions
 
 tf = try_import_tf()
+tfp = try_import_tfp()
 
 
 @DeveloperAPI

@@ -188,6 +190,79 @@ class DiagGaussian(TFActionDistribution):
         return np.prod(action_space.shape) * 2
 
 
+class SquashedGaussian(TFActionDistribution):
+    """A tanh-squashed Gaussian distribution defined by: mean, std, low, high.
+
+    The distribution will never return low or high exactly, but
+    `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
+    """
+
+    def __init__(self, inputs, model, low=-1.0, high=1.0):
+        """Parameterizes the distribution via `inputs`.
+
+        Args:
+            low (float): The lowest possible sampling value
+                (excluding this value).
+            high (float): The highest possible sampling value
+                (excluding this value).
+        """
+        assert tfp is not None
+        loc, log_scale = tf.split(inputs, 2, axis=-1)
+        # Clip `scale` values (coming from NN) to reasonable values.
+        log_scale = tf.clip_by_value(log_scale, MIN_LOG_NN_OUTPUT,
+                                     MAX_LOG_NN_OUTPUT)
+        scale = tf.exp(log_scale)
+        self.distr = tfp.distributions.Normal(loc=loc, scale=scale)
+        assert np.all(np.less(low, high))
+        self.low = low
+        self.high = high
+        super().__init__(inputs, model)
+
+    @override(TFActionDistribution)
+    def sampled_action_logp(self):
+        unsquashed_values = self._unsquash(self.sample_op)
+        log_prob = tf.reduce_sum(
+            self.distr.log_prob(unsquashed_values), axis=-1)
+        unsquashed_values_tanhd = tf.math.tanh(unsquashed_values)
+        log_prob -= tf.math.reduce_sum(
+            tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER),
+            axis=-1)
+        return log_prob
+
+    @override(ActionDistribution)
+    def deterministic_sample(self):
+        mean = self.distr.mean()
+        return self._squash(mean)
+
+    @override(TFActionDistribution)
+    def _build_sample_op(self):
+        return self._squash(self.distr.sample())
+
+    @override(ActionDistribution)
+    def logp(self, x):
+        unsquashed_values = self._unsquash(x)
+        log_prob = tf.reduce_sum(
+            self.distr.log_prob(value=unsquashed_values), axis=-1)
+        unsquashed_values_tanhd = tf.math.tanh(unsquashed_values)
+        log_prob -= tf.math.reduce_sum(
+            tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER),
+            axis=-1)
+        return log_prob
+
+    def _squash(self, raw_values):
+        # Make sure raw_values are not too high/low (such that tanh would
+        # return exactly 1.0/-1.0, which would lead to +/-inf log-probs).
+        return (tf.clip_by_value(
+            tf.math.tanh(raw_values),
+            -1.0 + SMALL_NUMBER,
+            1.0 - SMALL_NUMBER) + 1.0) / 2.0 * (self.high - self.low) + \
+            self.low
+
+    def _unsquash(self, values):
+        return tf.math.atanh((values - self.low) /
+                             (self.high - self.low) * 2.0 - 1.0)
+
+
 class Deterministic(TFActionDistribution):
     """Action distribution that returns the input values directly.
 
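`SquashedGaussian.logp()` applies the change-of-variables correction for tanh squashing used here: the log-density of the unsquashed Gaussian sample minus `sum(log(1 - tanh(u)^2))` (with a SMALL_NUMBER epsilon for stability). The new test above checks exactly this against NumPy/SciPy; a condensed version of that check (illustration only):

    import numpy as np
    from scipy.stats import norm

    low, high = -2.0, 1.0
    means = np.array([0.1, -0.3])
    stds = np.array([0.5, 0.8])
    values = np.array([0.4, -1.05])        # squashed actions inside (low, high)

    # Unsquash back to the Gaussian's domain, then correct the log-prob.
    u = np.arctanh((values - low) / (high - low) * 2.0 - 1.0)
    logp = np.sum(np.log(norm.pdf(u, means, stds))) - np.sum(np.log(1 - np.tanh(u) ** 2))
    print(logp)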
@@ -1,7 +1,7 @@
 from collections import namedtuple
 import logging
 
-from ray.rllib.policy.sample_batch import MultiAgentBatch
+from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.annotations import DeveloperAPI
 
 logger = logging.getLogger(__name__)

@@ -57,24 +57,13 @@ class OffPolicyEstimator:
             if k.startswith("state_in_"):
                 num_state_inputs += 1
         state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
-        # TODO(sven): This is wrong. The info["action_prob"] needs to refer
-        # to the old action (from the batch). It might be the action-prob of
-        # a different action (as the policy has changed).
-        # https://github.com/ray-project/ray/issues/7107
-        _, _, info = self.policy.compute_actions(
-            obs_batch=batch["obs"],
+        log_likelihoods = self.policy.compute_log_likelihoods(
+            actions=batch[SampleBatch.ACTIONS],
+            obs_batch=batch[SampleBatch.CUR_OBS],
             state_batches=[batch[k] for k in state_keys],
-            prev_action_batch=batch.data.get("prev_action"),
-            prev_reward_batch=batch.data.get("prev_reward"),
-            info_batch=batch.data.get("info"))
-        if "action_prob" not in info:
-            raise ValueError(
-                "Off-policy estimation is not possible unless the policy "
-                "returns action probabilities when computing actions (i.e., "
-                "the 'action_prob' key is output by the policy). You "
-                "can set `input_evaluation: []` to resolve this.")
-        return info["action_prob"]
+            prev_action_batch=batch.data.get(SampleBatch.PREV_ACTIONS),
+            prev_reward_batch=batch.data.get(SampleBatch.PREV_REWARDS))
+        return log_likelihoods
 
     @DeveloperAPI
     def process(self, batch):
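The estimator now asks the current policy for log-likelihoods of the logged actions instead of re-computing actions and reading `info["action_prob"]`, which is what the removed TODO (ray issue 7107) called out as wrong. Off-policy estimators such as importance sampling then weight logged rewards by the ratio of new to behaviour action probabilities; schematically (illustration only, not the estimator's exact code):

    import numpy as np

    def importance_ratios(new_logp, behaviour_prob):
        """Per-step importance weights pi_new(a|s) / pi_behaviour(a|s)."""
        return np.exp(np.asarray(new_logp)) / np.asarray(behaviour_prob)

    print(importance_ratios(np.log([0.2, 0.6]), [0.25, 0.5]))  # -> [0.8 1.2]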
@@ -47,6 +47,7 @@ class DynamicTFPolicy(TFPolicy):
                  before_loss_init=None,
                  make_model=None,
                  action_sampler_fn=None,
+                 log_likelihood_fn=None,
                  existing_inputs=None,
                  existing_model=None,
                  get_batch_divisibility_req=None,

@@ -69,10 +70,14 @@ class DynamicTFPolicy(TFPolicy):
                 given (policy, obs_space, action_space, config).
                 All policy variables should be created in this function. If not
                 specified, a default model will be created.
-            action_sampler_fn (func): optional function that returns a
-                tuple of action and action logp tensors given
+            action_sampler_fn (Optional[callable]): An optional callable
+                returning a tuple of action and action prob tensors given
                 (policy, model, input_dict, obs_space, action_space, config).
-                If not specified, a default action distribution will be used.
+                If None, a default action distribution will be used.
+            log_likelihood_fn (Optional[callable]): A callable,
+                returning a log-likelihood op.
+                If None, a default class is used and distribution inputs
+                (for parameterization) will be generated by a model call.
             existing_inputs (OrderedDict): When copying a policy, this
                 specifies an existing dict of placeholders to use instead of
                 defining new ones

@@ -98,11 +103,13 @@ class DynamicTFPolicy(TFPolicy):
             if self._obs_include_prev_action_reward:
                 prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
                 prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
+            action_input = existing_inputs[SampleBatch.ACTIONS]
         else:
             obs = tf.placeholder(
                 tf.float32,
                 shape=[None] + list(obs_space.shape),
                 name="observation")
+            action_input = ModelCatalog.get_action_placeholder(action_space)
             if self._obs_include_prev_action_reward:
                 prev_actions = ModelCatalog.get_action_placeholder(
                     action_space, "prev_action")

@@ -121,16 +128,16 @@ class DynamicTFPolicy(TFPolicy):
             self._seq_lens = tf.placeholder(
                 dtype=tf.int32, shape=[None], name="seq_lens")
 
-        # Setup model
         if action_sampler_fn:
             if not make_model:
                 raise ValueError(
-                    "make_model is required if action_sampler_fn is given")
+                    "`make_model` is required if `action_sampler_fn` is given")
             self.dist_class = None
         else:
             self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                 action_space, self.config["model"])
 
+        # Setup model
         if existing_model:
             self.model = existing_model
         elif make_model:

@@ -165,35 +172,49 @@ class DynamicTFPolicy(TFPolicy):
 
         # Setup custom action sampler.
         if action_sampler_fn:
-            action_sampler, action_logp = action_sampler_fn(
+            sampled_action, sampled_action_logp = action_sampler_fn(
                 self, self.model, self._input_dict, obs_space, action_space,
                 explore, config, timestep)
         # Create a default action sampler.
         else:
-            # Using an exporation setup.
-            action_sampler, action_logp = \
+            # Using an exploration setup.
+            sampled_action, sampled_action_logp = \
                 self.exploration.get_exploration_action(
                     model_out,
+                    self.dist_class,
                     self.model,
-                    action_dist_class=self.dist_class,
                     explore=explore,
                     timestep=timestep)
 
-        # Phase 1 init
+        # Phase 1 init.
         sess = tf.get_default_session() or tf.Session()
         if get_batch_divisibility_req:
             batch_divisibility_req = get_batch_divisibility_req(self)
         else:
             batch_divisibility_req = 1
 
+        # Generate the log-likelihood op.
+        log_likelihood = None
+        # From a given function.
+        if log_likelihood_fn:
+            log_likelihood = log_likelihood_fn(self, self.model, action_input,
+                                               self._input_dict, obs_space,
+                                               action_space, config)
+        # Create default, iff we have a distribution class.
+        elif self.dist_class is not None:
+            log_likelihood = self.dist_class(model_out,
+                                             self.model).logp(action_input)
+
         super().__init__(
             obs_space,
             action_space,
             config,
             sess,
             obs_input=obs,
-            action_sampler=action_sampler,
-            action_logp=action_logp,
+            action_input=action_input,  # for logp calculations
+            sampled_action=sampled_action,
+            sampled_action_logp=sampled_action_logp,
+            log_likelihood=log_likelihood,
             loss=None,  # dynamically initialized on run
             loss_inputs=[],
             model=self.model,
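This hunk makes the distinction explicit: `sampled_action_logp` is the log-prob of the action the sampler just drew (what ends up in the sample batch), while `log_likelihood` is an op over a separate `action_input` placeholder, so arbitrary, externally supplied actions can be scored later. A tiny NumPy illustration of the two quantities for a categorical policy (sketch only):

    import numpy as np

    rng = np.random.default_rng(0)
    logits = np.array([2.0, 0.5, -1.0])
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    # sampled_action_logp: log-prob of the action the sampler just drew.
    sampled_action = rng.choice(len(probs), p=probs)
    sampled_action_logp = np.log(probs[sampled_action])

    # log_likelihood: log-prob of an externally supplied action (action_input).
    action_input = 1
    log_likelihood = np.log(probs[action_input])

    print(sampled_action, sampled_action_logp, log_likelihood)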
@@ -14,7 +14,7 @@ from ray.rllib.policy.policy import ACTION_PROB, ACTION_LOGP
 from ray.rllib.utils import add_mixins
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.debug import log_once
-from ray.rllib.utils import try_import_tf
+from ray.rllib.utils.framework import try_import_tf
 
 tf = try_import_tf()
 logger = logging.getLogger(__name__)

@@ -176,13 +176,14 @@ def build_eager_tf_policy(name,
                           after_init=None,
                           make_model=None,
                           action_sampler_fn=None,
+                          log_likelihood_fn=None,
                           mixins=None,
                           obs_include_prev_action_reward=True,
                           get_batch_divisibility_req=None):
     """Build an eager TF policy.
 
     An eager policy runs all operations in eager mode, which makes debugging
-    much simpler, but is lower performance.
+    much simpler, but has lower performance.
 
     You shouldn't need to call this directly. Rather, prefer to build a TF
     graph policy and use set {"eager": true} in the trainer config to have

@@ -208,12 +209,12 @@ def build_eager_tf_policy(name,
                 before_init(self, observation_space, action_space, config)
 
             self.config = config
+            self.dist_class = None
 
             if action_sampler_fn:
                 if not make_model:
-                    raise ValueError(
-                        "make_model is required if action_sampler_fn is given")
-                self.dist_class = None
+                    raise ValueError("`make_model` is required if "
+                                     "`action_sampler_fn` is given")
             else:
                 self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                     action_space, self.config["model"])

@@ -235,13 +236,14 @@ def build_eager_tf_policy(name,
                 for s in self.model.get_initial_state()
             ]
 
-            self.model({
+            input_dict = {
                 SampleBatch.CUR_OBS: tf.convert_to_tensor(
                     np.array([observation_space.sample()])),
                 SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                     [_flatten_action(action_space.sample())]),
                 SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]),
-            }, self._state_in, tf.convert_to_tensor([1]))
+            }
+            self.model(input_dict, self._state_in, tf.convert_to_tensor([1]))
 
             if before_loss_init:
                 before_loss_init(self, observation_space, action_space, config)

@@ -312,8 +314,8 @@ def build_eager_tf_policy(name,
                 n = len(obs_batch)
             else:
                 n = obs_batch.shape[0]
 
             seq_lens = tf.ones(n, dtype=tf.int32)
 
             input_dict = {
                 SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                 "is_training": tf.constant(False),

@@ -326,24 +328,24 @@ def build_eager_tf_policy(name,
                         prev_reward_batch),
                 })
 
-            with tf.variable_creator_scope(_disallow_var_creation):
-                model_out, state_out = self.model(input_dict, state_batches,
-                                                  seq_lens)
-
             # Custom sampler fn given (which may handle self.exploration).
             if action_sampler_fn is not None:
+                state_out = []
                 action, logp = action_sampler_fn(
                     self, self.model, input_dict, self.observation_space,
                     self.action_space, explore, self.config, timestep)
             # Use Exploration object.
             else:
-                action, logp = self.exploration.get_exploration_action(
-                    model_out,
-                    self.model,
-                    action_dist_class=self.dist_class,
-                    explore=explore,
-                    timestep=timestep
-                    if timestep is not None else self.global_timestep)
+                with tf.variable_creator_scope(_disallow_var_creation):
+                    model_out, state_out = self.model(input_dict,
+                                                      state_batches, seq_lens)
+                    action, logp = self.exploration.get_exploration_action(
+                        model_out,
+                        self.dist_class,
+                        self.model,
+                        explore=explore,
+                        timestep=timestep
+                        if timestep is not None else self.global_timestep)
 
             extra_fetches = {}
             if logp is not None:

@@ -359,6 +361,41 @@ def build_eager_tf_policy(name,
 
             return action, state_out, extra_fetches
 
+        @override(Policy)
+        def compute_log_likelihoods(self,
+                                    actions,
+                                    obs_batch,
+                                    state_batches=None,
+                                    prev_action_batch=None,
+                                    prev_reward_batch=None):
+
+            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
+            input_dict = {
+                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
+                "is_training": tf.constant(False),
+            }
+            if obs_include_prev_action_reward:
+                input_dict.update({
+                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
+                        prev_action_batch),
+                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
+                        prev_reward_batch),
+                })
+
+            # Custom log_likelihood function given.
+            if log_likelihood_fn:
+                log_likelihoods = log_likelihood_fn(
+                    self, self.model, actions, input_dict,
+                    self.observation_space, self.action_space, self.config)
+            # Default log-likelihood calculation.
+            else:
+                dist_inputs, _ = self.model(input_dict, state_batches,
+                                            seq_lens)
+                action_dist = self.dist_class(dist_inputs, self.model)
+                log_likelihoods = action_dist.logp(actions)
+
+            return log_likelihoods
+
         @override(Policy)
         def apply_gradients(self, gradients):
             self._apply_gradients(
@@ -164,6 +164,34 @@ class Policy(metaclass=ABCMeta):
 return action, [s[0] for s in state_out], \
 {k: v[0] for k, v in info.items()}

+@abstractmethod
+@DeveloperAPI
+def compute_log_likelihoods(self,
+actions,
+obs_batch,
+state_batches=None,
+prev_action_batch=None,
+prev_reward_batch=None):
+"""Computes the log-prob/likelihood for a given action and observation.
+
+Args:
+actions (Union[List,np.ndarray]): Batch of actions, for which to
+retrieve the log-probs/likelihoods (given all other inputs:
+obs, states, ..).
+obs_batch (Union[List,np.ndarray]): Batch of observations.
+state_batches (Optional[list]): List of RNN state input batches,
+if any.
+prev_action_batch (Optional[List,np.ndarray]): Batch of previous
+action values.
+prev_reward_batch (Optional[List,np.ndarray]): Batch of previous
+rewards.
+
+Returns:
+log-likelihoods (np.ndarray): Batch of log probs/likelihoods, with
+shape: [BATCH_SIZE].
+"""
+raise NotImplementedError
+
 @DeveloperAPI
 def postprocess_trajectory(self,
 sample_batch,
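A minimal usage sketch of the new interface (illustrative only; the PPO/CartPole setup and the concrete values below are assumptions, not part of this diff):

import numpy as np
import ray
import ray.rllib.agents.ppo as ppo

ray.init()
trainer = ppo.PPOTrainer(config={"num_workers": 0}, env="CartPole-v0")
policy = trainer.get_policy()
# Log-likelihood of taking action 1 for a single observation.
logp = policy.compute_log_likelihoods(
    actions=np.array([1]),
    obs_batch=np.array([[0.0, 0.1, 0.0, 0.0]]),
    prev_action_batch=np.array([0]),
    prev_reward_batch=np.array([0.0]))
print(np.exp(logp))  # -> array with one action probability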
153 rllib/policy/tests/test_compute_log_likelihoods.py (new file)
@@ -0,0 +1,153 @@
+import numpy as np
+from scipy.stats import norm
+import unittest
+
+import ray.rllib.agents.dqn as dqn
+import ray.rllib.agents.ppo as ppo
+import ray.rllib.agents.sac as sac
+from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.numpy import one_hot, fc, MIN_LOG_NN_OUTPUT, \
+MAX_LOG_NN_OUTPUT
+
+tf = try_import_tf()
+
+
+def test_log_likelihood(run,
+config,
+prev_a=None,
+continuous=False,
+layer_key=("fc", (0, 4)),
+logp_func=None):
+config = config.copy()
+# Run locally.
+config["num_workers"] = 0
+# Env setup.
+if continuous:
+env = "Pendulum-v0"
+obs_batch = preprocessed_obs_batch = np.array([[0.0, 0.1, -0.1]])
+else:
+env = "FrozenLake-v0"
+config["env_config"] = {"is_slippery": False, "map_name": "4x4"}
+obs_batch = np.array([0])
+preprocessed_obs_batch = one_hot(obs_batch, depth=16)
+
+# Use Soft-Q for DQNs.
+if run is dqn.DQNTrainer:
+config["exploration_config"] = {"type": "SoftQ", "temperature": 0.5}
+
+prev_r = None if prev_a is None else np.array(0.0)
+
+# Test against all frameworks.
+for fw in ["tf", "eager", "torch"]:
+if run in [dqn.DQNTrainer, sac.SACTrainer] and fw == "torch":
+continue
+print("Testing {} with framework={}".format(run, fw))
+config["eager"] = True if fw == "eager" else False
+config["use_pytorch"] = True if fw == "torch" else False
+
+trainer = run(config=config, env=env)
+policy = trainer.get_policy()
+vars = policy.get_weights()
+# Sample n actions, then roughly check their logp against their
+# counts.
+num_actions = 500
+actions = []
+for _ in range(num_actions):
+# Single action from single obs.
+actions.append(
+trainer.compute_action(
+obs_batch[0],
+prev_action=prev_a,
+prev_reward=prev_r,
+explore=True))
+
+# Test 50 actions for their log-likelihoods vs expected values.
+if continuous:
+for idx in range(50):
+a = actions[idx]
+if fw == "tf" or fw == "eager":
+if isinstance(vars, list):
+expected_mean_logstd = fc(
+fc(obs_batch, vars[layer_key[1][0]]),
+vars[layer_key[1][1]])
+else:
+expected_mean_logstd = fc(
+fc(
+obs_batch,
+vars["default_policy/{}_1/kernel".format(
+layer_key[0])]),
+vars["default_policy/{}_out/kernel".format(
+layer_key[0])])
+else:
+expected_mean_logstd = fc(
+fc(obs_batch,
+vars["_hidden_layers.0._model.0.weight"]),
+vars["_logits._model.0.weight"])
+mean, log_std = np.split(expected_mean_logstd, 2, axis=-1)
+if logp_func is None:
+expected_logp = np.log(norm.pdf(a, mean, np.exp(log_std)))
+else:
+expected_logp = logp_func(mean, log_std, a)
+logp = policy.compute_log_likelihoods(
+np.array([a]),
+preprocessed_obs_batch,
+prev_action_batch=np.array([prev_a]),
+prev_reward_batch=np.array([prev_r]))
+check(logp, expected_logp[0], rtol=0.2)
+# Test all available actions for their logp values.
+else:
+for a in [0, 1, 2, 3]:
+count = actions.count(a)
+expected_logp = np.log(count / num_actions)
+logp = policy.compute_log_likelihoods(
+np.array([a]),
+preprocessed_obs_batch,
+prev_action_batch=np.array([prev_a]),
+prev_reward_batch=np.array([prev_r]))
+check(logp, expected_logp, rtol=0.3)
+
+
+class TestComputeLogLikelihood(unittest.TestCase):
+def test_dqn(self):
+"""Tests, whether DQN correctly computes logp in soft-q mode."""
+test_log_likelihood(dqn.DQNTrainer, dqn.DEFAULT_CONFIG)
+
+def test_ppo_cont(self):
+"""Tests PPO's (cont. actions) compute_log_likelihoods method."""
+config = ppo.DEFAULT_CONFIG.copy()
+config["model"]["fcnet_hiddens"] = [10]
+config["model"]["fcnet_activation"] = "linear"
+prev_a = np.array([0.0])
+test_log_likelihood(ppo.PPOTrainer, config, prev_a, continuous=True)
+
+def test_ppo_discr(self):
+"""Tests PPO's (discr. actions) compute_log_likelihoods method."""
+prev_a = np.array(0)
+test_log_likelihood(ppo.PPOTrainer, ppo.DEFAULT_CONFIG, prev_a)
+
+def test_sac(self):
+"""Tests SAC's compute_log_likelihoods method."""
+config = sac.DEFAULT_CONFIG.copy()
+config["policy_model"]["hidden_layer_sizes"] = [10]
+config["policy_model"]["hidden_activation"] = "linear"
+prev_a = np.array([0.0])
+
+def logp_func(means, log_stds, values, low=-1.0, high=1.0):
+stds = np.exp(
+np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
+unsquashed_values = np.arctanh((values - low) /
+(high - low) * 2.0 - 1.0)
+log_prob_unsquashed = \
+np.sum(np.log(norm.pdf(unsquashed_values, means, stds)), -1)
+return log_prob_unsquashed - \
+np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
+axis=-1)
+
+test_log_likelihood(
+sac.SACTrainer,
+config,
+prev_a,
+continuous=True,
+layer_key=("sequential/action", (0, 2)),
+logp_func=logp_func)
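For reference, the `logp_func` used in `test_sac` above is the change-of-variables formula for SAC's tanh-squashed Gaussian: with u = arctanh(2 * (a - low) / (high - low) - 1), log p(a) = sum_i [ log N(u_i; mean_i, std_i) - log(1 - tanh(u_i)^2) ], which is exactly what the numpy code computes element-wise before summing over the action dimensions.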
@@ -1,6 +1,7 @@
 import random

 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.annotations import override


 class TestPolicy(Policy):

@@ -9,6 +10,7 @@ class TestPolicy(Policy):
 and implements all other abstract methods of Policy with "pass".
 """

+@override(Policy)
 def compute_actions(self,
 obs_batch,
 state_batches=None,

@@ -19,3 +21,12 @@ class TestPolicy(Policy):
 timestep=None,
 **kwargs):
 return [random.choice([0, 1])] * len(obs_batch), [], {}
+
+@override(Policy)
+def compute_log_likelihoods(self,
+actions,
+obs_batch,
+state_batches=None,
+prev_action_batch=None,
+prev_reward_batch=None):
+return [random.random()] * len(obs_batch)
@@ -13,9 +13,9 @@ from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.debug import log_once, summarize
 from ray.rllib.utils.exploration.exploration import Exploration
+from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
-from ray.rllib.utils import try_import_tf

 tf = try_import_tf()
 logger = logging.getLogger(__name__)

@@ -38,7 +38,7 @@ class TFPolicy(Policy):

 Examples:
 >>> policy = TFPolicySubclass(
-sess, obs_input, action_sampler, loss, loss_inputs)
+sess, obs_input, sampled_action, loss, loss_inputs)

 >>> print(policy.compute_actions([1, 0, 2]))
 (array([0, 1, 1]), [], {})

@@ -54,11 +54,13 @@ class TFPolicy(Policy):
 config,
 sess,
 obs_input,
-action_sampler,
+sampled_action,
 loss,
 loss_inputs,
 model=None,
-action_logp=None,
+sampled_action_logp=None,
+action_input=None,
+log_likelihood=None,
 state_inputs=None,
 state_outputs=None,
 prev_action_input=None,

@@ -78,7 +80,7 @@ class TFPolicy(Policy):
 sess (Session): The TensorFlow session to use.
 obs_input (Tensor): Input placeholder for observations, of shape
 [BATCH_SIZE, obs...].
-action_sampler (Tensor): Tensor for sampling an action, of shape
+sampled_action (Tensor): Tensor for sampling an action, of shape
 [BATCH_SIZE, action...]
 loss (Tensor): Scalar policy loss output tensor.
 loss_inputs (list): A (name, placeholder) tuple for each loss

@@ -89,7 +91,12 @@ class TFPolicy(Policy):
 placeholders during loss computation.
 model (rllib.models.Model): used to integrate custom losses and
 stats from user-defined RLlib models.
-action_logp (Tensor): log probability of the sampled action.
+sampled_action_logp (Tensor): log probability of the sampled
+action.
+action_input (Optional[Tensor]): Input placeholder for actions for
+logp/log-likelihood calculations.
+log_likelihood (Optional[Tensor]): Tensor to calculate the
+log_likelihood (given action_input and obs_input).
 state_inputs (list): list of RNN state input Tensors.
 state_outputs (list): list of RNN state output Tensors.
 prev_action_input (Tensor): placeholder for previous actions

@@ -115,13 +122,16 @@ class TFPolicy(Policy):
 self._obs_input = obs_input
 self._prev_action_input = prev_action_input
 self._prev_reward_input = prev_reward_input
-self._action = action_sampler
+self._sampled_action = sampled_action
 self._is_training = self._get_is_training_placeholder()
 self._is_exploring = explore if explore is not None else \
 tf.placeholder_with_default(True, (), name="is_exploring")
-self._action_logp = action_logp
-self._action_prob = (tf.exp(self._action_logp)
-if self._action_logp is not None else None)
+self._sampled_action_logp = sampled_action_logp
+self._sampled_action_prob = (tf.exp(self._sampled_action_logp)
+if self._sampled_action_logp is not None
+else None)
+self._action_input = action_input  # For logp calculations.
+self._log_likelihood = log_likelihood
 self._state_inputs = state_inputs or []
 self._state_outputs = state_outputs or []
 self._seq_lens = seq_lens

@@ -152,6 +162,9 @@ class TFPolicy(Policy):
 raise ValueError(
 "seq_lens tensor must be given if state inputs are defined")

+# Generate the log-likelihood calculator.
+self._log_likelihood = log_likelihood
+
 def variables(self):
 """Return the list of all savable variables for this policy."""
 return self.model.variables()

@@ -255,6 +268,46 @@ class TFPolicy(Policy):
 # Execute session run to get action (and other fetches).
 return builder.get(fetches)

+@override(Policy)
+def compute_log_likelihoods(self,
+actions,
+obs_batch,
+state_batches=None,
+prev_action_batch=None,
+prev_reward_batch=None):
+if self._log_likelihood is None:
+raise ValueError("Cannot compute log-prob/likelihood w/o a "
+"self._log_likelihood op!")
+
+# Do the forward pass through the model to capture the parameters
+# for the action distribution, then do a logp on that distribution.
+builder = TFRunBuilder(self._sess, "compute_log_likelihoods")
+# Feed actions (for which we want logp values) into graph.
+builder.add_feed_dict({self._action_input: actions})
+# Feed observations.
+builder.add_feed_dict({self._obs_input: obs_batch})
+# Internal states.
+state_batches = state_batches or []
+if len(self._state_inputs) != len(state_batches):
+raise ValueError(
+"Must pass in RNN state batches for placeholders {}, got {}".
+format(self._state_inputs, state_batches))
+builder.add_feed_dict(
+{k: v
+for k, v in zip(self._state_inputs, state_batches)})
+if state_batches:
+builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
+# Prev-a and r.
+if self._prev_action_input is not None and \
+prev_action_batch is not None:
+builder.add_feed_dict({self._prev_action_input: prev_action_batch})
+if self._prev_reward_input is not None and \
+prev_reward_batch is not None:
+builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
+# Fetch the log_likelihoods output and return.
+fetches = builder.add_fetches([self._log_likelihood])
+return builder.get(fetches)[0]
+
 @override(Policy)
 def compute_gradients(self, postprocessed_batch):
 assert self.loss_initialized()

@@ -341,9 +394,9 @@ class TFPolicy(Policy):
 By default we only return action probability info (if present).
 """
 ret = {}
-if self._action_logp is not None:
-ret[ACTION_PROB] = self._action_prob
-ret[ACTION_LOGP] = self._action_logp
+if self._sampled_action_logp is not None:
+ret[ACTION_PROB] = self._sampled_action_prob
+ret[ACTION_LOGP] = self._sampled_action_logp
 return ret

 @DeveloperAPI

@@ -441,7 +494,7 @@ class TFPolicy(Policy):
 # build output signatures
 output_signature = self._extra_output_signature_def()
 output_signature["actions"] = \
-tf.saved_model.utils.build_tensor_info(self._action)
+tf.saved_model.utils.build_tensor_info(self._sampled_action)
 for state_output in self._state_outputs:
 output_signature[state_output.name] = \
 tf.saved_model.utils.build_tensor_info(state_output)

@@ -463,6 +516,7 @@ class TFPolicy(Policy):
 episodes=None,
 explore=None,
 timestep=None):

 explore = explore if explore is not None else self.config["explore"]

 state_batches = state_batches or []

@@ -485,7 +539,8 @@ class TFPolicy(Policy):
 if timestep is not None:
 builder.add_feed_dict({self._timestep: timestep})
 builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
-fetches = builder.add_fetches([self._action] + self._state_outputs +
+fetches = builder.add_fetches([self._sampled_action] +
+self._state_outputs +
 [self.extra_compute_action_fetches()])
 return fetches[0], fetches[1:-1], fetches[-1]
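The graph-mode path above only feeds the new `action_input` placeholder (plus observations and, if present, RNN state and prev-action/reward inputs) and fetches the pre-built `log_likelihood` op. A stripped-down, TF1-style illustration of that feed/fetch pattern (plain TensorFlow ops, not RLlib code; names and values are placeholders):

import numpy as np
import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, 3], name="dist_inputs")
action_input = tf.placeholder(tf.int64, [None], name="action_input")
log_likelihood = tf.distributions.Categorical(logits=logits).log_prob(action_input)

with tf.Session() as sess:
    print(sess.run(
        log_likelihood,
        feed_dict={logits: np.array([[2.0, 0.5, -1.0]]),
                   action_input: np.array([0])}))  # approx. [-0.24]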
@@ -26,6 +26,7 @@ def build_tf_policy(name,
 after_init=None,
 make_model=None,
 action_sampler_fn=None,
+log_likelihood_fn=None,
 mixins=None,
 get_batch_divisibility_req=None,
 obs_include_prev_action_reward=True):

@@ -81,10 +82,14 @@ def build_tf_policy(name,
 given (policy, obs_space, action_space, config).
 All policy variables should be created in this function. If not
 specified, a default model will be created.
-action_sampler_fn (func): optional function that returns a
-tuple of action and action prob tensors given
+action_sampler_fn (Optional[callable]): An optional callable returning
+a tuple of action and action prob tensors given
 (policy, model, input_dict, obs_space, action_space, config).
-If not specified, a default action distribution will be used.
+If None, a default action distribution will be used.
+log_likelihood_fn (Optional[callable]): A callable,
+returning a log-likelihood op.
+If None, a default class is used and distribution inputs
+(for parameterization) will be generated by a model call.
 mixins (list): list of any class mixins for the returned policy class.
 These mixins will be applied in order and will have higher
 precedence than the DynamicTFPolicy class

@@ -132,6 +137,7 @@ def build_tf_policy(name,
 before_loss_init=before_loss_init_wrapper,
 make_model=make_model,
 action_sampler_fn=action_sampler_fn,
+log_likelihood_fn=log_likelihood_fn,
 existing_model=existing_model,
 existing_inputs=existing_inputs,
 get_batch_divisibility_req=get_batch_divisibility_req,
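A hedged sketch of a custom `log_likelihood_fn`, following the call signature used elsewhere in this change (policy, model, actions, input_dict, obs_space, action_space, config); the function body and all names are illustrative, not taken from this diff:

def my_log_likelihood_fn(policy, model, actions, input_dict,
                         obs_space, action_space, config):
    # Re-run the model to get distribution inputs, then score the actions.
    dist_inputs, _ = model(input_dict, [], None)
    action_dist = policy.dist_class(dist_inputs, model)
    return action_dist.logp(actions)

# Plugged in via:
# MyPolicy = build_tf_policy("MyPolicy", ..., log_likelihood_fn=my_log_likelihood_fn)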
@@ -4,10 +4,10 @@ import time
 from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, ACTION_PROB, \
 ACTION_LOGP
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils import try_import_torch
 from ray.rllib.utils.annotations import override, DeveloperAPI
-from ray.rllib.utils.tracking_dict import UsageTrackingDict
+from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
+from ray.rllib.utils.tracking_dict import UsageTrackingDict

 torch, _ = try_import_torch()

@@ -86,7 +86,7 @@ class TorchPolicy(Policy):
 action_dist = None
 actions, logp = \
 self.exploration.get_exploration_action(
-logits, self.model, self.dist_class, explore,
+logits, self.dist_class, self.model, explore,
 timestep if timestep is not None else
 self.global_timestep)
 input_dict[SampleBatch.ACTIONS] = actions

@@ -101,6 +101,28 @@ class TorchPolicy(Policy):
 return (actions.cpu().numpy(), [h.cpu().numpy() for h in state],
 extra_action_out)

+@override(Policy)
+def compute_log_likelihoods(self,
+actions,
+obs_batch,
+state_batches=None,
+prev_action_batch=None,
+prev_reward_batch=None):
+with torch.no_grad():
+input_dict = self._lazy_tensor_dict({
+SampleBatch.CUR_OBS: obs_batch,
+SampleBatch.ACTIONS: actions
+})
+if prev_action_batch:
+input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
+if prev_reward_batch:
+input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
+
+parameters, _ = self.model(input_dict, state_batches, [1])
+action_dist = self.dist_class(parameters, self.model)
+log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
+return log_likelihoods
+
 @override(Policy)
 def learn_on_batch(self, postprocessed_batch):
 train_batch = self._lazy_tensor_dict(postprocessed_batch)
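For a discrete action space, the `action_dist.logp(...)` call above reduces to looking up the log-softmax of the model's logits at the chosen action. A tiny self-contained torch check (illustrative values, not RLlib code):

import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])   # model output for one observation
actions = torch.tensor([0])
log_likelihoods = torch.distributions.Categorical(logits=logits).log_prob(actions)
print(log_likelihoods)  # approx. tensor([-0.2414])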
@@ -63,6 +63,9 @@ class TestEagerSupport(unittest.TestCase):
 "timesteps_per_iteration": 100
 })

+def testSAC(self):
+check_support("SAC", {"num_workers": 0})
+

 if __name__ == "__main__":
 import pytest
@@ -1,4 +1,5 @@
 import numpy as np
+from tensorflow.python.eager.context import eager_mode
 import unittest

 import ray

@@ -30,14 +31,23 @@ def test_explorations(run,
 impala.ImpalaTrainer, sac.SACTrainer]:
 continue
 print("Testing {} in framework={}".format(run, fw))
-config["eager"] = True if fw == "eager" else False
-config["use_pytorch"] = True if fw == "torch" else False
+config["eager"] = (fw == "eager")
+config["use_pytorch"] = (fw == "torch")

 # Test for both the default Agent's exploration AND the `Random`
 # exploration class.
-for exploration in [None]:  # , "Random"]:
+for exploration in [None, "Random"]:
 if exploration == "Random":
+# TODO(sven): Random doesn't work for cont. action spaces
+# or IMPALA yet.
+if env == "Pendulum-v0" or run is impala.ImpalaTrainer:
+continue
 config["exploration_config"] = {"type": "Random"}
+print("exploration={}".format(exploration or "default"))
+
+eager_mode_ctx = eager_mode()
+if fw == "eager":
+eager_mode_ctx.__enter__()
+
 trainer = run(config=config, env=env)

@@ -53,8 +63,8 @@ def test_explorations(run,
 prev_reward=1.0 if prev_a is not None else None))
 check(actions[-1], actions[0])

-# Make sure actions drawn are different (around some mean value),
-# given constant observations.
+# Make sure actions drawn are different
+# (around some mean value), given constant observations.
 actions = []
 for _ in range(100):
 actions.append(

@@ -71,6 +81,9 @@ def test_explorations(run,
 # Check that the stddev is not 0.0 (values differ).
 check(np.std(actions), 0.0, false=True)

+if fw == "eager":
+eager_mode_ctx.__exit__(None, None, None)
+

 class TestExplorations(unittest.TestCase):
 """

@@ -109,7 +122,7 @@ class TestExplorations(unittest.TestCase):
 "CartPole-v0",
 impala.DEFAULT_CONFIG,
 np.array([0.0, 0.1, 0.0, 0.0]),
-prev_a=np.array([0]))
+prev_a=np.array(0))

 def test_pg(self):
 test_explorations(

@@ -117,7 +130,7 @@ class TestExplorations(unittest.TestCase):
 "CartPole-v0",
 pg.DEFAULT_CONFIG,
 np.array([0.0, 0.1, 0.0, 0.0]),
-prev_a=np.array([1]))
+prev_a=np.array(1))

 def test_ppo_discr(self):
 test_explorations(

@@ -125,7 +138,7 @@ class TestExplorations(unittest.TestCase):
 "CartPole-v0",
 ppo.DEFAULT_CONFIG,
 np.array([0.0, 0.1, 0.0, 0.0]),
-prev_a=np.array([0]))
+prev_a=np.array(0))

 def test_ppo_cont(self):
 test_explorations(

@@ -133,7 +146,7 @@ class TestExplorations(unittest.TestCase):
 "Pendulum-v0",
 ppo.DEFAULT_CONFIG,
 np.array([0.0, 0.1, 0.0]),
-prev_a=np.array([0]),
+prev_a=np.array([0.0]),
 expected_mean_action=0.0)

 def test_sac(self):
@@ -237,12 +237,13 @@ class JsonIOTest(unittest.TestCase):
 for _ in range(100):
 writer.write(SAMPLES)
 num_files = len(os.listdir(self.test_dir))
-# Magic numbers: 2: On travis, it seems to create only 2 files.
+# Magic numbers: 2: On travis, it seems to create only 2 files,
+# but sometimes also 7.
 # 12 or 13: Mac locally.
 # Reasons: Different compressions, file-size interpretations,
 # json writers?
-assert num_files in [2, 12, 13], \
-"Expected 12|13 files, but found {} ({})". \
+assert num_files in [2, 7, 12, 13], \
+"Expected 2|7|12|13 files, but found {} ({})". \
 format(num_files, os.listdir(self.test_dir))

 def testReadWrite(self):
@@ -7,6 +7,7 @@ import unittest
 import traceback

 import ray
+from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.agents.registry import get_agent_class
 from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2
 from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2

@@ -14,6 +15,7 @@ from ray.rllib.tests.test_multi_agent_env import MultiCartpole, \
 MultiMountainCar
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.tune.registry import register_env
+tf = try_import_tf()

 ACTION_SPACES_TO_TEST = {
 "discrete": Discrete(5),

@@ -220,16 +222,6 @@ class ModelSupportedSpaces(unittest.TestCase):
 def test_sac(self):
 check_support("SAC", {}, self.stats, check_bounds=True)

-# def testAll(self):
-# num_unexpected_errors = 0
-# for (alg, a_name, o_name), stat in sorted(self.stats.items()):
-# if stat not in ["ok", "unsupported", "skip"]:
-# num_unexpected_errors += 1
-# print(alg, "action_space", a_name, "obs_space", o_name, "result",
-# stat)
-# self.assertEqual(num_unexpected_errors, 0)

 def test_a3c_multiagent(self):
 check_support_multiagent("A3C", {
 "num_workers": 1,
@@ -23,7 +23,7 @@ halfcheetah_sac:
 target_network_update_freq: 1
 timesteps_per_iteration: 1000
 learning_starts: 10000
-exploration_enabled: True
+explore: True
 optimization:
 actor_learning_rate: 0.0003
 critic_learning_rate: 0.0003
@@ -24,7 +24,7 @@ pendulum_sac:
 target_network_update_freq: 1
 timesteps_per_iteration: 1000
 learning_starts: 256
-exploration_enabled: True
+explore: True
 optimization:
 actor_learning_rate: 0.0003
 critic_learning_rate: 0.0003
@@ -8,7 +8,7 @@ from ray.rllib.utils.deprecation import deprecation_warning, renamed_agent, \
 from ray.rllib.utils.filter_manager import FilterManager
 from ray.rllib.utils.filter import Filter
 from ray.rllib.utils.numpy import sigmoid, softmax, relu, one_hot, fc, lstm, \
-SMALL_NUMBER, LARGE_INTEGER
+SMALL_NUMBER, LARGE_INTEGER, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT
 from ray.rllib.utils.policy_client import PolicyClient
 from ray.rllib.utils.policy_server import PolicyServer
 from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \

@@ -85,6 +85,8 @@ __all__ = [
 "FilterManager",
 "LARGE_INTEGER",
 "LinearSchedule",
+"MAX_LOG_NN_OUTPUT",
+"MIN_LOG_NN_OUTPUT",
 "PiecewiseSchedule",
 "PolicyClient",
 "PolicyServer",
@@ -63,24 +63,24 @@ class EpsilonGreedy(Exploration):

 @override(Exploration)
 def get_exploration_action(self,
-model_output,
-model,
-action_dist_class,
+distribution_inputs,
+action_dist_class=None,
+model=None,
 explore=True,
 timestep=None):

 if self.framework == "tf":
-return self._get_tf_exploration_action_op(model_output, explore,
-timestep)
+return self._get_tf_exploration_action_op(distribution_inputs,
+explore, timestep)
 else:
-return self._get_torch_exploration_action(model_output, explore,
-timestep)
+return self._get_torch_exploration_action(distribution_inputs,
+explore, timestep)

-def _get_tf_exploration_action_op(self, model_output, explore, timestep):
+def _get_tf_exploration_action_op(self, q_values, explore, timestep):
 """Tf method to produce the tf op for an epsilon exploration action.

 Args:
-model_output (tf.Tensor):
+q_values (Tensor): The Q-values coming from some q-model.

 Returns:
 tf.Tensor: The tf exploration-action op.

@@ -90,15 +90,14 @@ class EpsilonGreedy(Exploration):
 self.last_timestep))

 # Get the exploit action as the one with the highest logit value.
-exploit_action = tf.argmax(model_output, axis=1)
+exploit_action = tf.argmax(q_values, axis=1)

-batch_size = tf.shape(model_output)[0]
+batch_size = tf.shape(q_values)[0]
 # Mask out actions with q-value=-inf so that we don't
 # even consider them for exploration.
 random_valid_action_logits = tf.where(
-tf.equal(model_output, tf.float32.min),
-tf.ones_like(model_output) * tf.float32.min,
-tf.ones_like(model_output))
+tf.equal(q_values, tf.float32.min),
+tf.ones_like(q_values) * tf.float32.min, tf.ones_like(q_values))
 random_actions = tf.squeeze(
 tf.multinomial(random_valid_action_logits, 1), axis=1)

@@ -122,11 +121,11 @@ class EpsilonGreedy(Exploration):
 with tf.control_dependencies([assign_op]):
 return action, tf.zeros_like(action, dtype=tf.float32)

-def _get_torch_exploration_action(self, model_output, explore, timestep):
+def _get_torch_exploration_action(self, q_values, explore, timestep):
 """Torch method to produce an epsilon exploration action.

 Args:
-model_output (torch.Tensor):
+q_values (Tensor): The Q-values coming from some q-model.

 Returns:
 torch.Tensor: The exploration-action.

@@ -135,20 +134,20 @@ class EpsilonGreedy(Exploration):
 self.last_timestep = timestep if timestep is not None else \
 self.last_timestep + 1

-_, exploit_action = torch.max(model_output, 1)
+_, exploit_action = torch.max(q_values, 1)
 action_logp = torch.zeros_like(exploit_action)

 # Explore.
 if explore:
 # Get the current epsilon.
 epsilon = self.epsilon_schedule(self.last_timestep)
-batch_size = model_output.size()[0]
+batch_size = q_values.size()[0]
 # Mask out actions, whose Q-values are -inf, so that we don't
 # even consider them for exploration.
 random_valid_action_logits = torch.where(
-model_output == float("-inf"),
-torch.ones_like(model_output) * float("-inf"),
-torch.ones_like(model_output))
+q_values == float("-inf"),
+torch.ones_like(q_values) * float("-inf"),
+torch.ones_like(q_values))
 # A random action.
 random_actions = torch.squeeze(
 torch.multinomial(random_valid_action_logits, 1), axis=1)
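The same epsilon-greedy rule, condensed into a framework-free numpy sketch (illustrative only, not RLlib code): exploit the argmax of the Q-values, otherwise pick a random action among those whose Q-value is not -inf.

import numpy as np

def epsilon_greedy(q_values, epsilon, rng=np.random):
    # Exploit: argmax over Q-values.
    exploit_actions = np.argmax(q_values, axis=1)
    # Explore: uniform over valid (non -inf) actions only.
    valid_mask = q_values > -np.inf
    random_actions = np.array(
        [rng.choice(np.flatnonzero(row)) for row in valid_mask])
    take_random = rng.uniform(size=q_values.shape[0]) < epsilon
    return np.where(take_random, random_actions, exploit_actions)

q = np.array([[1.0, -np.inf, 0.5]])
print(epsilon_greedy(q, epsilon=0.1))  # usually [0], sometimes [2], never [1]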
@@ -31,9 +31,9 @@ class Exploration:
 self.framework = check_framework(framework)

 def get_exploration_action(self,
-model_output,
-model,
+distribution_inputs,
 action_dist_class,
+model=None,
 explore=True,
 timestep=None):
 """Returns a (possibly) exploratory action.

@@ -42,10 +42,12 @@ class Exploration:
 exploratory action.

 Args:
-model_output (any): The raw output coming from the model
+distribution_inputs (any): The output coming from the model,
+ready for parameterizing a distribution
 (e.g. q-values or PG-logits).
+action_dist_class (class): The action distribution class
+to use.
 model (ModelV2): The Model object.
-action_dist_class: The ActionDistribution class.
 explore (bool): True: "Normal" exploration behavior.
 False: Suppress all exploratory behavior and return
 a deterministic action.
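A hedged sketch of a custom Exploration component following the new signature (the class name and behavior are illustrative, not part of this diff): it samples from the action distribution when exploring and takes the deterministic sample otherwise.

from ray.rllib.utils.exploration.exploration import Exploration

class SampleOrGreedy(Exploration):
    def get_exploration_action(self,
                               distribution_inputs,
                               action_dist_class,
                               model=None,
                               explore=True,
                               timestep=None):
        # Parameterize the distribution from the model's outputs.
        action_dist = action_dist_class(distribution_inputs, model)
        if explore:
            action = action_dist.sample()
        else:
            action = action_dist.deterministic_sample()
        return action, action_dist.logp(action)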
@@ -31,13 +31,13 @@ class Random(Exploration):

 @override(Exploration)
 def get_exploration_action(self,
-model_output,
-model,
+distribution_inputs,
 action_dist_class,
+model=None,
 explore=True,
 timestep=None):
 # Instantiate the distribution object.
-action_dist = action_dist_class(model_output, model)
+action_dist = action_dist_class(distribution_inputs, model)

 if self.framework == "tf":
 return self._get_tf_exploration_action_op(action_dist, explore,

@@ -49,7 +49,7 @@ class Random(Exploration):
 @tf_function(tf)
 def _get_tf_exploration_action_op(self, action_dist, explore, timestep):
 if explore:
-action = self.action_space.sample()
+action = tf.py_function(self.action_space.sample, [], tf.int64)
 # Will be unnecessary, once we support batch/time-aware Spaces.
 action = tf.expand_dims(tf.cast(action, dtype=tf.int32), 0)
 else:

@@ -67,8 +67,8 @@ class Random(Exploration):
 if explore:
 # Unsqueeze will be unnecessary, once we support batch/time-aware
 # Spaces.
-action = torch.IntTensor(self.action_space.sample()).unsqueeze(0)
+action = torch.LongTensor(self.action_space.sample()).unsqueeze(0)
 else:
-action = torch.IntTensor(action_dist.deterministic_sample())
+action = torch.LongTensor(action_dist.deterministic_sample())
 logp = torch.zeros((action.size()[0], ), dtype=torch.float32)
 return action, logp
@@ -1,4 +1,3 @@
-from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.exploration.exploration import Exploration
 from ray.rllib.utils.framework import try_import_tf, try_import_torch

@@ -46,9 +45,9 @@ class StochasticSampling(Exploration):

 @override(Exploration)
 def get_exploration_action(self,
-model_output,
-model,
+distribution_inputs,
 action_dist_class,
+model=None,
 explore=True,
 timestep=None):
 kwargs = self.static_params.copy()

@@ -60,12 +59,7 @@ class StochasticSampling(Exploration):
 # if self.time_dependent_params:
 # for k, v in self.time_dependent_params:
 # kwargs[k] = v(timestep)
-constructor, _ = ModelCatalog.get_action_dist(
-self.action_space,
-None,
-action_dist_class,
-framework=self.framework)
-action_dist = constructor(model_output, model, **kwargs)
+action_dist = action_dist_class(distribution_inputs, model, **kwargs)

 if self.framework == "torch":
 return self._get_torch_exploration_action(action_dist, explore)
@@ -65,7 +65,7 @@ def tf_function(tf_module):
 # The actual decorator to use (pass in `tf` (which could be None)).
 def decorator(func):
 # If tf not installed -> return function as is (won't be used anyways).
-if tf_module is None:
+if tf_module is None or tf_module.executing_eagerly():
 return func
 # If tf installed, return @tf.function-decorated function.
 return tf_module.function(func)
@@ -1,5 +1,8 @@
 import numpy as np

+from ray.rllib.utils.framework import try_import_torch
+
+torch, _ = try_import_torch()

 SMALL_NUMBER = 1e-6
 # Some large int number. May be increased here, if needed.

@@ -58,7 +61,7 @@ def relu(x, alpha=0.0):
 Returns:
 np.ndarray: The leaky ReLU output for x.
 """
-return np.maximum(x, x*alpha, x)
+return np.maximum(x, x * alpha, x)


 def one_hot(x, depth=0, on_value=1, off_value=0):

@@ -89,7 +92,7 @@ def one_hot(x, depth=0, on_value=1, off_value=0):
 shape = x.shape

 # Python 2.7 compatibility, (*shape, depth) is not allowed.
-shape_list = shape[:]
+shape_list = list(shape[:])
 shape_list.append(depth)
 out = np.ones(shape_list) * off_value
 indices = []

@@ -99,7 +102,7 @@ def one_hot(x, depth=0, on_value=1, off_value=0):
 s[i] = -1
 r = np.arange(shape[i]).reshape(s)
 if i > 0:
-tiles[i-1] = shape[i-1]
+tiles[i - 1] = shape[i - 1]
 r = np.tile(r, tiles)
 indices.append(r)
 indices.append(x)

@@ -120,11 +123,18 @@ def fc(x, weights, biases=None):
 Returns:
 The dense layer's output.
 """
+# Torch stores matrices in transpose (faster for backprop).
+if torch and isinstance(weights, torch.Tensor):
+weights = np.transpose(weights.numpy())
 return np.matmul(x, weights) + (0.0 if biases is None else biases)


-def lstm(x, weights, biases=None, initial_internal_states=None,
-time_major=False, forget_bias=1.0):
+def lstm(x,
+weights,
+biases=None,
+initial_internal_states=None,
+time_major=False,
+forget_bias=1.0):
 """
 Calculates the outputs of an LSTM layer given weights/biases,
 internal_states, and input.

@@ -174,15 +184,15 @@ def lstm(x, weights, biases=None, initial_internal_states=None,
 input_matrix = np.concatenate((input_matrix, h_states), axis=1)
 input_matmul_matrix = np.matmul(input_matrix, weights) + biases
 # Forget gate (3rd slot in tf output matrix). Add static forget bias.
-sigmoid_1 = sigmoid(input_matmul_matrix[:, units*2:units*3] +
+sigmoid_1 = sigmoid(input_matmul_matrix[:, units * 2:units * 3] +
 forget_bias)
 c_states = np.multiply(c_states, sigmoid_1)
 # Add gate (1st and 2nd slots in tf output matrix).
 sigmoid_2 = sigmoid(input_matmul_matrix[:, 0:units])
-tanh_3 = np.tanh(input_matmul_matrix[:, units:units*2])
+tanh_3 = np.tanh(input_matmul_matrix[:, units:units * 2])
 c_states = np.add(c_states, np.multiply(sigmoid_2, tanh_3))
 # Output gate (last slot in tf output matrix).
-sigmoid_4 = sigmoid(input_matmul_matrix[:, units*3:units*4])
+sigmoid_4 = sigmoid(input_matmul_matrix[:, units * 3:units * 4])
 h_states = np.multiply(sigmoid_4, np.tanh(c_states))

 # Store this output time-slice.
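A quick check of the numpy `fc` helper against a hand computation (illustrative shapes and values):

import numpy as np
from ray.rllib.utils.numpy import fc

x = np.array([[1.0, 2.0]])
w = np.array([[0.5, -1.0],
              [0.25, 0.0]])
b = np.array([0.1, 0.2])
print(fc(x, w, b))  # [[1.1, -0.8]] == x.dot(w) + b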