mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Add multi agent support in rollout.py (#4114)
This commit is contained in:
parent
48f6cd3e5d
commit
8288deb92d
1 changed file with 47 additions and 20 deletions
|
@@ -73,15 +73,15 @@ def run(args, parser):
|
|||
if not config:
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
config_path = os.path.join(config_dir, "params.json")
|
||||
config_path = os.path.join(config_dir, "params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
config_path = os.path.join(config_dir, "../params.json")
|
||||
config_path = os.path.join(config_dir, "../params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
raise ValueError(
|
||||
"Could not find params.json in either the checkpoint dir or "
|
||||
"Could not find params.pkl in either the checkpoint dir or "
|
||||
"its parent directory.")
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
with open(config_path, 'rb') as f:
|
||||
config = pickle.load(f)
|
||||
if "num_workers" in config:
|
||||
config["num_workers"] = min(2, config["num_workers"])
|
||||
|
||||
|
@@ -102,18 +102,18 @@ def run(args, parser):
|
|||
def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
||||
if hasattr(agent, "local_evaluator"):
|
||||
env = agent.local_evaluator.env
|
||||
multiagent = agent.local_evaluator.multiagent
|
||||
if multiagent:
|
||||
policy_agent_mapping = agent.config["multiagent"][
|
||||
"policy_mapping_fn"]
|
||||
mapping_cache = {}
|
||||
policy_map = agent.local_evaluator.policy_map
|
||||
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
|
||||
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
|
||||
else:
|
||||
env = gym.make(env_name)
|
||||
|
||||
if hasattr(agent, "local_evaluator"):
|
||||
state_init = agent.local_evaluator.policy_map[
|
||||
"default"].get_initial_state()
|
||||
else:
|
||||
state_init = []
|
||||
if state_init:
|
||||
use_lstm = True
|
||||
else:
|
||||
use_lstm = False
|
||||
multiagent = False
|
||||
use_lstm = {'default': False}
|
||||
|
||||
if out is not None:
|
||||
rollouts = []
|
||||
|
@@ -125,12 +125,38 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
|||
done = False
|
||||
reward_total = 0.0
|
||||
while not done and steps < (num_steps or steps + 1):
|
||||
if use_lstm:
|
||||
action, state_init, logits = agent.compute_action(
|
||||
if multiagent:
|
||||
action_dict = {}
|
||||
for agent_id in state.keys():
|
||||
a_state = state[agent_id]
|
||||
if a_state is not None:
|
||||
policy_id = mapping_cache.setdefault(
|
||||
agent_id, policy_agent_mapping(agent_id))
|
||||
p_use_lstm = use_lstm[policy_id]
|
||||
if p_use_lstm:
|
||||
a_action, p_state_init, _ = agent.compute_action(
|
||||
a_state,
|
||||
state=state_init[policy_id],
|
||||
policy_id=policy_id)
|
||||
state_init[policy_id] = p_state_init
|
||||
else:
|
||||
a_action = agent.compute_action(
|
||||
a_state, policy_id=policy_id)
|
||||
action_dict[agent_id] = a_action
|
||||
action = action_dict
|
||||
else:
|
||||
if use_lstm["default"]:
|
||||
action, state_init, _ = agent.compute_action(
|
||||
state, state=state_init)
|
||||
else:
|
||||
action = agent.compute_action(state)
|
||||
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
|
||||
if multiagent:
|
||||
done = done["__all__"]
|
||||
reward_total += sum(reward.values())
|
||||
else:
|
||||
reward_total += reward
|
||||
if not no_render:
|
||||
env.render()
|
||||
|
@@ -141,6 +167,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
|||
if out is not None:
|
||||
rollouts.append(rollout)
|
||||
print("Episode reward", reward_total)
|
||||
|
||||
if out is not None:
|
||||
pickle.dump(rollouts, open(out, "wb"))
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue