ray/rllib/utils/numpy.py
Sven Mika 0db2046b0a
[RLlib] Policy.compute_log_likelihoods() and SAC refactor. (issue #7107) (#7124)

Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-02-22 14:19:49 -08:00


import numpy as np
from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()

# Small epsilon, e.g. used as the lower bound for softmax outputs below.
SMALL_NUMBER = 1e-6
# Some large int number. May be increased here, if needed.
LARGE_INTEGER = 100000000
# Min and Max outputs (clipped) from an NN-output layer interpreted as the
# log(x) of some x (e.g. a stddev of a normal distribution).
MIN_LOG_NN_OUTPUT = -20
MAX_LOG_NN_OUTPUT = 2
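

# Illustrative sketch (hypothetical helper, not part of this module's API):
# how an NN output interpreted as log(stddev) might be clipped into
# [MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT] before exponentiating. This bounds
# the resulting stddev to roughly [2.1e-9, 7.39]. The bounds are passed as
# local defaults so the sketch is self-contained.
import numpy as np


def _example_clipped_stddev(log_std_output, min_log=-20, max_log=2):
    # Clip the raw log-space NN output, then map it back to stddev space.
    log_std = np.clip(log_std_output, min_log, max_log)
    return np.exp(log_std)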


def sigmoid(x, derivative=False):
    """
    Returns the sigmoid function applied to x.
    Alternatively, can return the derivative of the sigmoid function.

    Args:
        x (np.ndarray): The input to the sigmoid function.
        derivative (bool): Whether to return the derivative or not.
            Default: False.

    Returns:
        np.ndarray: The sigmoid function (or its derivative) applied to x.
    """
    s = 1 / (1 + np.exp(-x))
    if derivative:
        # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
        return s * (1 - s)
    else:
        return s
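

# Illustrative sketch (hypothetical helper, not part of this module's API):
# a central finite-difference check of the identity
# d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)). The derivative is
# expressed in terms of the sigmoid's *output*, not its raw input x.
import numpy as np


def _example_sigmoid_grad_check(x, eps=1e-5):
    def sig(z):
        return 1 / (1 + np.exp(-z))

    # Closed-form derivative vs. numerical central difference.
    analytic = sig(x) * (1 - sig(x))
    numeric = (sig(x + eps) - sig(x - eps)) / (2 * eps)
    return analytic, numeric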
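

def softmax(x, axis=-1):
    """
    Returns the softmax values for x as:
    S(xi) = e^xi / SUMj(e^xj), where j goes over all elements of x along
    `axis`.
    The outputs are clipped from below at SMALL_NUMBER, so no entry is
    exactly 0.

    Args:
        x (np.ndarray): The input to the softmax function.
        axis (int): The axis along which to softmax.

    Returns:
        np.ndarray: The softmax over x.
    """
    # Subtract the max along `axis` before exponentiating: this leaves the
    # result unchanged, but avoids overflow in np.exp for large inputs.
    x_exp = np.exp(x - np.max(x, axis, keepdims=True))
    return np.maximum(x_exp / np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER)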
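

# Illustrative sketch (hypothetical helper, not part of this module's API):
# softmax is invariant to subtracting a constant c from every logit, since
# e^(xi - c) / SUMj(e^(xj - c)) = e^xi / SUMj(e^xj). Choosing c = max(x)
# keeps every exponent <= 0, so np.exp cannot overflow even for huge logits.
import numpy as np


def _example_shifted_softmax(x, axis=-1):
    # Shift by the per-slice max, then normalize as usual.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    x_exp = np.exp(shifted)
    return x_exp / np.sum(x_exp, axis=axis, keepdims=True)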
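

def relu(x, alpha=0.0):
    """
    Implementation of the leaky ReLU function:
    y = x * alpha if x < 0 else x

    Args:
        x (np.ndarray): The input values.
        alpha (float): A scaling ("leak") factor to use for negative x.
            The formula above only holds for 0.0 <= alpha <= 1.0.

    Returns:
        np.ndarray: The leaky ReLU output for x.
    """
    # For 0 <= alpha <= 1, max(x, alpha * x) is x for x >= 0 and
    # alpha * x for x < 0. (A 3rd positional arg to np.maximum would be its
    # `out` array, so it must not be passed here.)
    return np.maximum(x, x * alpha)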
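

# Illustrative sketch (hypothetical helper, not part of this module's API):
# the piecewise leaky ReLU definition written directly with np.where. For
# 0 <= alpha <= 1 it matches np.maximum(x, alpha * x). For alpha > 1 the two
# differ, because max() would then pick alpha * x for positive inputs.
import numpy as np


def _example_leaky_relu_where(x, alpha=0.0):
    # Keep x where it is non-negative, scale it by alpha elsewhere.
    return np.where(x < 0, x * alpha, x)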


def one_hot(x, depth=0, on_value=1, off_value=0):
    """
    One-hot utility function for numpy.
    Thanks to qianyizhang:
    https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30.

    Args:
        x (np.ndarray): The input to be one-hot encoded.
        depth (int): The size of the last (one-hot) rank, i.e. the max.
            index to be encoded + 1. If 0, it is inferred as max(x) + 1.
            Default: 0.
        on_value (float): The value to use for on. Default: 1.0.
        off_value (float): The value to use for off. Default: 0.0.

    Returns:
        np.ndarray: The one-hot encoded equivalent of the input array.
    """
    # Handle bool arrays correctly.
    if x.dtype == np.bool_:
        x = x.astype(int)
        depth = 2

    if depth == 0:
        depth = np.max(x) + 1
    assert np.max(x) < depth, \
        "ERROR: The max. index of `x` ({}) must be smaller than depth " \
        "({})!".format(np.max(x), depth)
    shape = x.shape

    # Python 2.7 compatibility, (*shape, depth) is not allowed.
    shape_list = list(shape[:])
    shape_list.append(depth)
    out = np.ones(shape_list) * off_value
    indices = []
    for i in range(x.ndim):
        tiles = [1] * x.ndim
        s = [1] * x.ndim
        s[i] = -1
        r = np.arange(shape[i]).reshape(s)
        if i > 0:
            tiles[i - 1] = shape[i - 1]
            r = np.tile(r, tiles)
        indices.append(r)
    indices.append(x)
    out[tuple(indices)] = on_value
    return out
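

# Illustrative sketch (hypothetical helper, not part of this module's API):
# a shorter alternative one-hot encoding via fancy indexing into an
# identity matrix. np.eye(depth)[x] selects row x[...] for every element of
# x, which appends a trailing rank of size depth, like one_hot() above.
# Assumes non-negative integer (or bool) indices.
import numpy as np


def _example_one_hot_eye(x, depth=0, on_value=1.0, off_value=0.0):
    x = np.asarray(x)
    if x.dtype == np.bool_:
        # Mirror one_hot(): bools always get depth 2.
        x, depth = x.astype(int), 2
    if depth == 0:
        depth = int(np.max(x)) + 1
    eye = np.eye(depth)
    # Rescale the 0/1 identity entries to off_value/on_value.
    return eye[x] * (on_value - off_value) + off_value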


def fc(x, weights, biases=None):
    """
    Calculates the outputs of a fully-connected (dense) layer given
    weights/biases and an input.

    Args:
        x (np.ndarray): The input to the dense layer.
        weights (np.ndarray): The weights matrix of shape [in, out] (or a
            torch tensor of shape [out, in]).
        biases (Optional[np.ndarray]): The biases vector. All 0s if None.

    Returns:
        np.ndarray: The dense layer's output.
    """
    # Torch stores Linear weights with shape [out, in], i.e. transposed
    # relative to the [in, out] layout expected by the matmul below.
    if torch and isinstance(weights, torch.Tensor):
        weights = np.transpose(weights.detach().cpu().numpy())
    if torch and isinstance(biases, torch.Tensor):
        biases = biases.detach().cpu().numpy()
    return np.matmul(x, weights) + (0.0 if biases is None else biases)
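

# Illustrative sketch (hypothetical helper, not part of this module's API):
# a dense layer y = x @ W + b with x of shape [batch, in] and W of shape
# [in, out]. A weight matrix stored the torch way, with shape [out, in],
# gives the same result once transposed, which is what fc() does for torch
# tensors. The `weights_are_out_by_in` flag is made up for this sketch.
import numpy as np


def _example_dense(x, weights, biases=None, weights_are_out_by_in=False):
    if weights_are_out_by_in:
        # [out, in] -> [in, out]
        weights = weights.T
    out = np.matmul(x, weights)
    return out if biases is None else out + biases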


def lstm(x,
         weights,
         biases=None,
         initial_internal_states=None,
         time_major=False,
         forget_bias=1.0):
    """
    Calculates the outputs of an LSTM layer given weights/biases,
    internal_states, and input.

    Args:
        x (np.ndarray): The inputs to the LSTM layer including time-rank
            (0th if time-major, else 1st) and the batch-rank
            (1st if time-major, else 0th).
        weights (np.ndarray): The weights matrix of shape
            [input_dim + units, 4 * units].
        biases (Optional[np.ndarray]): The biases vector. All 0s if None.
        initial_internal_states (Optional[Tuple[np.ndarray, np.ndarray]]):
            The initial (c-state, h-state) to pass into the layer.
            All 0s if None.
        time_major (bool): Whether to use time-major or not. Default: False.
        forget_bias (float): Gets added to the forget gate's pre-activation
            (the input of its sigmoid). Default: 1.0.

    Returns:
        Tuple:
            - The LSTM layer's output.
            - Tuple: Last (c-state, h-state).
    """
    sequence_length = x.shape[0 if time_major else 1]
    batch_size = x.shape[1 if time_major else 0]
    units = weights.shape[1] // 4  # 4 internal layers (3x sigmoid, 1x tanh)

    if initial_internal_states is None:
        c_states = np.zeros(shape=(batch_size, units))
        h_states = np.zeros(shape=(batch_size, units))
    else:
        c_states = initial_internal_states[0]
        h_states = initial_internal_states[1]

    if biases is None:
        biases = np.zeros(shape=(4 * units, ))

    # Create a placeholder for all n-time step outputs.
    if time_major:
        unrolled_outputs = np.zeros(shape=(sequence_length, batch_size, units))
    else:
        unrolled_outputs = np.zeros(shape=(batch_size, sequence_length, units))

    # Push the batch through the LSTM cell once per time step and capture
    # the outputs plus the final h- and c-states.
    for t in range(sequence_length):
        input_matrix = x[t, :, :] if time_major else x[:, t, :]
        input_matrix = np.concatenate((input_matrix, h_states), axis=1)
        input_matmul_matrix = np.matmul(input_matrix, weights) + biases
        # Forget gate (3rd slot in tf output matrix). Add static forget bias.
        sigmoid_1 = sigmoid(input_matmul_matrix[:, units * 2:units * 3] +
                            forget_bias)
        c_states = np.multiply(c_states, sigmoid_1)
        # Input gate and candidate values (1st and 2nd slots in tf output
        # matrix).
        sigmoid_2 = sigmoid(input_matmul_matrix[:, 0:units])
        tanh_3 = np.tanh(input_matmul_matrix[:, units:units * 2])
        c_states = np.add(c_states, np.multiply(sigmoid_2, tanh_3))
        # Output gate (last slot in tf output matrix).
        sigmoid_4 = sigmoid(input_matmul_matrix[:, units * 3:units * 4])
        h_states = np.multiply(sigmoid_4, np.tanh(c_states))

        # Store this output time-slice.
        if time_major:
            unrolled_outputs[t, :, :] = h_states
        else:
            unrolled_outputs[:, t, :] = h_states

    return unrolled_outputs, (c_states, h_states)
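

# Illustrative sketch (hypothetical helper, not part of this module's API):
# one LSTM cell step using the same gate layout as lstm() above. The
# 4 * units pre-activation columns are split into
# [input gate i, candidate j, forget gate f, output gate o]:
#   c' = sigmoid(f + forget_bias) * c + sigmoid(i) * tanh(j)
#   h' = sigmoid(o) * tanh(c')
import numpy as np


def _example_lstm_step(x_t, c, h, weights, biases, forget_bias=1.0):
    # x_t: [batch, input_dim]; c, h: [batch, units];
    # weights: [input_dim + units, 4 * units]; biases: [4 * units].
    z = np.matmul(np.concatenate((x_t, h), axis=1), weights) + biases
    i, j, f, o = np.split(z, 4, axis=1)

    def sig(v):
        return 1 / (1 + np.exp(-v))

    c_new = sig(f + forget_bias) * c + sig(i) * np.tanh(j)
    h_new = sig(o) * np.tanh(c_new)
    return c_new, h_new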