mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

The DDPG/TD3 algorithms currently do not have a PyTorch implementation. This PR adds PyTorch support for DDPG/TD3 to RLlib. This PR:
- Depends on the refactor PR for DDPG (Functional Algorithm API).
- Adds learning regression tests for the PyTorch version of DDPG and a DDPG (torch)
- Updates the documentation to reflect that DDPG and TD3 now support PyTorch.

* The torch version learns Pendulum-v0 (same config as tf). Wall time is roughly 20% slower than with tf.
* Fix GPU target model problem.
10 lines
268 B
Python
from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG
from ray.rllib.agents.ddpg.td3 import TD3Trainer

__all__ = [
    "ApexDDPGTrainer",
    "DDPGTrainer",
    "DEFAULT_CONFIG",
    "TD3Trainer",
]
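The `__all__` list in this package's `__init__.py` controls which names a star-import (`from ray.rllib.agents.ddpg import *`) binds. Explicit imports such as `from ray.rllib.agents.ddpg import DDPGTrainer` work whether or not a name is listed. The sketch below shows that behavior without needing Ray installed. It builds an in-memory stand-in module called `fake_ddpg`, a hypothetical name. The trainer "classes" in it are empty placeholders, not the real RLlib trainers.

```python
import sys
import types

# Hypothetical stand-in for the ddpg package: same exported names as the
# real __init__.py, but the values are empty placeholders, not RLlib classes.
pkg = types.ModuleType("fake_ddpg")
pkg.ApexDDPGTrainer = type("ApexDDPGTrainer", (), {})
pkg.DDPGTrainer = type("DDPGTrainer", (), {})
pkg.TD3Trainer = type("TD3Trainer", (), {})
pkg.DEFAULT_CONFIG = {}
pkg._internal_helper = lambda: None  # underscore-prefixed (private) name
pkg.unlisted_helper = lambda: None   # public-looking, but not in __all__
pkg.__all__ = ["ApexDDPGTrainer", "DDPGTrainer", "DEFAULT_CONFIG", "TD3Trainer"]

# Register the module so the import system can resolve it by name.
sys.modules["fake_ddpg"] = pkg

# A star-import binds only the names listed in __all__.
namespace = {}
exec("from fake_ddpg import *", namespace)
star_names = sorted(k for k in namespace if k != "__builtins__")
print(star_names)

# An explicit import still reaches names that __all__ leaves out.
from fake_ddpg import unlisted_helper
print(unlisted_helper is pkg.unlisted_helper)
```

Listing the trainers and `DEFAULT_CONFIG` in `__all__` keeps the package's star-import surface limited to the intended public entry points. Helpers that happen to be in the module namespace stay out of it.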