[RLlib] QMIX better defaults + added to CI learning tests (#21332)

Sven Mika, 2022-01-04 08:54:41 +01:00 (committed by GitHub)
parent 8cc268096c
commit abd3bef63b
10 changed files with 194 additions and 63 deletions

@ -11,10 +11,9 @@
# Currently we have the following categories:
# - Learning tests/regression, tagged:
# -- "learning_tests_[tf|tf2|torch]": Distinguish tf/tf2 vs torch.
# -- "learning_tests_[discrete|continuous]_[tf|tf2|torch]": distinguish discrete
# actions vs continuous actions AND tf vs torch.
# -- "fake_gpus_[tf|torch]": Tests that run using 2 fake GPUs.
# -- "learning_tests_[discrete|continuous]": distinguish discrete
# actions vs continuous actions.
# -- "fake_gpus": Tests that run using 2 fake GPUs.
# - Quick agent compilation/tune-train tests, tagged "quick_train".
# NOTE: These should be obsoleted in favor of "trainers_dir" tests as
@ -413,6 +412,37 @@ py_test(
args = ["--yaml-dir=tuned_examples/ppo"]
)
# QMIX
py_test(
name = "learning_tests_two_step_game_qmix",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/qmix/two-step-game-qmix.yaml"],
args = ["--yaml-dir=tuned_examples/qmix", "--framework=torch"]
)
py_test(
name = "learning_tests_two_step_game_qmix_vdn_mixer",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/qmix/two-step-game-qmix-vdn-mixer.yaml"],
args = ["--yaml-dir=tuned_examples/qmix", "--framework=torch"]
)
py_test(
name = "learning_tests_two_step_game_qmix_no_mixer",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/qmix/two-step-game-qmix-no-mixer.yaml"],
args = ["--yaml-dir=tuned_examples/qmix", "--framework=torch"]
)
# R2D2
py_test(
name = "learning_tests_stateless_cartpole_r2d2",
@ -2683,15 +2713,6 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=7", "--run=PG"]
)
py_test(
name = "examples/two_step_game_qmix",
main = "examples/two_step_game.py",
tags = ["team:ml", "examples", "examples_T"],
size = "large",
srcs = ["examples/two_step_game.py"],
args = ["--as-test", "--framework=torch", "--stop-reward=7", "--run=QMIX"]
)
py_test(
name = "contrib/bandits/examples/lin_ts",
main = "contrib/bandits/examples/simple_context_bandit.py",

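The three new BUILD targets each route one of the QMIX tuned-example YAMLs (added at the bottom of this commit) through rllib/tests/run_regression_tests.py with --framework=torch. Roughly, each target boils down to the following minimal sketch (not the actual runner script; the path is relative to the rllib/ directory and picks the qmix-mixer example as an illustration):

import yaml
from ray.tune import run_experiments

# Load one tuned example, force the torch framework (QMIX is torch-only),
# and run it until its stop criteria are met.
experiments = yaml.safe_load(open("tuned_examples/qmix/two-step-game-qmix.yaml"))
for exp in experiments.values():
    exp["config"]["framework"] = "torch"
run_experiments(experiments, verbose=2)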
@ -34,8 +34,9 @@ DEFAULT_CONFIG = with_common_config({
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 1.0,
"final_epsilon": 0.02,
"epsilon_timesteps": 10000, # Timesteps over which to anneal epsilon.
"final_epsilon": 0.01,
# Timesteps over which to anneal epsilon.
"epsilon_timesteps": 40000,
# For soft_q, use:
# "exploration_config" = {

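These new defaults anneal epsilon from 1.0 down to 0.01 over 40k timesteps instead of down to 0.02 over 10k, giving QMIX a longer exploration phase. They are only defaults; a minimal sketch of restoring the old schedule in a trainer config (the env name is a placeholder for any registered, grouped multi-agent env):

from ray.rllib.agents.qmix import QMixTrainer

trainer = QMixTrainer(
    env="my_grouped_env",  # placeholder: any registered, grouped multi-agent env
    config={
        "framework": "torch",  # QMIX only supports torch.
        "exploration_config": {
            "type": "EpsilonGreedy",
            "initial_epsilon": 1.0,
            "final_epsilon": 0.02,       # old default
            "epsilon_timesteps": 10000,  # old default
        },
    },
)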
@ -24,9 +24,10 @@ class AvailActionsTestEnv(MultiAgentEnv):
def __init__(self, env_config):
self.state = None
self.avail = env_config["avail_action"]
self.avail = env_config.get("avail_actions", [3])
self.action_mask = np.array([0] * 10)
self.action_mask[env_config["avail_action"]] = 1
for a in self.avail:
self.action_mask[a] = 1
def reset(self):
self.state = 0
@ -34,22 +35,31 @@ class AvailActionsTestEnv(MultiAgentEnv):
"agent_1": {
"obs": self.observation_space["obs"].sample(),
"action_mask": self.action_mask
}
},
"agent_2": {
"obs": self.observation_space["obs"].sample(),
"action_mask": self.action_mask
},
}
def step(self, action_dict):
if self.state > 0:
assert action_dict["agent_1"] == self.avail, \
assert (action_dict["agent_1"] in self.avail and
action_dict["agent_2"] in self.avail), \
"Failed to obey available actions mask!"
self.state += 1
rewards = {"agent_1": 1}
rewards = {"agent_1": 1, "agent_2": 0.5}
obs = {
"agent_1": {
"obs": self.observation_space["obs"].sample(),
"action_mask": self.action_mask
},
"agent_2": {
"obs": self.observation_space["obs"].sample(),
"action_mask": self.action_mask
}
}
dones = {"__all__": self.state > 20}
dones = {"__all__": self.state >= 20}
return obs, rewards, dones, {}
@ -64,28 +74,33 @@ class TestQMix(unittest.TestCase):
def test_avail_actions_qmix(self):
grouping = {
"group_1": ["agent_1"], # trivial grouping for testing
"group_1": ["agent_1", "agent_2"],
}
obs_space = Tuple([AvailActionsTestEnv.observation_space])
act_space = Tuple([AvailActionsTestEnv.action_space])
obs_space = Tuple([
AvailActionsTestEnv.observation_space,
AvailActionsTestEnv.observation_space
])
act_space = Tuple([
AvailActionsTestEnv.action_space, AvailActionsTestEnv.action_space
])
register_env(
"action_mask_test",
lambda config: AvailActionsTestEnv(config).with_agent_groups(
grouping, obs_space=obs_space, act_space=act_space))
agent = QMixTrainer(
trainer = QMixTrainer(
env="action_mask_test",
config={
"num_envs_per_worker": 5, # test with vectorization on
"env_config": {
"avail_action": 3,
"avail_actions": [3, 4, 8],
},
"framework": "torch",
})
for _ in range(4):
agent.train() # OK if it doesn't trip the action assertion error
assert agent.train()["episode_reward_mean"] == 21.0
agent.stop()
trainer.train() # OK if it doesn't trip the action assertion error
assert trainer.train()["episode_reward_mean"] == 30.0
trainer.stop()
ray.shutdown()

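For reference, the grouping the updated test relies on, condensed into a standalone snippet (it assumes the AvailActionsTestEnv defined above): both agents are merged into one agent group, so QMIX sees Tuple observation/action spaces plus a per-agent action mask. Because each step now pays 1.0 + 0.5 in total and episodes end after 20 steps, the expected episode_reward_mean is 20 * 1.5 = 30.0, which is exactly what the new assertion checks.

from gym.spaces import Tuple
from ray.tune import register_env

def make_grouped_env(env_config):
    # Two agents -> one agent group ("group_1") with Tuple obs/action spaces.
    env = AvailActionsTestEnv(env_config)
    grouping = {"group_1": ["agent_1", "agent_2"]}
    obs_space = Tuple([AvailActionsTestEnv.observation_space] * 2)
    act_space = Tuple([AvailActionsTestEnv.action_space] * 2)
    return env.with_agent_groups(
        grouping, obs_space=obs_space, act_space=act_space)

register_env("action_mask_test", make_grouped_env)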
@ -1,4 +1,4 @@
from gym.spaces import MultiDiscrete, Dict, Discrete
from gym.spaces import Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv, ENV_STATE
@ -109,3 +109,23 @@ class TwoStepGame(MultiAgentEnv):
return np.concatenate([self.state, [2]])
else:
return np.flatnonzero(self.state)[0] + 3
class TwoStepGameWithGroupedAgents(MultiAgentEnv):
def __init__(self, env_config):
env = TwoStepGame(env_config)
tuple_obs_space = Tuple([env.observation_space, env.observation_space])
tuple_act_space = Tuple([env.action_space, env.action_space])
self.env = env.with_agent_groups(
groups={"agents": [0, 1]},
obs_space=tuple_obs_space,
act_space=tuple_act_space)
self.observation_space = self.env.observation_space
self.action_space = self.env.action_space
def reset(self):
return self.env.reset()
def step(self, actions):
return self.env.step(actions)

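The new wrapper bakes the two-agent grouping into the env itself, which is what lets the tuned-example YAMLs below reference it by module path. Used directly from Python, the equivalent setup looks roughly like this (a sketch that mirrors the qmix-mixer YAML further down; it is not part of this commit):

from ray import tune
from ray.rllib.examples.env.two_step_game import TwoStepGameWithGroupedAgents

tune.run(
    "QMIX",
    stop={"episode_reward_mean": 8.0, "timesteps_total": 70000},
    config={
        "env": TwoStepGameWithGroupedAgents,
        # The nested env_config mirrors the tuned-example YAMLs; the wrapper
        # passes its env_config straight through to TwoStepGame.
        "env_config": {"env_config": {
            "separate_state_space": True,
            "one_hot_state_encoding": True,
        }},
        "framework": "torch",
        "mixer": "qmix",
        "exploration_config": {"final_epsilon": 0.0},
        "rollout_fragment_length": 4,
        "train_batch_size": 32,
        "num_workers": 0,
    },
)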
@ -14,7 +14,7 @@ import os
import ray
from ray import tune
from ray.tune import register_env, grid_search
from ray.tune import register_env
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
@ -32,6 +32,12 @@ parser.add_argument(
default="tf",
help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
"--mixer",
type=str,
default="qmix",
choices=["qmix", "vdn", "none"],
help="The mixer model to use.")
parser.add_argument(
"--as-test",
action="store_true",
@ -45,12 +51,12 @@ parser.add_argument(
parser.add_argument(
"--stop-timesteps",
type=int,
default=50000,
default=70000,
help="Number of timesteps to train.")
parser.add_argument(
"--stop-reward",
type=float,
default=7.0,
default=8.0,
help="Reward at which we stop training.")
parser.add_argument(
"--local-mode",
@ -116,11 +122,10 @@ if __name__ == "__main__":
"rollout_fragment_length": 4,
"train_batch_size": 32,
"exploration_config": {
"epsilon_timesteps": 5000,
"final_epsilon": 0.05,
"final_epsilon": 0.0,
},
"num_workers": 0,
"mixer": grid_search([None, "qmix"]),
"mixer": args.mixer,
"env_config": {
"separate_state_space": True,
"one_hot_state_encoding": True
@ -147,9 +152,6 @@ if __name__ == "__main__":
"env": "grouped_twostep" if group else TwoStepGame,
})
if args.as_test:
config["seed"] = 1234
results = tune.run(args.run, stop=stop, config=config, verbose=2)
if args.as_test:

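With the grid_search over mixers removed, a single run now trains exactly one variant, chosen via the new --mixer flag (for example --run=QMIX --framework=torch --mixer=vdn), and the tightened stop criteria plus final_epsilon=0.0 line up with the tuned-example YAMLs added below. For the vdn case, the QMIX-relevant slice of the resulting config is roughly (values consolidated from the diff above, not new code):

config = {
    "rollout_fragment_length": 4,
    "train_batch_size": 32,
    "exploration_config": {"final_epsilon": 0.0},
    "num_workers": 0,
    "mixer": "vdn",  # from --mixer
    "env_config": {
        "separate_state_space": True,
        "one_hot_state_encoding": True,
    },
    "framework": "torch",  # from --framework
}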
@ -53,6 +53,10 @@ parser.add_argument(
if __name__ == "__main__":
args = parser.parse_args()
# Error if deprecated --torch option used.
if args.torch:
deprecation_warning(old="--torch", new="--framework=torch", error=True)
# Bazel regression test mode: Get path to look for yaml files.
# Get the path or single file to use.
rllib_dir = Path(__file__).parent.parent
@ -81,13 +85,14 @@ if __name__ == "__main__":
assert len(experiments) == 1,\
"Error, can only run a single experiment per yaml file!"
# Add torch option to exp config.
exp = list(experiments.values())[0]
exp["config"]["framework"] = args.framework
if args.torch:
deprecation_warning(old="--torch", new="--framework=torch")
exp["config"]["framework"] = "torch"
args.framework = "torch"
# QMIX does not support tf yet -> skip.
if exp["run"] == "QMIX" and args.framework != "torch":
print(f"Skipping framework='{args.framework}' for QMIX.")
continue
# Always run with eager-tracing when framework=tf2.
if args.framework in ["tf2", "tfe"]:
exp["config"]["eager_tracing"] = True

@ -42,7 +42,7 @@ class AgentIOTest(unittest.TestCase):
shutil.rmtree(self.test_dir)
ray.shutdown()
def writeOutputs(self, output, fw):
def write_outputs(self, output, fw):
agent = PGTrainer(
env="CartPole-v0",
config={
@ -53,23 +53,23 @@ class AgentIOTest(unittest.TestCase):
agent.train()
return agent
def testAgentOutputOk(self):
def test_agent_output_ok(self):
for fw in framework_iterator(frameworks=("torch", "tf")):
self.writeOutputs(self.test_dir, fw)
self.write_outputs(self.test_dir, fw)
self.assertEqual(len(os.listdir(self.test_dir + fw)), 1)
reader = JsonReader(self.test_dir + fw + "/*.json")
reader.next()
def testAgentOutputLogdir(self):
def test_agent_output_logdir(self):
"""Test special value 'logdir' as Agent's output."""
for fw in framework_iterator():
agent = self.writeOutputs("logdir", fw)
agent = self.write_outputs("logdir", fw)
self.assertEqual(
len(glob.glob(agent.logdir + "/output-*.json")), 1)
def testAgentInputDir(self):
def test_agent_input_dir(self):
for fw in framework_iterator(frameworks=("torch", "tf")):
self.writeOutputs(self.test_dir, fw)
self.write_outputs(self.test_dir, fw)
agent = PGTrainer(
env="CartPole-v0",
config={
@ -81,16 +81,16 @@ class AgentIOTest(unittest.TestCase):
self.assertEqual(result["timesteps_total"], 250) # read from input
self.assertTrue(np.isnan(result["episode_reward_mean"]))
def testSplitByEpisode(self):
def test_split_by_episode(self):
splits = SAMPLES.split_by_episode()
self.assertEqual(len(splits), 3)
self.assertEqual(splits[0].count, 2)
self.assertEqual(splits[1].count, 1)
self.assertEqual(splits[2].count, 1)
def testAgentInputPostprocessingEnabled(self):
def test_agent_input_postprocessing_enabled(self):
for fw in framework_iterator(frameworks=("tf", "torch")):
self.writeOutputs(self.test_dir, fw)
self.write_outputs(self.test_dir, fw)
# Rewrite the files to drop advantages and value_targets for
# testing
@ -100,7 +100,7 @@ class AgentIOTest(unittest.TestCase):
for line in f.readlines():
data = json.loads(line)
# Data won't contain rewards as these are not included
# in the writeOutputs run (not needed in the
# in the write_outputs run (not needed in the
# SampleBatch). Flip out "rewards" for "advantages"
# just for testing.
data["rewards"] = data["advantages"]
@ -125,9 +125,9 @@ class AgentIOTest(unittest.TestCase):
self.assertEqual(result["timesteps_total"], 250) # read from input
self.assertTrue(np.isnan(result["episode_reward_mean"]))
def testAgentInputEvalSim(self):
def test_agent_input_eval_sim(self):
for fw in framework_iterator():
self.writeOutputs(self.test_dir, fw)
self.write_outputs(self.test_dir, fw)
agent = PGTrainer(
env="CartPole-v0",
config={
@ -142,9 +142,9 @@ class AgentIOTest(unittest.TestCase):
time.sleep(0.1)
assert False, "did not see any simulation results"
def testAgentInputList(self):
def test_agent_input_list(self):
for fw in framework_iterator(frameworks=("torch", "tf")):
self.writeOutputs(self.test_dir, fw)
self.write_outputs(self.test_dir, fw)
agent = PGTrainer(
env="CartPole-v0",
config={
@ -157,9 +157,9 @@ class AgentIOTest(unittest.TestCase):
self.assertEqual(result["timesteps_total"], 250) # read from input
self.assertTrue(np.isnan(result["episode_reward_mean"]))
def testAgentInputDict(self):
def test_agent_input_dict(self):
for fw in framework_iterator():
self.writeOutputs(self.test_dir, fw)
self.write_outputs(self.test_dir, fw)
agent = PGTrainer(
env="CartPole-v0",
config={
@ -174,7 +174,7 @@ class AgentIOTest(unittest.TestCase):
result = agent.train()
self.assertTrue(not np.isnan(result["episode_reward_mean"]))
def testMultiAgent(self):
def test_multi_agent(self):
register_env("multi_agent_cartpole",
lambda _: MultiAgentCartPole({"num_agents": 10}))
@ -234,7 +234,7 @@ class AgentIOTest(unittest.TestCase):
]
for input_procedure in test_input_procedure:
for fw in framework_iterator(frameworks=("torch", "tf")):
self.writeOutputs(self.test_dir, fw)
self.write_outputs(self.test_dir, fw)
agent = PGTrainer(
env="CartPole-v0",
config={

@ -0,0 +1,22 @@
two-step-game-qmix-without-mixer:
env: ray.rllib.examples.env.two_step_game.TwoStepGameWithGroupedAgents
run: QMIX
stop:
episode_reward_mean: 7.0
timesteps_total: 70000
config:
# QMIX only supports torch for now.
framework: torch
env_config:
env_config:
separate_state_space: true
one_hot_state_encoding: true
exploration_config:
final_epsilon: 0.0
rollout_fragment_length: 4
train_batch_size: 32
num_workers: 0
mixer: null

@ -0,0 +1,22 @@
two-step-game-qmix-with-vdn-mixer:
env: ray.rllib.examples.env.two_step_game.TwoStepGameWithGroupedAgents
run: QMIX
stop:
episode_reward_mean: 7.0
timesteps_total: 70000
config:
# QMIX only supports torch for now.
framework: torch
env_config:
env_config:
separate_state_space: true
one_hot_state_encoding: true
exploration_config:
final_epsilon: 0.0
rollout_fragment_length: 4
train_batch_size: 32
num_workers: 0
mixer: vdn

@ -0,0 +1,23 @@
two-step-game-qmix-with-qmix-mixer:
env: ray.rllib.examples.env.two_step_game.TwoStepGameWithGroupedAgents
run: QMIX
stop:
episode_reward_mean: 8.0
timesteps_total: 70000
config:
# QMIX only supports torch for now.
framework: torch
env_config:
env_config:
separate_state_space: true
one_hot_state_encoding: true
# W/o this setting, won't get to 8.0 reward.
exploration_config:
final_epsilon: 0.0
rollout_fragment_length: 4
train_batch_size: 32
num_workers: 0
mixer: qmix
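A note on the differing stop rewards across these three files: in the classic two-step game, the first step pays nothing and only selects one of two terminal matrix games, and only the fully coordinated joint action in the second game reaches a payoff of 8. The usual result from the QMIX paper is that the QMIX mixer can learn that optimum while a VDN-style (or mixer-free) decomposition settles at 7, hence the episode_reward_mean target of 8.0 here versus 7.0 for the vdn and no-mixer configs. A sketch of the assumed payoff structure (values as reported in the QMIX paper, not taken from this diff):

import numpy as np

# Second-step payoff matrices of the two-step game: state 2A pays 7 for every
# joint action; state 2B pays 8 only for the coordinated joint action (1, 1).
payoff_2a = np.full((2, 2), 7)
payoff_2b = np.array([[0, 1],
                      [1, 8]])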