Mirror of https://github.com/vale981/ray (synced 2025-03-04 17:41:43 -05:00)

Commit 935d8308fb (parent: 2f674728a6)
10 changed files with 36 additions and 21 deletions
@@ -12,6 +12,7 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.exploration import ParameterNoise
+from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import huber_loss, reduce_mean_ignore_inf, \
     minimize_and_clip
@@ -378,7 +379,8 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
             batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
             batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
         new_priorities = (
-            np.abs(td_errors) + policy.config["prioritized_replay_eps"])
+            np.abs(convert_to_numpy(td_errors)) +
+            policy.config["prioritized_replay_eps"])
         batch.data[PRIO_WEIGHTS] = new_priorities

     return batch
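The hunk above keeps the usual prioritized-replay rule but first converts the TD errors to numpy, since the torch policy now returns tensors living on the policy's device. A minimal standalone sketch of the priority rule itself, with made-up TD errors and an illustrative epsilon:

import numpy as np

# Hypothetical TD errors for a batch of 4 transitions (stand-ins for what
# policy.compute_td_error() returns after convert_to_numpy()).
td_errors = np.array([0.5, -1.2, 0.0, 2.3], dtype=np.float32)
prioritized_replay_eps = 1e-6  # illustrative stand-in for config["prioritized_replay_eps"]

# Same rule as the patched code: absolute TD error plus a small epsilon,
# so no transition ends up with a zero replay priority.
new_priorities = np.abs(td_errors) + prioritized_replay_eps
print(new_priorities)  # ~[0.5, 1.2, 1e-06, 2.3]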
@@ -65,7 +65,7 @@ def before_init(policy, observation_space, action_space, config):
         observation = policy.observation_filter(
             observation[None], update=update)

-        observation = convert_to_torch_tensor(observation)
+        observation = convert_to_torch_tensor(observation, policy.device)
         dist_inputs, _ = policy.model({
             SampleBatch.CUR_OBS: observation
         }, [], None)
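This is the recurring pattern of the commit: inputs are converted to torch tensors on the policy's device rather than on the default (CPU) device, so a GPU-resident model never receives CPU tensors. A small self-contained sketch of that pattern with a hypothetical observation and device (the real helper is convert_to_torch_tensor, whose definition is updated in the last hunks of this diff):

import numpy as np
import torch

# Hypothetical single observation and target device; in the diff the device
# comes from policy.device / self.device.
ob = np.zeros(4, dtype=np.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Roughly what convert_to_torch_tensor(observation, policy.device) does for a
# flat array: build the tensor, then move it to the policy's device.
obs_tensor = torch.from_numpy(np.asarray([ob])).to(device)
print(obs_tensor.shape, obs_tensor.device)  # torch.Size([1, 4]) on the chosen device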
@@ -188,15 +188,16 @@ class ValueNetworkMixin:
             def value(ob, prev_action, prev_reward, *state):
                 model_out, _ = self.model({
                     SampleBatch.CUR_OBS: convert_to_torch_tensor(
-                        np.asarray([ob])),
+                        np.asarray([ob]), self.device),
                     SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
-                        np.asarray([prev_action])),
+                        np.asarray([prev_action]), self.device),
                     SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
-                        np.asarray([prev_reward])),
+                        np.asarray([prev_reward]), self.device),
                     "is_training": False,
-                }, [convert_to_torch_tensor(np.asarray([s])) for s in state],
-                   convert_to_torch_tensor(
-                       np.asarray([1])))
+                }, [convert_to_torch_tensor(np.asarray([s]), self.device) for
+                    s in state],
+                   convert_to_torch_tensor(
+                       np.asarray([1]), self.device))
                 return self.model.value_function()[0]

         else:
@@ -83,10 +83,13 @@ def centralized_critic_postprocessing(policy,
         # overwrite default VF prediction with the central VF
         if args.torch:
             sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
-                convert_to_torch_tensor(sample_batch[SampleBatch.CUR_OBS]),
-                convert_to_torch_tensor(sample_batch[OPPONENT_OBS]),
-                convert_to_torch_tensor(sample_batch[OPPONENT_ACTION])). \
-                detach().numpy()
+                convert_to_torch_tensor(
+                    sample_batch[SampleBatch.CUR_OBS], policy.device),
+                convert_to_torch_tensor(
+                    sample_batch[OPPONENT_OBS], policy.device),
+                convert_to_torch_tensor(
+                    sample_batch[OPPONENT_ACTION], policy.device)) \
+                .detach().numpy()
         else:
             sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                 sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
@@ -178,12 +178,13 @@ class GumbelSoftmax(TFActionDistribution):
         assert temperature >= 0.0
         self.dist = tfp.distributions.RelaxedOneHotCategorical(
             temperature=temperature, logits=inputs)
+        self.probs = tf.nn.softmax(self.dist._distribution.logits)
         super().__init__(inputs, model)

     @override(ActionDistribution)
     def deterministic_sample(self):
         # Return the dist object's prob values.
-        return self.dist._distribution.probs
+        return self.probs

     @override(ActionDistribution)
     def logp(self, x):
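The deterministic sample of a GumbelSoftmax is just the softmax of the logits, which the constructor now caches in self.probs instead of reading the private _distribution attribute on every call. A tiny TensorFlow sketch with hypothetical logits:

import tensorflow as tf

# Hypothetical action logits for a 3-way discrete action space.
logits = tf.constant([[2.0, 0.5, -1.0]])

# What the patch stores in self.probs at construction time.
probs = tf.nn.softmax(logits)

# deterministic_sample() now simply returns these probabilities.
print(probs.numpy())  # approx. [[0.79, 0.18, 0.04]]; each row sums to 1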
@@ -149,7 +149,8 @@ class TorchPolicy(Policy):
                 input_dict[SampleBatch.PREV_REWARDS] = \
                     np.asarray(prev_reward_batch)
             state_batches = [
-                convert_to_torch_tensor(s) for s in (state_batches or [])
+                convert_to_torch_tensor(s, self.device)
+                for s in (state_batches or [])
             ]
             actions, state_out, extra_fetches, logp = \
                 self._compute_action_helper(
@@ -556,7 +557,8 @@ class TorchPolicy(Policy):

     def _lazy_tensor_dict(self, postprocessed_batch):
         train_batch = UsageTrackingDict(postprocessed_batch)
-        train_batch.set_get_interceptor(convert_to_torch_tensor)
+        train_batch.set_get_interceptor(functools.partial(
+            convert_to_torch_tensor, device=self.device))
         return train_batch

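The interceptor slot expects a single-argument callable, so the device is bound ahead of time with functools.partial. A minimal sketch of the idiom with a stand-in converter (UsageTrackingDict itself is RLlib-internal, so it is not reproduced here):

import functools

def to_tensor_stub(x, device=None):
    # Stand-in for convert_to_torch_tensor(x, device).
    return ("tensor", x, device)

device = "cuda:0"  # hypothetical policy device

# Before: the interceptor was the bare two-argument function, so conversion
# always used the default device. Binding device up front keeps the result a
# one-argument callable but pins the target device.
interceptor = functools.partial(to_tensor_stub, device=device)
print(interceptor([1, 2, 3]))  # ('tensor', [1, 2, 3], 'cuda:0')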
@@ -28,8 +28,10 @@ class PerWorkerEpsilonGreedy(EpsilonGreedy):
         if worker_index > 0:
             # From page 5 of https://arxiv.org/pdf/1803.00933.pdf
             alpha, eps, i = 7, 0.4, worker_index - 1
+            num_workers_minus_1 = float(num_workers - 1) \
+                if num_workers > 1 else 1.0
             epsilon_schedule = ConstantSchedule(
-                eps**(1 + i / float(num_workers - 1) * alpha),
+                eps**(1 + (i / num_workers_minus_1) * alpha),
                 framework=framework)
         # Local worker should have zero exploration so that eval
         # rollouts run properly.
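The new guard avoids a division by zero when num_workers == 1. A worked sketch of the per-worker epsilon formula from the Ape-X paper as used here (alpha=7 and eps=0.4 as in the code; the helper name is hypothetical):

def per_worker_epsilon(worker_index, num_workers, eps=0.4, alpha=7):
    # Epsilon for a rollout worker, per p. 5 of https://arxiv.org/pdf/1803.00933.pdf.
    assert worker_index > 0, "the local worker (index 0) explores with epsilon 0"
    i = worker_index - 1
    # The patched guard: with a single worker, float(num_workers - 1) would be
    # 0.0 and the old expression divided by zero.
    num_workers_minus_1 = float(num_workers - 1) if num_workers > 1 else 1.0
    return eps**(1 + (i / num_workers_minus_1) * alpha)

# With 4 workers the epsilons spread from 0.4 down to 0.4**8:
print([round(per_worker_epsilon(i, 4), 5) for i in range(1, 5)])
# -> [0.4, 0.04716, 0.00556, 0.00066]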
@@ -24,7 +24,9 @@ class PerWorkerGaussianNoise(GaussianNoise):
         # Use a fixed, different epsilon per worker. See: Ape-X paper.
         if num_workers > 0:
             if worker_index > 0:
-                exponent = (1 + worker_index / float(num_workers - 1) * 7)
+                num_workers_minus_1 = float(num_workers - 1) \
+                    if num_workers > 1 else 1.0
+                exponent = (1 + (worker_index / num_workers_minus_1) * 7)
                 scale_schedule = ConstantSchedule(
                     0.4**exponent, framework=framework)
             # Local worker should have zero exploration so that eval
@@ -25,7 +25,9 @@ class PerWorkerOrnsteinUhlenbeckNoise(OrnsteinUhlenbeckNoise):
         # Use a fixed, different epsilon per worker. See: Ape-X paper.
         if num_workers > 0:
             if worker_index > 0:
-                exponent = (1 + worker_index / float(num_workers - 1) * 7)
+                num_workers_minus_1 = float(num_workers - 1) \
+                    if num_workers > 1 else 1.0
+                exponent = (1 + (worker_index / num_workers_minus_1) * 7)
                 scale_schedule = ConstantSchedule(
                     0.4**exponent, framework=framework)
             # Local worker should have zero exploration so that eval
@@ -110,10 +110,10 @@ def convert_to_non_torch_type(stats):
     return tree.map_structure(mapping, stats)


-def convert_to_torch_tensor(stats, device=None):
+def convert_to_torch_tensor(x, device=None):
     """Converts any struct to torch.Tensors.

-    stats (any): Any (possibly nested) struct, the values in which will be
+    x (any): Any (possibly nested) struct, the values in which will be
         converted and returned as a new struct with all leaves converted
         to torch tensors.

@@ -137,7 +137,7 @@ def convert_to_torch_tensor(stats, device=None):
             tensor = tensor.float()
         return tensor if device is None else tensor.to(device)

-    return tree.map_structure(mapping, stats)
+    return tree.map_structure(mapping, x)


def atanh(x):
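For reference, a simplified sketch of what the renamed convert_to_torch_tensor(x, device=None) does. The real implementation above handles more cases (e.g. leaves that are already tensors), so this only covers plain nested numpy structs:

import numpy as np
import torch
import tree  # dm-tree, the same structure-mapping library used above

def convert_to_torch_tensor_sketch(x, device=None):
    # Map every leaf of a (possibly nested) struct to a torch tensor.
    def mapping(item):
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify float64 leaves, mirroring the `tensor = tensor.float()` above.
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)
    return tree.map_structure(mapping, x)

batch = {"obs": np.zeros((2, 4)), "rewards": np.array([1.0, 0.5])}
out = convert_to_torch_tensor_sketch(batch, device="cpu")
print(out["obs"].dtype, out["rewards"].shape)  # torch.float32 torch.Size([2])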