[RLlib] QMIX better defaults + added to CI learning tests (#21332)
Parent: 8cc268096c
Commit: abd3bef63b
10 changed files with 194 additions and 63 deletions
rllib/BUILD (47 changed lines)

@@ -11,10 +11,9 @@
 # Currently we have the following categories:
 
 # - Learning tests/regression, tagged:
-# -- "learning_tests_[tf|tf2|torch]": Distinguish tf/tf2 vs torch.
-# -- "learning_tests_[discrete|continuous]_[tf|tf2|torch]": distinguish discrete
-#    actions vs continuous actions AND tf vs torch.
-# -- "fake_gpus_[tf|torch]": Tests that run using 2 fake GPUs.
+# -- "learning_tests_[discrete|continuous]": distinguish discrete
+#    actions vs continuous actions.
+# -- "fake_gpus": Tests that run using 2 fake GPUs.
 
 # - Quick agent compilation/tune-train tests, tagged "quick_train".
 #   NOTE: These should be obsoleted in favor of "trainers_dir" tests as
@@ -413,6 +412,37 @@ py_test(
     args = ["--yaml-dir=tuned_examples/ppo"]
 )
 
+# QMIX
+py_test(
+    name = "learning_tests_two_step_game_qmix",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/qmix/two-step-game-qmix.yaml"],
+    args = ["--yaml-dir=tuned_examples/qmix", "--framework=torch"]
+)
+
+py_test(
+    name = "learning_tests_two_step_game_qmix_vdn_mixer",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/qmix/two-step-game-qmix-vdn-mixer.yaml"],
+    args = ["--yaml-dir=tuned_examples/qmix", "--framework=torch"]
+)
+
+py_test(
+    name = "learning_tests_two_step_game_qmix_no_mixer",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/qmix/two-step-game-qmix-no-mixer.yaml"],
+    args = ["--yaml-dir=tuned_examples/qmix", "--framework=torch"]
+)
+
 # R2D2
 py_test(
     name = "learning_tests_stateless_cartpole_r2d2",
@@ -2683,15 +2713,6 @@ py_test(
     args = ["--as-test", "--framework=torch", "--stop-reward=7", "--run=PG"]
 )
 
-py_test(
-    name = "examples/two_step_game_qmix",
-    main = "examples/two_step_game.py",
-    tags = ["team:ml", "examples", "examples_T"],
-    size = "large",
-    srcs = ["examples/two_step_game.py"],
-    args = ["--as-test", "--framework=torch", "--stop-reward=7", "--run=QMIX"]
-)
-
 py_test(
     name = "contrib/bandits/examples/lin_ts",
     main = "contrib/bandits/examples/simple_context_bandit.py",
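The three new Bazel targets above all drive the same YAML-based regression harness with the same flags. As a rough local equivalent (a sketch only: running the script directly from the rllib/ directory, outside of Bazel's sandbox, is an assumption):

import subprocess

# Sketch: invoke the regression harness the way the new targets do
# (script path and flags taken from the targets above; cwd=rllib/ is assumed).
subprocess.run(
    [
        "python", "tests/run_regression_tests.py",
        "--yaml-dir=tuned_examples/qmix",
        "--framework=torch",  # QMIX is torch-only for now
    ],
    check=True)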
@@ -34,8 +34,9 @@ DEFAULT_CONFIG = with_common_config({
         "type": "EpsilonGreedy",
         # Config for the Exploration class' constructor:
         "initial_epsilon": 1.0,
-        "final_epsilon": 0.02,
-        "epsilon_timesteps": 10000,  # Timesteps over which to anneal epsilon.
+        "final_epsilon": 0.01,
+        # Timesteps over which to anneal epsilon.
+        "epsilon_timesteps": 40000,
 
         # For soft_q, use:
         # "exploration_config" = {
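The net effect of the new defaults above: exploration anneals more slowly and ends lower (final_epsilon 0.02 -> 0.01, epsilon_timesteps 10000 -> 40000). Assuming EpsilonGreedy's usual linear schedule, the annealed value looks roughly like this (the helper name is illustrative, not part of RLlib):

def epsilon_at(t, initial_epsilon=1.0, final_epsilon=0.01, epsilon_timesteps=40000):
    """Linear anneal from initial_epsilon to final_epsilon over epsilon_timesteps."""
    frac = min(t / epsilon_timesteps, 1.0)
    return initial_epsilon + frac * (final_epsilon - initial_epsilon)

print(epsilon_at(10000))  # ~0.75: still exploratory where the old schedule had already hit 0.02
print(epsilon_at(40000))  # 0.01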
@@ -24,9 +24,10 @@ class AvailActionsTestEnv(MultiAgentEnv):
 
     def __init__(self, env_config):
         self.state = None
-        self.avail = env_config["avail_action"]
+        self.avail = env_config.get("avail_actions", [3])
         self.action_mask = np.array([0] * 10)
-        self.action_mask[env_config["avail_action"]] = 1
+        for a in self.avail:
+            self.action_mask[a] = 1
 
     def reset(self):
         self.state = 0
@@ -34,22 +35,31 @@ class AvailActionsTestEnv(MultiAgentEnv):
             "agent_1": {
                 "obs": self.observation_space["obs"].sample(),
                 "action_mask": self.action_mask
-            }
+            },
+            "agent_2": {
+                "obs": self.observation_space["obs"].sample(),
+                "action_mask": self.action_mask
+            },
         }
 
     def step(self, action_dict):
         if self.state > 0:
-            assert action_dict["agent_1"] == self.avail, \
+            assert (action_dict["agent_1"] in self.avail and
+                    action_dict["agent_2"] in self.avail), \
                 "Failed to obey available actions mask!"
         self.state += 1
-        rewards = {"agent_1": 1}
+        rewards = {"agent_1": 1, "agent_2": 0.5}
         obs = {
             "agent_1": {
                 "obs": self.observation_space["obs"].sample(),
                 "action_mask": self.action_mask
             },
+            "agent_2": {
+                "obs": self.observation_space["obs"].sample(),
+                "action_mask": self.action_mask
+            }
         }
-        dones = {"__all__": self.state > 20}
+        dones = {"__all__": self.state >= 20}
         return obs, rewards, dones, {}
 
 
@@ -64,28 +74,33 @@ class TestQMix(unittest.TestCase):
 
     def test_avail_actions_qmix(self):
         grouping = {
-            "group_1": ["agent_1"],  # trivial grouping for testing
+            "group_1": ["agent_1", "agent_2"],
         }
-        obs_space = Tuple([AvailActionsTestEnv.observation_space])
-        act_space = Tuple([AvailActionsTestEnv.action_space])
+        obs_space = Tuple([
+            AvailActionsTestEnv.observation_space,
+            AvailActionsTestEnv.observation_space
+        ])
+        act_space = Tuple([
+            AvailActionsTestEnv.action_space, AvailActionsTestEnv.action_space
+        ])
         register_env(
             "action_mask_test",
             lambda config: AvailActionsTestEnv(config).with_agent_groups(
                 grouping, obs_space=obs_space, act_space=act_space))
 
-        agent = QMixTrainer(
+        trainer = QMixTrainer(
             env="action_mask_test",
             config={
                 "num_envs_per_worker": 5,  # test with vectorization on
                 "env_config": {
-                    "avail_action": 3,
+                    "avail_actions": [3, 4, 8],
                },
                "framework": "torch",
            })
        for _ in range(4):
-            agent.train()  # OK if it doesn't trip the action assertion error
-        assert agent.train()["episode_reward_mean"] == 21.0
-        agent.stop()
+            trainer.train()  # OK if it doesn't trip the action assertion error
+        assert trainer.train()["episode_reward_mean"] == 30.0
+        trainer.stop()
         ray.shutdown()
 
 
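A quick check on the new expected return in the assertion above. The old 21.0 came from a single grouped agent earning 1.0 per step over 21 steps (done only once state > 20); with two agents, per-step rewards of 1.0 and 0.5, and done on state >= 20, the episode return becomes:

# Sketch of the arithmetic behind the updated assertion.
steps_per_episode = 20        # dones = {"__all__": self.state >= 20}
reward_per_step = 1.0 + 0.5   # rewards = {"agent_1": 1, "agent_2": 0.5}
assert steps_per_episode * reward_per_step == 30.0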
rllib/examples/env/two_step_game.py (vendored, 22 changed lines)

@@ -1,4 +1,4 @@
-from gym.spaces import MultiDiscrete, Dict, Discrete
+from gym.spaces import Dict, Discrete, MultiDiscrete, Tuple
 import numpy as np
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv, ENV_STATE

@@ -109,3 +109,23 @@ class TwoStepGame(MultiAgentEnv):
             return np.concatenate([self.state, [2]])
         else:
             return np.flatnonzero(self.state)[0] + 3
+
+
+class TwoStepGameWithGroupedAgents(MultiAgentEnv):
+    def __init__(self, env_config):
+        env = TwoStepGame(env_config)
+        tuple_obs_space = Tuple([env.observation_space, env.observation_space])
+        tuple_act_space = Tuple([env.action_space, env.action_space])
+
+        self.env = env.with_agent_groups(
+            groups={"agents": [0, 1]},
+            obs_space=tuple_obs_space,
+            act_space=tuple_act_space)
+        self.observation_space = self.env.observation_space
+        self.action_space = self.env.action_space
+
+    def reset(self):
+        return self.env.reset()
+
+    def step(self, actions):
+        return self.env.step(actions)
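TwoStepGameWithGroupedAgents is what the new tuned QMIX examples point at via their env: key. A minimal sketch of using it directly (passing an empty env_config and relying on TwoStepGame's defaults is an assumption):

from ray.rllib.examples.env.two_step_game import TwoStepGameWithGroupedAgents

# Sketch: the wrapper puts both agents into a single "agents" group whose
# observation/action spaces are Tuples over the two underlying agents.
env = TwoStepGameWithGroupedAgents({})  # empty config -> TwoStepGame defaults (assumed)
obs = env.reset()
print(env.observation_space)
print(env.action_space)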
@@ -14,7 +14,7 @@ import os
 
 import ray
 from ray import tune
-from ray.tune import register_env, grid_search
+from ray.tune import register_env
 from ray.rllib.env.multi_agent_env import ENV_STATE
 from ray.rllib.examples.env.two_step_game import TwoStepGame
 from ray.rllib.policy.policy import PolicySpec

@@ -32,6 +32,12 @@ parser.add_argument(
     default="tf",
     help="The DL framework specifier.")
 parser.add_argument("--num-cpus", type=int, default=0)
+parser.add_argument(
+    "--mixer",
+    type=str,
+    default="qmix",
+    choices=["qmix", "vdn", "none"],
+    help="The mixer model to use.")
 parser.add_argument(
     "--as-test",
     action="store_true",

@@ -45,12 +51,12 @@ parser.add_argument(
 parser.add_argument(
     "--stop-timesteps",
     type=int,
-    default=50000,
+    default=70000,
     help="Number of timesteps to train.")
 parser.add_argument(
     "--stop-reward",
     type=float,
-    default=7.0,
+    default=8.0,
     help="Reward at which we stop training.")
 parser.add_argument(
     "--local-mode",

@@ -116,11 +122,10 @@ if __name__ == "__main__":
         "rollout_fragment_length": 4,
         "train_batch_size": 32,
         "exploration_config": {
-            "epsilon_timesteps": 5000,
-            "final_epsilon": 0.05,
+            "final_epsilon": 0.0,
         },
         "num_workers": 0,
-        "mixer": grid_search([None, "qmix"]),
+        "mixer": args.mixer,
         "env_config": {
             "separate_state_space": True,
             "one_hot_state_encoding": True

@@ -147,9 +152,6 @@ if __name__ == "__main__":
         "env": "grouped_twostep" if group else TwoStepGame,
     })
 
-    if args.as_test:
-        config["seed"] = 1234
-
     results = tune.run(args.run, stop=stop, config=config, verbose=2)
 
     if args.as_test:
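With the updated defaults, the two_step_game.py example now trains a single, explicitly chosen mixer instead of grid-searching over mixers, and stops later and at a higher reward. Roughly, the script ends up handing tune.run() something like the following (the exact stop-dict keys are assumed from the --stop-reward/--stop-timesteps argument names):

# Sketch of the effective defaults (values taken from the hunks above).
stop = {
    "episode_reward_mean": 8.0,  # --stop-reward default (was 7.0)
    "timesteps_total": 70000,    # --stop-timesteps default (was 50000)
}
config_fragment = {
    "mixer": "qmix",  # from --mixer; choices are "qmix", "vdn", "none"
    "exploration_config": {"final_epsilon": 0.0},
}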
@@ -53,6 +53,10 @@ parser.add_argument(
 if __name__ == "__main__":
     args = parser.parse_args()
 
+    # Error if deprecated --torch option used.
+    if args.torch:
+        deprecation_warning(old="--torch", new="--framework=torch", error=True)
+
     # Bazel regression test mode: Get path to look for yaml files.
     # Get the path or single file to use.
     rllib_dir = Path(__file__).parent.parent

@@ -81,13 +85,14 @@ if __name__ == "__main__":
         assert len(experiments) == 1,\
             "Error, can only run a single experiment per yaml file!"
 
-        # Add torch option to exp config.
         exp = list(experiments.values())[0]
-        if args.torch:
-            deprecation_warning(old="--torch", new="--framework=torch")
-            exp["config"]["framework"] = "torch"
-            args.framework = "torch"
+        exp["config"]["framework"] = args.framework
+
+        # QMIX does not support tf yet -> skip.
+        if exp["run"] == "QMIX" and args.framework != "torch":
+            print(f"Skipping framework='{args.framework}' for QMIX.")
+            continue
 
         # Always run with eager-tracing when framework=tf2.
         if args.framework in ["tf2", "tfe"]:
             exp["config"]["eager_tracing"] = True
@@ -42,7 +42,7 @@ class AgentIOTest(unittest.TestCase):
         shutil.rmtree(self.test_dir)
         ray.shutdown()
 
-    def writeOutputs(self, output, fw):
+    def write_outputs(self, output, fw):
         agent = PGTrainer(
             env="CartPole-v0",
             config={

@@ -53,23 +53,23 @@ class AgentIOTest(unittest.TestCase):
         agent.train()
         return agent
 
-    def testAgentOutputOk(self):
+    def test_agent_output_ok(self):
         for fw in framework_iterator(frameworks=("torch", "tf")):
-            self.writeOutputs(self.test_dir, fw)
+            self.write_outputs(self.test_dir, fw)
             self.assertEqual(len(os.listdir(self.test_dir + fw)), 1)
             reader = JsonReader(self.test_dir + fw + "/*.json")
             reader.next()
 
-    def testAgentOutputLogdir(self):
+    def test_agent_output_logdir(self):
         """Test special value 'logdir' as Agent's output."""
         for fw in framework_iterator():
-            agent = self.writeOutputs("logdir", fw)
+            agent = self.write_outputs("logdir", fw)
             self.assertEqual(
                 len(glob.glob(agent.logdir + "/output-*.json")), 1)
 
-    def testAgentInputDir(self):
+    def test_agent_input_dir(self):
         for fw in framework_iterator(frameworks=("torch", "tf")):
-            self.writeOutputs(self.test_dir, fw)
+            self.write_outputs(self.test_dir, fw)
             agent = PGTrainer(
                 env="CartPole-v0",
                 config={

@@ -81,16 +81,16 @@ class AgentIOTest(unittest.TestCase):
             self.assertEqual(result["timesteps_total"], 250)  # read from input
             self.assertTrue(np.isnan(result["episode_reward_mean"]))
 
-    def testSplitByEpisode(self):
+    def test_split_by_episode(self):
         splits = SAMPLES.split_by_episode()
         self.assertEqual(len(splits), 3)
         self.assertEqual(splits[0].count, 2)
         self.assertEqual(splits[1].count, 1)
         self.assertEqual(splits[2].count, 1)
 
-    def testAgentInputPostprocessingEnabled(self):
+    def test_agent_input_postprocessing_enabled(self):
         for fw in framework_iterator(frameworks=("tf", "torch")):
-            self.writeOutputs(self.test_dir, fw)
+            self.write_outputs(self.test_dir, fw)
 
             # Rewrite the files to drop advantages and value_targets for
             # testing

@@ -100,7 +100,7 @@ class AgentIOTest(unittest.TestCase):
                    for line in f.readlines():
                        data = json.loads(line)
                        # Data won't contain rewards as these are not included
-                        # in the writeOutputs run (not needed in the
+                        # in the write_outputs run (not needed in the
                        # SampleBatch). Flip out "rewards" for "advantages"
                        # just for testing.
                        data["rewards"] = data["advantages"]

@@ -125,9 +125,9 @@ class AgentIOTest(unittest.TestCase):
             self.assertEqual(result["timesteps_total"], 250)  # read from input
             self.assertTrue(np.isnan(result["episode_reward_mean"]))
 
-    def testAgentInputEvalSim(self):
+    def test_agent_input_eval_sim(self):
         for fw in framework_iterator():
-            self.writeOutputs(self.test_dir, fw)
+            self.write_outputs(self.test_dir, fw)
             agent = PGTrainer(
                 env="CartPole-v0",
                 config={

@@ -142,9 +142,9 @@ class AgentIOTest(unittest.TestCase):
                 time.sleep(0.1)
             assert False, "did not see any simulation results"
 
-    def testAgentInputList(self):
+    def test_agent_input_list(self):
         for fw in framework_iterator(frameworks=("torch", "tf")):
-            self.writeOutputs(self.test_dir, fw)
+            self.write_outputs(self.test_dir, fw)
             agent = PGTrainer(
                 env="CartPole-v0",
                 config={

@@ -157,9 +157,9 @@ class AgentIOTest(unittest.TestCase):
             self.assertEqual(result["timesteps_total"], 250)  # read from input
             self.assertTrue(np.isnan(result["episode_reward_mean"]))
 
-    def testAgentInputDict(self):
+    def test_agent_input_dict(self):
         for fw in framework_iterator():
-            self.writeOutputs(self.test_dir, fw)
+            self.write_outputs(self.test_dir, fw)
             agent = PGTrainer(
                 env="CartPole-v0",
                 config={

@@ -174,7 +174,7 @@ class AgentIOTest(unittest.TestCase):
             result = agent.train()
             self.assertTrue(not np.isnan(result["episode_reward_mean"]))
 
-    def testMultiAgent(self):
+    def test_multi_agent(self):
         register_env("multi_agent_cartpole",
                      lambda _: MultiAgentCartPole({"num_agents": 10}))
 

@@ -234,7 +234,7 @@ class AgentIOTest(unittest.TestCase):
         ]
         for input_procedure in test_input_procedure:
             for fw in framework_iterator(frameworks=("torch", "tf")):
-                self.writeOutputs(self.test_dir, fw)
+                self.write_outputs(self.test_dir, fw)
                 agent = PGTrainer(
                     env="CartPole-v0",
                     config={
rllib/tuned_examples/qmix/two-step-game-qmix-no-mixer.yaml (new file, 22 lines)

@@ -0,0 +1,22 @@
+two-step-game-qmix-without-mixer:
+    env: ray.rllib.examples.env.two_step_game.TwoStepGameWithGroupedAgents
+    run: QMIX
+    stop:
+        episode_reward_mean: 7.0
+        timesteps_total: 70000
+    config:
+        # QMIX only supports torch for now.
+        framework: torch
+
+        env_config:
+            env_config:
+                separate_state_space: true
+                one_hot_state_encoding: true
+
+        exploration_config:
+            final_epsilon: 0.0
+
+        rollout_fragment_length: 4
+        train_batch_size: 32
+        num_workers: 0
+        mixer: null
rllib/tuned_examples/qmix/two-step-game-qmix-vdn-mixer.yaml (new file, 22 lines)

@@ -0,0 +1,22 @@
+two-step-game-qmix-with-vdn-mixer:
+    env: ray.rllib.examples.env.two_step_game.TwoStepGameWithGroupedAgents
+    run: QMIX
+    stop:
+        episode_reward_mean: 7.0
+        timesteps_total: 70000
+    config:
+        # QMIX only supports torch for now.
+        framework: torch
+
+        env_config:
+            env_config:
+                separate_state_space: true
+                one_hot_state_encoding: true
+
+        exploration_config:
+            final_epsilon: 0.0
+
+        rollout_fragment_length: 4
+        train_batch_size: 32
+        num_workers: 0
+        mixer: vdn
rllib/tuned_examples/qmix/two-step-game-qmix.yaml (new file, 23 lines)

@@ -0,0 +1,23 @@
+two-step-game-qmix-with-qmix-mixer:
+    env: ray.rllib.examples.env.two_step_game.TwoStepGameWithGroupedAgents
+    run: QMIX
+    stop:
+        episode_reward_mean: 8.0
+        timesteps_total: 70000
+    config:
+        # QMIX only supports torch for now.
+        framework: torch
+
+        env_config:
+            env_config:
+                separate_state_space: true
+                one_hot_state_encoding: true
+
+        # W/o this setting, won't get to 8.0 reward.
+        exploration_config:
+            final_epsilon: 0.0
+
+        rollout_fragment_length: 4
+        train_batch_size: 32
+        num_workers: 0
+        mixer: qmix
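The three tuned examples differ only in their experiment name, stop reward, and mixer setting (qmix mixing network, vdn, or no mixer at all). A simplified sketch of feeding one of them to Tune, standing in for what the regression harness does (the file path relative to the repo root is an assumption):

import yaml
from ray import tune

# Sketch: load one of the new tuned examples and run it via Tune.
with open("rllib/tuned_examples/qmix/two-step-game-qmix.yaml") as f:
    experiments = yaml.safe_load(f)

name, exp = next(iter(experiments.items()))  # one experiment per file
tune.run(exp["run"], name=name, stop=exp["stop"], config=exp["config"])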