"""Eager mode TF policy built using build_tf_policy().

It supports both traced and non-traced eager execution modes."""
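
# Hedged usage sketch (not part of this module's API): an eager policy is
# normally obtained by selecting the eager framework in a trainer config
# rather than by instantiating anything here directly. The trainer and env
# names below are illustrative only.
#
#     from ray.rllib.agents.ppo import PPOTrainer
#     trainer = PPOTrainer(env="CartPole-v0", config={
#         "framework": "tfe",       # run the policy in (non-traced) eager mode
#         "eager_tracing": True,    # optionally wrap methods via tf.function
#     })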

import functools
import logging
import threading
from typing import Dict, List, Optional, Tuple

from ray.util.debug import log_once
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import TensorType

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


def _convert_to_tf(x, dtype=None):
    if isinstance(x, SampleBatch):
        dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
        return tf.nest.map_structure(_convert_to_tf, dict_)
    elif isinstance(x, Policy):
        return x
    # Special handling of "Repeated" values.
    elif isinstance(x, RepeatedValues):
        return RepeatedValues(
            tf.nest.map_structure(_convert_to_tf, x.values), x.lengths,
            x.max_len)

    if x is not None:
        d = dtype
        x = tf.nest.map_structure(
            lambda f: _convert_to_tf(f, d) if isinstance(f, RepeatedValues)
            else tf.convert_to_tensor(f, d) if f is not None else None, x)

    return x


def _convert_to_numpy(x):
    def _map(x):
        if isinstance(x, tf.Tensor):
            return x.numpy()
        return x

    try:
        return tf.nest.map_structure(_map, x)
    except AttributeError:
        raise TypeError(
            "Object of type {} has no method to convert to numpy.".format(
                type(x)))


def convert_eager_inputs(func):
    @functools.wraps(func)
    def _func(*args, **kwargs):
        if tf.executing_eagerly():
            args = [_convert_to_tf(x) for x in args]
            # TODO: (sven) find a way to remove key-specific hacks.
            kwargs = {
                k: _convert_to_tf(
                    v, dtype=tf.int64 if k == "timestep" else None)
                for k, v in kwargs.items()
                if k not in {"info_batch", "episodes"}
            }
        return func(*args, **kwargs)

    return _func


def convert_eager_outputs(func):
    @functools.wraps(func)
    def _func(*args, **kwargs):
        out = func(*args, **kwargs)
        if tf.executing_eagerly():
            out = tf.nest.map_structure(_convert_to_numpy, out)
        return out

    return _func


def _disallow_var_creation(next_creator, **kw):
    v = next_creator(**kw)
    raise ValueError("Detected a variable being created during an eager "
                     "forward pass. Variables should only be created during "
                     "model initialization: {}".format(v.name))


def traced_eager_policy(eager_policy_cls):
    """Wrapper that enables tracing for all eager policy methods.

    This is enabled by the --trace / "eager_tracing" config."""
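
    # Hedged sketch of how this wrapper is typically applied (the policy class
    # name below is hypothetical): the `with_tracing()` classmethod defined on
    # the eager policy class further below simply calls
    # `traced_eager_policy(cls)`.
    #
    #     TracedMyPolicy = MyEagerPolicy.with_tracing()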

    class TracedEagerPolicy(eager_policy_cls):
        def __init__(self, *args, **kwargs):
            self._traced_learn_on_batch = None
            self._traced_compute_actions = None
            self._traced_compute_gradients = None
            self._traced_apply_gradients = None
            super(TracedEagerPolicy, self).__init__(*args, **kwargs)

        @override(eager_policy_cls)
        @convert_eager_inputs
        @convert_eager_outputs
        def _learn_on_batch_eager(self, samples):
            if self._traced_learn_on_batch is None:
                self._traced_learn_on_batch = tf.function(
                    super(TracedEagerPolicy, self)._learn_on_batch_eager,
                    autograph=False,
                    experimental_relax_shapes=True)
            return self._traced_learn_on_batch(samples)

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_actions(self,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):
            obs_batch = tf.convert_to_tensor(obs_batch)
            state_batches = _convert_to_tf(state_batches)
            prev_action_batch = _convert_to_tf(prev_action_batch)
            prev_reward_batch = _convert_to_tf(prev_reward_batch)

            if self._traced_compute_actions is None:
                self._traced_compute_actions = tf.function(
                    super(TracedEagerPolicy, self).compute_actions,
                    autograph=False,
                    experimental_relax_shapes=True)
            return self._traced_compute_actions(
                obs_batch, state_batches, prev_action_batch, prev_reward_batch,
                info_batch, episodes, explore, timestep, **kwargs)

        @override(eager_policy_cls)
        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_gradients_eager(self, samples):
            if self._traced_compute_gradients is None:
                self._traced_compute_gradients = tf.function(
                    super(TracedEagerPolicy, self).compute_gradients,
                    autograph=False,
                    experimental_relax_shapes=True)
            return self._traced_compute_gradients(samples)

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def apply_gradients(self, grads):
            if self._traced_apply_gradients is None:
                self._traced_apply_gradients = tf.function(
                    super(TracedEagerPolicy, self).apply_gradients,
                    autograph=False,
                    experimental_relax_shapes=True)
            return self._traced_apply_gradients(grads)

    TracedEagerPolicy.__name__ = eager_policy_cls.__name__
    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__
    return TracedEagerPolicy


def build_eager_tf_policy(
        name,
        loss_fn,
        get_default_config=None,
        postprocess_fn=None,
        stats_fn=None,
        optimizer_fn=None,
        gradients_fn=None,
        apply_gradients_fn=None,
        grad_stats_fn=None,
        extra_learn_fetches_fn=None,
        extra_action_out_fn=None,
        validate_spaces=None,
        before_init=None,
        before_loss_init=None,
        after_init=None,
        make_model=None,
        action_sampler_fn=None,
        action_distribution_fn=None,
        mixins=None,
        obs_include_prev_action_reward=DEPRECATED_VALUE,
        get_batch_divisibility_req=None,
        # Deprecated args.
        extra_action_fetches_fn=None):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set {"framework": "tfe"} in the trainer config to have
    it automatically converted to an eager policy.

    This has the same signature as build_tf_policy()."""
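
    # Hedged sketch of a direct call (normally not needed, per the docstring
    # above; `my_loss_fn` and the policy name are hypothetical):
    #
    #     MyEagerPolicy = build_eager_tf_policy(
    #         name="MyEagerPolicy",
    #         loss_fn=my_loss_fn)
    #     policy = MyEagerPolicy(obs_space, action_space, config)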

    base = add_mixins(Policy, mixins)

    if extra_action_fetches_fn is not None:
        deprecation_warning(
            old="extra_action_fetches_fn",
            new="extra_action_out_fn",
            error=False)
        extra_action_out_fn = extra_action_fetches_fn

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)

            # Log device and worker index.
            from ray.rllib.evaluation.rollout_worker import get_global_worker
            worker = get_global_worker()
            worker_idx = worker.worker_index if worker else 0
            if tf.config.list_physical_devices("GPU"):
                logger.info(
                    "TF-eager Policy (worker={}) running on GPU.".format(
                        worker_idx if worker_idx > 0 else "local"))
            else:
                logger.info(
                    "TF-eager Policy (worker={}) running on CPU.".format(
                        worker_idx if worker_idx > 0 else "local"))

            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            self._loss = loss_fn
            self.batch_divisibility_req = get_batch_divisibility_req(self) if \
                callable(get_batch_divisibility_req) else \
                (get_batch_divisibility_req or 1)
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(self.model.view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
            # Just like torch Policy does.
            self._optimizer = optimizers[0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Got to reset global_timestep again after fake run-throughs.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches,
                                      episode)
            return sample_batch

        @with_lock
        @override(Policy)
        def learn_on_batch(self, postprocessed_batch):
            # Callback handling.
            learn_stats = {}
            self.callbacks.on_learn_on_batch(
                policy=self,
                train_batch=postprocessed_batch,
                result=learn_stats)

            if not isinstance(postprocessed_batch, SampleBatch) or \
                    not postprocessed_batch.zero_padded:
                pad_batch_to_sequences_of_same_size(
                    postprocessed_batch,
                    max_seq_len=self._max_seq_len,
                    shuffle=False,
                    batch_divisibility_req=self.batch_divisibility_req,
                    view_requirements=self.view_requirements,
                )

            self._is_training = True
            postprocessed_batch["is_training"] = True
            stats = self._learn_on_batch_eager(postprocessed_batch)
            stats.update({"custom_metrics": learn_stats})
            return stats

        @convert_eager_inputs
        @convert_eager_outputs
        def _learn_on_batch_eager(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            self._apply_gradients(grads_and_vars)
            return stats

        @override(Policy)
        def compute_gradients(self, samples):
            pad_batch_to_sequences_of_same_size(
                samples,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req)

            self._is_training = True
            samples["is_training"] = True
            return self._compute_gradients_eager(samples)

        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_gradients_eager(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            return grads, stats

        @override(Policy)
        def compute_actions(self,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            explore=None,
                            timestep=None,
                            **kwargs):

            self._is_training = False
            self._is_recurrent = \
                state_batches is not None and state_batches != []

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    tf.convert_to_tensor(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    tf.convert_to_tensor(prev_reward_batch)

            return self._compute_action_helper(input_dict, state_batches,
                                               episodes, explore, timestep)

        @override(Policy)
        def compute_actions_from_input_dict(
                self,
                input_dict: Dict[str, TensorType],
                explore: bool = None,
                timestep: Optional[int] = None,
                **kwargs
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            # Pass lazy (eager) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]

            return self._compute_action_helper(input_dict, state_batches, None,
                                               explore, timestep)

        @with_lock
        @convert_eager_inputs
        @convert_eager_outputs
        def _compute_action_helper(self, input_dict, state_batches, episodes,
                                   explore, timestep):

            explore = explore if explore is not None else \
                self.config["explore"]
            timestep = timestep if timestep is not None else \
                self.global_timestep
            if isinstance(timestep, tf.Tensor):
                timestep = int(timestep.numpy())
            self._is_training = False
            self._state_in = state_batches or []
            # Calculate RNN sequence lengths.
            batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
                else None

            # Add default and custom fetches.
            extra_fetches = {}

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    dist_inputs = None
                    state_out = []
                    actions, logp = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes)
                else:
                    # Exploration hook before each forward pass.
                    self.exploration.before_compute_actions(
                        timestep=timestep, explore=explore)

                    if action_distribution_fn:
                        # Try new action_distribution_fn signature, supporting
                        # state_batches and seq_lens.
                        try:
                            dist_inputs, self.dist_class, state_out = \
                                action_distribution_fn(
                                    self,
                                    self.model,
                                    input_dict=input_dict,
                                    state_batches=state_batches,
                                    seq_lens=seq_lens,
                                    explore=explore,
                                    timestep=timestep,
                                    is_training=False)
                        # Trying the old way (to stay backward compatible).
                        # TODO: Remove in future.
                        except TypeError as e:
                            if "positional argument" in e.args[0] or \
                                    "unexpected keyword argument" in e.args[0]:
                                dist_inputs, self.dist_class, state_out = \
                                    action_distribution_fn(
                                        self, self.model,
                                        input_dict[SampleBatch.CUR_OBS],
                                        explore=explore,
                                        timestep=timestep,
                                        is_training=False)
                            else:
                                raise e
                    elif isinstance(self.model, tf.keras.Model):
                        input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                        if state_batches and "state_in_0" not in input_dict:
                            for i, s in enumerate(state_batches):
                                input_dict[f"state_in_{i}"] = s
                        self._lazy_tensor_dict(input_dict)
                        dist_inputs, state_out, extra_fetches = \
                            self.model(input_dict)
                    else:
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens)

                    action_dist = self.dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_out_fn:
                extra_fetches.update(extra_action_out_fn(self))

            # Update our global timestep by the batch size.
            self.global_timestep += int(batch_size)

            return actions, state_out, extra_fetches

        @with_lock
        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError("Cannot compute log-prob/likelihood when an "
                                 "`action_sampler_fn` is given but no "
                                 "`action_distribution_fn` is provided!")

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.constant(False),
            }
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    tf.convert_to_tensor(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    tf.convert_to_tensor(prev_reward_batch)

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if action_distribution_fn:
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=False,
                    is_training=False)
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)
            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def apply_gradients(self, gradients):
            self._apply_gradients(
                zip([(tf.convert_to_tensor(g) if g is not None else None)
                     for g in gradients], self.model.trainable_variables()))

        @override(Policy)
        def get_exploration_info(self):
            return _convert_to_numpy(self.exploration.get_info())

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights),
                                                    len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_state(self):
            state = {"_state": super().get_state()}
            if self._optimizer and \
                    len(self._optimizer.variables()) > 0:
                state["_optimizer_variables"] = \
                    self._optimizer.variables()
            return state

        @override(Policy)
        def set_state(self, state):
            state = state.copy()  # shallow copy
            # Set optimizer vars first.
            optimizer_vars = state.pop("_optimizer_variables", None)
            if optimizer_vars and self._optimizer.variables():
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints.")
                for opt_var, value in zip(self._optimizer.variables(),
                                          optimizer_vars):
                    opt_var.assign(value)
            # Then the Policy's (NN) weights.
            super().set_state(state["_state"])

        def variables(self):
            """Return the list of all savable variables for this policy."""
            if isinstance(self.model, tf.keras.Model):
                return self.model.variables
            else:
                return self.model.variables()

        @override(Policy)
        def is_recurrent(self):
            return self._is_recurrent

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_inputs)

        @override(Policy)
        def get_initial_state(self):
            if hasattr(self, "model"):
                return self.model.get_initial_state()
            return []

        def get_session(self):
            return None  # None implies eager

        def get_placeholder(self, ph):
            raise ValueError(
                "get_placeholder() is not allowed in eager mode. Try using "
                "rllib.utils.tf_ops.make_tf_callable() to write "
                "functions that work in both graph and eager mode.")

        def loss_initialized(self):
            return self._loss_initialized

        @override(Policy)
        def export_model(self, export_dir):
            pass

        @override(Policy)
        def export_checkpoint(self, export_dir):
            pass

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        def _apply_gradients(self, grads_and_vars):
            if apply_gradients_fn:
                apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                self._optimizer.apply_gradients(
                    [(g, v) for g, v in grads_and_vars if g is not None])

        @with_lock
        def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""

            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                loss = loss_fn(self, self.model, self.dist_class, samples)

            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()

            if gradients_fn:

                class OptimizerWrapper:
                    def __init__(self, tape):
                        self.tape = tape

                    def compute_gradients(self, loss, var_list):
                        return list(
                            zip(self.tape.gradient(loss, var_list), var_list))

                grads_and_vars = gradients_fn(self, OptimizerWrapper(tape),
                                              loss)
            else:
                grads_and_vars = list(
                    zip(tape.gradient(loss, variables), variables))

            if log_once("grad_vars"):
                for _, v in grads_and_vars:
                    logger.info("Optimizing variable {}".format(v.name))

            grads = [g for g, v in grads_and_vars]
            stats = self._stats(self, samples, grads)
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):

            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}

            if extra_learn_fetches_fn:
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches

        def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
            # TODO: (sven): Keep for a while to ensure backward compatibility.
            if not isinstance(postprocessed_batch, SampleBatch):
                postprocessed_batch = SampleBatch(postprocessed_batch)
            postprocessed_batch.set_get_interceptor(_convert_to_tf)
            return postprocessed_batch

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls