from gym.spaces import Space
from typing import List, Optional, Union, TYPE_CHECKING

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_torch, TensorType
from ray.rllib.utils.typing import LocalOptimizer, TrainerConfigDict

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

_, nn = try_import_torch()

@DeveloperAPI
class Exploration:
    """Implements an exploration strategy for Policies.

    An Exploration takes model outputs, an action distribution, and a
    timestep from the agent and computes an action to apply to the
    environment using an implemented exploration schema.
    """

    def __init__(self, action_space: Space, *, framework: str,
                 policy_config: TrainerConfigDict, model: ModelV2,
                 num_workers: int, worker_index: int):
        """
        Args:
            action_space (Space): The action space in which to explore.
            framework (str): One of "tf" or "torch".
            policy_config (TrainerConfigDict): The Policy's config dict.
            model (ModelV2): The Policy's model.
            num_workers (int): The overall number of workers used.
            worker_index (int): The index of the worker using this class.
        """
        self.action_space = action_space
        self.policy_config = policy_config
        self.model = model
        self.num_workers = num_workers
        self.worker_index = worker_index
        self.framework = framework
        # The device on which the Model has been placed.
        # This Exploration will be on the same device.
        self.device = None
        # `nn` is None if torch is not installed; guard the isinstance
        # check so tf-only setups do not crash here.
        if nn is not None and isinstance(self.model, nn.Module):
            params = list(self.model.parameters())
            if params:
                self.device = params[0].device
    @DeveloperAPI
    def before_compute_actions(
            self,
            *,
            timestep: Optional[Union[TensorType, int]] = None,
            explore: Optional[Union[TensorType, bool]] = None,
            tf_sess: Optional["tf.Session"] = None,
            **kwargs):
        """Hook for preparations before policy.compute_actions() is called.

        Args:
            timestep (Optional[Union[TensorType, int]]): An optional
                timestep tensor.
            explore (Optional[Union[TensorType, bool]]): An optional
                explore boolean flag.
            tf_sess (Optional[tf.Session]): The tf-session object to use.
            **kwargs: Forward compatibility kwargs.
        """
        pass
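
    # NOTE (illustrative): Exploration schemes that need per-forward-pass
    # setup use this hook; a parameter-noise scheme, for example, might
    # (re-)apply weight perturbations to the model right before
    # `policy.compute_actions()` runs.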
    # yapf: disable
    # __sphinx_doc_begin_get_exploration_action__

    @DeveloperAPI
    def get_exploration_action(self,
                               *,
                               action_distribution: ActionDistribution,
                               timestep: Union[TensorType, int],
                               explore: bool = True):
        """Returns a (possibly) exploratory action and its log-likelihood.

        Given the Model's logits outputs and action distribution, returns
        an exploratory action.

        Args:
            action_distribution (ActionDistribution): The instantiated
                ActionDistribution object to work with when creating
                exploration actions.
            timestep (Union[TensorType, int]): The current sampling time
                step. It can be a tensor for TF graph mode, otherwise an
                integer.
            explore (Union[TensorType, bool]): True: "Normal" exploration
                behavior. False: Suppress all exploratory behavior and
                return a deterministic action.

        Returns:
            Tuple:
            - The chosen exploration action or a tf-op to fetch the
              exploration action from the graph.
            - The log-likelihood of the exploration action.
        """
        pass

    # __sphinx_doc_end_get_exploration_action__
    # yapf: enable
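
    # NOTE (illustrative): A typical implementation samples from
    # `action_distribution` when `explore` is True and falls back to
    # `action_distribution.deterministic_sample()` otherwise; see the
    # sketch at the bottom of this module.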
    @DeveloperAPI
    def on_episode_start(self,
                         policy,
                         *,
                         environment=None,
                         episode=None,
                         tf_sess=None):
        """Handles necessary exploration logic at the beginning of an episode.

        Args:
            policy (Policy): The Policy object that holds this Exploration.
            environment (BaseEnv): The environment object we are acting in.
            episode (int): The number of the episode that is starting.
            tf_sess (Optional[tf.Session]): In case of tf, the session object.
        """
        pass
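
    # NOTE (illustrative): The episode-boundary hooks above and below are
    # where stateful schemes reset themselves; a parameter-noise scheme,
    # for example, might re-sample its weight perturbations at the start
    # of every episode.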
    @DeveloperAPI
    def on_episode_end(self,
                       policy,
                       *,
                       environment=None,
                       episode=None,
                       tf_sess=None):
        """Handles necessary exploration logic at the end of an episode.

        Args:
            policy (Policy): The Policy object that holds this Exploration.
            environment (BaseEnv): The environment object we are acting in.
            episode (int): The number of the episode that just ended.
            tf_sess (Optional[tf.Session]): In case of tf, the session object.
        """
        pass
    @DeveloperAPI
    def postprocess_trajectory(self,
                               policy: "Policy",
                               sample_batch,
                               tf_sess=None):
        """Handles post-processing of done episode trajectories.

        Changes the given batch in place. This callback is invoked by the
        sampler after policy.postprocess_trajectory() is called.

        Args:
            policy (Policy): The owning policy object.
            sample_batch (SampleBatch): The SampleBatch object to
                post-process.
            tf_sess (Optional[tf.Session]): An optional tf.Session object.

        Returns:
            SampleBatch: The (possibly changed) sample batch.
        """
        return sample_batch
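
    # NOTE (illustrative): A curiosity-style Exploration, for example,
    # could compute intrinsic rewards here and add them in place to
    # `sample_batch["rewards"]` before returning the batch.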
    @DeveloperAPI
    def get_exploration_optimizer(self, optimizers: List[LocalOptimizer]):
        """May add optimizer(s) to the Policy's own `optimizers`.

        The number of optimizers (Policy's plus Exploration's optimizers)
        must match the number of loss terms produced by the Policy's loss
        function plus the Exploration component's loss terms.

        Args:
            optimizers (List[LocalOptimizer]): The list of the Policy's
                local optimizers.

        Returns:
            List[LocalOptimizer]: The updated list of local optimizers to
                use on the different loss terms.
        """
        return optimizers
    @DeveloperAPI
    def get_exploration_loss(self, policy_loss: List[TensorType],
                             train_batch: SampleBatch):
        """May add loss term(s) to the Policy's own loss(es).

        Args:
            policy_loss (List[TensorType]): Loss(es) already calculated by
                the Policy's own loss function and maybe the Model's custom
                loss.
            train_batch (SampleBatch): The training data to calculate the
                loss(es) for. This train data has already gone through this
                Exploration's `postprocess_trajectory()` method.

        Returns:
            List[TensorType]: The updated list of loss terms.
                This may be the original Policy loss(es), altered, and/or
                new loss terms added to it.
        """
        return policy_loss
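
    # NOTE (illustrative): `get_exploration_optimizer()` and
    # `get_exploration_loss()` work as a pair: an Exploration that trains
    # its own sub-model (e.g. a curiosity module) would append one
    # optimizer to `optimizers` and one matching loss term to
    # `policy_loss`, keeping the optimizer/loss counts aligned as
    # required above.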
    @DeveloperAPI
    def get_info(self, sess=None):
        """Returns a description of the current exploration state.

        This is not necessarily the state itself (and cannot be used in
        set_state!), but rather useful (e.g. debugging) information.

        Args:
            sess (Optional[tf.Session]): An optional tf Session object to
                use.

        Returns:
            dict: A description of the Exploration (not necessarily its
                state). This may include tf.ops as values in graph mode.
        """
        return {}
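

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal, torch-only
# Exploration subclass showing how `get_exploration_action` is typically
# implemented. `StochasticSamplingSketch` is a hypothetical name; RLlib ships
# a comparable built-in (`StochasticSampling`) that also supports tf, and
# plugs it in via the Policy's `exploration_config` dict.
# ---------------------------------------------------------------------------
class StochasticSamplingSketch(Exploration):
    """Samples from the action distribution when exploring.

    With explore=False, returns the distribution's deterministic
    (e.g. argmax/mean) action instead.
    """

    def get_exploration_action(self,
                               *,
                               action_distribution: ActionDistribution,
                               timestep: Union[TensorType, int],
                               explore: bool = True):
        torch, _nn = try_import_torch()
        if explore:
            # Draw a stochastic sample and record its log-likelihood.
            action = action_distribution.sample()
            logp = action_distribution.sampled_action_logp()
        else:
            # Deterministic action; its log-likelihood is log(1.0) = 0.0.
            action = action_distribution.deterministic_sample()
            logp = torch.zeros((action.size()[0], ), dtype=torch.float32)
        return action, logp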