[RLlib] Add simple action-masking example script/env/model (tf and torch). (#18494)
Commit: ea4a22249c (parent: 370473fc5f)
8 changed files with 240 additions and 8 deletions
rllib/BUILD (17 changes)

@@ -1866,6 +1866,23 @@ py_test(
 # for `examples/all_stuff.py`.
 # --------------------------------------------------------------------

+py_test(
+    name = "examples/action_masking_tf",
+    main = "examples/action_masking.py",
+    tags = ["team:ml", "examples", "examples_A"],
+    size = "small",
+    srcs = ["examples/action_masking.py"],
+    args = ["--stop-iter=2"]
+)
+
+py_test(
+    name = "examples/action_masking_torch",
+    main = "examples/action_masking.py",
+    tags = ["team:ml", "examples", "examples_A"],
+    size = "small",
+    srcs = ["examples/action_masking.py"],
+    args = ["--stop-iter=2", "--framework=torch"]
+)

 py_test(
     name = "examples/attention_net_tf",
@@ -215,11 +215,6 @@ def validate_config(config):
     if config["entropy_coeff"] < 0.0:
         raise ValueError("`entropy_coeff` must be >= 0.0!")

-    if config["vtrace"] and not config["in_evaluation"]:
-        if config["batch_mode"] != "truncate_episodes":
-            raise ValueError(
-                "Must use `batch_mode`=truncate_episodes if `vtrace` is True.")
-
     # Check whether worker to aggregation-worker ratio makes sense.
     if config["num_aggregation_workers"] > config["num_workers"]:
         raise ValueError(
@@ -138,7 +138,9 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False):
         T = tf.shape(tensor)[0] // B
     else:
         # Important: chop the tensor into batches at known episode cut
-        # boundaries. TODO(ekl) this is kind of a hack
+        # boundaries.
+        # TODO: (sven) this is kind of a hack and won't work for
+        # batch_mode=complete_episodes.
         T = policy.config["rollout_fragment_length"]
         B = tf.shape(tensor)[0] // T
     rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
@@ -222,6 +222,8 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False):
     else:
         # Important: chop the tensor into batches at known episode cut
         # boundaries.
+        # TODO: (sven) this is kind of a hack and won't work for
+        # batch_mode=complete_episodes.
         T = policy.config["rollout_fragment_length"]
         B = tensor.shape[0] // T
     rs = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))
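Both time-major hunks above do the same thing: a flat [B*T, ...] sample batch is chopped into B fragments of the known rollout_fragment_length T. A minimal sketch of that reshape (not part of the diff), assuming a fragment length of 4 and a flat batch of 8 timesteps:

import torch

rollout_fragment_length = 4              # T: assumed fragment length
tensor = torch.arange(16).reshape(8, 2)  # flat [B*T, feature] batch (B*T = 8)

T = rollout_fragment_length
B = tensor.shape[0] // T                 # B = 2 fragments
rs = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))
print(rs.shape)                          # torch.Size([2, 4, 2])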
rllib/examples/action_masking.py (new file, 76 lines)

@@ -0,0 +1,76 @@
import argparse
from gym.spaces import Box, Discrete
import os

from ray.rllib.examples.env.action_mask_env import ActionMaskEnv
from ray.rllib.examples.models.action_mask_model import \
    ActionMaskModel, TorchActionMaskModel

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="APPO",
    help="The RLlib-registered algorithm to use.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--eager-tracing", action="store_true")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=80.0,
    help="Reward at which we stop training.")
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.")

if __name__ == "__main__":
    import ray
    from ray import tune

    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    config = {
        "env": ActionMaskEnv,
        "env_config": {
            "action_space": Discrete(100),
            "observation_space": Box(-1.0, 1.0, (5, )),
        },
        "model": {
            "custom_model": ActionMaskModel
            if args.framework != "torch" else TorchActionMaskModel,
        },

        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
        # Run with tracing enabled for tfe/tf2?
        "eager_tracing": args.eager_tracing,
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, config=config, stop=stop, verbose=2)

    ray.shutdown()
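The script wires the custom model in by passing the class directly to the `custom_model` config key. As a side note (not part of the diff), the same model could instead be registered under a string name via RLlib's ModelCatalog; a minimal sketch, where the registration names are arbitrary choices:

# Sketch only: registering the masking models by name instead of by class.
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.models.action_mask_model import (
    ActionMaskModel, TorchActionMaskModel)

ModelCatalog.register_custom_model("action_mask_tf", ActionMaskModel)
ModelCatalog.register_custom_model("action_mask_torch", TorchActionMaskModel)

model_config = {
    # Reference the registered name in the Trainer config instead of the class.
    "custom_model": "action_mask_tf",  # or "action_mask_torch"
}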
rllib/examples/env/action_mask_env.py (new file, 40 lines)

@@ -0,0 +1,40 @@
from gym.spaces import Box, Dict, Discrete
import numpy as np

from ray.rllib.examples.env.random_env import RandomEnv


class ActionMaskEnv(RandomEnv):
    """A randomly acting environment that publishes an action-mask each step.
    """

    def __init__(self, config):
        super().__init__(config)

        # Masking only works for Discrete actions.
        assert isinstance(self.action_space, Discrete)
        # Add action_mask to observations.
        self.observation_space = Dict({
            "action_mask": Box(0.0, 1.0, shape=(self.action_space.n, )),
            "observations": self.observation_space,
        })
        self.valid_actions = None

    def reset(self):
        obs = super().reset()
        self._fix_action_mask(obs)
        return obs

    def step(self, action):
        # Check whether action is valid.
        if not self.valid_actions[action]:
            raise ValueError(f"Invalid action sent to env! "
                             f"valid_actions={self.valid_actions}")
        obs, rew, done, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, info

    def _fix_action_mask(self, obs):
        # Fix action-mask: everything larger than 0.5 becomes 1.0,
        # everything else 0.0.
        self.valid_actions = np.round(obs["action_mask"])
        obs["action_mask"] = self.valid_actions
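To make the env's observation contract concrete, here is a minimal hand-driven rollout sketch (not part of the diff); the env_config values mirror the ones used in `action_masking.py`:

# Sketch: drive ActionMaskEnv by hand, always picking a currently valid action.
import numpy as np
from gym.spaces import Box, Discrete
from ray.rllib.examples.env.action_mask_env import ActionMaskEnv

env = ActionMaskEnv({
    "action_space": Discrete(100),
    "observation_space": Box(-1.0, 1.0, (5, )),
})
obs = env.reset()
for _ in range(10):
    # obs is a Dict with "action_mask" (0/1 vector of length 100) and
    # "observations" (the underlying Box observation).
    valid = np.flatnonzero(obs["action_mask"])
    action = np.random.choice(valid)
    obs, reward, done, info = env.step(action)
    if done:
        break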
rllib/examples/env/random_env.py (4 changes)

@@ -2,7 +2,7 @@ import gym
 from gym.spaces import Discrete, Tuple
 import numpy as np

-from ray.rllib.examples.env.multi_agent import make_multiagent
+from ray.rllib.examples.env.multi_agent import make_multi_agent


 class RandomEnv(gym.Env):

@@ -62,4 +62,4 @@ class RandomEnv(gym.Env):


 # Multi-agent version of the RandomEnv.
-RandomMultiAgentEnv = make_multiagent(lambda c: RandomEnv(c))
+RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c))
rllib/examples/models/action_mask_model.py (new file, 100 lines)

@@ -0,0 +1,100 @@
from gym.spaces import Dict

from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_ops import FLOAT_MIN

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


class ActionMaskModel(TFModelV2):
    """Model that handles simple discrete action masking.

    This assumes the outputs are logits for a single Categorical action dist.
    Getting this to work with a more complex output (e.g., if the action space
    is a tuple of several distributions) is also possible but left as an
    exercise for the reader.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):

        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict) and \
            "action_mask" in orig_space.spaces and \
            "observations" in orig_space.spaces

        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.internal_model = FullyConnectedNetwork(
            orig_space["observations"], action_space, num_outputs,
            model_config, name + "_internal")

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({
            "obs": input_dict["obs"]["observations"]
        })

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        return self.internal_model.value_function()


class TorchActionMaskModel(TorchModelV2, nn.Module):
    """PyTorch version of the above ActionMaskModel."""

    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            **kwargs,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict) and \
            "action_mask" in orig_space.spaces and \
            "observations" in orig_space.spaces

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name, **kwargs)
        nn.Module.__init__(self)

        self.internal_model = TorchFC(orig_space["observations"],
                                      action_space, num_outputs, model_config,
                                      name + "_internal")

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({
            "obs": input_dict["obs"]["observations"]
        })

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        return self.internal_model.value_function()
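Both models use the same trick: adding log(action_mask) (clamped to a very negative finite value) to the logits makes masked actions contribute essentially zero probability after the softmax. A small numeric sketch of that effect (not part of the diff), using numpy with a stand-in for RLlib's FLOAT_MIN:

# Sketch: why adding log(mask), clamped to a large negative value, drives the
# softmax probability of masked actions to (effectively) zero.
import numpy as np

logits = np.array([2.0, 1.0, 0.5, -1.0])
action_mask = np.array([1.0, 0.0, 1.0, 0.0])  # actions 1 and 3 are invalid

FLOAT_MIN = np.finfo(np.float32).min  # stand-in for RLlib's FLOAT_MIN
with np.errstate(divide="ignore"):    # log(0) = -inf is intentional here
    inf_mask = np.maximum(np.log(action_mask), FLOAT_MIN)
masked_logits = logits + inf_mask

probs = np.exp(masked_logits - masked_logits.max())
probs /= probs.sum()
print(probs)  # masked actions end up with probability 0.0 (after underflow)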