# ray/rllib/utils/exploration/exploration.py

from gym.spaces import Space
from typing import Dict, List, Optional, Union, TYPE_CHECKING

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch, TensorType
from ray.rllib.utils.typing import LocalOptimizer, TrainerConfigDict

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

_, tf, _ = try_import_tf()
_, nn = try_import_torch()


@DeveloperAPI
class Exploration:
    """Implements an exploration strategy for Policies.

    An Exploration takes model outputs, a distribution, and a timestep from
    the agent and computes an action to apply to the environment using an
    implemented exploration schema.
    """
    def __init__(self, action_space: Space, *, framework: str,
                 policy_config: TrainerConfigDict, model: ModelV2,
                 num_workers: int, worker_index: int):
        """
        Args:
            action_space (Space): The action space in which to explore.
            framework (str): One of "tf" or "torch".
            policy_config (TrainerConfigDict): The Policy's config dict.
            model (ModelV2): The Policy's model.
            num_workers (int): The overall number of workers used.
            worker_index (int): The index of the worker using this class.
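
        Example (this class is normally not constructed directly; RLlib
        builds it from the Policy's `exploration_config`. A config sketch,
        assuming the built-in StochasticSampling class):
            >>> config["exploration_config"] = {
            ...     "type": "StochasticSampling",
            ... }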
"""
self.action_space = action_space
self.policy_config = policy_config
self.model = model
self.num_workers = num_workers
self.worker_index = worker_index
self.framework = framework
        # The device on which the Model has been placed.
        # This Exploration will be on the same device.
        self.device = None
        if isinstance(self.model, nn.Module):
            params = list(self.model.parameters())
            if params:
                self.device = params[0].device
    @DeveloperAPI
    def before_compute_actions(
            self,
            *,
            timestep: Optional[Union[TensorType, int]] = None,
            explore: Optional[Union[TensorType, bool]] = None,
            tf_sess: Optional["tf.Session"] = None,
            **kwargs):
"""Hook for preparations before policy.compute_actions() is called.
Args:
timestep (Optional[Union[TensorType, int]]): An optional timestep
tensor.
explore (Optional[Union[TensorType, bool]]): An optional explore
boolean flag.
            tf_sess (Optional[tf.Session]): The tf-session object to use.
            **kwargs: Forward compatibility kwargs.
        """
        pass

    # yapf: disable
    # __sphinx_doc_begin_get_exploration_action__
    @DeveloperAPI
    def get_exploration_action(self,
                               *,
                               action_distribution: ActionDistribution,
                               timestep: Union[TensorType, int],
                               explore: bool = True):
"""Returns a (possibly) exploratory action and its log-likelihood.
Given the Model's logits outputs and action distribution, returns an
exploratory action.
Args:
action_distribution (ActionDistribution): The instantiated
ActionDistribution object to work with when creating
exploration actions.
timestep (Union[TensorType, int]): The current sampling time step.
It can be a tensor for TF graph mode, otherwise an integer.
explore (Union[TensorType, bool]): True: "Normal" exploration
behavior. False: Suppress all exploratory behavior and return
a deterministic action.
Returns:
Tuple:
- The chosen exploration action or a tf-op to fetch the exploration
action from the graph.
- The log-likelihood of the exploration action.
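
        Example (a minimal sketch of a subclass override; assumes the
        standard ActionDistribution API and samples stochastically when
        exploring, deterministically otherwise):
            >>> def get_exploration_action(self, *, action_distribution,
            ...                            timestep, explore=True):
            ...     if explore:
            ...         action = action_distribution.sample()
            ...     else:
            ...         action = action_distribution.deterministic_sample()
            ...     return action, action_distribution.sampled_action_logp()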
"""
pass
    # __sphinx_doc_end_get_exploration_action__
    # yapf: enable
    @DeveloperAPI
    def on_episode_start(self,
                         policy: "Policy",
                         *,
                         environment: BaseEnv = None,
                         episode: int = None,
                         tf_sess: Optional["tf.Session"] = None):
"""Handles necessary exploration logic at the beginning of an episode.
Args:
2020-03-29 00:16:30 +01:00
policy (Policy): The Policy object that holds this Exploration.
environment (BaseEnv): The environment object we are acting in.
episode (int): The number of the episode that is starting.
tf_sess (Optional[tf.Session]): In case of tf, the session object.
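
        Example (a sketch of a stateful-noise subclass; `self.noise_state`
        is a hypothetical attribute, not part of this base class, and
        numpy is assumed to be imported as np):
            >>> def on_episode_start(self, policy, *, environment=None,
            ...                      episode=None, tf_sess=None):
            ...     # Re-initialize the per-episode noise process.
            ...     self.noise_state = np.zeros(self.action_space.shape)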
"""
pass
    @DeveloperAPI
    def on_episode_end(self,
                       policy: "Policy",
                       *,
                       environment: BaseEnv = None,
                       episode: int = None,
                       tf_sess: Optional["tf.Session"] = None):
"""Handles necessary exploration logic at the end of an episode.
Args:
policy (Policy): The Policy object that holds this Exploration.
environment (BaseEnv): The environment object we are acting in.
episode (int): The number of the episode that is starting.
tf_sess (Optional[tf.Session]): In case of tf, the session object.
"""
pass

    @DeveloperAPI
    def postprocess_trajectory(self,
                               policy: "Policy",
                               sample_batch: SampleBatch,
                               tf_sess: Optional["tf.Session"] = None):
"""Handles post-processing of done episode trajectories.
Changes the given batch in place. This callback is invoked by the
sampler after policy.postprocess_trajectory() is called.
Args:
policy (Policy): The owning policy object.
sample_batch (SampleBatch): The SampleBatch object to post-process.
tf_sess (Optional[tf.Session]): An optional tf.Session object.
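
        Example (a sketch of a curiosity-style subclass that adds an
        intrinsic bonus to the rewards; `compute_intrinsic_rewards` is a
        hypothetical helper, not part of this base class):
            >>> def postprocess_trajectory(self, policy, sample_batch,
            ...                            tf_sess=None):
            ...     bonus = self.compute_intrinsic_rewards(sample_batch)
            ...     sample_batch[SampleBatch.REWARDS] += bonus
            ...     return sample_batch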
"""
2020-03-29 00:16:30 +01:00
return sample_batch

    @DeveloperAPI
    def get_exploration_optimizer(self, optimizers: List[LocalOptimizer]) -> \
            List[LocalOptimizer]:
"""May add optimizer(s) to the Policy's own `optimizers`.
The number of optimizers (Policy's plus Exploration's optimizers) must
match the number of loss terms produced by the Policy's loss function
and the Exploration component's loss terms.
Args:
optimizers (List[LocalOptimizer]): The list of the Policy's
local optimizers.
Returns:
List[LocalOptimizer]: The updated list of local optimizers to use
on the different loss terms.
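
        Example (a torch sketch; assumes a hypothetical
        `self.exploration_model` holding trainable parameters and torch
        imported at the call site):
            >>> def get_exploration_optimizer(self, optimizers):
            ...     return optimizers + [torch.optim.Adam(
            ...         self.exploration_model.parameters(), lr=1e-4)]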
"""
return optimizers
    @DeveloperAPI
    def get_state(self, sess: Optional["tf.Session"] = None) -> \
            Dict[str, TensorType]:
"""Returns the current exploration state.
Args:
sess (Optional[tf.Session]): An optional tf Session object to use.
Returns:
Dict[str, TensorType]: The Exploration object's current state.
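
        Example (a sketch; `last_timestep` is an illustrative key a
        timestep-dependent subclass might store):
            >>> def get_state(self, sess=None):
            ...     return {"last_timestep": self.last_timestep}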
"""
return {}

    @DeveloperAPI
    def set_state(self, state: object,
                  sess: Optional["tf.Session"] = None) -> None:
"""Sets the Exploration object's state to the given values.
Note that some exploration components are stateless, even though they
decay some values over time (e.g. EpsilonGreedy). However the decay is
only dependent on the current global timestep of the policy and we
therefore don't need to keep track of it.
Args:
state (object): The state to set this Exploration to.
sess (Optional[tf.Session]): An optional tf Session object to use.
"""
pass

    @Deprecated(new="get_state", error=False)
    def get_info(self, sess: Optional["tf.Session"] = None):
        return self.get_state(sess)