mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
142 lines
4.8 KiB
Python
142 lines
4.8 KiB
Python
import ray
|
|
import unittest
|
|
import numpy as np
|
|
from ray.rllib.agents.callbacks import DefaultCallbacks
|
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
|
from ray.rllib.examples.env.mock_env import MockEnv3
|
|
from ray.rllib.policy import Policy
|
|
from ray.rllib.utils import override
|
|
|
|
# Episode length handed to every mock environment built in this module.
NUM_STEPS: int = 25
# Number of agents (and per-agent policies) in the multi-agent test.
NUM_AGENTS: int = 4
|
|
|
|
|
|
class LastInfoCallback(DefaultCallbacks):
    """Callbacks that assert on an episode's cached "last" values.

    At episode start, on every step and at episode end, the per-agent last
    observation, info, done flag, action and reward stored on the episode
    are compared against values derived from this callback's step counter.
    """

    def __init__(self):
        super().__init__()
        # Reuse TestCase's assert helpers outside of a test method.
        self.tc = unittest.TestCase()
        # Count of on_episode_step calls in the current episode.
        self.step = 0

    def on_episode_start(self, worker, base_env, policies, episode, env_index,
                         **kwargs):
        # New episode: restart the counter, then verify nothing is cached.
        self.step = 0
        self._check_last_values(episode)

    def on_episode_step(self,
                        worker,
                        base_env,
                        episode,
                        env_index=None,
                        **kwargs):
        self.step = self.step + 1
        self._check_last_values(episode)

    def on_episode_end(self, worker, base_env, policies, episode, **kwargs):
        # The counter is not advanced here; values must still match the
        # final step.
        self._check_last_values(episode)

    def _check_last_values(self, episode):
        """Assert the episode's cached last values match ``self.step``.

        Expected values assume the env reports ``step + agent_index`` as
        observation, reward and ``info["timestep"]`` (EpisodeEnv adds the
        agent index on top of MockEnv3's output; MockEnv3's exact output
        is an assumption here -- TODO confirm).
        """
        # Each cached observation vector is reduced to the position of its
        # single nonzero entry (.item() fails unless there is exactly one).
        obs_index = {}
        for agent_id, vec in episode._agent_to_last_obs.items():
            obs_index[agent_id] = np.where(vec)[0].item()
        infos = episode._agent_to_last_info
        dones = episode._agent_to_last_done
        actions = episode._agent_to_last_action
        rewards = {
            agent_id: history[-1]
            for agent_id, history in episode._agent_reward_history.items()
        }

        if self.step == 0:
            # Before any env step, every cache must still be empty.
            for cached in (obs_index, infos, dones, actions, rewards):
                self.tc.assertEqual(cached, {})
            return

        for agent_id in obs_index:
            # Agent ids are either ints or strings like "agent0".
            offset = int(str(agent_id).replace("agent", ""))
            expected = self.step + offset
            self.tc.assertEqual(obs_index[agent_id], expected)
            self.tc.assertEqual(rewards[agent_id], expected)
            self.tc.assertEqual(dones[agent_id], self.step == NUM_STEPS)
            # The first action is computed from the (unshifted) reset
            # observation, hence 0; later actions echo the previous obs.
            expected_action = 0 if self.step == 1 else expected - 1
            self.tc.assertEqual(actions[agent_id], expected_action)
            self.tc.assertEqual(infos[agent_id]["timestep"], expected)
|
|
|
|
|
class EchoPolicy(Policy):
    """Policy that acts with the argmax over axis 1 of each observation.

    Presumably the observations are one-hot encodings of a discrete obs, so
    the action "echoes" the observed index -- TODO confirm against the
    preprocessor in use.
    """

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        # argmax(axis=1) requires obs_batch to be at least 2-D
        # (assumed (batch, features)).
        actions = obs_batch.argmax(axis=1)
        state_outs = []  # stateless policy: no RNN state outputs
        extra_fetches = {}
        return actions, state_outs, extra_fetches
|
|
|
|
|
|
class EpisodeEnv(MultiAgentEnv):
    """Multi-agent env built from ``num`` independent MockEnv3 copies.

    Agent ``i`` drives the ``i``-th MockEnv3. After each step, the agent's
    observation, reward and ``info["timestep"]`` are shifted by ``i`` so the
    agents' values are distinguishable. Observations returned by ``reset``
    are NOT shifted.
    """

    def __init__(self, episode_length, num):
        # Initialize the base env first so our attributes below take effect.
        super().__init__()
        self.agents = [MockEnv3(episode_length) for _ in range(num)]
        # Indices of agents whose sub-env has reported done this episode.
        self.dones = set()
        # All agents share the first sub-env's spaces.
        self.observation_space = self.agents[0].observation_space
        self.action_space = self.agents[0].action_space

    def reset(self):
        """Reset every sub-env and return ``{agent_index: initial_obs}``."""
        self.dones = set()
        return {i: a.reset() for i, a in enumerate(self.agents)}

    def step(self, action_dict):
        """Step the sub-env of every agent present in ``action_dict``.

        Args:
            action_dict: Mapping of agent index to that agent's action.

        Returns:
            ``(obs, rew, done, info)`` dicts keyed by agent index. ``done``
            also holds ``"__all__"``, which is true once every agent has
            reported done since the last reset.
        """
        obs, rew, done, info = {}, {}, {}, {}
        for i, action in action_dict.items():
            obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
            # Offset by the agent index so per-agent values differ.
            obs[i] = obs[i] + i
            rew[i] = rew[i] + i
            info[i]["timestep"] = info[i]["timestep"] + i
            if done[i]:
                self.dones.add(i)
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
|
|
|
|
|
|
class TestEpisodeLastValues(unittest.TestCase):
    """End-to-end checks of the "last value" caches kept on episodes.

    Each test builds a RolloutWorker with LastInfoCallback and samples once.
    The assertions run inside the callback's hooks.
    """

    @classmethod
    def setUpClass(cls):
        # Start Ray with a single CPU for this test module.
        ray.init(num_cpus=1)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_singleagent_env(self):
        worker = RolloutWorker(
            env_creator=lambda _: MockEnv3(NUM_STEPS),
            policy_spec=EchoPolicy,
            callbacks=LastInfoCallback)
        worker.sample()

    def test_multiagent_env(self):
        # Throwaway instance, used only to read the shared spaces.
        probe_env = EpisodeEnv(NUM_STEPS, NUM_AGENTS)
        obs_space = probe_env.observation_space
        act_space = probe_env.action_space
        # One EchoPolicy per agent, keyed by the stringified agent index.
        policies = {}
        for agent_id in range(NUM_AGENTS):
            policies[str(agent_id)] = (EchoPolicy, obs_space, act_space, {})

        worker = RolloutWorker(
            env_creator=lambda _: EpisodeEnv(NUM_STEPS, NUM_AGENTS),
            policy_spec=policies,
            policy_mapping_fn=lambda aid, eps, **kwargs: str(aid),
            callbacks=LastInfoCallback)
        worker.sample()
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    import pytest

    # Run this module's tests verbosely and exit with pytest's status code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|