# ray/rllib/examples/custom_model_api.py
# (106 lines, 3.7 KiB, Python)
import argparse
from gym.spaces import Box, Discrete
import numpy as np
from ray.rllib.examples.models.custom_model_api import DuelingQModel, \
TorchDuelingQModel, ContActionQModel, TorchContActionQModel
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
# Command-line interface: a single --framework flag that selects the
# deep-learning backend used throughout the example.
parser = argparse.ArgumentParser()
framework_choices = ["tf", "tf2", "tfe", "torch"]
parser.add_argument(
    "--framework",
    choices=framework_choices,
    default="tf",
    help="The DL framework specifier.",
)
if __name__ == "__main__":
    args = parser.parse_args()

    # === 1) API wrapper for a dueling Q-head on a discrete action space. ===
    obs_space = Box(-1.0, 1.0, (3, ))
    action_space = Discrete(3)

    # Run in eager mode for value checking and debugging.
    # Guarded: in a torch-only environment `try_import_tf()` returns None
    # for `tf1`, so the unconditional call would crash.
    if args.framework != "torch" and tf1 is not None:
        tf1.enable_eager_execution()

    # __sphinx_doc_model_construct_1_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel
        if args.framework != "torch" else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_1_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)
    input_dict = {
        "obs": input_,
        "is_training": False,
    }
    out, state_outs = my_dueling_model(input_dict=input_dict)
    # The wrapped default model emits a 256-dim feature vector per sample
    # (presumably the default fcnet hidden size — confirm against
    # MODEL_DEFAULTS if this assert ever fires).
    assert out.shape == (batch_size, 256)
    # Pass `out` into `get_q_values` to obtain per-action Q-values.
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (batch_size, action_space.n)

    # === 2) API wrapper for a single Q-value from obs/action input. ===
    obs_space = Box(-1.0, 1.0, (3, ))
    # BUG FIX: original used Box(-1.0, -1.0, (2, )) — low and high were
    # both -1.0, a degenerate space; the high bound should be 1.0.
    action_space = Box(-1.0, 1.0, (2, ))

    # __sphinx_doc_model_construct_2_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel
        if args.framework != "torch" else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_2_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)
    input_dict = {
        "obs": input_,
        "is_training": False,
    }
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (batch_size, 256)

    # Pass `out` plus a batch of sampled actions into the single-Q head.
    action = np.array([action_space.sample() for _ in range(batch_size)])
    if args.framework == "torch":
        action = torch.from_numpy(action)
    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (batch_size, 1)