[RLlib] Add simple action-masking example script/env/model (tf and torch). (#18494)

Sven Mika 2021-09-11 23:08:09 +02:00 committed by GitHub
parent 370473fc5f
commit ea4a22249c
8 changed files with 240 additions and 8 deletions


@@ -1866,6 +1866,23 @@ py_test(
# for `examples/all_stuff.py`.
# --------------------------------------------------------------------
py_test(
    name = "examples/action_masking_tf",
    main = "examples/action_masking.py",
    tags = ["team:ml", "examples", "examples_A"],
    size = "small",
    srcs = ["examples/action_masking.py"],
    args = ["--stop-iter=2"]
)

py_test(
    name = "examples/action_masking_torch",
    main = "examples/action_masking.py",
    tags = ["team:ml", "examples", "examples_A"],
    size = "small",
    srcs = ["examples/action_masking.py"],
    args = ["--stop-iter=2", "--framework=torch"]
)

py_test(
    name = "examples/attention_net_tf",


@@ -215,11 +215,6 @@ def validate_config(config):
    if config["entropy_coeff"] < 0.0:
        raise ValueError("`entropy_coeff` must be >= 0.0!")
    if config["vtrace"] and not config["in_evaluation"]:
        if config["batch_mode"] != "truncate_episodes":
            raise ValueError(
                "Must use `batch_mode`=truncate_episodes if `vtrace` is True.")

    # Check whether worker to aggregation-worker ratio makes sense.
    if config["num_aggregation_workers"] > config["num_workers"]:
        raise ValueError(


@@ -138,7 +138,9 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False):
        T = tf.shape(tensor)[0] // B
    else:
        # Important: chop the tensor into batches at known episode cut
        # boundaries. TODO(ekl) this is kind of a hack
        # boundaries.
        # TODO: (sven) this is kind of a hack and won't work for
        # batch_mode=complete_episodes.
        T = policy.config["rollout_fragment_length"]
        B = tf.shape(tensor)[0] // T

    rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
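For intuition, a rough numpy sketch (not part of the commit) of what this `truncate_episodes` branch assumes: the flat batch consists of B contiguous fragments of length T = `rollout_fragment_length`, which get regrouped and then swapped into time-major [T, B, ...] order:

import numpy as np

T, B = 4, 2                                      # rollout_fragment_length, num fragments
flat = np.arange(T * B * 3).reshape(T * B, 3)    # flat [B*T, feature] train batch
rs = flat.reshape((B, T) + flat.shape[1:])       # regroup into [B, T, feature]
time_major = rs.swapaxes(0, 1)                   # -> [T, B, feature]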


@@ -222,6 +222,8 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False):
    else:
        # Important: chop the tensor into batches at known episode cut
        # boundaries.
        # TODO: (sven) this is kind of a hack and won't work for
        # batch_mode=complete_episodes.
        T = policy.config["rollout_fragment_length"]
        B = tensor.shape[0] // T

    rs = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))


@@ -0,0 +1,76 @@
import argparse
from gym.spaces import Box, Discrete
import os

from ray.rllib.examples.env.action_mask_env import ActionMaskEnv
from ray.rllib.examples.models.action_mask_model import \
    ActionMaskModel, TorchActionMaskModel

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="APPO",
    help="The RLlib-registered algorithm to use.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--eager-tracing", action="store_true")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=80.0,
    help="Reward at which we stop training.")
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.")

if __name__ == "__main__":
    import ray
    from ray import tune

    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    config = {
        "env": ActionMaskEnv,
        "env_config": {
            "action_space": Discrete(100),
            "observation_space": Box(-1.0, 1.0, (5, )),
        },
        "model": {
            "custom_model": ActionMaskModel
            if args.framework != "torch" else TorchActionMaskModel,
        },
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
        # Run with tracing enabled for tfe/tf2?
        "eager_tracing": args.eager_tracing,
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, config=config, stop=stop, verbose=2)

    ray.shutdown()
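The config above passes the model class directly via `custom_model`. An equivalent setup (sketch only, using RLlib's standard `ModelCatalog` registration) would register the model under a string name first:

from ray.rllib.models import ModelCatalog
from ray.rllib.examples.models.action_mask_model import ActionMaskModel

# Register once, then refer to the model by name in the config.
ModelCatalog.register_custom_model("action_mask_model", ActionMaskModel)
config["model"] = {"custom_model": "action_mask_model"}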

rllib/examples/env/action_mask_env.py

@@ -0,0 +1,40 @@
from gym.spaces import Box, Dict, Discrete
import numpy as np

from ray.rllib.examples.env.random_env import RandomEnv


class ActionMaskEnv(RandomEnv):
    """A randomly acting environment that publishes an action-mask each step."""

    def __init__(self, config):
        super().__init__(config)
        # Masking only works for Discrete actions.
        assert isinstance(self.action_space, Discrete)
        # Add action_mask to observations.
        self.observation_space = Dict({
            "action_mask": Box(0.0, 1.0, shape=(self.action_space.n, )),
            "observations": self.observation_space,
        })
        self.valid_actions = None

    def reset(self):
        obs = super().reset()
        self._fix_action_mask(obs)
        return obs

    def step(self, action):
        # Check whether action is valid.
        if not self.valid_actions[action]:
            raise ValueError(f"Invalid action sent to env! "
                             f"valid_actions={self.valid_actions}")
        obs, rew, done, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, info

    def _fix_action_mask(self, obs):
        # Fix action-mask: Everything larger than 0.5 becomes 1.0, everything
        # else 0.0.
        self.valid_actions = np.round(obs["action_mask"])
        obs["action_mask"] = self.valid_actions


@@ -2,7 +2,7 @@ import gym
from gym.spaces import Discrete, Tuple
import numpy as np

from ray.rllib.examples.env.multi_agent import make_multiagent
from ray.rllib.examples.env.multi_agent import make_multi_agent


class RandomEnv(gym.Env):

@@ -62,4 +62,4 @@ class RandomEnv(gym.Env):
# Multi-agent version of the RandomEnv.
RandomMultiAgentEnv = make_multiagent(lambda c: RandomEnv(c))
RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c))


@@ -0,0 +1,100 @@
from gym.spaces import Dict

from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_ops import FLOAT_MIN

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


class ActionMaskModel(TFModelV2):
    """Model that handles simple discrete action masking.

    This assumes the outputs are logits for a single Categorical action dist.
    Getting this to work with a more complex output (e.g., if the action space
    is a tuple of several distributions) is also possible but left as an
    exercise to the reader.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict) and \
            "action_mask" in orig_space.spaces and \
            "observations" in orig_space.spaces

        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.internal_model = FullyConnectedNetwork(
            orig_space["observations"], action_space, num_outputs,
            model_config, name + "_internal")

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({
            "obs": input_dict["obs"]["observations"]
        })

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        return self.internal_model.value_function()


class TorchActionMaskModel(TorchModelV2, nn.Module):
    """PyTorch version of above ActionMaskModel."""

    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            **kwargs,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict) and \
            "action_mask" in orig_space.spaces and \
            "observations" in orig_space.spaces

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name, **kwargs)
        nn.Module.__init__(self)

        self.internal_model = TorchFC(orig_space["observations"], action_space,
                                      num_outputs, model_config,
                                      name + "_internal")

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({
            "obs": input_dict["obs"]["observations"]
        })

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        return self.internal_model.value_function()
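The `[0.0 || -inf]` mask trick used in both `forward()` methods can be seen in isolation with plain numpy (illustrative sketch): logits of invalid actions are pushed down to the float minimum, so a Categorical distribution built from the masked logits assigns them (numerically) zero probability:

import numpy as np

logits = np.array([1.0, 2.0, 0.5], dtype=np.float32)
mask = np.array([1.0, 0.0, 1.0], dtype=np.float32)  # action 1 is invalid
with np.errstate(divide="ignore"):  # log(0.0) -> -inf
    inf_mask = np.maximum(np.log(mask), np.finfo(np.float32).min)
masked_logits = logits + inf_mask
probs = np.exp(masked_logits) / np.sum(np.exp(masked_logits))
# probs ~= [0.62, 0.0, 0.38] -> the invalid action can never be sampled.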