ray/rllib/policy/torch_policy_template.py

import gym
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict
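
# NOTE (added comment): try_import_torch() returns the torch module (or None
# if torch is not installed) together with torch.nn; only the module itself
# is needed here, so the nn handle is discarded.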
torch, _ = try_import_torch()


@Deprecated(new="build_policy_class(framework='torch')", error=False)
def build_torch_policy(
    name: str,
    *,
    loss_fn: Optional[
        Callable[
            [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
            Union[TensorType, List[TensorType]],
        ]
    ],
    get_default_config: Optional[Callable[[], AlgorithmConfigDict]] = None,
    stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
    postprocess_fn=None,
    extra_action_out_fn: Optional[
        Callable[
            [
                Policy,
                Dict[str, TensorType],
                List[TensorType],
                ModelV2,
                TorchDistributionWrapper,
            ],
            Dict[str, TensorType],
        ]
    ] = None,
    extra_grad_process_fn: Optional[
        Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]
    ] = None,
    extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
    optimizer_fn: Optional[
        Callable[[Policy, AlgorithmConfigDict], "torch.optim.Optimizer"]
    ] = None,
    validate_spaces: Optional[
        Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
    ] = None,
    before_init: Optional[
        Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
    ] = None,
    before_loss_init: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
        ]
    ] = None,
    after_init: Optional[
        Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
    ] = None,
    _after_loss_init: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
        ]
    ] = None,
    action_sampler_fn: Optional[
        Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
    ] = None,
    action_distribution_fn: Optional[
        Callable[
            [Policy, ModelV2, TensorType, TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]],
        ]
    ] = None,
    make_model: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2
        ]
    ] = None,
    make_model_and_action_dist: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
            Tuple[ModelV2, Type[TorchDistributionWrapper]],
        ]
    ] = None,
    compute_gradients_fn: Optional[
        Callable[[Policy, SampleBatch], Tuple[ModelGradients, dict]]
    ] = None,
    apply_gradients_fn: Optional[
        Callable[[Policy, "torch.optim.Optimizer"], None]
    ] = None,
    mixins: Optional[List[type]] = None,
    get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
) -> Type[TorchPolicy]:
    """Deprecated helper for building a TorchPolicy class at runtime.

    This is a thin shim: it forwards all of its arguments unchanged to
    build_policy_class() with framework="torch".
    """
    # Capture all call arguments; locals() at this point contains exactly
    # the function parameters.
    kwargs = locals().copy()
    # Set the framework to torch and call the new function.
    kwargs["framework"] = "torch"
    return build_policy_class(**kwargs)
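
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module). The
# policy name "MyTorchPolicy" and the loss function `my_loss` below are
# hypothetical, showing the typical shape of a `loss_fn` passed to this
# (deprecated) helper. New code should call
# build_policy_class(framework="torch") directly instead.
#
#     def my_loss(policy, model, dist_class, train_batch):
#         # Forward pass through the model, build the action distribution,
#         # and return the negative mean log-likelihood of the taken actions.
#         logits, _ = model(train_batch)
#         action_dist = dist_class(logits, model)
#         return -action_dist.logp(train_batch[SampleBatch.ACTIONS]).mean()
#
#     MyTorchPolicy = build_torch_policy(
#         name="MyTorchPolicy",
#         loss_fn=my_loss,
#     )
# ---------------------------------------------------------------------------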