2017-08-03 19:29:01 -07:00
|
|
|
import numpy as np
|
2020-11-12 03:18:50 -08:00
|
|
|
from typing import Tuple, Any, Optional
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
from ray.rllib.utils.annotations import DeveloperAPI
|
2020-06-16 08:52:20 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_tf
|
2020-11-12 03:18:50 -08:00
|
|
|
from ray.rllib.utils.typing import TensorType
|
2019-05-10 20:36:18 -07:00
|
|
|
|
2020-06-30 10:13:20 +02:00
|
|
|
tf1, tf, tfv = try_import_tf()
|
2017-08-03 19:29:01 -07:00
|
|
|
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
@DeveloperAPI
def normc_initializer(std: float = 1.0) -> Any:
    """Build an initializer whose columns are rescaled to a fixed L2 norm.

    Args:
        std: Target L2 norm of every column (each slice taken along axis 0)
            of the generated array.

    Returns:
        A callable ``(shape, dtype=None, partition_info=None)`` that draws
        standard-normal samples of ``shape``, rescales them so that every
        column has L2 norm ``std``, and returns them wrapped in
        ``tf.constant``.
    """

    def _initializer(shape, dtype=None, partition_info=None):
        # `partition_info` is accepted only so the signature matches what
        # TF1 initializers are called with; it is not used here.
        if hasattr(dtype, "name"):
            # Dtype objects (e.g. TF or NumPy dtypes) expose a `.name`
            # string that NumPy's `astype` understands.
            np_dtype = dtype.name
        else:
            np_dtype = dtype or np.float32
        samples = np.random.randn(*shape).astype(np_dtype)
        col_norms = np.sqrt(np.square(samples).sum(axis=0, keepdims=True))
        samples *= std / col_norms
        return tf.constant(samples)

    return _initializer
|
2017-08-28 12:23:14 -07:00
|
|
|
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
@DeveloperAPI
def conv2d(
    x: TensorType,
    num_filters: int,
    name: str,
    filter_size: Tuple[int, int] = (3, 3),
    stride: Tuple[int, int] = (1, 1),
    pad: str = "SAME",
    dtype: Optional[Any] = None,
    collections: Optional[Any] = None,
) -> TensorType:
    """2D convolution with a uniformly-initialized kernel plus a bias.

    The kernel ``W`` and bias ``b`` are obtained via ``tf1.get_variable``
    inside ``tf1.variable_scope(name)``.

    Args:
        x: Input tensor. The kernel depth is read from ``x.get_shape()[3]``
            and strides are laid out as ``[1, sh, sw, 1]``, so this assumes
            an NHWC layout (channels on axis 3).
        num_filters: Number of output channels.
        name: Variable-scope name for ``W`` and ``b``.
        filter_size: Kernel ``(height, width)``.
        stride: Convolution stride ``(height, width)``.
        pad: Padding mode passed to ``tf1.nn.conv2d`` (e.g. ``"SAME"``).
        dtype: Dtype for both ``W`` and ``b``; defaults to ``tf.float32``.
        collections: Optional collections the variables are added to.

    Returns:
        ``conv2d(x, W) + b``.
    """
    if dtype is None:
        dtype = tf.float32

    with tf1.variable_scope(name):
        stride_shape = [1, stride[0], stride[1], 1]
        filter_shape = [
            filter_size[0],
            filter_size[1],
            int(x.get_shape()[3]),
            num_filters,
        ]

        # Each hidden unit has "num input feature maps * filter height *
        # filter width" inputs.
        fan_in = np.prod(filter_shape[:3])
        # Each unit in the lower layer receives a gradient from "num output
        # feature maps * filter height * filter width" units (no pooling
        # adjustment is applied here).
        fan_out = np.prod(filter_shape[:2]) * num_filters
        # Glorot/Xavier-style uniform bound.
        w_bound = np.sqrt(6 / (fan_in + fan_out))

        w = tf1.get_variable(
            "W",
            filter_shape,
            dtype,
            tf1.random_uniform_initializer(-w_bound, w_bound),
            collections=collections,
        )
        # Create the bias in the same dtype as the kernel. Without an
        # explicit dtype, `get_variable` falls back to float32, so any
        # non-float32 `dtype` would make the final addition fail with a
        # dtype mismatch.
        b = tf1.get_variable(
            "b",
            [1, 1, 1, num_filters],
            dtype=dtype,
            initializer=tf1.constant_initializer(0.0),
            collections=collections,
        )
        return tf1.nn.conv2d(x, w, stride_shape, pad) + b
|
2017-08-28 12:23:14 -07:00
|
|
|
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
@DeveloperAPI
def linear(
    x: TensorType,
    size: int,
    name: str,
    initializer: Optional[Any] = None,
    bias_init: float = 0.0,
) -> TensorType:
    """Fully-connected layer computing ``x @ w + b``.

    Args:
        x: Input tensor; its static dimension 1 is used as the input width
            (presumably a 2-D ``(batch, features)`` tensor — confirm with
            callers).
        size: Number of output units.
        name: Prefix for the variable names ``<name>/w`` and ``<name>/b``.
        initializer: Initializer for the weight matrix; ``None`` leaves the
            choice to ``tf1.get_variable``.
        bias_init: Constant used to initialize every bias entry.

    Returns:
        ``tf.matmul(x, w) + b``.
    """
    in_dim = x.get_shape()[1]
    weights = tf1.get_variable(
        name + "/w", [in_dim, size], initializer=initializer
    )
    bias_init_fn = tf1.constant_initializer(bias_init)
    bias = tf1.get_variable(name + "/b", [size], initializer=bias_init_fn)
    return tf.matmul(x, weights) + bias
|
|
|
|
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
@DeveloperAPI
def flatten(x: TensorType) -> TensorType:
    """Collapse every dimension after the first into a single one.

    Args:
        x: Input tensor with a known static rank.

    Returns:
        ``x`` reshaped to ``[-1, prod(shape[1:])]``.
    """
    trailing_dims = x.get_shape().as_list()[1:]
    if all(d is not None for d in trailing_dims):
        # `np.prod([])` is the float 1.0 (rank-1 input), so cast to int to
        # keep the target shape integral.
        return tf.reshape(x, [-1, int(np.prod(trailing_dims))])
    # Some trailing dimensions are only known at run time; `np.prod` would
    # raise on `None`, so compute their product dynamically instead.
    return tf.reshape(x, [-1, tf.reduce_prod(tf.shape(x)[1:])])
|