Mirror of https://github.com/vale981/ray (synced 2025-03-12 22:26:39 -04:00).
40 lines · 1.1 KiB · Python
import unittest
|
||
|
|
||
|
import ray
|
||
|
import ray.rllib.agents.sac as sac
|
||
|
from ray.rllib.utils.framework import try_import_tf
|
||
|
|
||
|
# Import TensorFlow through RLlib's framework helper rather than a bare
# `import tensorflow`.
# NOTE(review): presumably the helper returns None instead of raising when TF
# is not installed — confirm in ray.rllib.utils.framework. `tf` is not
# referenced anywhere else in the visible code of this file.
tf = try_import_tf()
|
||
|
|
||
|
|
||
|
class TestSAC(unittest.TestCase):
    """Smoke tests for RLlib's SAC trainer."""

    @classmethod
    def setUpClass(cls):
        # Start Ray once for every test in this class.
        ray.init()

    @classmethod
    def tearDownClass(cls):
        # Pair ray.init() with ray.shutdown() so the Ray processes do not
        # outlive the tests and a later ray.init() in the same interpreter
        # (e.g. another test module run by pytest) does not fail.
        ray.shutdown()

    def test_sac_compilation(self):
        """Test whether an SACTrainer can be built with all frameworks.

        For each framework, build an SACTrainer on a discrete-action env
        (CartPole-v0) and a continuous-action env (Pendulum-v0), then run
        ``num_iterations`` training iterations and print the results.
        """
        config = sac.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.
        num_iterations = 1

        for fw in ["eager", "tf", "torch"]:
            print("framework={}".format(fw))
            # NOTE(review): torch is skipped unconditionally, so this test
            # currently only exercises the eager and tf code paths.
            if fw == "torch":
                continue
            config["eager"] = fw == "eager"
            config["use_pytorch"] = fw == "torch"
            for env in ["CartPole-v0", "Pendulum-v0"]:
                print("Env={}".format(env))
                trainer = sac.SACTrainer(config=config, env=env)
                try:
                    for _ in range(num_iterations):
                        results = trainer.train()
                        print(results)
                finally:
                    # Release this trainer's resources before building the
                    # next one, even if training raised.
                    trainer.stop()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # When executed directly, run this file under pytest in verbose mode and
    # hand pytest's exit status back to the shell.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|