# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for V-trace.

For details and theory see:
"IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner
Architectures" by Espeholt, Soyer, Munos et al.
"""

from gym.spaces import Box
import numpy as np
import unittest

from ray.rllib.algorithms.impala import vtrace_tf as vtrace_tf
from ray.rllib.algorithms.impala import vtrace_torch as vtrace_torch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import softmax
from ray.rllib.utils.test_utils import check, framework_iterator

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


def _ground_truth_calculation(
    vtrace,
    discounts,
    log_rhos,
    rewards,
    values,
    bootstrap_value,
    clip_rho_threshold,
    clip_pg_rho_threshold,
):
    """Calculates the ground truth for V-trace in Python/Numpy."""
    vs = []
    seq_len = len(discounts)
    rhos = np.exp(log_rhos)
    cs = np.minimum(rhos, 1.0)
    clipped_rhos = rhos
    if clip_rho_threshold:
        clipped_rhos = np.minimum(rhos, clip_rho_threshold)
    clipped_pg_rhos = rhos
    if clip_pg_rho_threshold:
        clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold)

    # This is a very inefficient way to calculate the V-trace ground truth.
    # We calculate it this way because it is close to the mathematical
    # notation of V-trace:
    # v_s = V(x_s)
    #       + \sum^{T-1}_{t=s} \gamma^{t-s}
    #         * \prod_{i=s}^{t-1} c_i
    #         * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))
    # Note that when we take the product over c_i, we write `s:t` because the
    # notation of the paper is inclusive of `t-1`, but Python slicing is
    # exclusive.
    # Also note that np.prod([]) == 1.
    values_t_plus_1 = np.concatenate([values[1:], bootstrap_value[None, :]], axis=0)
    for s in range(seq_len):
        v_s = np.copy(values[s])  # Very important copy.
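        # Each inner-loop iteration below adds one term of the sum above, i.e.
        # \gamma^{t-s} * \prod_{i=s}^{t-1} c_i * \rho_t
        #   * (r_t + \gamma V(x_{t+1}) - V(x_t)).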
        for t in range(s, seq_len):
            v_s += (
                np.prod(discounts[s:t], axis=0)
                * np.prod(cs[s:t], axis=0)
                * clipped_rhos[t]
                * (rewards[t] + discounts[t] * values_t_plus_1[t] - values[t])
            )
        vs.append(v_s)
    vs = np.stack(vs, axis=0)
    pg_advantages = clipped_pg_rhos * (
        rewards
        + discounts * np.concatenate([vs[1:], bootstrap_value[None, :]], axis=0)
        - values
    )

    return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages)
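

# NOTE: The sketch below is illustrative only and is not used by the tests.
# It shows the equivalent backward-recursion form of the V-trace targets
# (cf. Remark 1 of the IMPALA paper): instead of summing the discounted,
# c-weighted TD errors forward as `_ground_truth_calculation` does above,
# one can accumulate them backwards in time via
#   v_s - V(x_s) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})).
# The helper name `_ground_truth_vs_recursive` is hypothetical (introduced
# here for illustration); it follows the same argument and shape conventions
# (T x B inputs, B-shaped bootstrap_value) as `_ground_truth_calculation`.
def _ground_truth_vs_recursive(
    discounts, log_rhos, rewards, values, bootstrap_value, clip_rho_threshold=None
):
    """Illustrative backward-recursion computation of the V-trace vs targets."""
    rhos = np.exp(log_rhos)
    cs = np.minimum(rhos, 1.0)
    clipped_rhos = (
        np.minimum(rhos, clip_rho_threshold) if clip_rho_threshold else rhos
    )
    values_t_plus_1 = np.concatenate([values[1:], bootstrap_value[None, :]], axis=0)
    # delta_t = clipped_rho_t * (r_t + gamma_t * V(x_{t+1}) - V(x_t)).
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
    acc = np.zeros_like(bootstrap_value)
    vs_minus_v = []
    # Accumulate backwards: the correction at step s is the TD error at s plus
    # the discounted, c-weighted correction from step s + 1.
    for t in reversed(range(len(discounts))):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v.append(acc)
    vs_minus_v = np.stack(vs_minus_v[::-1], axis=0)
    return values + vs_minus_v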


class LogProbsFromLogitsAndActionsTest(unittest.TestCase):
    def test_log_probs_from_logits_and_actions(self):
        """Tests log_probs_from_logits_and_actions."""
        seq_len = 7
        num_actions = 3
        batch_size = 4

        for fw, sess in framework_iterator(frameworks=("torch", "tf"), session=True):
            vtrace = vtrace_tf if fw != "torch" else vtrace_torch
            policy_logits = Box(
                -1.0, 1.0, (seq_len, batch_size, num_actions), np.float32
            ).sample()
            actions = np.random.randint(
                0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32
            )

            if fw == "torch":
                action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
                    torch.from_numpy(policy_logits), torch.from_numpy(actions)
                )
            else:
                action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
                    policy_logits, actions
                )

            # Ground Truth
            # Using broadcasting to create a mask that indexes action logits.
            action_index_mask = actions[..., None] == np.arange(num_actions)

            def index_with_mask(array, mask):
                return array[mask].reshape(*array.shape[:-1])

            # Note: Normally log(softmax) is not a good idea because it's not
            # numerically stable. However, in this test we have well-behaved
            # values.
            ground_truth_v = index_with_mask(
                np.log(softmax(policy_logits)), action_index_mask
            )

            if sess:
                action_log_probs_tensor = sess.run(action_log_probs_tensor)
            check(action_log_probs_tensor, ground_truth_v)


class VtraceTest(unittest.TestCase):
    def test_vtrace(self):
        """Tests V-trace against ground truth data calculated in python."""
        seq_len = 5
        batch_size = 10

        # Create log_rhos such that rho will span from near-zero to above the
        # clipping thresholds. In particular, calculate log_rhos in
        # [-2.5, 2.5), so that rho is in approx [0.08, 12.2).
        space_w_time = Box(-1.0, 1.0, (seq_len, batch_size), np.float32)
        space_only_batch = Box(-1.0, 1.0, (batch_size,), np.float32)
        log_rhos = space_w_time.sample() / (batch_size * seq_len)
        log_rhos = 5 * (log_rhos - 0.5)  # [0.0, 1.0) -> [-2.5, 2.5).

        values = {
            "log_rhos": log_rhos,
            # T, B where B_i: [0.9 / (i+1)] * T
            "discounts": np.array(
                [[0.9 / (b + 1) for b in range(batch_size)] for _ in range(seq_len)]
            ),
            "rewards": space_w_time.sample(),
            "values": space_w_time.sample() / batch_size,
            "bootstrap_value": space_only_batch.sample() + 1.0,
            "clip_rho_threshold": 3.7,
            "clip_pg_rho_threshold": 2.2,
        }

        for fw, sess in framework_iterator(frameworks=("torch", "tf"), session=True):
            vtrace = vtrace_tf if fw != "torch" else vtrace_torch
            output = vtrace.from_importance_weights(**values)
            if sess:
                output = sess.run(output)

            ground_truth_v = _ground_truth_calculation(vtrace, **values)
            check(output, ground_truth_v)

    def test_vtrace_from_logits(self):
        """Tests V-trace calculated from logits."""
        seq_len = 5
        batch_size = 15
        num_actions = 3
        clip_rho_threshold = None  # No clipping.
        clip_pg_rho_threshold = None  # No clipping.

        space = Box(-1.0, 1.0, (seq_len, batch_size, num_actions))
        action_space = Box(
            0,
            num_actions - 1,
            (
                seq_len,
                batch_size,
            ),
            dtype=np.int32,
        )
        space_w_time = Box(
            -1.0,
            1.0,
            (
                seq_len,
                batch_size,
            ),
        )
        space_only_batch = Box(-1.0, 1.0, (batch_size,))

        for fw, sess in framework_iterator(frameworks=("torch", "tf"), session=True):
            vtrace = vtrace_tf if fw != "torch" else vtrace_torch
            if fw == "tf":
                # Intentionally leaving shapes unspecified to test if V-trace
                # can deal with that.
                inputs_ = {
                    # T, B, NUM_ACTIONS
                    "behaviour_policy_logits": tf1.placeholder(
                        dtype=tf.float32, shape=[None, None, None]
                    ),
                    # T, B, NUM_ACTIONS
                    "target_policy_logits": tf1.placeholder(
                        dtype=tf.float32, shape=[None, None, None]
                    ),
                    "actions": tf1.placeholder(dtype=tf.int32, shape=[None, None]),
                    "discounts": tf1.placeholder(dtype=tf.float32, shape=[None, None]),
                    "rewards": tf1.placeholder(dtype=tf.float32, shape=[None, None]),
                    "values": tf1.placeholder(dtype=tf.float32, shape=[None, None]),
                    "bootstrap_value": tf1.placeholder(dtype=tf.float32, shape=[None]),
                }
            else:
                inputs_ = {
                    # T, B, NUM_ACTIONS
                    "behaviour_policy_logits": space.sample(),
                    # T, B, NUM_ACTIONS
                    "target_policy_logits": space.sample(),
                    "actions": action_space.sample(),
                    "discounts": space_w_time.sample(),
                    "rewards": space_w_time.sample(),
                    "values": space_w_time.sample(),
                    "bootstrap_value": space_only_batch.sample(),
                }

            from_logits_output = vtrace.from_logits(
                clip_rho_threshold=clip_rho_threshold,
                clip_pg_rho_threshold=clip_pg_rho_threshold,
                **inputs_
            )

            if fw != "torch":
                target_log_probs = vtrace.log_probs_from_logits_and_actions(
                    inputs_["target_policy_logits"], inputs_["actions"]
                )
                behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
                    inputs_["behaviour_policy_logits"], inputs_["actions"]
                )
            else:
                target_log_probs = vtrace.log_probs_from_logits_and_actions(
                    torch.from_numpy(inputs_["target_policy_logits"]),
                    torch.from_numpy(inputs_["actions"]),
                )
                behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
                    torch.from_numpy(inputs_["behaviour_policy_logits"]),
                    torch.from_numpy(inputs_["actions"]),
                )
            log_rhos = target_log_probs - behaviour_log_probs
            ground_truth = (log_rhos, behaviour_log_probs, target_log_probs)

            if sess:
                values = {
                    "behaviour_policy_logits": space.sample(),
                    "target_policy_logits": space.sample(),
                    "actions": action_space.sample(),
                    "discounts": space_w_time.sample(),
                    "rewards": space_w_time.sample(),
                    "values": space_w_time.sample() / batch_size,
                    "bootstrap_value": space_only_batch.sample() + 1.0,
                }
                feed_dict = {inputs_[k]: v for k, v in values.items()}
                from_logits_output = sess.run(from_logits_output, feed_dict=feed_dict)
                log_rhos, behaviour_log_probs, target_log_probs = sess.run(
                    ground_truth, feed_dict=feed_dict
                )

                # Calculate V-trace using the ground truth logits.
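                # Feeding the ground-truth log_rhos back through
                # `from_importance_weights` should reproduce what `from_logits`
                # computed internally; the checks further below compare the two
                # outputs element-wise.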
                from_iw = vtrace.from_importance_weights(
                    log_rhos=log_rhos,
                    discounts=values["discounts"],
                    rewards=values["rewards"],
                    values=values["values"],
                    bootstrap_value=values["bootstrap_value"],
                    clip_rho_threshold=clip_rho_threshold,
                    clip_pg_rho_threshold=clip_pg_rho_threshold,
                )
                from_iw = sess.run(from_iw)
            else:
                from_iw = vtrace.from_importance_weights(
                    log_rhos=log_rhos,
                    discounts=inputs_["discounts"],
                    rewards=inputs_["rewards"],
                    values=inputs_["values"],
                    bootstrap_value=inputs_["bootstrap_value"],
                    clip_rho_threshold=clip_rho_threshold,
                    clip_pg_rho_threshold=clip_pg_rho_threshold,
                )

            check(from_iw.vs, from_logits_output.vs)
            check(from_iw.pg_advantages, from_logits_output.pg_advantages)
            check(behaviour_log_probs, from_logits_output.behaviour_action_log_probs)
            check(target_log_probs, from_logits_output.target_action_log_probs)
            check(log_rhos, from_logits_output.log_rhos)

    def test_higher_rank_inputs_for_importance_weights(self):
        """Checks support for additional dimensions in inputs."""
        for fw in framework_iterator(frameworks=("torch", "tf"), session=True):
            vtrace = vtrace_tf if fw != "torch" else vtrace_torch
            if fw == "tf":
                inputs_ = {
                    "log_rhos": tf1.placeholder(
                        dtype=tf.float32, shape=[None, None, 1]
                    ),
                    "discounts": tf1.placeholder(
                        dtype=tf.float32, shape=[None, None, 1]
                    ),
                    "rewards": tf1.placeholder(
                        dtype=tf.float32, shape=[None, None, 42]
                    ),
                    "values": tf1.placeholder(dtype=tf.float32, shape=[None, None, 42]),
                    "bootstrap_value": tf1.placeholder(
                        dtype=tf.float32, shape=[None, 42]
                    ),
                }
            else:
                inputs_ = {
                    "log_rhos": Box(-1.0, 1.0, (8, 10, 1)).sample(),
                    "discounts": Box(-1.0, 1.0, (8, 10, 1)).sample(),
                    "rewards": Box(-1.0, 1.0, (8, 10, 42)).sample(),
                    "values": Box(-1.0, 1.0, (8, 10, 42)).sample(),
                    "bootstrap_value": Box(-1.0, 1.0, (10, 42)).sample(),
                }
            output = vtrace.from_importance_weights(**inputs_)
            check(int(output.vs.shape[-1]), 42)

    def test_inconsistent_rank_inputs_for_importance_weights(self):
        """Test one of many possible errors in shape of inputs."""
        for fw in framework_iterator(frameworks=("torch", "tf"), session=True):
            vtrace = vtrace_tf if fw != "torch" else vtrace_torch
            if fw == "tf":
                inputs_ = {
                    "log_rhos": tf1.placeholder(
                        dtype=tf.float32, shape=[None, None, 1]
                    ),
                    "discounts": tf1.placeholder(
                        dtype=tf.float32, shape=[None, None, 1]
                    ),
                    "rewards": tf1.placeholder(
                        dtype=tf.float32, shape=[None, None, 42]
                    ),
                    "values": tf1.placeholder(dtype=tf.float32, shape=[None, None, 42]),
                    # Should be [None, 42].
                    "bootstrap_value": tf1.placeholder(dtype=tf.float32, shape=[None]),
                }
            else:
                inputs_ = {
                    "log_rhos": Box(-1.0, 1.0, (7, 15, 1)).sample(),
                    "discounts": Box(-1.0, 1.0, (7, 15, 1)).sample(),
                    "rewards": Box(-1.0, 1.0, (7, 15, 42)).sample(),
                    "values": Box(-1.0, 1.0, (7, 15, 42)).sample(),
                    # Should be [15, 42].
                    "bootstrap_value": Box(-1.0, 1.0, (7,)).sample(),
                }
            with self.assertRaisesRegex(
                (ValueError, AssertionError), "must have rank 2"
            ):
                vtrace.from_importance_weights(**inputs_)


if __name__ == "__main__":
    tf.test.main()