Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
Add multi agent support in rollout.py (#4114)

parent 48f6cd3e5d
commit 8288deb92d

1 changed file with 47 additions and 20 deletions
rollout.py

@@ -73,15 +73,15 @@ def run(args, parser):
     if not config:
         # Load configuration from file
         config_dir = os.path.dirname(args.checkpoint)
-        config_path = os.path.join(config_dir, "params.json")
+        config_path = os.path.join(config_dir, "params.pkl")
         if not os.path.exists(config_path):
-            config_path = os.path.join(config_dir, "../params.json")
+            config_path = os.path.join(config_dir, "../params.pkl")
         if not os.path.exists(config_path):
             raise ValueError(
-                "Could not find params.json in either the checkpoint dir or "
+                "Could not find params.pkl in either the checkpoint dir or "
                 "its parent directory.")
-        with open(config_path) as f:
-            config = json.load(f)
+        with open(config_path, 'rb') as f:
+            config = pickle.load(f)
         if "num_workers" in config:
             config["num_workers"] = min(2, config["num_workers"])
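A note on the params.json to params.pkl switch above: presumably the move to pickle is needed because a multi-agent config carries a callable ("policy_mapping_fn", used later in this patch), and callables are not JSON-serializable. A minimal sketch of the distinction, with an illustrative mapping function that is not part of the patch:

    import json
    import pickle

    def policy_mapping_fn(agent_id):
        # Toy mapping: every agent shares one policy.
        return "shared_policy"

    config = {"multiagent": {"policy_mapping_fn": policy_mapping_fn}}

    pickle.dumps(config)  # works: a module-level function pickles by reference
    try:
        json.dumps(config)
    except TypeError as err:
        print("not JSON-serializable:", err)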
@@ -102,18 +102,18 @@ def run(args, parser):
 def rollout(agent, env_name, num_steps, out=None, no_render=True):
     if hasattr(agent, "local_evaluator"):
         env = agent.local_evaluator.env
+        multiagent = agent.local_evaluator.multiagent
+        if multiagent:
+            policy_agent_mapping = agent.config["multiagent"][
+                "policy_mapping_fn"]
+            mapping_cache = {}
+        policy_map = agent.local_evaluator.policy_map
+        state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
+        use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
     else:
         env = gym.make(env_name)
-
-    if hasattr(agent, "local_evaluator"):
-        state_init = agent.local_evaluator.policy_map[
-            "default"].get_initial_state()
-    else:
-        state_init = []
-    if state_init:
-        use_lstm = True
-    else:
-        use_lstm = False
+        multiagent = False
+        use_lstm = {'default': False}

     if out is not None:
         rollouts = []
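To see what the new per-policy bookkeeping holds: state_init maps each policy id to that policy's initial RNN state, and use_lstm marks a policy as recurrent exactly when that state list is non-empty. A runnable sketch with stand-in policy objects (FakePolicy is hypothetical, not RLlib API):

    import numpy as np

    class FakePolicy:
        """Stand-in for a policy graph; only get_initial_state() matters here."""

        def __init__(self, recurrent):
            self.recurrent = recurrent

        def get_initial_state(self):
            # Recurrent policies return their initial RNN state, others [].
            return [np.zeros(4), np.zeros(4)] if self.recurrent else []

    policy_map = {"lstm_policy": FakePolicy(True), "ff_policy": FakePolicy(False)}
    state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
    use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
    print(use_lstm)  # {'lstm_policy': True, 'ff_policy': False}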
@@ -125,13 +125,39 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
         done = False
         reward_total = 0.0
         while not done and steps < (num_steps or steps + 1):
-            if use_lstm:
-                action, state_init, logits = agent.compute_action(
-                    state, state=state_init)
+            if multiagent:
+                action_dict = {}
+                for agent_id in state.keys():
+                    a_state = state[agent_id]
+                    if a_state is not None:
+                        policy_id = mapping_cache.setdefault(
+                            agent_id, policy_agent_mapping(agent_id))
+                        p_use_lstm = use_lstm[policy_id]
+                        if p_use_lstm:
+                            a_action, p_state_init, _ = agent.compute_action(
+                                a_state,
+                                state=state_init[policy_id],
+                                policy_id=policy_id)
+                            state_init[policy_id] = p_state_init
+                        else:
+                            a_action = agent.compute_action(
+                                a_state, policy_id=policy_id)
+                        action_dict[agent_id] = a_action
+                action = action_dict
             else:
-                action = agent.compute_action(state)
+                if use_lstm["default"]:
+                    action, state_init, _ = agent.compute_action(
+                        state, state=state_init)
+                else:
+                    action = agent.compute_action(state)
+
             next_state, reward, done, _ = env.step(action)
-            reward_total += reward
+
+            if multiagent:
+                done = done["__all__"]
+                reward_total += sum(reward.values())
+            else:
+                reward_total += reward
             if not no_render:
                 env.render()
             if out is not None:
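The reworked step handling reflects the multi-agent env contract: step() returns observation, reward, and done as dicts keyed by agent id, and done carries a special "__all__" key, hence the done["__all__"] and sum(reward.values()) above. A toy illustration with invented agent ids:

    # Hypothetical step() result from a two-agent environment:
    next_state = {"agent_0": [0.1, 0.2], "agent_1": [0.3, 0.4]}
    reward = {"agent_0": 1.0, "agent_1": -0.5}
    done = {"agent_0": False, "agent_1": False, "__all__": False}

    reward_total = sum(reward.values())  # 0.5
    done = done["__all__"]  # the episode ends only once all agents are done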
@@ -141,6 +167,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
         if out is not None:
             rollouts.append(rollout)
         print("Episode reward", reward_total)

     if out is not None:
         pickle.dump(rollouts, open(out, "wb"))
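Finally, a sketch of reading back the file that pickle.dump writes above when the out argument is set; the per-step record layout is an assumption based on the rollout list built elsewhere in rollout.py, not shown in this diff:

    import pickle

    # Load the saved episodes (the path here is illustrative).
    with open("rollouts.pkl", "rb") as f:
        rollouts = pickle.load(f)

    for i, episode in enumerate(rollouts):
        # Assumed record layout: [state, action, next_state, reward, done]
        print("episode", i, "steps", len(episode))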