mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] rollout.py - Add multi-agent test case. (#9981)
This commit is contained in:
parent
10baecb8c2
commit
4b10bdf8fc
2 changed files with 102 additions and 3 deletions
|
@ -1320,7 +1320,7 @@ py_test(
|
|||
name = "test_rollout_w_learning",
|
||||
main = "tests/test_rollout.py",
|
||||
tags = ["tests_dir", "tests_dir_R"],
|
||||
size = "medium",
|
||||
size = "large",
|
||||
data = ["train.py", "rollout.py"],
|
||||
srcs = ["tests/test_rollout.py"],
|
||||
args = ["TestRolloutLearntPolicy"]
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
from pathlib import Path
|
||||
from gym.spaces import Box, Discrete
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
|
||||
|
||||
|
@ -64,7 +67,7 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
|
|||
|
||||
|
||||
def learn_test_plus_rollout(algo, env="CartPole-v0"):
|
||||
for fw in framework_iterator(frameworks="tf"):
|
||||
for fw in framework_iterator(frameworks=("tf", "torch")):
|
||||
fw_ = ", \\\"framework\\\": \\\"{}\\\"".format(fw)
|
||||
|
||||
tmp_dir = os.popen("mktemp -d").read()[:-1]
|
||||
|
@ -129,6 +132,99 @@ def learn_test_plus_rollout(algo, env="CartPole-v0"):
|
|||
os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
|
||||
|
||||
|
||||
def learn_test_multi_agent_plus_rollout(algo):
|
||||
for fw in framework_iterator(frameworks=("tf", "torch")):
|
||||
tmp_dir = os.popen("mktemp -d").read()[:-1]
|
||||
if not os.path.exists(tmp_dir):
|
||||
# Last resort: Resolve via underlying tempdir (and cut tmp_.
|
||||
tmp_dir = ray.utils.tempfile.gettempdir() + tmp_dir[4:]
|
||||
if not os.path.exists(tmp_dir):
|
||||
sys.exit(1)
|
||||
|
||||
print("Saving results to {}".format(tmp_dir))
|
||||
|
||||
rllib_dir = str(Path(__file__).parent.parent.absolute())
|
||||
print("RLlib dir = {}\nexists={}".format(rllib_dir,
|
||||
os.path.exists(rllib_dir)))
|
||||
|
||||
def policy_fn(agent):
|
||||
return "pol{}".format(agent)
|
||||
|
||||
observation_space = Box(float("-inf"), float("inf"), (4, ))
|
||||
action_space = Discrete(2)
|
||||
|
||||
config = {
|
||||
"num_gpus": 0,
|
||||
"num_workers": 1,
|
||||
"evaluation_config": {
|
||||
"explore": False
|
||||
},
|
||||
"framework": fw,
|
||||
"env": MultiAgentCartPole,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"pol0": (None, observation_space, action_space, {}),
|
||||
"pol1": (None, observation_space, action_space, {}),
|
||||
},
|
||||
"policy_mapping_fn": policy_fn,
|
||||
},
|
||||
}
|
||||
stop = {"episode_reward_mean": 190.0}
|
||||
tune.run(
|
||||
algo,
|
||||
config=config,
|
||||
stop=stop,
|
||||
checkpoint_freq=1,
|
||||
checkpoint_at_end=True,
|
||||
local_dir=tmp_dir,
|
||||
verbose=1)
|
||||
|
||||
# Find last checkpoint and use that for the rollout.
|
||||
checkpoint_path = os.popen("ls {}/PPO/*/checkpoint_*/"
|
||||
"checkpoint-*".format(tmp_dir)).read()[:-1]
|
||||
checkpoint_paths = checkpoint_path.split("\n")
|
||||
assert len(checkpoint_paths) > 0
|
||||
checkpoints = [
|
||||
cp for cp in checkpoint_paths
|
||||
if re.match(r"^.+checkpoint-\d+$", cp)
|
||||
]
|
||||
# Sort by number and pick last (which should be the best checkpoint).
|
||||
last_checkpoint = sorted(
|
||||
checkpoints,
|
||||
key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1)))[-1]
|
||||
assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
|
||||
if not os.path.exists(last_checkpoint):
|
||||
sys.exit(1)
|
||||
print("Best checkpoint={} (exists)".format(last_checkpoint))
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
# Test rolling out n steps.
|
||||
result = os.popen(
|
||||
"python {}/rollout.py --run={} "
|
||||
"--steps=400 "
|
||||
"--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
|
||||
rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
|
||||
if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
|
||||
sys.exit(1)
|
||||
print("Rollout output exists -> Checking reward ...".format(
|
||||
checkpoint_path))
|
||||
episodes = result.split("\n")
|
||||
mean_reward = 0.0
|
||||
num_episodes = 0
|
||||
for ep in episodes:
|
||||
mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
|
||||
if mo:
|
||||
mean_reward += float(mo.group(1))
|
||||
num_episodes += 1
|
||||
mean_reward /= num_episodes
|
||||
print("Rollout's mean episode reward={}".format(mean_reward))
|
||||
assert mean_reward >= 190.0
|
||||
|
||||
# Cleanup.
|
||||
os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
|
||||
|
||||
|
||||
class TestRolloutSimple(unittest.TestCase):
|
||||
def test_a3c(self):
|
||||
rollout_test("A3C")
|
||||
|
@ -156,6 +252,9 @@ class TestRolloutLearntPolicy(unittest.TestCase):
|
|||
def test_ppo_train_then_rollout(self):
|
||||
learn_test_plus_rollout("PPO")
|
||||
|
||||
def test_ppo_multi_agent_train_then_rollout(self):
|
||||
learn_test_multi_agent_plus_rollout("PPO")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
|
Loading…
Add table
Reference in a new issue