ray/rllib/utils/exploration/tests/test_random_encoder.py



import sys
import unittest

import pytest

import ray
import ray.rllib.algorithms.ppo as ppo
import ray.rllib.algorithms.sac as sac
from ray.rllib.algorithms.callbacks import RE3UpdateCallbacks
from ray.rllib.utils.test_utils import framework_iterator

class TestRE3(unittest.TestCase):
    """Tests for RE3 exploration algorithm."""

    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def run_re3(self, rl_algorithm):
        """Tests RE3 for PPO and SAC.

        Both the on-policy and off-policy setups are validated.
        """
        if rl_algorithm == "PPO":
            config = ppo.PPOConfig().to_dict()
            algo_cls = ppo.PPO
            beta_schedule = "constant"
        elif rl_algorithm == "SAC":
            config = sac.SACConfig().to_dict()
            algo_cls = sac.SAC
            beta_schedule = "linear_decay"
        class RE3Callbacks(RE3UpdateCallbacks, config["callbacks"]):
            pass

        config["env"] = "Pendulum-v1"
        config["callbacks"] = RE3Callbacks
config["exploration_config"] = {
"type": "RE3",
"embeds_dim": 128,
"beta_schedule": beta_schedule,
"sub_exploration": {
"type": "StochasticSampling",
},
}
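
        # Train for up to 60 iterations per framework and stop early once the
        # best episode reward on Pendulum-v1 exceeds -900.0, i.e. the agent
        # shows clear learning progress.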
        num_iterations = 60
        for _ in framework_iterator(config, frameworks=("tf", "tf2"), session=True):
            algo = algo_cls(config=config)
            learnt = False
            for i in range(num_iterations):
                result = algo.train()
                print(result)
                if result["episode_reward_max"] > -900.0:
                    print("Reached goal after {} iters!".format(i))
                    learnt = True
                    break
            algo.stop()
            self.assertTrue(learnt)

    def test_re3_ppo(self):
        """Tests RE3 with PPO."""
        self.run_re3("PPO")

    def test_re3_sac(self):
        """Tests RE3 with SAC."""
        self.run_re3("SAC")


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))