From 8288deb92df50678790024dc1fba1057f7e2a874 Mon Sep 17 00:00:00 2001 From: Antoine Galataud Date: Sat, 2 Mar 2019 04:45:39 +0100 Subject: [PATCH] Add multi agent support in rollout.py (#4114) --- python/ray/rllib/rollout.py | 67 ++++++++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 05f19e1e8..68f2d6456 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -73,15 +73,15 @@ def run(args, parser): if not config: # Load configuration from file config_dir = os.path.dirname(args.checkpoint) - config_path = os.path.join(config_dir, "params.json") + config_path = os.path.join(config_dir, "params.pkl") if not os.path.exists(config_path): - config_path = os.path.join(config_dir, "../params.json") + config_path = os.path.join(config_dir, "../params.pkl") if not os.path.exists(config_path): raise ValueError( - "Could not find params.json in either the checkpoint dir or " + "Could not find params.pkl in either the checkpoint dir or " "its parent directory.") - with open(config_path) as f: - config = json.load(f) + with open(config_path, 'rb') as f: + config = pickle.load(f) if "num_workers" in config: config["num_workers"] = min(2, config["num_workers"]) @@ -102,18 +102,18 @@ def run(args, parser): def rollout(agent, env_name, num_steps, out=None, no_render=True): if hasattr(agent, "local_evaluator"): env = agent.local_evaluator.env + multiagent = agent.local_evaluator.multiagent + if multiagent: + policy_agent_mapping = agent.config["multiagent"][ + "policy_mapping_fn"] + mapping_cache = {} + policy_map = agent.local_evaluator.policy_map + state_init = {p: m.get_initial_state() for p, m in policy_map.items()} + use_lstm = {p: len(s) > 0 for p, s in state_init.items()} else: env = gym.make(env_name) - - if hasattr(agent, "local_evaluator"): - state_init = agent.local_evaluator.policy_map[ - "default"].get_initial_state() - else: - state_init = [] - if state_init: - use_lstm = True - else: - use_lstm = False + multiagent = False + use_lstm = {'default': False} if out is not None: rollouts = [] @@ -125,13 +125,39 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): done = False reward_total = 0.0 while not done and steps < (num_steps or steps + 1): - if use_lstm: - action, state_init, logits = agent.compute_action( - state, state=state_init) + if multiagent: + action_dict = {} + for agent_id in state.keys(): + a_state = state[agent_id] + if a_state is not None: + policy_id = mapping_cache.setdefault( + agent_id, policy_agent_mapping(agent_id)) + p_use_lstm = use_lstm[policy_id] + if p_use_lstm: + a_action, p_state_init, _ = agent.compute_action( + a_state, + state=state_init[policy_id], + policy_id=policy_id) + state_init[policy_id] = p_state_init + else: + a_action = agent.compute_action( + a_state, policy_id=policy_id) + action_dict[agent_id] = a_action + action = action_dict else: - action = agent.compute_action(state) + if use_lstm["default"]: + action, state_init, _ = agent.compute_action( + state, state=state_init) + else: + action = agent.compute_action(state) + next_state, reward, done, _ = env.step(action) - reward_total += reward + + if multiagent: + done = done["__all__"] + reward_total += sum(reward.values()) + else: + reward_total += reward if not no_render: env.render() if out is not None: @@ -141,6 +167,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): if out is not None: rollouts.append(rollout) print("Episode reward", reward_total) + if out is not None: pickle.dump(rollouts, open(out, "wb"))