mirror of
https://github.com/vale981/ray
synced 2025-03-09 12:56:46 -04:00

* Rollback. * WIP. * WIP. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
94 lines
3.7 KiB
Python
94 lines
3.7 KiB
Python
from gym.spaces import Discrete, MultiDiscrete, Tuple
|
|
from typing import Union
|
|
|
|
from ray.rllib.models.action_dist import ActionDistribution
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.exploration.exploration import Exploration
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
|
TensorType
|
|
from ray.rllib.utils.tuple_actions import TupleActions
|
|
|
|
tf = try_import_tf()
|
|
torch, _ = try_import_torch()
|
|
|
|
|
|
class Random(Exploration):
    """A random action selector (deterministic/greedy for explore=False).

    If explore=True, returns actions randomly from `self.action_space` (via
    Space.sample()).
    If explore=False, returns the greedy/max-likelihood action (via the
    action distribution's `deterministic_sample()`).

    Either way, the returned log-likelihoods are all zeros.
    """

    def __init__(self, action_space, *, model, framework, **kwargs):
        """Initialize a Random Exploration object.

        Args:
            action_space (Space): The gym action space used by the environment.
            model: Forwarded unchanged to the `Exploration` base constructor.
            framework (Optional[str]): One of None, "tf", "torch".
            **kwargs: Forwarded unchanged to the `Exploration` base
                constructor.
        """
        super().__init__(
            action_space=action_space,
            framework=framework,
            model=model,
            **kwargs)

        # Determine py_func types, depending on our action-space.
        # These dtypes are only consumed by `get_tf_exploration_action_op`.
        # `tf` is the result of `try_import_tf()` and may be None when
        # TensorFlow is not installed (e.g. a torch-only setup); touching
        # `tf.int64` would then crash the constructor, so leave the dtypes
        # unset in that case.
        self.dtype_sample, self.dtype = None, None
        if tf is not None:
            is_discrete = isinstance(
                self.action_space, (Discrete, MultiDiscrete)) or (
                    isinstance(self.action_space, Tuple) and isinstance(
                        self.action_space[0], (Discrete, MultiDiscrete)))
            if is_discrete:
                self.dtype_sample, self.dtype = (tf.int64, tf.int32)
            else:
                self.dtype_sample, self.dtype = (tf.float64, tf.float32)

    @override(Exploration)
    def get_exploration_action(self,
                               *,
                               action_distribution: ActionDistribution,
                               timestep: Union[int, TensorType],
                               explore: bool = True):
        """Return a (action, logp) pair for the configured framework.

        Args:
            action_distribution (ActionDistribution): Distribution whose
                `deterministic_sample()` is used when not exploring.
            timestep (Union[int, TensorType]): Not used by this exploration
                type.
            explore (bool): True for a random action from the action space,
                False for the distribution's deterministic sample.

        Returns:
            The (action, logp) tuple produced by the tf- or torch-specific
            helper.
        """
        # Dispatch on the framework: "tf" builds a graph op, everything else
        # is handled by the torch implementation.
        if self.framework == "tf":
            return self.get_tf_exploration_action_op(action_distribution,
                                                     explore)
        else:
            return self.get_torch_exploration_action(action_distribution,
                                                     explore)

    def get_tf_exploration_action_op(self, action_dist, explore):
        """Build the tf op(s) producing the action and its (zero) logp.

        Args:
            action_dist: Object providing `deterministic_sample()`.
            explore: Python bool or a tensor usable as a `tf.cond`
                predicate.

        Returns:
            Tuple of (action, logp), where logp is a float32 zeros tensor
            with one entry per item along the action's first axis.
        """

        def true_fn():
            # Wrap the python-side `Space.sample()` into a tf op.
            action = tf.py_function(self.action_space.sample, [],
                                    self.dtype_sample)
            # Will be unnecessary, once we support batch/time-aware Spaces.
            return tf.expand_dims(tf.cast(action, dtype=self.dtype), 0)

        def false_fn():
            return tf.cast(
                action_dist.deterministic_sample(), dtype=self.dtype)

        # A python bool must be turned into a tensor to serve as predicate.
        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool) else explore,
            true_fn=true_fn,
            false_fn=false_fn)

        # TODO(sven): Move into (deterministic_)sample(logp=True|False)
        if isinstance(action, TupleActions):
            batch_size = tf.shape(action[0][0])[0]
        else:
            batch_size = tf.shape(action)[0]
        logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32)
        return action, logp

    def get_torch_exploration_action(self, action_dist, explore):
        """Compute the action and its (zero) logp as torch tensors.

        Args:
            action_dist: Object providing `deterministic_sample()`.
            explore (bool): True for a random action from the action space,
                False for the distribution's deterministic sample.

        Returns:
            Tuple of (action, logp). The action is int64 for
            Discrete/MultiDiscrete action spaces and float32 otherwise; logp
            is a float32 zeros tensor sized by the action's first axis.
        """
        # isinstance (rather than an exact type match) also covers subclasses
        # of the discrete spaces, matching the check style of `__init__`.
        if isinstance(self.action_space, (Discrete, MultiDiscrete)):
            dtype = torch.int64
        else:
            dtype = torch.float32

        if explore:
            # `torch.as_tensor` keeps the sampled *value*. The legacy
            # `torch.LongTensor(n)` constructor interprets a scalar int (what
            # a Discrete space samples) as a size and would return an
            # uninitialized tensor of length n instead.
            # Unsqueeze will be unnecessary, once we support batch/time-aware
            # Spaces.
            action = torch.as_tensor(
                self.action_space.sample(), dtype=dtype).unsqueeze(0)
        else:
            action = torch.as_tensor(
                action_dist.deterministic_sample(), dtype=dtype)
        logp = torch.zeros((action.size()[0], ), dtype=torch.float32)
        return action, logp
|