mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] rllib train
crashes when using torch PPO/PG/A2C. (#7508)
* Fix. * Rollback. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST. * TEST.
This commit is contained in:
parent
bc637a2546
commit
f08687f550
2 changed files with 9 additions and 5 deletions
|
@ -12,6 +12,7 @@ class TorchDistributionWrapper(ActionDistribution):
|
|||
|
||||
@override(ActionDistribution)
|
||||
def __init__(self, inputs, model):
|
||||
if not isinstance(inputs, torch.Tensor):
|
||||
inputs = torch.Tensor(inputs)
|
||||
super().__init__(inputs, model)
|
||||
# Store the last sample here.
|
||||
|
@ -46,7 +47,8 @@ class TorchCategorical(TorchDistributionWrapper):
|
|||
@override(ActionDistribution)
def __init__(self, inputs, model=None, temperature=1.0):
    """Initialize a torch Categorical distribution over `inputs` logits.

    Args:
        inputs: Logits tensor (or value convertible to one by the base
            class) parameterizing the categorical distribution.
        model: Optional model object this distribution was produced by.
        temperature (float): Softmax temperature (must be > 0.0). Logits
            are divided by it, so larger values flatten the distribution.
    """
    assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
    # Apply the temperature exactly once, inside the super-call, instead
    # of mutating `inputs` in place (`inputs /= temperature`) — in-place
    # division crashes on leaf tensors requiring grad and silently
    # clobbers the caller's tensor.
    super().__init__(inputs / temperature, model)
    self.dist = torch.distributions.categorical.Categorical(
        logits=self.inputs)
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
|||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
@ -94,12 +95,13 @@ class TorchPolicy(Policy):
|
|||
extra_action_out = self.extra_action_out(input_dict, state_batches,
|
||||
self.model, action_dist)
|
||||
if logp is not None:
|
||||
logp = convert_to_non_torch_type(logp)
|
||||
extra_action_out.update({
|
||||
ACTION_PROB: torch.exp(logp),
|
||||
ACTION_PROB: np.exp(logp),
|
||||
ACTION_LOGP: logp
|
||||
})
|
||||
return (actions.cpu().numpy(), [h.cpu().numpy() for h in state],
|
||||
extra_action_out)
|
||||
return convert_to_non_torch_type(
|
||||
(actions, state, extra_action_out))
|
||||
|
||||
@override(Policy)
|
||||
def compute_log_likelihoods(self,
|
||||
|
|
Loading…
Add table
Reference in a new issue