# ray/rllib/utils/exploration/exploration.py


from gym.spaces import Space
from typing import List, Optional, Union, TYPE_CHECKING
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_torch, TensorType
from ray.rllib.utils.typing import LocalOptimizer, TrainerConfigDict

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

_, nn = try_import_torch()


@DeveloperAPI
class Exploration:
    """Implements an exploration strategy for Policies.

    An Exploration takes model outputs, a distribution, and a timestep from
    the agent and computes an action to apply to the environment according
    to the implemented exploration scheme.
    """

    def __init__(self, action_space: Space, *, framework: str,
                 policy_config: TrainerConfigDict, model: ModelV2,
                 num_workers: int, worker_index: int):
        """
        Args:
            action_space (Space): The action space in which to explore.
            framework (str): One of "tf" or "torch".
            policy_config (TrainerConfigDict): The Policy's config dict.
            model (ModelV2): The Policy's model.
            num_workers (int): The overall number of workers used.
            worker_index (int): The index of the worker using this class.
        """
        self.action_space = action_space
        self.policy_config = policy_config
        self.model = model
        self.num_workers = num_workers
        self.worker_index = worker_index
        self.framework = framework
        # The device on which the (torch) Model has been placed; stays None
        # for non-torch models. This Exploration will use the same device.
        self.device = None
        # Guard against `nn` being None, e.g. when torch could not be
        # imported.
        if nn is not None and isinstance(self.model, nn.Module):
            params = list(self.model.parameters())
            if params:
                self.device = params[0].device

    @DeveloperAPI
    def before_compute_actions(
            self,
            *,
            timestep: Optional[Union[TensorType, int]] = None,
            explore: Optional[Union[TensorType, bool]] = None,
            tf_sess: Optional["tf.Session"] = None,
            **kwargs):
        """Hook for preparations before policy.compute_actions() is called.

        Args:
            timestep (Optional[Union[TensorType, int]]): An optional current
                sampling time step (a tensor in TF graph mode, otherwise an
                int).
            explore (Optional[Union[TensorType, bool]]): An optional flag
                (bool, or bool tensor in TF graph mode) indicating whether
                to explore.
            tf_sess (Optional[tf.Session]): The tf-session object to use.
            **kwargs: Forward compatibility kwargs.
        """
        pass

    # yapf: disable
    # __sphinx_doc_begin_get_exploration_action__

    @DeveloperAPI
    def get_exploration_action(self,
                               *,
                               action_distribution: ActionDistribution,
                               timestep: Union[TensorType, int],
                               explore: Union[TensorType, bool] = True):
        """Returns a (possibly) exploratory action and its log-likelihood.

        Given the action distribution (parameterized by the Model's outputs),
        returns an exploratory action.

        Args:
            action_distribution (ActionDistribution): The instantiated
                ActionDistribution object to work with when creating
                exploration actions.
            timestep (Union[TensorType, int]): The current sampling time step.
                It can be a tensor for TF graph mode, otherwise an integer.
            explore (Union[TensorType, bool]): True: "Normal" exploration
                behavior. False: Suppress all exploratory behavior and return
                a deterministic action.

        Returns:
            Tuple:
            - The chosen exploration action or a tf-op to fetch the
              exploration action from the graph.
            - The log-likelihood of the exploration action.
        """
        pass

    # __sphinx_doc_end_get_exploration_action__
    # yapf: enable

    @DeveloperAPI
    def on_episode_start(self,
                         policy: "Policy",
                         *,
                         environment: Optional[BaseEnv] = None,
                         episode: Optional[int] = None,
                         tf_sess: Optional["tf.Session"] = None):
        """Handles necessary exploration logic at the beginning of an episode.

        Args:
            policy (Policy): The Policy object that holds this Exploration.
            environment (BaseEnv): The environment object we are acting in.
            episode (int): The number of the episode that is starting.
            tf_sess (Optional[tf.Session]): In case of tf, the session object.
        """
        pass

    @DeveloperAPI
    def on_episode_end(self,
                       policy: "Policy",
                       *,
                       environment: Optional[BaseEnv] = None,
                       episode: Optional[int] = None,
                       tf_sess: Optional["tf.Session"] = None):
        """Handles necessary exploration logic at the end of an episode.

        Args:
            policy (Policy): The Policy object that holds this Exploration.
            environment (BaseEnv): The environment object we are acting in.
            episode (int): The number of the episode that is ending.
            tf_sess (Optional[tf.Session]): In case of tf, the session object.
        """
        pass

    @DeveloperAPI
    def postprocess_trajectory(self,
                               policy: "Policy",
                               sample_batch: SampleBatch,
                               tf_sess: Optional["tf.Session"] = None):
        """Handles post-processing of done episode trajectories.

        Changes the given batch in place. This callback is invoked by the
        sampler after policy.postprocess_trajectory() is called.

        Args:
            policy (Policy): The owning policy object.
            sample_batch (SampleBatch): The SampleBatch object to post-process.
            tf_sess (Optional[tf.Session]): An optional tf.Session object.

        Returns:
            SampleBatch: The (possibly altered) `sample_batch`.
        """
        return sample_batch

    @DeveloperAPI
    def get_exploration_optimizer(self, optimizers: List[LocalOptimizer]):
        """May add optimizer(s) to the Policy's own `optimizers`.

        The number of optimizers (Policy's plus Exploration's optimizers) must
        match the number of loss terms produced by the Policy's loss function
        and the Exploration component's loss terms.

        Args:
            optimizers (List[LocalOptimizer]): The list of the Policy's
                local optimizers.

        Returns:
            List[LocalOptimizer]: The updated list of local optimizers to use
                on the different loss terms.
        """
        return optimizers

    @DeveloperAPI
    def get_exploration_loss(self, policy_loss: List[TensorType],
                             train_batch: SampleBatch):
        """May add loss term(s) to the Policy's own loss(es).

        Args:
            policy_loss (List[TensorType]): Loss(es) already calculated by the
                Policy's own loss function and maybe the Model's custom loss.
            train_batch (SampleBatch): The training data to calculate the
                loss(es) for. This train data has already gone through
                this Exploration's `postprocess_trajectory()` method.

        Returns:
            List[TensorType]: The updated list of loss terms.
                This may be the original Policy loss(es), altered, and/or new
                loss terms added to it.
        """
        return policy_loss

    @DeveloperAPI
    def get_info(self, sess: Optional["tf.Session"] = None):
        """Returns a description of the current exploration state.

        This is not necessarily the state itself (and cannot be used in
        set_state!), but rather useful (e.g. debugging) information.

        Args:
            sess (Optional[tf.Session]): An optional tf Session object to use.

        Returns:
            dict: A description of the Exploration (not necessarily its state).
                This may include tf.ops as values in graph mode.
        """
        return {}
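

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of RLlib's API).
#
# The classes below mimic the `get_exploration_action()` contract documented
# above: given an action distribution, a timestep, and an `explore` flag,
# return an `(action, log-likelihood)` tuple. To stay runnable without Ray,
# gym, tf, or torch, they assume a discrete action space, use plain Python
# floats and ints instead of framework tensors, and deliberately do NOT
# subclass `Exploration`. The names `_ToyCategorical` and
# `_ToyEpsilonGreedy` and the linear epsilon schedule are assumptions chosen
# for illustration only; RLlib's own strategies may differ in detail.
# ---------------------------------------------------------------------------
import math  # noqa: E402
import random  # noqa: E402


class _ToyCategorical:
    """Hypothetical stand-in for an ActionDistribution over n actions."""

    def __init__(self, logits):
        # Numerically stable softmax over the (non-empty) list of logits.
        top = max(logits)
        exps = [math.exp(x - top) for x in logits]
        total = sum(exps)
        self.probs = [e / total for e in exps]

    def deterministic_sample(self):
        # The greedy action: index of the most probable action.
        return max(range(len(self.probs)), key=lambda i: self.probs[i])


class _ToyEpsilonGreedy:
    """Hypothetical epsilon-greedy strategy with linearly decaying epsilon."""

    def __init__(self, num_actions, initial_epsilon=1.0, final_epsilon=0.05,
                 decay_timesteps=1000, seed=None):
        self.num_actions = num_actions
        self.initial_epsilon = initial_epsilon
        self.final_epsilon = final_epsilon
        self.decay_timesteps = decay_timesteps
        self.rng = random.Random(seed)

    def epsilon(self, timestep):
        # Interpolate from initial to final epsilon, then hold constant.
        frac = min(max(timestep, 0) / self.decay_timesteps, 1.0)
        return self.initial_epsilon + frac * (
            self.final_epsilon - self.initial_epsilon)

    def get_exploration_action(self, *, action_distribution, timestep,
                               explore=True):
        greedy = action_distribution.deterministic_sample()
        if not explore:
            # Exploration suppressed: the greedy action has probability 1,
            # so its log-likelihood is log(1) = 0.
            return greedy, 0.0
        eps = self.epsilon(timestep)
        if self.rng.random() < eps:
            action = self.rng.randrange(self.num_actions)
        else:
            action = greedy
        # Log-likelihood of `action` under the epsilon-greedy mixture:
        # eps / n for every action, plus (1 - eps) for the greedy one.
        prob = eps / self.num_actions
        if action == greedy:
            prob += 1.0 - eps
        return action, math.log(prob)


if __name__ == "__main__":
    # Usage: query the toy strategy at a few points of the epsilon schedule.
    dist = _ToyCategorical([0.1, 2.0, -1.0])
    strategy = _ToyEpsilonGreedy(num_actions=3, seed=42)
    for t in (0, 500, 1000):
        action, logp = strategy.get_exploration_action(
            action_distribution=dist, timestep=t)
        print(t, round(strategy.epsilon(t), 3), action, round(logp, 3))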