ray/rllib/agents/ars/ars_torch_policy.py
Sven Mika d15609ba2a
[RLlib] PyTorch version of ARS (Augmented Random Search). (#8106)
This PR implements a PyTorch version of RLlib's ARS algorithm using RLlib's functional algo builder API. It also adds a regression test for ARS (torch) on CartPole.
2020-04-21 09:47:52 +02:00


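# Code in this file is adapted from:
# https://github.com/openai/evolution-strategies-starter.
import ray
# ARS reuses the model-construction and init hooks of the ES (Evolution
# Strategies) torch policy.
from ray.rllib.agents.es.es_torch_policy import after_init, before_init, \
    make_model_and_action_dist
from ray.rllib.policy.torch_policy_template import build_torch_policy

ARSTorchPolicy = build_torch_policy(
    name="ARSTorchPolicy",
    # No loss function: ARS updates weights by random search over parameter
    # perturbations, not by gradient descent on a loss.
    loss_fn=None,
    # Resolved lazily, at policy construction time.
    get_default_config=lambda: ray.rllib.agents.ars.ars.DEFAULT_CONFIG,
    before_init=before_init,
    after_init=after_init,
    make_model_and_action_dist=make_model_and_action_dist)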
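For context on why `loss_fn=None` is passed: ARS does not optimize a differentiable loss. It estimates an ascent direction from the returns of symmetrically perturbed parameter vectors. The sketch below is a minimal, framework-agnostic NumPy version of one basic ARS-style update (keep the top-b directions, normalize the step by the standard deviation of the rewards used), assuming a flat parameter vector and a `reward_fn` that returns a scalar. The names `ars_step` and `reward_fn` and all hyperparameter defaults are hypothetical; this is not RLlib's implementation, which runs the rollouts on remote workers.

```python
import numpy as np


def ars_step(theta, reward_fn, rng, num_directions=8, top_b=4,
             noise_std=0.02, step_size=0.02):
    """One ARS-style update of a flat parameter vector (hypothetical sketch).

    reward_fn maps a parameter vector to a scalar return. No gradient or
    loss function is required anywhere in this update.
    """
    # Sample random perturbation directions and evaluate both signs.
    deltas = rng.standard_normal((num_directions, theta.size))
    r_pos = np.array([reward_fn(theta + noise_std * d) for d in deltas])
    r_neg = np.array([reward_fn(theta - noise_std * d) for d in deltas])

    # Keep only the top-b directions, ranked by max(r+, r-).
    order = np.argsort(np.maximum(r_pos, r_neg))[::-1][:top_b]
    r_pos, r_neg, deltas = r_pos[order], r_neg[order], deltas[order]

    # Normalize the step by the std of the rewards actually used.
    sigma = np.concatenate([r_pos, r_neg]).std()
    if sigma == 0:
        sigma = 1.0

    # Finite-difference update along the selected directions.
    update = ((r_pos - r_neg)[:, None] * deltas).sum(axis=0)
    return theta + step_size / (top_b * sigma) * update


# Usage: maximize a toy reward whose optimum is at `target`.
target = np.array([1.0, -2.0, 0.5])


def reward(params):
    return -np.sum((params - target) ** 2)


rng = np.random.default_rng(0)
theta = np.zeros(3)
for _ in range(300):
    theta = ars_step(theta, reward, rng)
```

Dividing by the reward standard deviation keeps the step size roughly independent of the reward scale, which is one reason ARS can share hyperparameters across environments more easily than plain finite differences.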