from gym.spaces import Space
from typing import Dict, List, Optional, Union, TYPE_CHECKING

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch, TensorType
from ray.rllib.utils.typing import LocalOptimizer, TrainerConfigDict

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy
    from ray.rllib.utils import try_import_tf

    _, tf, _ = try_import_tf()

_, nn = try_import_torch()


@DeveloperAPI
class Exploration:
    """Implements an exploration strategy for Policies.

    An Exploration takes model outputs, a distribution, and a timestep from
    the agent and computes an action to apply to the environment using an
    implemented exploration schema.
    """

    def __init__(self, action_space: Space, *, framework: str,
                 policy_config: TrainerConfigDict, model: ModelV2,
                 num_workers: int, worker_index: int):
        """
        Args:
            action_space (Space): The action space in which to explore.
            framework (str): One of "tf" or "torch".
            policy_config (TrainerConfigDict): The Policy's config dict.
            model (ModelV2): The Policy's model.
            num_workers (int): The overall number of workers used.
            worker_index (int): The index of the worker using this class.
        """
        self.action_space = action_space
        self.policy_config = policy_config
        self.model = model
        self.num_workers = num_workers
        self.worker_index = worker_index
        self.framework = framework
        # The device on which the Model has been placed.
        # This Exploration will be on the same device.
        self.device = None
        if isinstance(self.model, nn.Module):
            params = list(self.model.parameters())
            if params:
                self.device = params[0].device

    @DeveloperAPI
    def before_compute_actions(
            self,
            *,
            timestep: Optional[Union[TensorType, int]] = None,
            explore: Optional[Union[TensorType, bool]] = None,
            tf_sess: Optional["tf.Session"] = None,
            **kwargs):
        """Hook for preparations before policy.compute_actions() is called.

        Args:
            timestep (Optional[Union[TensorType, int]]): An optional timestep
                tensor.
            explore (Optional[Union[TensorType, bool]]): An optional explore
                boolean flag.
            tf_sess (Optional[tf.Session]): The tf-session object to use.
            **kwargs: Forward compatibility kwargs.
        """
        pass

    # yapf: disable
    # __sphinx_doc_begin_get_exploration_action__

    @DeveloperAPI
    def get_exploration_action(self,
                               *,
                               action_distribution: ActionDistribution,
                               timestep: Union[TensorType, int],
                               explore: bool = True):
        """Returns a (possibly) exploratory action and its log-likelihood.

        Given the Model's logits outputs and action distribution, returns an
        exploratory action.

        Args:
            action_distribution (ActionDistribution): The instantiated
                ActionDistribution object to work with when creating
                exploration actions.
            timestep (Union[TensorType, int]): The current sampling time step.
                It can be a tensor for TF graph mode, otherwise an integer.
            explore (Union[TensorType, bool]): True: "Normal" exploration
                behavior. False: Suppress all exploratory behavior and return
                a deterministic action.

        Returns:
            Tuple:
            - The chosen exploration action or a tf-op to fetch the
              exploration action from the graph.
            - The log-likelihood of the exploration action.
        """
        pass

    # __sphinx_doc_end_get_exploration_action__
    # yapf: enable
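
    # NOTE: Illustrative sketch only (not part of the original class). A
    # typical torch/eager override of `get_exploration_action` branches on
    # `explore`, e.g.:
    #
    #     if explore:
    #         action = action_distribution.sample()
    #     else:
    #         action = action_distribution.deterministic_sample()
    #     return action, action_distribution.sampled_action_logp()
    #
    # The `sample()`, `deterministic_sample()`, and `sampled_action_logp()`
    # calls assume RLlib's standard ActionDistribution interface.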

    @DeveloperAPI
    def on_episode_start(self,
                         policy: "Policy",
                         *,
                         environment: BaseEnv = None,
                         episode: int = None,
                         tf_sess: Optional["tf.Session"] = None):
        """Handles necessary exploration logic at the beginning of an episode.

        Args:
            policy (Policy): The Policy object that holds this Exploration.
            environment (BaseEnv): The environment object we are acting in.
            episode (int): The number of the episode that is starting.
            tf_sess (Optional[tf.Session]): In case of tf, the session object.
        """
        pass

    @DeveloperAPI
    def on_episode_end(self,
                       policy: "Policy",
                       *,
                       environment: BaseEnv = None,
                       episode: int = None,
                       tf_sess: Optional["tf.Session"] = None):
        """Handles necessary exploration logic at the end of an episode.

        Args:
            policy (Policy): The Policy object that holds this Exploration.
            environment (BaseEnv): The environment object we are acting in.
            episode (int): The number of the episode that has just ended.
            tf_sess (Optional[tf.Session]): In case of tf, the session object.
        """
        pass

    @DeveloperAPI
    def postprocess_trajectory(self,
                               policy: "Policy",
                               sample_batch: SampleBatch,
                               tf_sess: Optional["tf.Session"] = None):
        """Handles post-processing of done episode trajectories.

        Changes the given batch in place. This callback is invoked by the
        sampler after policy.postprocess_trajectory() is called.

        Args:
            policy (Policy): The owning policy object.
            sample_batch (SampleBatch): The SampleBatch object to post-process.
            tf_sess (Optional[tf.Session]): An optional tf.Session object.
        """
        return sample_batch
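
    # NOTE: Illustrative sketch only. An exploration scheme that shapes
    # rewards (e.g. with an intrinsic bonus) would typically modify the batch
    # in place here; `self._intrinsic_reward()` below is a hypothetical
    # helper, not part of this API:
    #
    #     sample_batch[SampleBatch.REWARDS] += \
    #         self._intrinsic_reward(sample_batch)
    #     return sample_batch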

    @DeveloperAPI
    def get_exploration_optimizer(self, optimizers: List[LocalOptimizer]) -> \
            List[LocalOptimizer]:
        """May add optimizer(s) to the Policy's own `optimizers`.

        The number of optimizers (the Policy's plus the Exploration's) must
        match the number of loss terms produced by the Policy's loss function
        plus the Exploration component's own loss terms.

        Args:
            optimizers (List[LocalOptimizer]): The list of the Policy's
                local optimizers.

        Returns:
            List[LocalOptimizer]: The updated list of local optimizers to use
                on the different loss terms.
        """
        return optimizers
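
    # NOTE: Illustrative sketch only (torch case; `self._exploration_net` is
    # a hypothetical attribute). An exploration component that learns its own
    # parameters would typically append one optimizer per extra loss term:
    #
    #     return optimizers + [
    #         torch.optim.Adam(self._exploration_net.parameters(), lr=1e-4)
    #     ]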

    @DeveloperAPI
    def get_state(self, sess: Optional["tf.Session"] = None) -> \
            Dict[str, TensorType]:
        """Returns the current exploration state.

        Args:
            sess (Optional[tf.Session]): An optional tf Session object to use.

        Returns:
            Dict[str, TensorType]: The Exploration object's current state.
        """
        return {}

    @DeveloperAPI
    def set_state(self, state: object,
                  sess: Optional["tf.Session"] = None) -> None:
        """Sets the Exploration object's state to the given values.

        Note that some exploration components are stateless, even though they
        decay some values over time (e.g. EpsilonGreedy). However, the decay
        only depends on the Policy's current global timestep, so it does not
        need to be tracked separately here.

        Args:
            state (object): The state to set this Exploration to.
            sess (Optional[tf.Session]): An optional tf Session object to use.
        """
        pass
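
    # NOTE: Illustrative sketch only (hypothetical state key). A stateful
    # exploration scheme, e.g. one with a decaying epsilon, would typically
    # round-trip its variables like this:
    #
    #     def get_state(self, sess=None):
    #         return {"cur_epsilon": self.cur_epsilon}
    #
    #     def set_state(self, state, sess=None):
    #         self.cur_epsilon = state["cur_epsilon"]
    #
    # In the tf graph case, tensor values would be fetched/assigned via
    # `sess`.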

    @Deprecated(new="get_state", error=False)
    def get_info(self, sess: Optional["tf.Session"] = None):
        return self.get_state(sess)
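

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only; NOT part of the original module).
# It assumes torch, gym, and RLlib's standard torch ActionDistribution
# classes are available. `ToyEpsilonGreedy` and its `epsilon` argument are
# hypothetical names introduced here for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from gym.spaces import Discrete
    from ray.rllib.models.torch.torch_action_dist import TorchCategorical

    class ToyEpsilonGreedy(Exploration):
        """Toy subclass: random action w.p. epsilon, else greedy action."""

        def __init__(self, *args, epsilon: float = 0.3, **kwargs):
            super().__init__(*args, **kwargs)
            self.epsilon = epsilon

        def get_exploration_action(self,
                                   *,
                                   action_distribution: ActionDistribution,
                                   timestep: Union[TensorType, int],
                                   explore: bool = True):
            if explore and torch.rand(()).item() < self.epsilon:
                # Uniformly random action from the (Discrete) action space.
                action = torch.tensor([self.action_space.sample()])
            else:
                # Greedy (deterministic) action from the distribution.
                action = action_distribution.deterministic_sample()
            return action, action_distribution.logp(action)

    # Build a dummy categorical distribution over 4 actions and query the
    # toy exploration once with and once without exploration enabled.
    dist = TorchCategorical(torch.tensor([[0.1, 0.2, 0.3, 0.4]]), None)
    exploration = ToyEpsilonGreedy(
        action_space=Discrete(4),
        framework="torch",
        policy_config={},
        model=torch.nn.Linear(4, 4),  # Stand-in for a real ModelV2.
        num_workers=0,
        worker_index=0,
        epsilon=0.3)
    for explore in (True, False):
        action, logp = exploration.get_exploration_action(
            action_distribution=dist, timestep=0, explore=explore)
        print(f"explore={explore}: action={action}, logp={logp}")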