[RLlib] Issue 21489: Unity3D env lacks group rewards (#24016).

This commit is contained in:
Grzegorz Rypeść 2022-04-21 18:49:52 +02:00 committed by GitHub
parent 732175e245
commit dfb9689701
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 2 deletions

View file

@ -219,7 +219,9 @@ class Unity3DEnv(MultiAgentEnv):
os = tuple(o[idx] for o in decision_steps.obs)
os = os[0] if len(os) == 1 else os
obs[key] = os
rewards[key] = decision_steps.reward[idx] # rewards vector
rewards[key] = (
decision_steps.reward[idx] + decision_steps.group_reward[idx]
)
for agent_id, idx in terminal_steps.agent_id_to_index.items():
key = behavior_name + "_{}".format(agent_id)
# Only overwrite rewards (last reward in episode), b/c obs
@ -228,7 +230,9 @@ class Unity3DEnv(MultiAgentEnv):
if key not in obs:
os = tuple(o[idx] for o in terminal_steps.obs)
obs[key] = os = os[0] if len(os) == 1 else os
rewards[key] = terminal_steps.reward[idx] # rewards vector
rewards[key] = (
terminal_steps.reward[idx] + terminal_steps.group_reward[idx]
)
# Only use dones if all agents are done, then we should do a reset.
return obs, rewards, {"__all__": False}, infos
@ -256,6 +260,13 @@ class Unity3DEnv(MultiAgentEnv):
Box(float("-inf"), float("inf"), (4,)),
]
),
# SoccerTwos.
"SoccerPlayer": TupleSpace(
[
Box(-1.0, 1.0, (264,)),
Box(-1.0, 1.0, (72,)),
]
),
# SoccerStrikersVsGoalie.
"Goalie": Box(float("-inf"), float("inf"), (738,)),
"Striker": TupleSpace(
@ -305,6 +316,8 @@ class Unity3DEnv(MultiAgentEnv):
# SoccerStrikersVsGoalie.
"Goalie": MultiDiscrete([3, 3, 3]),
"Striker": MultiDiscrete([3, 3, 3]),
# SoccerTwos.
"SoccerPlayer": MultiDiscrete([3, 3, 3]),
# Sorter.
"Sorter": MultiDiscrete([3, 3, 3]),
# Tennis.
@ -333,6 +346,21 @@ class Unity3DEnv(MultiAgentEnv):
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
return "Striker" if "Striker" in agent_id else "Goalie"
elif game_name == "SoccerTwos":
policies = {
"PurplePlayer": PolicySpec(
observation_space=obs_spaces["SoccerPlayer"],
action_space=action_spaces["SoccerPlayer"],
),
"BluePlayer": PolicySpec(
observation_space=obs_spaces["SoccerPlayer"],
action_space=action_spaces["SoccerPlayer"],
),
}
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
return "BluePlayer" if "1_" in agent_id else "PurplePlayer"
else:
policies = {
game_name: PolicySpec(

View file

@ -40,6 +40,7 @@ parser.add_argument(
"GridFoodCollector",
"Pyramids",
"SoccerStrikersVsGoalie",
"SoccerTwos",
"Sorter",
"Tennis",
"VisualHallway",