diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index 74f3cfd95..d0e6c5fff 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -79,7 +79,7 @@ fi if [[ "$RLLIB_TESTING" == "1" ]]; then pip install -q tensorflow-probability==$tfp_version gast==0.2.2 \ torch==$torch_version torchvision \ - gym[atari] atari_py smart_open + gym[atari] atari_py smart_open lz4 fi if [[ "$PYTHON" == "3.6" ]] || [[ "$MAC_WHEELS" == "1" ]]; then diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 5c9ddd1a4..2080ed30b 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -4,7 +4,6 @@ import logging import os import yaml import numbers - import numpy as np import ray.cloudpickle as cloudpickle diff --git a/rllib/BUILD b/rllib/BUILD index a066c3010..99ab84d21 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -796,7 +796,6 @@ py_test( ] ) - # -------------------------------------------------------------------- # Models and Distributions # rllib/models/ @@ -811,6 +810,20 @@ py_test( srcs = ["models/tests/test_distributions.py"] ) +# -------------------------------------------------------------------- +# Policies +# rllib/policy/ +# +# Tag: policy +# -------------------------------------------------------------------- + +py_test( + name = "policy/tests/test_compute_log_likelihoods", + tags = ["policy"], + size = "small", + srcs = ["policy/tests/test_compute_log_likelihoods.py"] +) + # -------------------------------------------------------------------- # Utils: # rllib/utils/ @@ -880,14 +893,6 @@ py_test( srcs = ["tests/test_dependency.py"] ) -# PR 7086 -#py_test( -# name = "tests/test_deterministic_support", -# tags = ["tests_dir", "tests_dir_D"], -# size = "small", -# srcs = ["tests/test_deterministic_support.py"] -#) - py_test( name = "tests/test_eager_support", tags = ["tests_dir", "tests_dir_E"], @@ -912,7 +917,7 @@ py_test( py_test( name = "tests/test_explorations", - tags = ["tests_dir", "tests_dir_E"], + tags = ["tests_dir", "tests_dir_E", "explorations"], size = "medium", srcs = ["tests/test_explorations.py"] ) diff --git a/rllib/agents/ddpg/ddpg_policy.py b/rllib/agents/ddpg/ddpg_policy.py index f226e013e..32ca36dd0 100644 --- a/rllib/agents/ddpg/ddpg_policy.py +++ b/rllib/agents/ddpg/ddpg_policy.py @@ -291,7 +291,7 @@ class DDPGTFPolicy(DDPGPostprocessing, TFPolicy): self.config, self.sess, obs_input=self.cur_observations, - action_sampler=self.output_actions, + sampled_action=self.output_actions, loss=self.actor_loss + self.critic_loss, loss_inputs=self.loss_inputs, update_ops=q_batchnorm_update_ops + policy_batchnorm_update_ops) diff --git a/rllib/agents/dqn/dqn_policy.py b/rllib/agents/dqn/dqn_policy.py index 11034faaa..50520b72e 100644 --- a/rllib/agents/dqn/dqn_policy.py +++ b/rllib/agents/dqn/dqn_policy.py @@ -202,20 +202,29 @@ def build_q_model(policy, obs_space, action_space, config): return policy.q_model -def sample_action_from_q_network(policy, q_model, input_dict, obs_space, - action_space, explore, config, timestep): - +def get_log_likelihood(policy, q_model, actions, input_dict, obs_space, + action_space, config): # Action Q network. 
q_vals = _compute_q_values(policy, q_model, input_dict[SampleBatch.CUR_OBS], obs_space, action_space) + q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals + action_dist = Categorical(q_vals, q_model) + return action_dist.logp(actions) + +def sample_action_from_q_network(policy, q_model, input_dict, obs_space, + action_space, explore, config, timestep): + # Action Q network. + q_vals = _compute_q_values(policy, q_model, + input_dict[SampleBatch.CUR_OBS], obs_space, + action_space) policy.q_values = q_vals[0] if isinstance(q_vals, tuple) else q_vals policy.q_func_vars = q_model.variables() - policy.output_actions, policy.action_logp = \ + policy.output_actions, policy.sampled_action_logp = \ policy.exploration.get_exploration_action( - policy.q_values, q_model, Categorical, explore, timestep) + policy.q_values, Categorical, q_model, explore, timestep) # Noise vars for Q network except for layer normalization vars. if config["parameter_noise"]: @@ -224,7 +233,7 @@ def sample_action_from_q_network(policy, q_model, input_dict, obs_space, [var for var in policy.q_func_vars if "LayerNorm" not in var.name]) policy.action_probs = tf.nn.softmax(policy.q_values) - return policy.output_actions, policy.action_logp + return policy.output_actions, policy.sampled_action_logp def _build_parameter_noise(policy, pnet_params): @@ -448,6 +457,7 @@ DQNTFPolicy = build_tf_policy( get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, make_model=build_q_model, action_sampler_fn=sample_action_from_q_network, + log_likelihood_fn=get_log_likelihood, loss_fn=build_q_losses, stats_fn=build_q_stats, postprocess_fn=postprocess_nstep_and_prio, diff --git a/rllib/agents/dqn/simple_q_policy.py b/rllib/agents/dqn/simple_q_policy.py index 903212c0d..fa68b1cdf 100644 --- a/rllib/agents/dqn/simple_q_policy.py +++ b/rllib/agents/dqn/simple_q_policy.py @@ -88,6 +88,17 @@ def build_q_models(policy, obs_space, action_space, config): return policy.q_model +def get_log_likelihood(policy, q_model, actions, input_dict, obs_space, + action_space, config): + # Action Q network. 
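The log-likelihood returned by `get_log_likelihood` here (and by the SimpleQ variant further below) is simply the log-softmax of the Q-values evaluated at the given actions, since the Q-values are fed to `Categorical` as logits. A minimal numpy sketch of that computation, with an illustrative helper name and made-up inputs rather than the actual RLlib code path:

import numpy as np

def soft_q_log_likelihood(q_values, actions):
    # Log-prob of `actions` under a softmax (Categorical) whose logits are
    # the Q-values; q_values: [batch, num_actions], actions: [batch] (int).
    z = q_values - q_values.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return log_probs[np.arange(len(actions)), actions]

q = np.array([[1.0, 2.0, 0.5], [0.0, 0.0, 0.0]])
a = np.array([1, 2])
print(soft_q_log_likelihood(q, a))  # approx. [-0.46, -1.10]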
+ q_vals = _compute_q_values(policy, q_model, + input_dict[SampleBatch.CUR_OBS], obs_space, + action_space) + q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals + action_dist = Categorical(q_vals, q_model) + return action_dist.logp(actions) + + def simple_sample_action_from_q_network(policy, q_model, input_dict, obs_space, action_space, explore, config, timestep): @@ -95,15 +106,14 @@ def simple_sample_action_from_q_network(policy, q_model, input_dict, obs_space, q_vals = _compute_q_values(policy, q_model, input_dict[SampleBatch.CUR_OBS], obs_space, action_space) - policy.q_values = q_vals[0] if isinstance(q_vals, tuple) else q_vals policy.q_func_vars = q_model.variables() - policy.output_actions, policy.action_logp = \ + policy.output_actions, policy.sampled_action_logp = \ policy.exploration.get_exploration_action( - policy.q_values, q_model, Categorical, explore, timestep) + policy.q_values, Categorical, q_model, explore, timestep) - return policy.output_actions, policy.action_logp + return policy.output_actions, policy.sampled_action_logp def build_q_losses(policy, model, dist_class, train_batch): @@ -167,13 +177,11 @@ SimpleQPolicy = build_tf_policy( get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, make_model=build_q_models, action_sampler_fn=simple_sample_action_from_q_network, + log_likelihood_fn=get_log_likelihood, loss_fn=build_q_losses, extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values}, extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error}, before_init=setup_early_mixins, after_init=setup_late_mixins, obs_include_prev_action_reward=False, - mixins=[ - ParameterNoiseMixin, - TargetNetworkMixin, - ]) + mixins=[ParameterNoiseMixin, TargetNetworkMixin]) diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index 463840716..3fdd6a9e4 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -1,8 +1,6 @@ from gym.spaces import Tuple, Discrete, Dict import logging import numpy as np -import torch as th -import torch.nn as nn from torch.optim import RMSprop from torch.distributions import Categorical @@ -16,9 +14,13 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.model import _unpack_obs from ray.rllib.env.constants import GROUP_REWARDS +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.annotations import override from ray.rllib.utils.tuple_actions import TupleActions +# Torch must be installed. 
+torch, nn = try_import_torch(error=True) + logger = logging.getLogger(__name__) # if the obs space is Dict type, look for the global state under this key @@ -85,7 +87,7 @@ class QMixLoss(nn.Module): mac_out = _unroll_mac(self.model, obs) # Pick the Q-Values for the actions taken -> [B * n_agents, T] - chosen_action_qvals = th.gather( + chosen_action_qvals = torch.gather( mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3) # Calculate the Q-Values necessary for the target @@ -114,8 +116,8 @@ class QMixLoss(nn.Module): # use the target network to estimate the Q-values of policy # network's selected actions - target_max_qvals = th.gather(target_mac_out, 3, - cur_max_actions).squeeze(3) + target_max_qvals = torch.gather(target_mac_out, 3, + cur_max_actions).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] @@ -167,8 +169,8 @@ class QMixTorchPolicy(Policy): self.h_size = config["model"]["lstm_cell_size"] self.has_env_global_state = False self.has_action_mask = False - self.device = (th.device("cuda") - if th.cuda.is_available() else th.device("cpu")) + self.device = (torch.device("cuda") + if torch.cuda.is_available() else torch.device("cpu")) agent_obs_space = obs_space.original_space.spaces[0] if isinstance(agent_obs_space, Dict): @@ -262,20 +264,21 @@ class QMixTorchPolicy(Policy): # to compute actions # Compute actions - with th.no_grad(): + with torch.no_grad(): q_values, hiddens = _mac( self.model, - th.as_tensor(obs_batch, dtype=th.float, device=self.device), [ - th.as_tensor( - np.array(s), dtype=th.float, device=self.device) - for s in state_batches - ]) - avail = th.as_tensor( - action_mask, dtype=th.float, device=self.device) + torch.as_tensor( + obs_batch, dtype=torch.float, device=self.device), [ + torch.as_tensor( + np.array(s), dtype=torch.float, device=self.device) + for s in state_batches + ]) + avail = torch.as_tensor( + action_mask, dtype=torch.float, device=self.device) masked_q_values = q_values.clone() masked_q_values[avail == 0.0] = -float("inf") # epsilon-greedy action selector - random_numbers = th.rand_like(q_values[:, :, 0]) + random_numbers = torch.rand_like(q_values[:, :, 0]) pick_random = (random_numbers < (self.cur_epsilon if explore else 0.0)).long() random_actions = Categorical(avail).sample().long() @@ -286,6 +289,16 @@ class QMixTorchPolicy(Policy): return TupleActions(list(actions.transpose([1, 0]))), hiddens, {} + @override(Policy) + def compute_log_likelihoods(self, + actions, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None): + obs_batch, action_mask, _ = self._unpack_observation(obs_batch) + return np.zeros(obs_batch.size()[0]) + @override(Policy) def learn_on_batch(self, samples): obs_batch, action_mask, env_global_state = self._unpack_observation( @@ -323,31 +336,32 @@ class QMixTorchPolicy(Policy): def to_batches(arr, dtype): new_shape = [B, T] + list(arr.shape[1:]) - return th.as_tensor( + return torch.as_tensor( np.reshape(arr, new_shape), dtype=dtype, device=self.device) - rewards = to_batches(rew, th.float) - actions = to_batches(act, th.long) - obs = to_batches(obs, th.float).reshape( + rewards = to_batches(rew, torch.float) + actions = to_batches(act, torch.long) + obs = to_batches(obs, torch.float).reshape( [B, T, self.n_agents, self.obs_size]) - action_mask = to_batches(action_mask, th.float) - next_obs = to_batches(next_obs, th.float).reshape( + action_mask = to_batches(action_mask, torch.float) + next_obs = to_batches(next_obs, torch.float).reshape( [B, T, self.n_agents, self.obs_size]) 
- next_action_mask = to_batches(next_action_mask, th.float) + next_action_mask = to_batches(next_action_mask, torch.float) if self.has_env_global_state: - env_global_state = to_batches(env_global_state, th.float) - next_env_global_state = to_batches(next_env_global_state, th.float) + env_global_state = to_batches(env_global_state, torch.float) + next_env_global_state = to_batches(next_env_global_state, + torch.float) # TODO(ekl) this treats group termination as individual termination - terminated = to_batches(dones, th.float).unsqueeze(2).expand( + terminated = to_batches(dones, torch.float).unsqueeze(2).expand( B, T, self.n_agents) # Create mask for where index is < unpadded sequence length filled = np.reshape( np.tile(np.arange(T, dtype=np.float32), B), [B, T]) < np.expand_dims(seq_lens, 1) - mask = th.as_tensor( - filled, dtype=th.float, device=self.device).unsqueeze(2).expand( + mask = torch.as_tensor( + filled, dtype=torch.float, device=self.device).unsqueeze(2).expand( B, T, self.n_agents) # Compute loss @@ -359,7 +373,7 @@ class QMixTorchPolicy(Policy): # Optimise self.optimiser.zero_grad() loss_out.backward() - grad_norm = th.nn.utils.clip_grad_norm_( + grad_norm = torch.nn.utils.clip_grad_norm_( self.params, self.config["grad_norm_clipping"]) self.optimiser.step() @@ -432,7 +446,7 @@ class QMixTorchPolicy(Policy): def _device_dict(self, state_dict): return { - k: th.as_tensor(v, device=self.device) + k: torch.as_tensor(v, device=self.device) for k, v in state_dict.items() } @@ -539,7 +553,7 @@ def _unroll_mac(model, obs_tensor): for t in range(T): q, h = _mac(model, obs_tensor[:, t], h) mac_out.append(q) - mac_out = th.stack(mac_out, dim=1) # Concat over time + mac_out = torch.stack(mac_out, dim=1) # Concat over time return mac_out diff --git a/rllib/agents/sac/sac_model.py b/rllib/agents/sac/sac_model.py index 78e268ee0..a7919f1f8 100644 --- a/rllib/agents/sac/sac_model.py +++ b/rllib/agents/sac/sac_model.py @@ -80,7 +80,7 @@ class SACModel(TFModelV2): shift_and_log_scale_diag = tf.keras.Sequential([ tf.keras.layers.Dense( units=hidden, - activation=getattr(tf.nn, actor_hidden_activation), + activation=getattr(tf.nn, actor_hidden_activation, None), name="action_hidden_{}".format(i)) for i, hidden in enumerate(actor_hiddens) ] + [ diff --git a/rllib/agents/sac/sac_policy.py b/rllib/agents/sac/sac_policy.py index 4ccb5cfc6..1b3eccb2b 100644 --- a/rllib/agents/sac/sac_policy.py +++ b/rllib/agents/sac/sac_policy.py @@ -86,6 +86,30 @@ def postprocess_trajectory(policy, return postprocess_nstep_and_prio(policy, sample_batch) +def unsquash_actions(actions, action_space): + # Use sigmoid to scale to [0,1], but also double magnitude of input to + # emulate behaviour of tanh activation used in SAC and TD3 papers. 
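The comment above leans on the identity sigmoid(2a) = (tanh(a) + 1) / 2, which is why doubling the input to the sigmoid reproduces a tanh squashing before rescaling into [low, high]. A small numpy check of that identity and of the rescaling (illustrative sketch only; the helper name is not from this PR):

import numpy as np

def unsquash_np(actions, low, high):
    # sigmoid(2a) == (tanh(a) + 1) / 2, so this maps R into (low, high).
    sigmoid_out = 1.0 / (1.0 + np.exp(-2.0 * actions))
    return (high - low) * sigmoid_out + low

a = np.linspace(-3.0, 3.0, 7)
low, high = np.array([-2.0]), np.array([1.0])
assert np.allclose(unsquash_np(a, low, high),
                   (high - low) * (np.tanh(a) + 1.0) / 2.0 + low)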
+ sigmoid_out = tf.nn.sigmoid(2 * actions) + # Rescale to actual env policy scale + # (shape of sigmoid_out is [batch_size, dim_actions], so we reshape to + # get same dims) + action_range = (action_space.high - action_space.low)[None] + low_action = action_space.low[None] + unsquashed_actions = action_range * sigmoid_out + low_action + + return unsquashed_actions + + +def get_log_likelihood(policy, model, actions, input_dict, obs_space, + action_space, config): + model_out, _ = model({ + "obs": input_dict[SampleBatch.CUR_OBS], + "is_training": policy._get_is_training_placeholder(), + }, [], None) + log_pis = policy.model.log_pis_model((model_out, actions)) + return log_pis + + def build_action_output(policy, model, input_dict, obs_space, action_space, explore, config, timestep): model_out, _ = model({ @@ -93,28 +117,16 @@ def build_action_output(policy, model, input_dict, obs_space, action_space, "is_training": policy._get_is_training_placeholder(), }, [], None) - def unsquash_actions(actions): - # Use sigmoid to scale to [0,1], but also double magnitude of input to - # emulate behaviour of tanh activation used in SAC and TD3 papers. - sigmoid_out = tf.nn.sigmoid(2 * actions) - # Rescale to actual env policy scale - # (shape of sigmoid_out is [batch_size, dim_actions], so we reshape to - # get same dims) - action_range = (action_space.high - action_space.low)[None] - low_action = action_space.low[None] - unsquashed_actions = action_range * sigmoid_out + low_action - - return unsquashed_actions - squashed_stochastic_actions, log_pis = policy.model.get_policy_output( model_out, deterministic=False) stochastic_actions = squashed_stochastic_actions if config[ - "normalize_actions"] else unsquash_actions(squashed_stochastic_actions) + "normalize_actions"] else unsquash_actions(squashed_stochastic_actions, + action_space) squashed_deterministic_actions, _ = policy.model.get_policy_output( model_out, deterministic=True) deterministic_actions = squashed_deterministic_actions if config[ "normalize_actions"] else unsquash_actions( - squashed_deterministic_actions) + squashed_deterministic_actions, action_space) actions = tf.cond( tf.constant(explore) if isinstance(explore, bool) else explore, @@ -409,6 +421,7 @@ SACTFPolicy = build_tf_policy( make_model=build_sac_model, postprocess_fn=postprocess_trajectory, action_sampler_fn=build_action_output, + log_likelihood_fn=get_log_likelihood, loss_fn=actor_critic_loss, stats_fn=stats, gradients_fn=gradients, diff --git a/rllib/contrib/maddpg/maddpg_policy.py b/rllib/contrib/maddpg/maddpg_policy.py index 50aa88246..fe18fad42 100644 --- a/rllib/contrib/maddpg/maddpg_policy.py +++ b/rllib/contrib/maddpg/maddpg_policy.py @@ -245,7 +245,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy): config=config, sess=self.sess, obs_input=obs_ph_n[agent_id], - action_sampler=act_sampler, + sampled_action=act_sampler, loss=actor_loss + critic_loss, loss_inputs=loss_inputs) @@ -339,9 +339,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy): out = tf.concat(obs_n + act_n, axis=1) for hidden in hiddens: - out = tf.layers.dense( - out, units=hidden, activation=activation - ) + out = tf.layers.dense(out, units=hidden, activation=activation) feature = out out = tf.layers.dense(feature, units=1, activation=None) @@ -367,9 +365,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy): out = obs for hidden in hiddens: - out = tf.layers.dense( - out, units=hidden, activation=activation - ) + out = tf.layers.dense(out, units=hidden, activation=activation) feature = 
tf.layers.dense( out, units=act_space.shape[0], activation=None) sampler = tfp.distributions.RelaxedOneHotCategorical( diff --git a/rllib/examples/rollout_worker_custom_workflow.py b/rllib/examples/rollout_worker_custom_workflow.py index f3745f819..167722cef 100644 --- a/rllib/examples/rollout_worker_custom_workflow.py +++ b/rllib/examples/rollout_worker_custom_workflow.py @@ -10,10 +10,10 @@ import gym import ray from ray import tune -from ray.rllib.policy import Policy from ray.rllib.evaluation import RolloutWorker from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tests.test_policy import TestPolicy parser = argparse.ArgumentParser() parser.add_argument("--gpu", action="store_true") @@ -22,7 +22,7 @@ parser.add_argument("--num-workers", type=int, default=2) parser.add_argument("--num-cpus", type=int, default=0) -class CustomPolicy(Policy): +class CustomPolicy(TestPolicy): """Example of a custom policy written from scratch. You might find it more convenient to extend TF/TorchPolicy instead @@ -30,7 +30,7 @@ class CustomPolicy(Policy): """ def __init__(self, observation_space, action_space, config): - Policy.__init__(self, observation_space, action_space, config) + super().__init__(observation_space, action_space, config) # example parameter self.w = 1.0 diff --git a/rllib/models/tests/test_distributions.py b/rllib/models/tests/test_distributions.py index f6bdf818e..531111aff 100644 --- a/rllib/models/tests/test_distributions.py +++ b/rllib/models/tests/test_distributions.py @@ -1,14 +1,22 @@ -import unittest import numpy as np +from gym.spaces import Box +from scipy.stats import norm +from tensorflow.python.eager.context import eager_mode +import unittest -from ray.rllib.models.tf.tf_action_dist import Categorical +from ray.rllib.models.tf.tf_action_dist import Categorical, SquashedGaussian from ray.rllib.utils import try_import_tf +from ray.rllib.utils.numpy import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT +from ray.rllib.utils.test_utils import check tf = try_import_tf() class TestDistributions(unittest.TestCase): + """Tests ActionDistribution classes.""" + def test_categorical(self): + """Tests the Categorical ActionDistribution (tf only).""" num_samples = 100000 logits = tf.placeholder(tf.float32, shape=(None, 10)) z = 8 * (np.random.rand(10) - 0.5) @@ -24,6 +32,76 @@ class TestDistributions(unittest.TestCase): probs = np.exp(z) / np.sum(np.exp(z)) self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01) + def test_squashed_gaussian(self): + """Tests the SquashedGaussia ActionDistribution (tf-eager only).""" + with eager_mode(): + input_space = Box(-1.0, 1.0, shape=(200, 10)) + low, high = -2.0, 1.0 + + # Batch of size=n and deterministic. + inputs = input_space.sample() + means, _ = np.split(inputs, 2, axis=-1) + squashed_distribution = SquashedGaussian( + inputs, {}, low=low, high=high) + expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low + # Sample n times, expect always mean value (deterministic draw). + out = squashed_distribution.deterministic_sample() + check(out, expected) + + # Batch of size=n and non-deterministic -> expect roughly the mean. 
+ inputs = input_space.sample() + means, log_stds = np.split(inputs, 2, axis=-1) + squashed_distribution = SquashedGaussian( + inputs, {}, low=low, high=high) + expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low + values = squashed_distribution.sample() + self.assertTrue(np.max(values) < high) + self.assertTrue(np.min(values) > low) + + check(np.mean(values), expected.mean(), decimals=1) + + # Test log-likelihood outputs. + sampled_action_logp = squashed_distribution.sampled_action_logp() + # Convert to parameters for distr. + stds = np.exp( + np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT)) + # Unsquash values, then get log-llh from regular gaussian. + unsquashed_values = np.arctanh((values - low) / + (high - low) * 2.0 - 1.0) + log_prob_unsquashed = \ + np.sum(np.log(norm.pdf(unsquashed_values, means, stds)), -1) + log_prob = log_prob_unsquashed - \ + np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2), + axis=-1) + check(np.mean(sampled_action_logp), np.mean(log_prob), rtol=0.01) + + # NN output. + means = np.array([[0.1, 0.2, 0.3, 0.4, 50.0], + [-0.1, -0.2, -0.3, -0.4, -1.0]]) + log_stds = np.array([[0.8, -0.2, 0.3, -1.0, 2.0], + [0.7, -0.3, 0.4, -0.9, 2.0]]) + squashed_distribution = SquashedGaussian( + np.concatenate([means, log_stds], axis=-1), {}, + low=low, + high=high) + # Convert to parameters for distr. + stds = np.exp(log_stds) + # Values to get log-likelihoods for. + values = np.array([[0.9, 0.2, 0.4, -0.1, -1.05], + [-0.9, -0.2, 0.4, -0.1, -1.05]]) + + # Unsquash values, then get log-llh from regular gaussian. + unsquashed_values = np.arctanh((values - low) / + (high - low) * 2.0 - 1.0) + log_prob_unsquashed = \ + np.sum(np.log(norm.pdf(unsquashed_values, means, stds)), -1) + log_prob = log_prob_unsquashed - \ + np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2), + axis=-1) + + out = squashed_distribution.logp(values) + check(out, log_prob) + if __name__ == "__main__": import unittest diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index 8f21fd708..2bac4f4bc 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -3,10 +3,12 @@ import functools from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils import try_import_tf +from ray.rllib.utils import try_import_tf, try_import_tfp, SMALL_NUMBER, \ + MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT from ray.rllib.utils.tuple_actions import TupleActions tf = try_import_tf() +tfp = try_import_tfp() @DeveloperAPI @@ -188,6 +190,79 @@ class DiagGaussian(TFActionDistribution): return np.prod(action_space.shape) * 2 +class SquashedGaussian(TFActionDistribution): + """A tanh-squashed Gaussian distribution defined by: mean, std, low, high. + + The distribution will never return low or high exactly, but + `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively. + """ + + def __init__(self, inputs, model, low=-1.0, high=1.0): + """Parameterizes the distribution via `inputs`. + + Args: + low (float): The lowest possible sampling value + (excluding this value). + high (float): The highest possible sampling value + (excluding this value). + """ + assert tfp is not None + loc, log_scale = tf.split(inputs, 2, axis=-1) + # Clip `scale` values (coming from NN) to reasonable values. 
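The expected log-probs computed in the test above, and `SquashedGaussian.logp` further below, both use the standard change-of-variables correction for a tanh-squashed Gaussian: unsquash the value, take the Normal log-density, then subtract sum(log(1 - tanh(u)^2)). A standalone numpy/scipy sketch for the default low=-1.0, high=1.0 case (inputs made up for illustration):

import numpy as np
from scipy.stats import norm

def tanh_gaussian_logp(values, mean, log_std):
    # log p(values) for y = tanh(u), u ~ Normal(mean, exp(log_std)),
    # with values in (-1, 1) and a diagonal covariance.
    u = np.arctanh(values)
    logp_u = np.sum(norm.logpdf(u, mean, np.exp(log_std)), axis=-1)
    # Jacobian of y = tanh(u) is 1 - tanh(u)^2.
    return logp_u - np.sum(np.log(1.0 - np.tanh(u) ** 2), axis=-1)

mean = np.array([[0.1, -0.2]])
log_std = np.array([[-0.5, 0.3]])
values = np.array([[0.4, -0.7]])
print(tanh_gaussian_logp(values, mean, log_std))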
+ log_scale = tf.clip_by_value(log_scale, MIN_LOG_NN_OUTPUT, + MAX_LOG_NN_OUTPUT) + scale = tf.exp(log_scale) + self.distr = tfp.distributions.Normal(loc=loc, scale=scale) + assert np.all(np.less(low, high)) + self.low = low + self.high = high + super().__init__(inputs, model) + + @override(TFActionDistribution) + def sampled_action_logp(self): + unsquashed_values = self._unsquash(self.sample_op) + log_prob = tf.reduce_sum( + self.distr.log_prob(unsquashed_values), axis=-1) + unsquashed_values_tanhd = tf.math.tanh(unsquashed_values) + log_prob -= tf.math.reduce_sum( + tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), + axis=-1) + return log_prob + + @override(ActionDistribution) + def deterministic_sample(self): + mean = self.distr.mean() + return self._squash(mean) + + @override(TFActionDistribution) + def _build_sample_op(self): + return self._squash(self.distr.sample()) + + @override(ActionDistribution) + def logp(self, x): + unsquashed_values = self._unsquash(x) + log_prob = tf.reduce_sum( + self.distr.log_prob(value=unsquashed_values), axis=-1) + unsquashed_values_tanhd = tf.math.tanh(unsquashed_values) + log_prob -= tf.math.reduce_sum( + tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), + axis=-1) + return log_prob + + def _squash(self, raw_values): + # Make sure raw_values are not too high/low (such that tanh would + # return exactly 1.0/-1.0, which would lead to +/-inf log-probs). + return (tf.clip_by_value( + tf.math.tanh(raw_values), + -1.0 + SMALL_NUMBER, + 1.0 - SMALL_NUMBER) + 1.0) / 2.0 * (self.high - self.low) + \ + self.low + + def _unsquash(self, values): + return tf.math.atanh((values - self.low) / + (self.high - self.low) * 2.0 - 1.0) + + class Deterministic(TFActionDistribution): """Action distribution that returns the input values directly. diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index 82c49b898..048e5d933 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -1,7 +1,7 @@ from collections import namedtuple import logging -from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.utils.annotations import DeveloperAPI logger = logging.getLogger(__name__) @@ -57,24 +57,13 @@ class OffPolicyEstimator: if k.startswith("state_in_"): num_state_inputs += 1 state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] - - # TODO(sven): This is wrong. The info["action_prob"] needs to refer - # to the old action (from the batch). It might be the action-prob of - # a different action (as the policy has changed). - # https://github.com/ray-project/ray/issues/7107 - _, _, info = self.policy.compute_actions( - obs_batch=batch["obs"], + log_likelihoods = self.policy.compute_log_likelihoods( + actions=batch[SampleBatch.ACTIONS], + obs_batch=batch[SampleBatch.CUR_OBS], state_batches=[batch[k] for k in state_keys], - prev_action_batch=batch.data.get("prev_action"), - prev_reward_batch=batch.data.get("prev_reward"), - info_batch=batch.data.get("info")) - if "action_prob" not in info: - raise ValueError( - "Off-policy estimation is not possible unless the policy " - "returns action probabilities when computing actions (i.e., " - "the 'action_prob' key is output by the policy). 
You " - "can set `input_evaluation: []` to resolve this.") - return info["action_prob"] + prev_action_batch=batch.data.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=batch.data.get(SampleBatch.PREV_REWARDS)) + return log_likelihoods @DeveloperAPI def process(self, batch): diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 37c3814f8..52c89e797 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -47,6 +47,7 @@ class DynamicTFPolicy(TFPolicy): before_loss_init=None, make_model=None, action_sampler_fn=None, + log_likelihood_fn=None, existing_inputs=None, existing_model=None, get_batch_divisibility_req=None, @@ -69,10 +70,14 @@ class DynamicTFPolicy(TFPolicy): given (policy, obs_space, action_space, config). All policy variables should be created in this function. If not specified, a default model will be created. - action_sampler_fn (func): optional function that returns a - tuple of action and action logp tensors given + action_sampler_fn (Optional[callable]): An optional callable + returning a tuple of action and action prob tensors given (policy, model, input_dict, obs_space, action_space, config). - If not specified, a default action distribution will be used. + If None, a default action distribution will be used. + log_likelihood_fn (Optional[callable]): A callable, + returning a log-likelihood op. + If None, a default class is used and distribution inputs + (for parameterization) will be generated by a model call. existing_inputs (OrderedDict): When copying a policy, this specifies an existing dict of placeholders to use instead of defining new ones @@ -98,11 +103,13 @@ class DynamicTFPolicy(TFPolicy): if self._obs_include_prev_action_reward: prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS] prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS] + action_input = existing_inputs[SampleBatch.ACTIONS] else: obs = tf.placeholder( tf.float32, shape=[None] + list(obs_space.shape), name="observation") + action_input = ModelCatalog.get_action_placeholder(action_space) if self._obs_include_prev_action_reward: prev_actions = ModelCatalog.get_action_placeholder( action_space, "prev_action") @@ -121,16 +128,16 @@ class DynamicTFPolicy(TFPolicy): self._seq_lens = tf.placeholder( dtype=tf.int32, shape=[None], name="seq_lens") - # Setup model if action_sampler_fn: if not make_model: raise ValueError( - "make_model is required if action_sampler_fn is given") + "`make_model` is required if `action_sampler_fn` is given") self.dist_class = None else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) + # Setup model if existing_model: self.model = existing_model elif make_model: @@ -165,35 +172,49 @@ class DynamicTFPolicy(TFPolicy): # Setup custom action sampler. if action_sampler_fn: - action_sampler, action_logp = action_sampler_fn( + sampled_action, sampled_action_logp = action_sampler_fn( self, self.model, self._input_dict, obs_space, action_space, explore, config, timestep) # Create a default action sampler. else: - # Using an exporation setup. - action_sampler, action_logp = \ + # Using an exploration setup. + sampled_action, sampled_action_logp = \ self.exploration.get_exploration_action( model_out, + self.dist_class, self.model, - action_dist_class=self.dist_class, explore=explore, timestep=timestep) - # Phase 1 init + # Phase 1 init. 
sess = tf.get_default_session() or tf.Session() if get_batch_divisibility_req: batch_divisibility_req = get_batch_divisibility_req(self) else: batch_divisibility_req = 1 + # Generate the log-likelihood op. + log_likelihood = None + # From a given function. + if log_likelihood_fn: + log_likelihood = log_likelihood_fn(self, self.model, action_input, + self._input_dict, obs_space, + action_space, config) + # Create default, iff we have a distribution class. + elif self.dist_class is not None: + log_likelihood = self.dist_class(model_out, + self.model).logp(action_input) + super().__init__( obs_space, action_space, config, sess, obs_input=obs, - action_sampler=action_sampler, - action_logp=action_logp, + action_input=action_input, # for logp calculations + sampled_action=sampled_action, + sampled_action_logp=sampled_action_logp, + log_likelihood=log_likelihood, loss=None, # dynamically initialized on run loss_inputs=[], model=self.model, diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 8ade4a83f..a01aa0d65 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -14,7 +14,7 @@ from ray.rllib.policy.policy import ACTION_PROB, ACTION_LOGP from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override from ray.rllib.utils.debug import log_once -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() logger = logging.getLogger(__name__) @@ -176,13 +176,14 @@ def build_eager_tf_policy(name, after_init=None, make_model=None, action_sampler_fn=None, + log_likelihood_fn=None, mixins=None, obs_include_prev_action_reward=True, get_batch_divisibility_req=None): """Build an eager TF policy. An eager policy runs all operations in eager mode, which makes debugging - much simpler, but is lower performance. + much simpler, but has lower performance. You shouldn't need to call this directly. 
Rather, prefer to build a TF graph policy and use set {"eager": true} in the trainer config to have @@ -208,12 +209,12 @@ def build_eager_tf_policy(name, before_init(self, observation_space, action_space, config) self.config = config + self.dist_class = None if action_sampler_fn: if not make_model: - raise ValueError( - "make_model is required if action_sampler_fn is given") - self.dist_class = None + raise ValueError("`make_model` is required if " + "`action_sampler_fn` is given") else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) @@ -235,13 +236,14 @@ def build_eager_tf_policy(name, for s in self.model.get_initial_state() ] - self.model({ + input_dict = { SampleBatch.CUR_OBS: tf.convert_to_tensor( np.array([observation_space.sample()])), SampleBatch.PREV_ACTIONS: tf.convert_to_tensor( [_flatten_action(action_space.sample())]), SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]), - }, self._state_in, tf.convert_to_tensor([1])) + } + self.model(input_dict, self._state_in, tf.convert_to_tensor([1])) if before_loss_init: before_loss_init(self, observation_space, action_space, config) @@ -312,8 +314,8 @@ def build_eager_tf_policy(name, n = len(obs_batch) else: n = obs_batch.shape[0] - seq_lens = tf.ones(n, dtype=tf.int32) + input_dict = { SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch), "is_training": tf.constant(False), @@ -326,24 +328,24 @@ def build_eager_tf_policy(name, prev_reward_batch), }) - with tf.variable_creator_scope(_disallow_var_creation): - model_out, state_out = self.model(input_dict, state_batches, - seq_lens) - # Custom sampler fn given (which may handle self.exploration). if action_sampler_fn is not None: + state_out = [] action, logp = action_sampler_fn( self, self.model, input_dict, self.observation_space, self.action_space, explore, self.config, timestep) # Use Exploration object. else: - action, logp = self.exploration.get_exploration_action( - model_out, - self.model, - action_dist_class=self.dist_class, - explore=explore, - timestep=timestep - if timestep is not None else self.global_timestep) + with tf.variable_creator_scope(_disallow_var_creation): + model_out, state_out = self.model(input_dict, + state_batches, seq_lens) + action, logp = self.exploration.get_exploration_action( + model_out, + self.dist_class, + self.model, + explore=explore, + timestep=timestep + if timestep is not None else self.global_timestep) extra_fetches = {} if logp is not None: @@ -359,6 +361,41 @@ def build_eager_tf_policy(name, return action, state_out, extra_fetches + @override(Policy) + def compute_log_likelihoods(self, + actions, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None): + + seq_lens = tf.ones(len(obs_batch), dtype=tf.int32) + input_dict = { + SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch), + "is_training": tf.constant(False), + } + if obs_include_prev_action_reward: + input_dict.update({ + SampleBatch.PREV_ACTIONS: tf.convert_to_tensor( + prev_action_batch), + SampleBatch.PREV_REWARDS: tf.convert_to_tensor( + prev_reward_batch), + }) + + # Custom log_likelihood function given. + if log_likelihood_fn: + log_likelihoods = log_likelihood_fn( + self, self.model, actions, input_dict, + self.observation_space, self.action_space, self.config) + # Default log-likelihood calculation. 
+ else: + dist_inputs, _ = self.model(input_dict, state_batches, + seq_lens) + action_dist = self.dist_class(dist_inputs, self.model) + log_likelihoods = action_dist.logp(actions) + + return log_likelihoods + @override(Policy) def apply_gradients(self, gradients): self._apply_gradients( diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 7d5fe0722..ba8ed404a 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -164,6 +164,34 @@ class Policy(metaclass=ABCMeta): return action, [s[0] for s in state_out], \ {k: v[0] for k, v in info.items()} + @abstractmethod + @DeveloperAPI + def compute_log_likelihoods(self, + actions, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None): + """Computes the log-prob/likelihood for a given action and observation. + + Args: + actions (Union[List,np.ndarray]): Batch of actions, for which to + retrieve the log-probs/likelihoods (given all other inputs: + obs, states, ..). + obs_batch (Union[List,np.ndarray]): Batch of observations. + state_batches (Optional[list]): List of RNN state input batches, + if any. + prev_action_batch (Optional[List,np.ndarray]): Batch of previous + action values. + prev_reward_batch (Optional[List,np.ndarray]): Batch of previous + rewards. + + Returns: + log-likelihoods (np.ndarray): Batch of log probs/likelihoods, with + shape: [BATCH_SIZE]. + """ + raise NotImplementedError + @DeveloperAPI def postprocess_trajectory(self, sample_batch, diff --git a/rllib/policy/tests/test_compute_log_likelihoods.py b/rllib/policy/tests/test_compute_log_likelihoods.py new file mode 100644 index 000000000..780b902ca --- /dev/null +++ b/rllib/policy/tests/test_compute_log_likelihoods.py @@ -0,0 +1,153 @@ +import numpy as np +from scipy.stats import norm +import unittest + +import ray.rllib.agents.dqn as dqn +import ray.rllib.agents.ppo as ppo +import ray.rllib.agents.sac as sac +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.test_utils import check +from ray.rllib.utils.numpy import one_hot, fc, MIN_LOG_NN_OUTPUT, \ + MAX_LOG_NN_OUTPUT + +tf = try_import_tf() + + +def test_log_likelihood(run, + config, + prev_a=None, + continuous=False, + layer_key=("fc", (0, 4)), + logp_func=None): + config = config.copy() + # Run locally. + config["num_workers"] = 0 + # Env setup. + if continuous: + env = "Pendulum-v0" + obs_batch = preprocessed_obs_batch = np.array([[0.0, 0.1, -0.1]]) + else: + env = "FrozenLake-v0" + config["env_config"] = {"is_slippery": False, "map_name": "4x4"} + obs_batch = np.array([0]) + preprocessed_obs_batch = one_hot(obs_batch, depth=16) + + # Use Soft-Q for DQNs. + if run is dqn.DQNTrainer: + config["exploration_config"] = {"type": "SoftQ", "temperature": 0.5} + + prev_r = None if prev_a is None else np.array(0.0) + + # Test against all frameworks. + for fw in ["tf", "eager", "torch"]: + if run in [dqn.DQNTrainer, sac.SACTrainer] and fw == "torch": + continue + print("Testing {} with framework={}".format(run, fw)) + config["eager"] = True if fw == "eager" else False + config["use_pytorch"] = True if fw == "torch" else False + + trainer = run(config=config, env=env) + policy = trainer.get_policy() + vars = policy.get_weights() + # Sample n actions, then roughly check their logp against their + # counts. + num_actions = 500 + actions = [] + for _ in range(num_actions): + # Single action from single obs. 
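The discrete-action check in the new test boils down to: draw many actions, then compare exp(compute_log_likelihoods(...)) against empirical action frequencies. A self-contained numpy sketch of that validation idea on a toy categorical distribution (not an actual RLlib policy):

import numpy as np

rng = np.random.default_rng(0)
logits = np.array([0.5, -1.0, 0.0, 2.0])
probs = np.exp(logits) / np.exp(logits).sum()

# What a "compute_log_likelihoods" would return for each action.
log_likelihoods = np.log(probs)

# Sample many actions; empirical frequencies should match exp(logp).
samples = rng.choice(len(probs), size=50000, p=probs)
freqs = np.bincount(samples, minlength=len(probs)) / len(samples)
assert np.allclose(freqs, np.exp(log_likelihoods), atol=0.01)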
+ actions.append( + trainer.compute_action( + obs_batch[0], + prev_action=prev_a, + prev_reward=prev_r, + explore=True)) + + # Test 50 actions for their log-likelihoods vs expected values. + if continuous: + for idx in range(50): + a = actions[idx] + if fw == "tf" or fw == "eager": + if isinstance(vars, list): + expected_mean_logstd = fc( + fc(obs_batch, vars[layer_key[1][0]]), + vars[layer_key[1][1]]) + else: + expected_mean_logstd = fc( + fc( + obs_batch, + vars["default_policy/{}_1/kernel".format( + layer_key[0])]), + vars["default_policy/{}_out/kernel".format( + layer_key[0])]) + else: + expected_mean_logstd = fc( + fc(obs_batch, + vars["_hidden_layers.0._model.0.weight"]), + vars["_logits._model.0.weight"]) + mean, log_std = np.split(expected_mean_logstd, 2, axis=-1) + if logp_func is None: + expected_logp = np.log(norm.pdf(a, mean, np.exp(log_std))) + else: + expected_logp = logp_func(mean, log_std, a) + logp = policy.compute_log_likelihoods( + np.array([a]), + preprocessed_obs_batch, + prev_action_batch=np.array([prev_a]), + prev_reward_batch=np.array([prev_r])) + check(logp, expected_logp[0], rtol=0.2) + # Test all available actions for their logp values. + else: + for a in [0, 1, 2, 3]: + count = actions.count(a) + expected_logp = np.log(count / num_actions) + logp = policy.compute_log_likelihoods( + np.array([a]), + preprocessed_obs_batch, + prev_action_batch=np.array([prev_a]), + prev_reward_batch=np.array([prev_r])) + check(logp, expected_logp, rtol=0.3) + + +class TestComputeLogLikelihood(unittest.TestCase): + def test_dqn(self): + """Tests, whether DQN correctly computes logp in soft-q mode.""" + test_log_likelihood(dqn.DQNTrainer, dqn.DEFAULT_CONFIG) + + def test_ppo_cont(self): + """Tests PPO's (cont. actions) compute_log_likelihoods method.""" + config = ppo.DEFAULT_CONFIG.copy() + config["model"]["fcnet_hiddens"] = [10] + config["model"]["fcnet_activation"] = "linear" + prev_a = np.array([0.0]) + test_log_likelihood(ppo.PPOTrainer, config, prev_a, continuous=True) + + def test_ppo_discr(self): + """Tests PPO's (discr. actions) compute_log_likelihoods method.""" + prev_a = np.array(0) + test_log_likelihood(ppo.PPOTrainer, ppo.DEFAULT_CONFIG, prev_a) + + def test_sac(self): + """Tests SAC's compute_log_likelihoods method.""" + config = sac.DEFAULT_CONFIG.copy() + config["policy_model"]["hidden_layer_sizes"] = [10] + config["policy_model"]["hidden_activation"] = "linear" + prev_a = np.array([0.0]) + + def logp_func(means, log_stds, values, low=-1.0, high=1.0): + stds = np.exp( + np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT)) + unsquashed_values = np.arctanh((values - low) / + (high - low) * 2.0 - 1.0) + log_prob_unsquashed = \ + np.sum(np.log(norm.pdf(unsquashed_values, means, stds)), -1) + return log_prob_unsquashed - \ + np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2), + axis=-1) + + test_log_likelihood( + sac.SACTrainer, + config, + prev_a, + continuous=True, + layer_key=("sequential/action", (0, 2)), + logp_func=logp_func) diff --git a/rllib/policy/tests/test_policy.py b/rllib/policy/tests/test_policy.py index 481410137..198f9ea67 100644 --- a/rllib/policy/tests/test_policy.py +++ b/rllib/policy/tests/test_policy.py @@ -1,6 +1,7 @@ import random from ray.rllib.policy.policy import Policy +from ray.rllib.utils.annotations import override class TestPolicy(Policy): @@ -9,6 +10,7 @@ class TestPolicy(Policy): and implements all other abstract methods of Policy with "pass". 
""" + @override(Policy) def compute_actions(self, obs_batch, state_batches=None, @@ -19,3 +21,12 @@ class TestPolicy(Policy): timestep=None, **kwargs): return [random.choice([0, 1])] * len(obs_batch), [], {} + + @override(Policy) + def compute_log_likelihoods(self, + actions, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None): + return [random.random()] * len(obs_batch) diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index cc4b6dbf6..516b7288a 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -13,9 +13,9 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils import try_import_tf tf = try_import_tf() logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ class TFPolicy(Policy): Examples: >>> policy = TFPolicySubclass( - sess, obs_input, action_sampler, loss, loss_inputs) + sess, obs_input, sampled_action, loss, loss_inputs) >>> print(policy.compute_actions([1, 0, 2])) (array([0, 1, 1]), [], {}) @@ -54,11 +54,13 @@ class TFPolicy(Policy): config, sess, obs_input, - action_sampler, + sampled_action, loss, loss_inputs, model=None, - action_logp=None, + sampled_action_logp=None, + action_input=None, + log_likelihood=None, state_inputs=None, state_outputs=None, prev_action_input=None, @@ -78,7 +80,7 @@ class TFPolicy(Policy): sess (Session): The TensorFlow session to use. obs_input (Tensor): Input placeholder for observations, of shape [BATCH_SIZE, obs...]. - action_sampler (Tensor): Tensor for sampling an action, of shape + sampled_action (Tensor): Tensor for sampling an action, of shape [BATCH_SIZE, action...] loss (Tensor): Scalar policy loss output tensor. loss_inputs (list): A (name, placeholder) tuple for each loss @@ -89,7 +91,12 @@ class TFPolicy(Policy): placeholders during loss computation. model (rllib.models.Model): used to integrate custom losses and stats from user-defined RLlib models. - action_logp (Tensor): log probability of the sampled action. + sampled_action_logp (Tensor): log probability of the sampled + action. + action_input (Optional[Tensor]): Input placeholder for actions for + logp/log-likelihood calculations. + log_likelihood (Optional[Tensor]): Tensor to calculate the + log_likelihood (given action_input and obs_input). state_inputs (list): list of RNN state input Tensors. state_outputs (list): list of RNN state output Tensors. 
prev_action_input (Tensor): placeholder for previous actions @@ -115,13 +122,16 @@ class TFPolicy(Policy): self._obs_input = obs_input self._prev_action_input = prev_action_input self._prev_reward_input = prev_reward_input - self._action = action_sampler + self._sampled_action = sampled_action self._is_training = self._get_is_training_placeholder() self._is_exploring = explore if explore is not None else \ tf.placeholder_with_default(True, (), name="is_exploring") - self._action_logp = action_logp - self._action_prob = (tf.exp(self._action_logp) - if self._action_logp is not None else None) + self._sampled_action_logp = sampled_action_logp + self._sampled_action_prob = (tf.exp(self._sampled_action_logp) + if self._sampled_action_logp is not None + else None) + self._action_input = action_input # For logp calculations. + self._log_likelihood = log_likelihood self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] self._seq_lens = seq_lens @@ -152,6 +162,9 @@ class TFPolicy(Policy): raise ValueError( "seq_lens tensor must be given if state inputs are defined") + # Generate the log-likelihood calculator. + self._log_likelihood = log_likelihood + def variables(self): """Return the list of all savable variables for this policy.""" return self.model.variables() @@ -255,6 +268,46 @@ class TFPolicy(Policy): # Execute session run to get action (and other fetches). return builder.get(fetches) + @override(Policy) + def compute_log_likelihoods(self, + actions, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None): + if self._log_likelihood is None: + raise ValueError("Cannot compute log-prob/likelihood w/o a " + "self._log_likelihood op!") + + # Do the forward pass through the model to capture the parameters + # for the action distribution, then do a logp on that distribution. + builder = TFRunBuilder(self._sess, "compute_log_likelihoods") + # Feed actions (for which we want logp values) into graph. + builder.add_feed_dict({self._action_input: actions}) + # Feed observations. + builder.add_feed_dict({self._obs_input: obs_batch}) + # Internal states. + state_batches = state_batches or [] + if len(self._state_inputs) != len(state_batches): + raise ValueError( + "Must pass in RNN state batches for placeholders {}, got {}". + format(self._state_inputs, state_batches)) + builder.add_feed_dict( + {k: v + for k, v in zip(self._state_inputs, state_batches)}) + if state_batches: + builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) + # Prev-a and r. + if self._prev_action_input is not None and \ + prev_action_batch is not None: + builder.add_feed_dict({self._prev_action_input: prev_action_batch}) + if self._prev_reward_input is not None and \ + prev_reward_batch is not None: + builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) + # Fetch the log_likelihoods output and return. + fetches = builder.add_fetches([self._log_likelihood]) + return builder.get(fetches)[0] + @override(Policy) def compute_gradients(self, postprocessed_batch): assert self.loss_initialized() @@ -341,9 +394,9 @@ class TFPolicy(Policy): By default we only return action probability info (if present). 
""" ret = {} - if self._action_logp is not None: - ret[ACTION_PROB] = self._action_prob - ret[ACTION_LOGP] = self._action_logp + if self._sampled_action_logp is not None: + ret[ACTION_PROB] = self._sampled_action_prob + ret[ACTION_LOGP] = self._sampled_action_logp return ret @DeveloperAPI @@ -441,7 +494,7 @@ class TFPolicy(Policy): # build output signatures output_signature = self._extra_output_signature_def() output_signature["actions"] = \ - tf.saved_model.utils.build_tensor_info(self._action) + tf.saved_model.utils.build_tensor_info(self._sampled_action) for state_output in self._state_outputs: output_signature[state_output.name] = \ tf.saved_model.utils.build_tensor_info(state_output) @@ -463,6 +516,7 @@ class TFPolicy(Policy): episodes=None, explore=None, timestep=None): + explore = explore if explore is not None else self.config["explore"] state_batches = state_batches or [] @@ -485,7 +539,8 @@ class TFPolicy(Policy): if timestep is not None: builder.add_feed_dict({self._timestep: timestep}) builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) - fetches = builder.add_fetches([self._action] + self._state_outputs + + fetches = builder.add_fetches([self._sampled_action] + + self._state_outputs + [self.extra_compute_action_fetches()]) return fetches[0], fetches[1:-1], fetches[-1] diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index d5387ba20..e704410fe 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -26,6 +26,7 @@ def build_tf_policy(name, after_init=None, make_model=None, action_sampler_fn=None, + log_likelihood_fn=None, mixins=None, get_batch_divisibility_req=None, obs_include_prev_action_reward=True): @@ -81,10 +82,14 @@ def build_tf_policy(name, given (policy, obs_space, action_space, config). All policy variables should be created in this function. If not specified, a default model will be created. - action_sampler_fn (func): optional function that returns a - tuple of action and action prob tensors given + action_sampler_fn (Optional[callable]): An optional callable returning + a tuple of action and action prob tensors given (policy, model, input_dict, obs_space, action_space, config). - If not specified, a default action distribution will be used. + If None, a default action distribution will be used. + log_likelihood_fn (Optional[callable]): A callable, + returning a log-likelihood op. + If None, a default class is used and distribution inputs + (for parameterization) will be generated by a model call. mixins (list): list of any class mixins for the returned policy class. 
These mixins will be applied in order and will have higher precedence than the DynamicTFPolicy class @@ -132,6 +137,7 @@ def build_tf_policy(name, before_loss_init=before_loss_init_wrapper, make_model=make_model, action_sampler_fn=action_sampler_fn, + log_likelihood_fn=log_likelihood_fn, existing_model=existing_model, existing_inputs=existing_inputs, get_batch_divisibility_req=get_batch_divisibility_req, diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index e0782aab1..e2291a431 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -4,10 +4,10 @@ import time from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, ACTION_PROB, \ ACTION_LOGP from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils import try_import_torch from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.tracking_dict import UsageTrackingDict +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule +from ray.rllib.utils.tracking_dict import UsageTrackingDict torch, _ = try_import_torch() @@ -86,7 +86,7 @@ class TorchPolicy(Policy): action_dist = None actions, logp = \ self.exploration.get_exploration_action( - logits, self.model, self.dist_class, explore, + logits, self.dist_class, self.model, explore, timestep if timestep is not None else self.global_timestep) input_dict[SampleBatch.ACTIONS] = actions @@ -101,6 +101,28 @@ class TorchPolicy(Policy): return (actions.cpu().numpy(), [h.cpu().numpy() for h in state], extra_action_out) + @override(Policy) + def compute_log_likelihoods(self, + actions, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None): + with torch.no_grad(): + input_dict = self._lazy_tensor_dict({ + SampleBatch.CUR_OBS: obs_batch, + SampleBatch.ACTIONS: actions + }) + if prev_action_batch: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch + if prev_reward_batch: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch + + parameters, _ = self.model(input_dict, state_batches, [1]) + action_dist = self.dist_class(parameters, self.model) + log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS]) + return log_likelihoods + @override(Policy) def learn_on_batch(self, postprocessed_batch): train_batch = self._lazy_tensor_dict(postprocessed_batch) diff --git a/rllib/tests/test_eager_support.py b/rllib/tests/test_eager_support.py index 348b0461c..0e36557fe 100644 --- a/rllib/tests/test_eager_support.py +++ b/rllib/tests/test_eager_support.py @@ -63,6 +63,9 @@ class TestEagerSupport(unittest.TestCase): "timesteps_per_iteration": 100 }) + def testSAC(self): + check_support("SAC", {"num_workers": 0}) + if __name__ == "__main__": import pytest diff --git a/rllib/tests/test_explorations.py b/rllib/tests/test_explorations.py index f92fdec5b..54cd2d3eb 100644 --- a/rllib/tests/test_explorations.py +++ b/rllib/tests/test_explorations.py @@ -1,4 +1,5 @@ import numpy as np +from tensorflow.python.eager.context import eager_mode import unittest import ray @@ -30,14 +31,23 @@ def test_explorations(run, impala.ImpalaTrainer, sac.SACTrainer]: continue print("Testing {} in framework={}".format(run, fw)) - config["eager"] = True if fw == "eager" else False - config["use_pytorch"] = True if fw == "torch" else False + config["eager"] = (fw == "eager") + config["use_pytorch"] = (fw == "torch") # Test for both the default Agent's exploration AND the `Random` # exploration class. 
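On the default (no `log_likelihood_fn`) code path above, e.g. `TorchPolicy.compute_log_likelihoods`, the model output parameterizes the policy's action distribution and `logp` is called on it. For a diagonal Gaussian this is a sum of per-dimension Normal log-densities, which is also what the new test expects for continuous PPO. A small numpy/scipy sketch (illustrative helper, not library code):

import numpy as np
from scipy.stats import norm

def diag_gaussian_logp(model_out, actions):
    # model_out: [batch, 2 * act_dim] = concat(mean, log_std).
    mean, log_std = np.split(model_out, 2, axis=-1)
    return np.sum(norm.logpdf(actions, mean, np.exp(log_std)), axis=-1)

model_out = np.array([[0.2, -0.1, -0.3, 0.0]])  # 2-d action space
actions = np.array([[0.25, 0.1]])
print(diag_gaussian_logp(model_out, actions))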
- for exploration in [None]: # , "Random"]: + for exploration in [None, "Random"]: if exploration == "Random": + # TODO(sven): Random doesn't work for cont. action spaces + # or IMPALA yet. + if env == "Pendulum-v0" or run is impala.ImpalaTrainer: + continue config["exploration_config"] = {"type": "Random"} + print("exploration={}".format(exploration or "default")) + + eager_mode_ctx = eager_mode() + if fw == "eager": + eager_mode_ctx.__enter__() trainer = run(config=config, env=env) @@ -53,8 +63,8 @@ def test_explorations(run, prev_reward=1.0 if prev_a is not None else None)) check(actions[-1], actions[0]) - # Make sure actions drawn are different (around some mean value), - # given constant observations. + # Make sure actions drawn are different + # (around some mean value), given constant observations. actions = [] for _ in range(100): actions.append( @@ -71,6 +81,9 @@ def test_explorations(run, # Check that the stddev is not 0.0 (values differ). check(np.std(actions), 0.0, false=True) + if fw == "eager": + eager_mode_ctx.__exit__(None, None, None) + class TestExplorations(unittest.TestCase): """ @@ -109,7 +122,7 @@ class TestExplorations(unittest.TestCase): "CartPole-v0", impala.DEFAULT_CONFIG, np.array([0.0, 0.1, 0.0, 0.0]), - prev_a=np.array([0])) + prev_a=np.array(0)) def test_pg(self): test_explorations( @@ -117,7 +130,7 @@ class TestExplorations(unittest.TestCase): "CartPole-v0", pg.DEFAULT_CONFIG, np.array([0.0, 0.1, 0.0, 0.0]), - prev_a=np.array([1])) + prev_a=np.array(1)) def test_ppo_discr(self): test_explorations( @@ -125,7 +138,7 @@ class TestExplorations(unittest.TestCase): "CartPole-v0", ppo.DEFAULT_CONFIG, np.array([0.0, 0.1, 0.0, 0.0]), - prev_a=np.array([0])) + prev_a=np.array(0)) def test_ppo_cont(self): test_explorations( @@ -133,7 +146,7 @@ class TestExplorations(unittest.TestCase): "Pendulum-v0", ppo.DEFAULT_CONFIG, np.array([0.0, 0.1, 0.0]), - prev_a=np.array([0]), + prev_a=np.array([0.0]), expected_mean_action=0.0) def test_sac(self): diff --git a/rllib/tests/test_io.py b/rllib/tests/test_io.py index 7485a21f6..a30ed447e 100644 --- a/rllib/tests/test_io.py +++ b/rllib/tests/test_io.py @@ -237,12 +237,13 @@ class JsonIOTest(unittest.TestCase): for _ in range(100): writer.write(SAMPLES) num_files = len(os.listdir(self.test_dir)) - # Magic numbers: 2: On travis, it seems to create only 2 files. + # Magic numbers: 2: On travis, it seems to create only 2 files, + # but sometimes also 7. # 12 or 13: Mac locally. # Reasons: Different compressions, file-size interpretations, # json writers? - assert num_files in [2, 12, 13], \ - "Expected 12|13 files, but found {} ({})". \ + assert num_files in [2, 7, 12, 13], \ + "Expected 2|7|12|13 files, but found {} ({})". 
\ format(num_files, os.listdir(self.test_dir)) def testReadWrite(self): diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index d6f44d409..f8a1ba51e 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -7,6 +7,7 @@ import unittest import traceback import ray +from ray.rllib.utils.framework import try_import_tf from ray.rllib.agents.registry import get_agent_class from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2 from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2 @@ -14,6 +15,7 @@ from ray.rllib.tests.test_multi_agent_env import MultiCartpole, \ MultiMountainCar from ray.rllib.utils.error import UnsupportedSpaceException from ray.tune.registry import register_env +tf = try_import_tf() ACTION_SPACES_TO_TEST = { "discrete": Discrete(5), @@ -220,16 +222,6 @@ class ModelSupportedSpaces(unittest.TestCase): def test_sac(self): check_support("SAC", {}, self.stats, check_bounds=True) - # def testAll(self): - - # num_unexpected_errors = 0 - # for (alg, a_name, o_name), stat in sorted(self.stats.items()): - # if stat not in ["ok", "unsupported", "skip"]: - # num_unexpected_errors += 1 - # print(alg, "action_space", a_name, "obs_space", o_name, "result", - # stat) - # self.assertEqual(num_unexpected_errors, 0) - def test_a3c_multiagent(self): check_support_multiagent("A3C", { "num_workers": 1, diff --git a/rllib/tuned_examples/halfcheetah-sac.yaml b/rllib/tuned_examples/halfcheetah-sac.yaml index 4669d51bc..81aaddda9 100644 --- a/rllib/tuned_examples/halfcheetah-sac.yaml +++ b/rllib/tuned_examples/halfcheetah-sac.yaml @@ -23,7 +23,7 @@ halfcheetah_sac: target_network_update_freq: 1 timesteps_per_iteration: 1000 learning_starts: 10000 - exploration_enabled: True + explore: True optimization: actor_learning_rate: 0.0003 critic_learning_rate: 0.0003 diff --git a/rllib/tuned_examples/pendulum-sac.yaml b/rllib/tuned_examples/pendulum-sac.yaml index f0c28c839..9b320fb78 100644 --- a/rllib/tuned_examples/pendulum-sac.yaml +++ b/rllib/tuned_examples/pendulum-sac.yaml @@ -3,7 +3,7 @@ pendulum_sac: env: Pendulum-v0 run: SAC - stop: + stop: episode_reward_mean: -150 config: horizon: 200 @@ -24,7 +24,7 @@ pendulum_sac: target_network_update_freq: 1 timesteps_per_iteration: 1000 learning_starts: 256 - exploration_enabled: True + explore: True optimization: actor_learning_rate: 0.0003 critic_learning_rate: 0.0003 diff --git a/rllib/utils/__init__.py b/rllib/utils/__init__.py index cd16f8c81..d8733a599 100644 --- a/rllib/utils/__init__.py +++ b/rllib/utils/__init__.py @@ -8,7 +8,7 @@ from ray.rllib.utils.deprecation import deprecation_warning, renamed_agent, \ from ray.rllib.utils.filter_manager import FilterManager from ray.rllib.utils.filter import Filter from ray.rllib.utils.numpy import sigmoid, softmax, relu, one_hot, fc, lstm, \ - SMALL_NUMBER, LARGE_INTEGER + SMALL_NUMBER, LARGE_INTEGER, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT from ray.rllib.utils.policy_client import PolicyClient from ray.rllib.utils.policy_server import PolicyServer from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \ @@ -85,6 +85,8 @@ __all__ = [ "FilterManager", "LARGE_INTEGER", "LinearSchedule", + "MAX_LOG_NN_OUTPUT", + "MIN_LOG_NN_OUTPUT", "PiecewiseSchedule", "PolicyClient", "PolicyServer", diff --git a/rllib/utils/exploration/epsilon_greedy.py b/rllib/utils/exploration/epsilon_greedy.py index d6f088c95..1bd6226d7 100644 --- a/rllib/utils/exploration/epsilon_greedy.py +++ 
b/rllib/utils/exploration/epsilon_greedy.py @@ -63,24 +63,24 @@ class EpsilonGreedy(Exploration): @override(Exploration) def get_exploration_action(self, - model_output, - model, - action_dist_class, + distribution_inputs, + action_dist_class=None, + model=None, explore=True, timestep=None): if self.framework == "tf": - return self._get_tf_exploration_action_op(model_output, explore, - timestep) + return self._get_tf_exploration_action_op(distribution_inputs, + explore, timestep) else: - return self._get_torch_exploration_action(model_output, explore, - timestep) + return self._get_torch_exploration_action(distribution_inputs, + explore, timestep) - def _get_tf_exploration_action_op(self, model_output, explore, timestep): + def _get_tf_exploration_action_op(self, q_values, explore, timestep): """Tf method to produce the tf op for an epsilon exploration action. Args: - model_output (tf.Tensor): + q_values (Tensor): The Q-values coming from some q-model. Returns: tf.Tensor: The tf exploration-action op. @@ -90,15 +90,14 @@ class EpsilonGreedy(Exploration): self.last_timestep)) # Get the exploit action as the one with the highest logit value. - exploit_action = tf.argmax(model_output, axis=1) + exploit_action = tf.argmax(q_values, axis=1) - batch_size = tf.shape(model_output)[0] + batch_size = tf.shape(q_values)[0] # Mask out actions with q-value=-inf so that we don't # even consider them for exploration. random_valid_action_logits = tf.where( - tf.equal(model_output, tf.float32.min), - tf.ones_like(model_output) * tf.float32.min, - tf.ones_like(model_output)) + tf.equal(q_values, tf.float32.min), + tf.ones_like(q_values) * tf.float32.min, tf.ones_like(q_values)) random_actions = tf.squeeze( tf.multinomial(random_valid_action_logits, 1), axis=1) @@ -122,11 +121,11 @@ class EpsilonGreedy(Exploration): with tf.control_dependencies([assign_op]): return action, tf.zeros_like(action, dtype=tf.float32) - def _get_torch_exploration_action(self, model_output, explore, timestep): + def _get_torch_exploration_action(self, q_values, explore, timestep): """Torch method to produce an epsilon exploration action. Args: - model_output (torch.Tensor): + q_values (Tensor): The Q-values coming from some q-model. Returns: torch.Tensor: The exploration-action. @@ -135,20 +134,20 @@ class EpsilonGreedy(Exploration): self.last_timestep = timestep if timestep is not None else \ self.last_timestep + 1 - _, exploit_action = torch.max(model_output, 1) + _, exploit_action = torch.max(q_values, 1) action_logp = torch.zeros_like(exploit_action) # Explore. if explore: # Get the current epsilon. epsilon = self.epsilon_schedule(self.last_timestep) - batch_size = model_output.size()[0] + batch_size = q_values.size()[0] # Mask out actions, whose Q-values are -inf, so that we don't # even consider them for exploration. random_valid_action_logits = torch.where( - model_output == float("-inf"), - torch.ones_like(model_output) * float("-inf"), - torch.ones_like(model_output)) + q_values == float("-inf"), + torch.ones_like(q_values) * float("-inf"), + torch.ones_like(q_values)) # A random action. 
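# Editor's sketch (illustrative only, not part of this diff): the
# epsilon-greedy rule implemented in the tf/torch branches above, reduced to
# plain NumPy. The helper name and the [batch, num_actions] q_values layout
# are assumptions for illustration.
import numpy as np


def epsilon_greedy_sketch(q_values, epsilon):
    """Argmax actions, each replaced w.p. epsilon by a random *valid* action.

    Actions whose Q-value is -inf are treated as invalid and are never drawn
    during exploration (mirroring the masking above).
    """
    batch_size = q_values.shape[0]
    exploit_actions = np.argmax(q_values, axis=1)
    random_actions = np.array(
        [np.random.choice(np.flatnonzero(row > -np.inf)) for row in q_values])
    use_random = np.random.random(batch_size) < epsilon
    return np.where(use_random, random_actions, exploit_actions)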
random_actions = torch.squeeze( torch.multinomial(random_valid_action_logits, 1), axis=1) diff --git a/rllib/utils/exploration/exploration.py b/rllib/utils/exploration/exploration.py index ed163ee7e..67c2f992e 100644 --- a/rllib/utils/exploration/exploration.py +++ b/rllib/utils/exploration/exploration.py @@ -31,9 +31,9 @@ class Exploration: self.framework = check_framework(framework) def get_exploration_action(self, - model_output, - model, + distribution_inputs, action_dist_class, + model=None, explore=True, timestep=None): """Returns a (possibly) exploratory action. @@ -42,10 +42,12 @@ class Exploration: exploratory action. Args: - model_output (any): The raw output coming from the model + distribution_inputs (any): The output coming from the model, + ready for parameterizing a distribution (e.g. q-values or PG-logits). + action_dist_class (class): The action distribution class + to use. model (ModelV2): The Model object. - action_dist_class: The ActionDistribution class. explore (bool): True: "Normal" exploration behavior. False: Suppress all exploratory behavior and return a deterministic action. diff --git a/rllib/utils/exploration/random.py b/rllib/utils/exploration/random.py index c4342df78..59e112635 100644 --- a/rllib/utils/exploration/random.py +++ b/rllib/utils/exploration/random.py @@ -31,13 +31,13 @@ class Random(Exploration): @override(Exploration) def get_exploration_action(self, - model_output, - model, + distribution_inputs, action_dist_class, + model=None, explore=True, timestep=None): # Instantiate the distribution object. - action_dist = action_dist_class(model_output, model) + action_dist = action_dist_class(distribution_inputs, model) if self.framework == "tf": return self._get_tf_exploration_action_op(action_dist, explore, @@ -49,7 +49,7 @@ class Random(Exploration): @tf_function(tf) def _get_tf_exploration_action_op(self, action_dist, explore, timestep): if explore: - action = self.action_space.sample() + action = tf.py_function(self.action_space.sample, [], tf.int64) # Will be unnecessary, once we support batch/time-aware Spaces. action = tf.expand_dims(tf.cast(action, dtype=tf.int32), 0) else: @@ -67,8 +67,8 @@ class Random(Exploration): if explore: # Unsqueeze will be unnecessary, once we support batch/time-aware # Spaces. 
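# Editor's sketch (illustrative only, not part of this diff): a minimal custom
# Exploration subclass written against the new `get_exploration_action()`
# signature documented above (distribution inputs first, then the dist class,
# then the optional model). The class name and its always-greedy behavior are
# assumptions for illustration.
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration


class GreedyOnly(Exploration):
    """Ignores `explore` and always returns the deterministic action."""

    @override(Exploration)
    def get_exploration_action(self,
                               distribution_inputs,
                               action_dist_class,
                               model=None,
                               explore=True,
                               timestep=None):
        action_dist = action_dist_class(distribution_inputs, model)
        action = action_dist.deterministic_sample()
        return action, action_dist.logp(action)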
- action = torch.IntTensor(self.action_space.sample()).unsqueeze(0) + action = torch.LongTensor(self.action_space.sample()).unsqueeze(0) else: - action = torch.IntTensor(action_dist.deterministic_sample()) + action = torch.LongTensor(action_dist.deterministic_sample()) logp = torch.zeros((action.size()[0], ), dtype=torch.float32) return action, logp diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index 880784d4d..ae325584e 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -1,4 +1,3 @@ -from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import try_import_tf, try_import_torch @@ -46,9 +45,9 @@ class StochasticSampling(Exploration): @override(Exploration) def get_exploration_action(self, - model_output, - model, + distribution_inputs, action_dist_class, + model=None, explore=True, timestep=None): kwargs = self.static_params.copy() @@ -60,12 +59,7 @@ class StochasticSampling(Exploration): # if self.time_dependent_params: # for k, v in self.time_dependent_params: # kwargs[k] = v(timestep) - constructor, _ = ModelCatalog.get_action_dist( - self.action_space, - None, - action_dist_class, - framework=self.framework) - action_dist = constructor(model_output, model, **kwargs) + action_dist = action_dist_class(distribution_inputs, model, **kwargs) if self.framework == "torch": return self._get_torch_exploration_action(action_dist, explore) diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index bd94f0fc5..77b0e35d4 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -65,7 +65,7 @@ def tf_function(tf_module): # The actual decorator to use (pass in `tf` (which could be None)). def decorator(func): # If tf not installed -> return function as is (won't be used anyways). - if tf_module is None: + if tf_module is None or tf_module.executing_eagerly(): return func # If tf installed, return @tf.function-decorated function. return tf_module.function(func) diff --git a/rllib/utils/numpy.py b/rllib/utils/numpy.py index 625b31e80..c32cdcc18 100644 --- a/rllib/utils/numpy.py +++ b/rllib/utils/numpy.py @@ -1,5 +1,8 @@ import numpy as np +from ray.rllib.utils.framework import try_import_torch + +torch, _ = try_import_torch() SMALL_NUMBER = 1e-6 # Some large int number. May be increased here, if needed. @@ -58,7 +61,7 @@ def relu(x, alpha=0.0): Returns: np.ndarray: The leaky ReLU output for x. """ - return np.maximum(x, x*alpha, x) + return np.maximum(x, x * alpha, x) def one_hot(x, depth=0, on_value=1, off_value=0): @@ -89,7 +92,7 @@ def one_hot(x, depth=0, on_value=1, off_value=0): shape = x.shape # Python 2.7 compatibility, (*shape, depth) is not allowed. - shape_list = shape[:] + shape_list = list(shape[:]) shape_list.append(depth) out = np.ones(shape_list) * off_value indices = [] @@ -99,7 +102,7 @@ def one_hot(x, depth=0, on_value=1, off_value=0): s[i] = -1 r = np.arange(shape[i]).reshape(s) if i > 0: - tiles[i-1] = shape[i-1] + tiles[i - 1] = shape[i - 1] r = np.tile(r, tiles) indices.append(r) indices.append(x) @@ -120,11 +123,18 @@ def fc(x, weights, biases=None): Returns: The dense layer's output. """ + # Torch stores matrices in transpose (faster for backprop). 
+ if torch and isinstance(weights, torch.Tensor): + weights = np.transpose(weights.numpy()) return np.matmul(x, weights) + (0.0 if biases is None else biases) -def lstm(x, weights, biases=None, initial_internal_states=None, - time_major=False, forget_bias=1.0): +def lstm(x, + weights, + biases=None, + initial_internal_states=None, + time_major=False, + forget_bias=1.0): """ Calculates the outputs of an LSTM layer given weights/biases, internal_states, and input. @@ -174,15 +184,15 @@ def lstm(x, weights, biases=None, initial_internal_states=None, input_matrix = np.concatenate((input_matrix, h_states), axis=1) input_matmul_matrix = np.matmul(input_matrix, weights) + biases # Forget gate (3rd slot in tf output matrix). Add static forget bias. - sigmoid_1 = sigmoid(input_matmul_matrix[:, units*2:units*3] + + sigmoid_1 = sigmoid(input_matmul_matrix[:, units * 2:units * 3] + forget_bias) c_states = np.multiply(c_states, sigmoid_1) # Add gate (1st and 2nd slots in tf output matrix). sigmoid_2 = sigmoid(input_matmul_matrix[:, 0:units]) - tanh_3 = np.tanh(input_matmul_matrix[:, units:units*2]) + tanh_3 = np.tanh(input_matmul_matrix[:, units:units * 2]) c_states = np.add(c_states, np.multiply(sigmoid_2, tanh_3)) # Output gate (last slot in tf output matrix). - sigmoid_4 = sigmoid(input_matmul_matrix[:, units*3:units*4]) + sigmoid_4 = sigmoid(input_matmul_matrix[:, units * 3:units * 4]) h_states = np.multiply(sigmoid_4, np.tanh(c_states)) # Store this output time-slice.
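# Editor's sketch (illustrative only, not part of this diff): checking the
# updated numpy `fc()` helper against a torch.nn.Linear layer. Because torch
# stores weights as [out, in], `fc()` now transposes torch weight tensors
# before the matmul (see above). Layer sizes and the tolerance are
# assumptions for illustration.
import numpy as np
import torch

from ray.rllib.utils.numpy import fc

layer = torch.nn.Linear(4, 2)
x = np.random.random((3, 4)).astype(np.float32)
with torch.no_grad():
    expected = layer(torch.from_numpy(x)).numpy()
# Pass the (detached) torch weight tensor directly; `fc` handles the transpose.
out = fc(x, layer.weight.data, layer.bias.data.numpy())
assert np.allclose(out, expected, atol=1e-5)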