[RLlib] rllib train crashes when using torch PPO/PG/A2C. (#7508)

* Fix.

* Rollback.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.

* TEST.
This commit is contained in:
Sven Mika 2020-03-08 21:03:18 +01:00 committed by GitHub
parent bc637a2546
commit f08687f550
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 5 deletions

View file

@@ -12,6 +12,7 @@ class TorchDistributionWrapper(ActionDistribution):
@override(ActionDistribution)
def __init__(self, inputs, model):
if not isinstance(inputs, torch.Tensor):
inputs = torch.Tensor(inputs)
super().__init__(inputs, model)
# Store the last sample here.
@@ -46,7 +47,8 @@ class TorchCategorical(TorchDistributionWrapper):
@override(ActionDistribution)
def __init__(self, inputs, model=None, temperature=1.0):
assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
super().__init__(inputs / temperature, model)
inputs /= temperature
super().__init__(inputs, model)
self.dist = torch.distributions.categorical.Categorical(
logits=self.inputs)

View file

@@ -7,6 +7,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
from ray.rllib.utils.tracking_dict import UsageTrackingDict
torch, _ = try_import_torch()
@@ -94,12 +95,13 @@ class TorchPolicy(Policy):
extra_action_out = self.extra_action_out(input_dict, state_batches,
self.model, action_dist)
if logp is not None:
logp = convert_to_non_torch_type(logp)
extra_action_out.update({
ACTION_PROB: torch.exp(logp),
ACTION_PROB: np.exp(logp),
ACTION_LOGP: logp
})
return (actions.cpu().numpy(), [h.cpu().numpy() for h in state],
extra_action_out)
return convert_to_non_torch_type(
(actions, state, extra_action_out))
@override(Policy)
def compute_log_likelihoods(self,