mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
53 lines
1.5 KiB
Python
53 lines
1.5 KiB
Python
![]() |
import unittest
|
||
|
|
||
|
from ray.rllib.agents.pg import PGTrainer
|
||
|
|
||
|
|
||
|
class TestCheckMultiAgent(unittest.TestCase):
|
||
|
def test_multi_agent_dict_invalid_subkeys(self):
|
||
|
config = {
|
||
|
"multiagent": {
|
||
|
"wrong_key": 1,
|
||
|
"policies": {"p0"},
|
||
|
"policies_to_train": ["p0"]
|
||
|
}
|
||
|
}
|
||
|
self.assertRaisesRegex(
|
||
|
KeyError,
|
||
|
"You have invalid keys in your",
|
||
|
lambda: PGTrainer(config, env="CartPole-v0"),
|
||
|
)
|
||
|
|
||
|
def test_multi_agent_dict_bad_policy_ids(self):
|
||
|
config = {
|
||
|
"multiagent": {
|
||
|
"policies": {1, "good_id"},
|
||
|
"policy_mapping_fn": lambda aid, **kw: "good_id"
|
||
|
}
|
||
|
}
|
||
|
self.assertRaisesRegex(
|
||
|
KeyError,
|
||
|
"Policy IDs must always be of type",
|
||
|
lambda: PGTrainer(config, env="CartPole-v0"),
|
||
|
)
|
||
|
|
||
|
def test_multi_agent_dict_invalid_sub_values(self):
|
||
|
config = {"multiagent": {"count_steps_by": "invalid_value"}}
|
||
|
self.assertRaisesRegex(
|
||
|
ValueError,
|
||
|
"config.multiagent.count_steps_by must be",
|
||
|
lambda: PGTrainer(config, env="CartPole-v0"),
|
||
|
)
|
||
|
|
||
|
config = {"multiagent": {"replay_mode": "invalid_value"}}
|
||
|
self.assertRaisesRegex(
|
||
|
ValueError,
|
||
|
"config.multiagent.replay_mode must be",
|
||
|
lambda: PGTrainer(config, env="CartPole-v0"),
|
||
|
)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
import pytest
|
||
|
pytest.main()
|