"""Wrap Kaggle's environment
|
|
|
|
Source: https://github.com/Kaggle/kaggle-environments
|
|
"""

from copy import deepcopy
from typing import Any, Dict, Optional, Tuple

# Soft dependency: don't fail at import time. If `kaggle_environments` is
# missing, the error surfaces only when the environment is constructed.
try:
    import kaggle_environments
except (ImportError, ModuleNotFoundError):
    pass

import numpy as np
from gym.spaces import Box
from gym.spaces import Dict as DictSpace
from gym.spaces import Discrete, MultiBinary, MultiDiscrete, Space
from gym.spaces import Tuple as TupleSpace

from ray.rllib.env import MultiAgentEnv
from ray.rllib.utils.typing import MultiAgentDict, AgentID
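
# For orientation, the raw kaggle_environments API that this class adapts is
# roughly the following (a comment-only sketch; the optional
# `kaggle_environments` package must be installed for any of it to work):
#
#     env = kaggle_environments.make("football", configuration={})
#     state = env.reset()  # list of per-agent dicts with "observation",
#                          # "status", "reward", and "info" keys
#     env.step([[action_0], [action_1]])  # one wrapped action per agent slot
#     env.state, env.done  # updated per-agent states / episode-done flag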


class KaggleFootballMultiAgentEnv(MultiAgentEnv):
    """An interface to Kaggle's football environment.

    See: https://github.com/Kaggle/kaggle-environments
    """

    def __init__(self, configuration: Optional[Dict[str, Any]] = None) -> None:
        """Initializes a Kaggle football environment.

        Args:
            configuration (Optional[Dict[str, Any]]): configuration of the
                football environment. For detailed information, see:
                https://github.com/Kaggle/kaggle-environments/blob/master/kaggle_environments/envs/football/football.json
        """
        super().__init__()
        self.kaggle_env = kaggle_environments.make(
            "football", configuration=configuration or {}
        )
        # Cumulative rewards reported at the previous step; used by step() to
        # derive per-step rewards.
        self.last_cumulative_reward = None

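    # A hypothetical configuration example (valid keys are defined by the
    # football.json linked above; "episodeSteps" is a common
    # kaggle_environments option, shown here as an assumption):
    #
    #     env = KaggleFootballMultiAgentEnv({"episodeSteps": 500})
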
    def reset(self) -> MultiAgentDict:
        kaggle_state = self.kaggle_env.reset()
        self.last_cumulative_reward = None
        return {
            f"agent{idx}": self._convert_obs(agent_state["observation"])
            for idx, agent_state in enumerate(kaggle_state)
            if agent_state["status"] == "ACTIVE"
        }

    def step(
        self, action_dict: Dict[AgentID, int]
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        # Convert action_dict (used by RLlib) to a list of actions (used by
        # kaggle_environments)
        action_list = [None] * len(self.kaggle_env.state)
        for idx, agent_state in enumerate(self.kaggle_env.state):
            if agent_state["status"] == "ACTIVE":
                action = action_dict[f"agent{idx}"]
                action_list[idx] = [action]
        self.kaggle_env.step(action_list)

        # Parse (obs, reward, done, info) from kaggle's "state" representation
        obs = {}
        cumulative_reward = {}
        done = {"__all__": self.kaggle_env.done}
        info = {}
        for idx in range(len(self.kaggle_env.state)):
            agent_state = self.kaggle_env.state[idx]
            agent_name = f"agent{idx}"
            if agent_state["status"] == "ACTIVE":
                obs[agent_name] = self._convert_obs(agent_state["observation"])
            cumulative_reward[agent_name] = agent_state["reward"]
            done[agent_name] = agent_state["status"] != "ACTIVE"
            info[agent_name] = agent_state["info"]
        # Compute the step rewards from the cumulative rewards
        if self.last_cumulative_reward is not None:
            reward = {
                agent_id: agent_reward - self.last_cumulative_reward[agent_id]
                for agent_id, agent_reward in cumulative_reward.items()
            }
        else:
            reward = cumulative_reward
        self.last_cumulative_reward = cumulative_reward
        return obs, reward, done, info

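    # Note on the reward bookkeeping above: kaggle_environments reports
    # *cumulative* episode rewards, so step() returns the difference between
    # consecutive reports. E.g. (hypothetical numbers), if "agent0" reports
    # cumulative rewards 1.0 and then 3.0 on two successive steps, the
    # per-step rewards returned for it are 1.0 and 2.0.
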
    def _convert_obs(self, obs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert raw observations

        These conversions are necessary to make the observations fall into the
        observation space defined below.
        """
        new_obs = deepcopy(obs)
        # Gym's Discrete spaces start at 0, so remap the -1 "nobody owns the
        # ball" sentinels to the extra category at the top of each space.
        if new_obs["players_raw"][0]["ball_owned_team"] == -1:
            new_obs["players_raw"][0]["ball_owned_team"] = 2
        if new_obs["players_raw"][0]["ball_owned_player"] == -1:
            new_obs["players_raw"][0]["ball_owned_player"] = 11
        # Box spaces expect array-like values, so wrap the scalar step counter.
        new_obs["players_raw"][0]["steps_left"] = [
            new_obs["players_raw"][0]["steps_left"]
        ]
        return new_obs

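    # Worked example of the conversion (hypothetical values): an observation
    # containing {"ball_owned_team": -1, "ball_owned_player": -1,
    # "steps_left": 3000} becomes {"ball_owned_team": 2,
    # "ball_owned_player": 11, "steps_left": [3000]}, which fits the
    # Discrete/Box spaces declared in build_agent_spaces().
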
    def build_agent_spaces(self) -> Tuple[Space, Space]:
        """Construct the action and observation spaces

        Description of actions and observations:
        https://github.com/google-research/football/blob/master/gfootball/doc/observation.md
        """  # noqa: E501
        action_space = Discrete(19)
        # The football field's corners are [+-1., +-0.42]. However, the players
        # and balls may get out of the field. Thus we multiply those limits by
        # a factor of 2.
        xlim = 1.0 * 2
        ylim = 0.42 * 2
        num_players: int = 11
        xy_space = Box(
            np.array([-xlim, -ylim], dtype=np.float32),
            np.array([xlim, ylim], dtype=np.float32),
        )
        xyz_space = Box(
            np.array([-xlim, -ylim, 0], dtype=np.float32),
            np.array([xlim, ylim, np.inf], dtype=np.float32),
        )
        observation_space = DictSpace(
            {
                "controlled_players": Discrete(2),
                "players_raw": TupleSpace(
                    [
                        DictSpace(
                            {
                                # ball information
                                "ball": xyz_space,
                                "ball_direction": Box(-np.inf, np.inf, (3,)),
                                "ball_rotation": Box(-np.inf, np.inf, (3,)),
                                "ball_owned_team": Discrete(3),
                                "ball_owned_player": Discrete(num_players + 1),
                                # left team
                                "left_team": TupleSpace([xy_space] * num_players),
                                "left_team_direction": TupleSpace(
                                    [xy_space] * num_players
                                ),
                                "left_team_tired_factor": Box(0.0, 1.0, (num_players,)),
                                "left_team_yellow_card": MultiBinary(num_players),
                                "left_team_active": MultiBinary(num_players),
                                "left_team_roles": MultiDiscrete([10] * num_players),
                                # right team
                                "right_team": TupleSpace([xy_space] * num_players),
                                "right_team_direction": TupleSpace(
                                    [xy_space] * num_players
                                ),
                                "right_team_tired_factor": Box(
                                    0.0, 1.0, (num_players,)
                                ),
                                "right_team_yellow_card": MultiBinary(num_players),
                                "right_team_active": MultiBinary(num_players),
                                "right_team_roles": MultiDiscrete([10] * num_players),
                                # controlled player information
                                "active": Discrete(num_players),
                                "designated": Discrete(num_players),
                                "sticky_actions": MultiBinary(10),
                                # match state
                                "score": Box(-np.inf, np.inf, (2,)),
                                "steps_left": Box(0, np.inf, (1,)),
                                "game_mode": Discrete(7),
                            }
                        )
                    ]
                ),
            }
        )
        return action_space, observation_space
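

# A minimal usage sketch, not part of the original module. It assumes the
# optional `kaggle_environments` package is installed; the env name
# "kaggle_football" passed to register_env() is an arbitrary choice.
if __name__ == "__main__":
    from ray.tune.registry import register_env

    # Make the wrapper constructible by RLlib via the registered name.
    register_env("kaggle_football", lambda cfg: KaggleFootballMultiAgentEnv(cfg))

    # Roll out one episode with uniformly random actions.
    env = KaggleFootballMultiAgentEnv()
    action_space, _ = env.build_agent_spaces()
    obs = env.reset()
    while True:
        actions = {agent_id: int(action_space.sample()) for agent_id in obs}
        obs, reward, done, info = env.step(actions)
        if done["__all__"]:
            break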