[RLlib] Issue #9437 (PyTorch converts to CPU tensor, even if on GPU). (#9497)

This commit is contained in:
Sven Mika 2020-07-16 14:55:50 +02:00 committed by GitHub
parent 2f674728a6
commit 935d8308fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 36 additions and 21 deletions

View file

@ -12,6 +12,7 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration import ParameterNoise
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import huber_loss, reduce_mean_ignore_inf, \
minimize_and_clip
@ -378,7 +379,8 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
new_priorities = (
np.abs(td_errors) + policy.config["prioritized_replay_eps"])
np.abs(convert_to_numpy(td_errors)) +
policy.config["prioritized_replay_eps"])
batch.data[PRIO_WEIGHTS] = new_priorities
return batch

View file

@ -65,7 +65,7 @@ def before_init(policy, observation_space, action_space, config):
observation = policy.observation_filter(
observation[None], update=update)
observation = convert_to_torch_tensor(observation)
observation = convert_to_torch_tensor(observation, policy.device)
dist_inputs, _ = policy.model({
SampleBatch.CUR_OBS: observation
}, [], None)

View file

@ -188,15 +188,16 @@ class ValueNetworkMixin:
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: convert_to_torch_tensor(
np.asarray([ob])),
np.asarray([ob]), self.device),
SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
np.asarray([prev_action])),
np.asarray([prev_action]), self.device),
SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
np.asarray([prev_reward])),
np.asarray([prev_reward]), self.device),
"is_training": False,
}, [convert_to_torch_tensor(np.asarray([s])) for s in state],
convert_to_torch_tensor(
np.asarray([1])))
}, [convert_to_torch_tensor(np.asarray([s]), self.device) for
s in state],
convert_to_torch_tensor(
np.asarray([1]), self.device))
return self.model.value_function()[0]
else:

View file

@ -83,10 +83,13 @@ def centralized_critic_postprocessing(policy,
# overwrite default VF prediction with the central VF
if args.torch:
sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
convert_to_torch_tensor(sample_batch[SampleBatch.CUR_OBS]),
convert_to_torch_tensor(sample_batch[OPPONENT_OBS]),
convert_to_torch_tensor(sample_batch[OPPONENT_ACTION])). \
detach().numpy()
convert_to_torch_tensor(
sample_batch[SampleBatch.CUR_OBS], policy.device),
convert_to_torch_tensor(
sample_batch[OPPONENT_OBS], policy.device),
convert_to_torch_tensor(
sample_batch[OPPONENT_ACTION], policy.device)) \
.detach().numpy()
else:
sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],

View file

@ -178,12 +178,13 @@ class GumbelSoftmax(TFActionDistribution):
assert temperature >= 0.0
self.dist = tfp.distributions.RelaxedOneHotCategorical(
temperature=temperature, logits=inputs)
self.probs = tf.nn.softmax(self.dist._distribution.logits)
super().__init__(inputs, model)
@override(ActionDistribution)
def deterministic_sample(self):
# Return the dist object's prob values.
return self.dist._distribution.probs
return self.probs
@override(ActionDistribution)
def logp(self, x):

View file

@ -149,7 +149,8 @@ class TorchPolicy(Policy):
input_dict[SampleBatch.PREV_REWARDS] = \
np.asarray(prev_reward_batch)
state_batches = [
convert_to_torch_tensor(s) for s in (state_batches or [])
convert_to_torch_tensor(s, self.device)
for s in (state_batches or [])
]
actions, state_out, extra_fetches, logp = \
self._compute_action_helper(
@ -556,7 +557,8 @@ class TorchPolicy(Policy):
def _lazy_tensor_dict(self, postprocessed_batch):
train_batch = UsageTrackingDict(postprocessed_batch)
train_batch.set_get_interceptor(convert_to_torch_tensor)
train_batch.set_get_interceptor(functools.partial(
convert_to_torch_tensor, device=self.device))
return train_batch

View file

@ -28,8 +28,10 @@ class PerWorkerEpsilonGreedy(EpsilonGreedy):
if worker_index > 0:
# From page 5 of https://arxiv.org/pdf/1803.00933.pdf
alpha, eps, i = 7, 0.4, worker_index - 1
num_workers_minus_1 = float(num_workers - 1) \
if num_workers > 1 else 1.0
epsilon_schedule = ConstantSchedule(
eps**(1 + i / float(num_workers - 1) * alpha),
eps**(1 + (i / num_workers_minus_1) * alpha),
framework=framework)
# Local worker should have zero exploration so that eval
# rollouts run properly.

View file

@ -24,7 +24,9 @@ class PerWorkerGaussianNoise(GaussianNoise):
# Use a fixed, different epsilon per worker. See: Ape-X paper.
if num_workers > 0:
if worker_index > 0:
exponent = (1 + worker_index / float(num_workers - 1) * 7)
num_workers_minus_1 = float(num_workers - 1) \
if num_workers > 1 else 1.0
exponent = (1 + (worker_index / num_workers_minus_1) * 7)
scale_schedule = ConstantSchedule(
0.4**exponent, framework=framework)
# Local worker should have zero exploration so that eval

View file

@ -25,7 +25,9 @@ class PerWorkerOrnsteinUhlenbeckNoise(OrnsteinUhlenbeckNoise):
# Use a fixed, different epsilon per worker. See: Ape-X paper.
if num_workers > 0:
if worker_index > 0:
exponent = (1 + worker_index / float(num_workers - 1) * 7)
num_workers_minus_1 = float(num_workers - 1) \
if num_workers > 1 else 1.0
exponent = (1 + (worker_index / num_workers_minus_1) * 7)
scale_schedule = ConstantSchedule(
0.4**exponent, framework=framework)
# Local worker should have zero exploration so that eval

View file

@ -110,10 +110,10 @@ def convert_to_non_torch_type(stats):
return tree.map_structure(mapping, stats)
def convert_to_torch_tensor(stats, device=None):
def convert_to_torch_tensor(x, device=None):
"""Converts any struct to torch.Tensors.
stats (any): Any (possibly nested) struct, the values in which will be
x (any): Any (possibly nested) struct, the values in which will be
converted and returned as a new struct with all leaves converted
to torch tensors.
@ -137,7 +137,7 @@ def convert_to_torch_tensor(stats, device=None):
tensor = tensor.float()
return tensor if device is None else tensor.to(device)
return tree.map_structure(mapping, stats)
return tree.map_structure(mapping, x)
def atanh(x):