import copy
import gym
import numpy as np
import unittest

import ray
import ray.rllib.agents.dyna as dyna
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, \
    framework_iterator

torch, _ = try_import_torch()


class TestDYNA(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init(local_mode=True)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_dyna_compilation(self):
        """Test whether a DYNATrainer can be built with the torch framework."""
        config = copy.deepcopy(dyna.DEFAULT_CONFIG)
        config["num_workers"] = 1
        config["train_batch_size"] = 1000
        num_iterations = 30
        env = "CartPole-v0"
        test_env = gym.make(env)

        for _ in framework_iterator(config, frameworks="torch"):
            trainer = dyna.DYNATrainer(config=config, env=env)
            policy = trainer.get_policy()
            # Do n supervised epochs, each over `train_batch_size` samples.
            # Ignore the validation loss here as a stopping criterion.
            for i in range(num_iterations):
                info = trainer.train()["info"]["learner"]["default_policy"]
                print("SL iteration: {}".format(i))
                print("train loss {}".format(info["dynamics_train_loss"]))
                print("validation loss {}".format(
                    info["dynamics_validation_loss"]))
            # Check whether normal action stepping works with DYNA's policy.
            # Note that DYNA does not train its policy; the weights must be
            # pushed down from the main model-based algo from time to time.
            check_compute_single_action(trainer)
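            # A minimal sketch of such a push, assuming a separate
            # (hypothetical) `main_trainer` holds the model-based algo;
            # RLlib policies expose get_weights()/set_weights() for this:
            #   trainer.get_policy().set_weights(
            #       main_trainer.get_policy().get_weights())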

            # Check whether env dynamics were actually learnt - more or less.
            obs = test_env.reset()
            for _ in range(10):
                action = trainer.compute_action(obs)
                obs = torch.from_numpy(np.array([obs])).float()
                # Make the prediction over the next state (deterministic delta
                # like in MBMPO).
                predicted_next_obs_delta = \
                    policy.dynamics_model.get_next_observation(
                        obs,
                        torch.from_numpy(np.array([action])))
                predicted_next_obs = obs + predicted_next_obs_delta
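                # `predicted_next_obs` is thus the model's estimate of the
                # absolute next observation: current obs plus predicted delta.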
                obs, _, done, _ = test_env.step(action)
                self.assertLess(
                    np.sum(np.abs(obs - predicted_next_obs.detach().numpy())),
                    0.05)
                # Reset if done.
                if done:
                    obs = test_env.reset()


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))