mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
52 lines
1.5 KiB
Python
52 lines
1.5 KiB
Python
import unittest
|
|
|
|
from ray.rllib.agents.pg import PGTrainer
|
|
|
|
|
|
class TestCheckMultiAgent(unittest.TestCase):
    """Checks that PGTrainer rejects malformed "multiagent" config dicts.

    Every test builds a PGTrainer on "CartPole-v0" with a deliberately
    broken "multiagent" section and asserts on the exception type and on
    a fragment of the error message.
    """

    def test_multi_agent_dict_invalid_subkeys(self):
        """An unexpected key under "multiagent" must raise a KeyError."""
        multiagent = {
            "wrong_key": 1,
            "policies": {"p0"},
            "policies_to_train": ["p0"],
        }
        bad_config = {"multiagent": multiagent}
        expected_msg = "You have invalid keys in your"

        with self.assertRaisesRegex(KeyError, expected_msg):
            PGTrainer(bad_config, env="CartPole-v0")

    def test_multi_agent_dict_bad_policy_ids(self):
        """A non-string policy ID (here the int 1) must raise a KeyError."""
        multiagent = {
            "policies": {1, "good_id"},
            "policy_mapping_fn": lambda aid, **kw: "good_id",
        }
        bad_config = {"multiagent": multiagent}
        expected_msg = "Policy IDs must always be of type"

        with self.assertRaisesRegex(KeyError, expected_msg):
            PGTrainer(bad_config, env="CartPole-v0")

    def test_multi_agent_dict_invalid_sub_values(self):
        """Invalid values for known sub-keys must raise a ValueError."""
        # (sub-key given an invalid value, expected error-message fragment)
        cases = [
            ("count_steps_by", "config.multiagent.count_steps_by must be"),
            ("replay_mode", "config.multiagent.replay_mode must be"),
        ]
        # Checked in order; the first failing case aborts the test, just
        # like two back-to-back assertions would.
        for sub_key, expected_msg in cases:
            bad_config = {"multiagent": {sub_key: "invalid_value"}}
            with self.assertRaisesRegex(ValueError, expected_msg):
                PGTrainer(bad_config, env="CartPole-v0")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the tests defined in this file with pytest.
    import sys

    import pytest

    # Pass this file explicitly and forward pytest's exit status. A bare
    # `pytest.main()` falls back to the (empty) command-line arguments, so
    # it would collect from the working directory instead of this module,
    # and dropping its return value makes the process exit 0 even when
    # tests fail.
    sys.exit(pytest.main(["-v", __file__]))
|