Add multi-agent support in rollout.py (#4114)

This commit is contained in:
Antoine Galataud 2019-03-02 04:45:39 +01:00 committed by Eric Liang
parent 48f6cd3e5d
commit 8288deb92d

View file

@ -73,15 +73,15 @@ def run(args, parser):
if not config:
# Load configuration from file
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.json")
config_path = os.path.join(config_dir, "params.pkl")
if not os.path.exists(config_path):
config_path = os.path.join(config_dir, "../params.json")
config_path = os.path.join(config_dir, "../params.pkl")
if not os.path.exists(config_path):
raise ValueError(
"Could not find params.json in either the checkpoint dir or "
"Could not find params.pkl in either the checkpoint dir or "
"its parent directory.")
with open(config_path) as f:
config = json.load(f)
with open(config_path, 'rb') as f:
config = pickle.load(f)
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])
@ -102,18 +102,18 @@ def run(args, parser):
def rollout(agent, env_name, num_steps, out=None, no_render=True):
if hasattr(agent, "local_evaluator"):
env = agent.local_evaluator.env
multiagent = agent.local_evaluator.multiagent
if multiagent:
policy_agent_mapping = agent.config["multiagent"][
"policy_mapping_fn"]
mapping_cache = {}
policy_map = agent.local_evaluator.policy_map
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
else:
env = gym.make(env_name)
if hasattr(agent, "local_evaluator"):
state_init = agent.local_evaluator.policy_map[
"default"].get_initial_state()
else:
state_init = []
if state_init:
use_lstm = True
else:
use_lstm = False
multiagent = False
use_lstm = {'default': False}
if out is not None:
rollouts = []
@ -125,12 +125,38 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
done = False
reward_total = 0.0
while not done and steps < (num_steps or steps + 1):
if use_lstm:
action, state_init, logits = agent.compute_action(
if multiagent:
action_dict = {}
for agent_id in state.keys():
a_state = state[agent_id]
if a_state is not None:
policy_id = mapping_cache.setdefault(
agent_id, policy_agent_mapping(agent_id))
p_use_lstm = use_lstm[policy_id]
if p_use_lstm:
a_action, p_state_init, _ = agent.compute_action(
a_state,
state=state_init[policy_id],
policy_id=policy_id)
state_init[policy_id] = p_state_init
else:
a_action = agent.compute_action(
a_state, policy_id=policy_id)
action_dict[agent_id] = a_action
action = action_dict
else:
if use_lstm["default"]:
action, state_init, _ = agent.compute_action(
state, state=state_init)
else:
action = agent.compute_action(state)
next_state, reward, done, _ = env.step(action)
if multiagent:
done = done["__all__"]
reward_total += sum(reward.values())
else:
reward_total += reward
if not no_render:
env.render()
@ -141,6 +167,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
if out is not None:
rollouts.append(rollout)
print("Episode reward", reward_total)
if out is not None:
pickle.dump(rollouts, open(out, "wb"))