2020-08-06 05:33:24 +02:00
|
|
|
import gym
|
2021-05-04 10:06:19 -07:00
|
|
|
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
2020-08-06 05:33:24 +02:00
|
|
|
|
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
|
|
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
2020-12-26 20:14:18 -05:00
|
|
|
from ray.rllib.policy.policy import Policy
|
|
|
|
from ray.rllib.policy.policy_template import build_policy_class
|
2020-08-06 05:33:24 +02:00
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
|
|
from ray.rllib.policy.torch_policy import TorchPolicy
|
2021-11-01 21:46:02 +01:00
|
|
|
from ray.rllib.utils.deprecation import Deprecated
|
2020-02-22 20:02:31 +01:00
|
|
|
from ray.rllib.utils.framework import try_import_torch
|
2022-06-11 15:10:39 +02:00
|
|
|
from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict
|
2021-05-03 14:23:28 -07:00
|
|
|
|
2020-02-22 20:02:31 +01:00
|
|
|
# Soft-import torch so this module stays importable without it installed;
# `torch` is None in that case (second tuple element, `nn`, is unused here).
torch, _ = try_import_torch()
|
2019-05-18 00:23:11 -07:00
|
|
|
|
|
|
|
|
2021-08-03 18:30:02 -04:00
|
|
|
@Deprecated(new="build_policy_class(framework='torch')", error=False)
def build_torch_policy(
    name: str,
    *,
    loss_fn: Optional[
        Callable[
            [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
            Union[TensorType, List[TensorType]],
        ]
    ],
    get_default_config: Optional[Callable[[], AlgorithmConfigDict]] = None,
    stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
    postprocess_fn=None,
    extra_action_out_fn: Optional[
        Callable[
            [
                Policy,
                Dict[str, TensorType],
                List[TensorType],
                ModelV2,
                TorchDistributionWrapper,
            ],
            Dict[str, TensorType],
        ]
    ] = None,
    extra_grad_process_fn: Optional[
        Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]
    ] = None,
    extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
    optimizer_fn: Optional[
        Callable[[Policy, AlgorithmConfigDict], "torch.optim.Optimizer"]
    ] = None,
    validate_spaces: Optional[
        Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
    ] = None,
    before_init: Optional[
        Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
    ] = None,
    before_loss_init: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
        ]
    ] = None,
    after_init: Optional[
        Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
    ] = None,
    _after_loss_init: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
        ]
    ] = None,
    action_sampler_fn: Optional[
        Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
    ] = None,
    action_distribution_fn: Optional[
        Callable[
            [Policy, ModelV2, TensorType, TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]],
        ]
    ] = None,
    make_model: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2
        ]
    ] = None,
    make_model_and_action_dist: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
            Tuple[ModelV2, Type[TorchDistributionWrapper]],
        ]
    ] = None,
    compute_gradients_fn: Optional[
        Callable[[Policy, SampleBatch], Tuple[ModelGradients, dict]]
    ] = None,
    apply_gradients_fn: Optional[
        Callable[[Policy, "torch.optim.Optimizer"], None]
    ] = None,
    mixins: Optional[List[type]] = None,
    get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
    """Deprecated shim: forward all arguments to ``build_policy_class``.

    Every keyword argument accepted here is passed through unchanged to
    :func:`ray.rllib.policy.policy_template.build_policy_class`, with the
    ``framework`` argument forced to ``"torch"``. The ``@Deprecated``
    decorator points callers at the replacement API (``error=False``, so
    calling this still works and only warns — behavior per the decorator,
    defined elsewhere; confirm there).

    Returns:
        The Policy class built by ``build_policy_class`` (a
        :class:`TorchPolicy` subclass).
    """
    # Snapshot the full argument set. This MUST stay the first statement of
    # the function body so that ``locals()`` contains exactly the declared
    # parameters and nothing else.
    forwarded = dict(locals())
    # Force the torch framework and delegate to the replacement API.
    forwarded["framework"] = "torch"
    return build_policy_class(**forwarded)
|