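"""End-to-end tests for RLlib's train.py and evaluate.py CLI scripts.

Each test trains an algorithm for a short time via the train.py CLI, then
restores the written checkpoint via evaluate.py and verifies the rollout
output (and, for the longer-running tests, the achieved mean episode
reward). Roughly, each test shells out to (paths illustrative):

    python <rllib_dir>/train.py --run=PPO --env=CartPole-v0 --checkpoint-freq=1 ...
    python <rllib_dir>/evaluate.py --run=PPO <checkpoint> --steps=10 ...
"""
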
import os
from pathlib import Path
import re
import sys
import unittest
import ray
from ray import tune
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import framework_iterator


def evaluate_test(algo, env="CartPole-v0", test_episode_rollout=False):
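    """Runs a quick train.py -> evaluate.py round trip for `algo` on `env`.

    Trains for a single iteration (writing a checkpoint), then rolls the
    checkpoint out for 10 timesteps (and optionally for one full episode)
    via the evaluate.py CLI, checking that the output files exist.
    """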
    # Algo-specific JSON snippets appended to the --config string below.
    extra_config = ""
    if algo == "ARS":
        extra_config = ',"train_batch_size": 10, "noise_size": 250000'
    elif algo == "ES":
        extra_config = (
            ',"episodes_per_batch": 1,"train_batch_size": 10, "noise_size": 250000'
        )

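    # Run the entire train -> checkpoint -> evaluate cycle once per
    # supported framework (tf and torch).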
    for fw in framework_iterator(frameworks=("tf", "torch")):
        fw_ = ', "framework": "{}"'.format(fw)

        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir, os.path.exists(rllib_dir)))
        os.system(
            "python {}/train.py --local-dir={} --run={} "
            "--checkpoint-freq=1 ".format(rllib_dir, tmp_dir, algo)
            + "--config='{"
            + '"num_workers": 1, "num_gpus": 0{}{}'.format(fw_, extra_config)
            + ', "timesteps_per_iteration": 5,"min_time_s_per_reporting": 0.1, '
            '"model": {"fcnet_hiddens": [10]}'
            "}' --stop='{\"training_iteration\": 1}'"
            + " --env={} --no-ray-ui".format(env)
        )
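
        # The train.py run above stops after one iteration and checkpoints
        # every iteration, so exactly one checkpoint should exist; read its
        # path back via `ls`.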
        checkpoint_path = os.popen(
            "ls {}/default/*/checkpoint_000001/checkpoint-1".format(tmp_dir)
        ).read()[:-1]
        if not os.path.exists(checkpoint_path):
            sys.exit(1)
        print("Checkpoint path {} (exists)".format(checkpoint_path))

        # Test rolling out n steps.
        os.popen(
            'python {}/evaluate.py --run={} "{}" --steps=10 '
            '--out="{}/rollouts_10steps.pkl" --no-render'.format(
                rllib_dir, algo, checkpoint_path, tmp_dir
            )
        ).read()
        if not os.path.exists(tmp_dir + "/rollouts_10steps.pkl"):
            sys.exit(1)
        print("evaluate output (10 steps) exists!")

        # Test rolling out 1 episode.
        if test_episode_rollout:
            os.popen(
                'python {}/evaluate.py --run={} "{}" --episodes=1 '
                '--out="{}/rollouts_1episode.pkl" --no-render'.format(
                    rllib_dir, algo, checkpoint_path, tmp_dir
                )
            ).read()
            if not os.path.exists(tmp_dir + "/rollouts_1episode.pkl"):
                sys.exit(1)
            print("evaluate output (1 ep) exists!")

        # Cleanup.
        os.popen('rm -rf "{}"'.format(tmp_dir)).read()


def learn_test_plus_evaluate(algo, env="CartPole-v0"):
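    """Trains `algo` on `env` up to a mean episode reward of 100.0, then
    evaluates the last written checkpoint via evaluate.py and asserts that
    the rollout reproduces a mean episode reward of at least 100.0.
    """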
    for fw in framework_iterator(frameworks=("tf", "torch")):
        fw_ = ', \\"framework\\": \\"{}\\"'.format(fw)

        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            # Last resort: Resolve via the underlying tempdir (replacing the
            # leading "/tmp" with tempfile.gettempdir()).
            tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
            if not os.path.exists(tmp_dir):
                sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir, os.path.exists(rllib_dir)))
        os.system(
            "python {}/train.py --local-dir={} --run={} "
            "--checkpoint-freq=1 --checkpoint-at-end ".format(rllib_dir, tmp_dir, algo)
            + '--config="{\\"num_gpus\\": 0, \\"num_workers\\": 1, '
            '\\"evaluation_config\\": {\\"explore\\": false}'
            + fw_
            + '}" '
            + '--stop="{\\"episode_reward_mean\\": 100.0}"'
            + " --env={}".format(env)
        )

        # Find last checkpoint and use that for the rollout.
        checkpoint_path = os.popen(
            "ls {}/default/*/checkpoint_*/checkpoint-*".format(tmp_dir)
        ).read()[:-1]
        checkpoints = [
            cp
            for cp in checkpoint_path.split("\n")
            if re.match(r"^.+checkpoint-\d+$", cp)
        ]
        # Sort by number and pick last (which should be the best checkpoint).
        last_checkpoint = sorted(
            checkpoints, key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1))
        )[-1]
        assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
        if not os.path.exists(last_checkpoint):
            sys.exit(1)
        print("Best checkpoint={} (exists)".format(last_checkpoint))

        # Test rolling out n steps.
        result = os.popen(
            "python {}/evaluate.py --run={} "
            "--steps=400 "
            '--out="{}/rollouts_n_steps.pkl" --no-render "{}"'.format(
                rllib_dir, algo, tmp_dir, last_checkpoint
            )
        ).read()[:-1]
        if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
            sys.exit(1)
        print("Rollout output exists -> Checking reward ...")
        episodes = result.split("\n")
        mean_reward = 0.0
        num_episodes = 0
        for ep in episodes:
            mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
            if mo:
                mean_reward += float(mo.group(1))
                num_episodes += 1
        mean_reward /= num_episodes
        print("Rollout's mean episode reward={}".format(mean_reward))
        assert mean_reward >= 100.0

        # Cleanup.
        os.popen('rm -rf "{}"'.format(tmp_dir)).read()


def learn_test_multi_agent_plus_evaluate(algo):
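    """Multi-agent version of `learn_test_plus_evaluate`.

    Trains two policies on MultiAgentCartPole via tune.run(), then evaluates
    the last written checkpoint via evaluate.py and asserts the learned
    mean episode reward.
    """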
    for fw in framework_iterator(frameworks=("tf", "torch")):
        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            # Last resort: Resolve via the underlying tempdir (replacing the
            # leading "/tmp" with tempfile.gettempdir()).
            tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
            if not os.path.exists(tmp_dir):
                sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir, os.path.exists(rllib_dir)))
        def policy_fn(agent_id, episode, **kwargs):
            return "pol{}".format(agent_id)

        config = {
            "num_gpus": 0,
            "num_workers": 1,
            "evaluation_config": {"explore": False},
            "framework": fw,
            "env": MultiAgentCartPole,
            "multiagent": {
                "policies": {"pol0", "pol1"},
                "policy_mapping_fn": policy_fn,
            },
        }
        stop = {"episode_reward_mean": 100.0}
        tune.run(
            algo,
            config=config,
            stop=stop,
            checkpoint_freq=1,
            checkpoint_at_end=True,
            local_dir=tmp_dir,
            verbose=1,
        )

        # Find last checkpoint and use that for the rollout.
        # NOTE: The results dir name is hardcoded to "PPO" here (this
        # function is only called with algo="PPO" in this file).
        checkpoint_path = os.popen(
            "ls {}/PPO/*/checkpoint_*/checkpoint-*".format(tmp_dir)
        ).read()[:-1]
        checkpoint_paths = checkpoint_path.split("\n")
        assert len(checkpoint_paths) > 0
        checkpoints = [
            cp for cp in checkpoint_paths if re.match(r"^.+checkpoint-\d+$", cp)
        ]
        # Sort by number and pick last (which should be the best checkpoint).
        last_checkpoint = sorted(
            checkpoints, key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1))
        )[-1]
        assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
        if not os.path.exists(last_checkpoint):
            sys.exit(1)
        print("Best checkpoint={} (exists)".format(last_checkpoint))
        ray.shutdown()

        # Test rolling out n steps.
        result = os.popen(
            "python {}/evaluate.py --run={} "
            "--steps=400 "
            '--out="{}/rollouts_n_steps.pkl" --no-render "{}"'.format(
                rllib_dir, algo, tmp_dir, last_checkpoint
            )
        ).read()[:-1]
        if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
            sys.exit(1)
        print("Rollout output exists -> Checking reward ...")
        episodes = result.split("\n")
        mean_reward = 0.0
        num_episodes = 0
        for ep in episodes:
            mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
            if mo:
                mean_reward += float(mo.group(1))
                num_episodes += 1
        mean_reward /= num_episodes
        print("Rollout's mean episode reward={}".format(mean_reward))
        assert mean_reward >= 100.0

        # Cleanup.
        os.popen('rm -rf "{}"'.format(tmp_dir)).read()


class TestEvaluate1(unittest.TestCase):
    def test_a3c(self):
        evaluate_test("A3C")

    def test_ddpg(self):
        evaluate_test("DDPG", env="Pendulum-v1")


class TestEvaluate2(unittest.TestCase):
    def test_dqn(self):
        evaluate_test("DQN")

    def test_es(self):
        evaluate_test("ES")


class TestEvaluate3(unittest.TestCase):
    def test_impala(self):
        evaluate_test("IMPALA", env="CartPole-v0")

    def test_ppo(self):
        evaluate_test("PPO", env="CartPole-v0", test_episode_rollout=True)


class TestEvaluate4(unittest.TestCase):
    def test_sac(self):
        evaluate_test("SAC", env="Pendulum-v1")


class TestTrainAndEvaluate(unittest.TestCase):
    def test_ppo_train_then_rollout(self):
        learn_test_plus_evaluate("PPO")

    def test_ppo_multi_agent_train_then_rollout(self):
        learn_test_multi_agent_plus_evaluate("PPO")


if __name__ == "__main__":
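    # Example invocation (assuming this file is saved as test_evaluate.py;
    # the filename is illustrative), running only the train+evaluate cases:
    #   python test_evaluate.py TestTrainAndEvaluate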
    import pytest

    # A single TestCase class to run can be given as the first CLI arg;
    # if none is given, all unittest.TestCase classes in this file are run.
    class_ = sys.argv[1] if len(sys.argv) > 1 else None
    sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)]))