[RLlib] Issue 21489: Unity3D env lacks group rewards (#24016).
parent 732175e245
commit dfb9689701
2 changed files with 31 additions and 2 deletions
rllib/env/wrappers/unity3d_env.py (vendored): 32 changed lines
@@ -219,7 +219,9 @@ class Unity3DEnv(MultiAgentEnv):
                 os = tuple(o[idx] for o in decision_steps.obs)
                 os = os[0] if len(os) == 1 else os
                 obs[key] = os
-                rewards[key] = decision_steps.reward[idx]  # rewards vector
+                rewards[key] = (
+                    decision_steps.reward[idx] + decision_steps.group_reward[idx]
+                )
             for agent_id, idx in terminal_steps.agent_id_to_index.items():
                 key = behavior_name + "_{}".format(agent_id)
                 # Only overwrite rewards (last reward in episode), b/c obs
@@ -228,7 +230,9 @@ class Unity3DEnv(MultiAgentEnv):
                 if key not in obs:
                     os = tuple(o[idx] for o in terminal_steps.obs)
                     obs[key] = os = os[0] if len(os) == 1 else os
-                rewards[key] = terminal_steps.reward[idx]  # rewards vector
+                rewards[key] = (
+                    terminal_steps.reward[idx] + terminal_steps.group_reward[idx]
+                )

         # Only use dones if all agents are done, then we should do a reset.
         return obs, rewards, {"__all__": False}, infos
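The point of the two hunks above: ML-Agents reports both a per-agent reward and a team-level group reward on its DecisionSteps and TerminalSteps batches, and the wrapper previously returned only the former, so team signals (e.g. a goal in SoccerTwos) never reached RLlib. Below is a minimal, self-contained sketch of the new summing logic; SimpleNamespace and the numbers are stand-ins for the real ML-Agents batch objects.

from types import SimpleNamespace

# Stand-in for an ML-Agents DecisionSteps batch: per-agent individual
# rewards plus a shared team ("group") reward.
decision_steps = SimpleNamespace(
    agent_id_to_index={0: 0, 1: 1},
    reward=[0.1, -0.2],       # individual rewards
    group_reward=[1.0, 1.0],  # shared team reward, e.g. a goal was scored
)

behavior_name = "SoccerTwos"
rewards = {}
for agent_id, idx in decision_steps.agent_id_to_index.items():
    key = behavior_name + "_{}".format(agent_id)
    # Before the patch only decision_steps.reward[idx] was returned, so the
    # group component was silently dropped.
    rewards[key] = decision_steps.reward[idx] + decision_steps.group_reward[idx]

print(rewards)  # {'SoccerTwos_0': 1.1, 'SoccerTwos_1': 0.8}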
@@ -256,6 +260,13 @@ class Unity3DEnv(MultiAgentEnv):
                     Box(float("-inf"), float("inf"), (4,)),
                 ]
             ),
+            # SoccerTwos.
+            "SoccerPlayer": TupleSpace(
+                [
+                    Box(-1.0, 1.0, (264,)),
+                    Box(-1.0, 1.0, (72,)),
+                ]
+            ),
             # SoccerStrikersVsGoalie.
             "Goalie": Box(float("-inf"), float("inf"), (738,)),
             "Striker": TupleSpace(
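For reference, the new SoccerPlayer observation space is a tuple of two flat Box vectors (264 and 72 floats, each bounded to [-1, 1]). A small sketch, assuming TupleSpace is gym.spaces.Tuple imported under another name, as in this file:

from gym.spaces import Box, Tuple as TupleSpace

# Two flat vector observations per SoccerTwos agent, bounded to [-1, 1].
soccer_player_obs = TupleSpace(
    [
        Box(-1.0, 1.0, (264,)),
        Box(-1.0, 1.0, (72,)),
    ]
)

sample = soccer_player_obs.sample()
print(sample[0].shape, sample[1].shape)  # (264,) (72,)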
@@ -305,6 +316,8 @@ class Unity3DEnv(MultiAgentEnv):
             # SoccerStrikersVsGoalie.
             "Goalie": MultiDiscrete([3, 3, 3]),
             "Striker": MultiDiscrete([3, 3, 3]),
+            # SoccerTwos.
+            "SoccerPlayer": MultiDiscrete([3, 3, 3]),
             # Sorter.
             "Sorter": MultiDiscrete([3, 3, 3]),
             # Tennis.
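The matching action space, MultiDiscrete([3, 3, 3]), encodes three independent discrete action branches with three choices each (ML-Agents' branched discrete-action convention); the in-game meaning of each branch is defined by the Unity scene, not by this wrapper.

from gym.spaces import MultiDiscrete

# Three independent branches, three choices each; a sampled action is one
# index per branch.
soccer_action_space = MultiDiscrete([3, 3, 3])
print(soccer_action_space.nvec)      # [3 3 3]
print(soccer_action_space.sample())  # e.g. [2 0 1]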
@@ -333,6 +346,21 @@ class Unity3DEnv(MultiAgentEnv):
             def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                 return "Striker" if "Striker" in agent_id else "Goalie"

+        elif game_name == "SoccerTwos":
+            policies = {
+                "PurplePlayer": PolicySpec(
+                    observation_space=obs_spaces["SoccerPlayer"],
+                    action_space=action_spaces["SoccerPlayer"],
+                ),
+                "BluePlayer": PolicySpec(
+                    observation_space=obs_spaces["SoccerPlayer"],
+                    action_space=action_spaces["SoccerPlayer"],
+                ),
+            }
+
+            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
+                return "BluePlayer" if "1_" in agent_id else "PurplePlayer"
+
         else:
             policies = {
                 game_name: PolicySpec(
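The new SoccerTwos branch trains one policy per team and routes agents by id. Agent keys are built as behavior_name + "_" + agent_id (see step() above), so the "1_" check keys off the team id embedded in the behavior name. A short demo follows; the example ids are an assumption about how ML-Agents names SoccerTwos behaviors, only the mapping logic itself comes from the patch.

def policy_mapping_fn(agent_id, episode=None, worker=None, **kwargs):
    return "BluePlayer" if "1_" in agent_id else "PurplePlayer"

# Hypothetical agent keys; ML-Agents typically encodes the team as
# "?team=<n>" in the behavior name, which the wrapper prefixes onto ids.
for agent_id in ["SoccerTwos?team=0_2", "SoccerTwos?team=1_2"]:
    print(agent_id, "->", policy_mapping_fn(agent_id))
# SoccerTwos?team=0_2 -> PurplePlayer
# SoccerTwos?team=1_2 -> BluePlayer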
@@ -40,6 +40,7 @@ parser.add_argument(
         "GridFoodCollector",
         "Pyramids",
         "SoccerStrikersVsGoalie",
+        "SoccerTwos",
         "Sorter",
         "Tennis",
         "VisualHallway",
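The second changed file (its path is not shown on this page) registers "SoccerTwos" as a valid --env choice in the example script's argument parser. An illustrative argparse sketch of that pattern; the surrounding options are assumptions.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--env",
    type=str,
    choices=[
        "GridFoodCollector",
        "Pyramids",
        "SoccerStrikersVsGoalie",
        "SoccerTwos",  # newly registered choice
        "Sorter",
        "Tennis",
        "VisualHallway",
    ],
)
args = parser.parse_args(["--env", "SoccerTwos"])
print(args.env)  # SoccerTwos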