"""Example of defining and using custom model APIs in RLlib.

(ray/rllib/examples/custom_model_api.py)

Builds models through ``ModelCatalog.get_model_v2`` with a custom
``model_interface`` class: a dueling Q-head model (discrete actions) and a
single-Q-value model taking obs/action inputs (continuous actions), for both
the TF and PyTorch frameworks.
"""
import argparse
from gym.spaces import Box, Discrete
import numpy as np
from ray.rllib.examples.models.custom_model_api import (
DuelingQModel,
TorchDuelingQModel,
ContActionQModel,
TorchContActionQModel,
)
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
# Command-line interface: choose which deep-learning framework to test with.
parser = argparse.ArgumentParser()
_FRAMEWORK_CHOICES = ["tf", "tf2", "tfe", "torch"]
parser.add_argument(
    "--framework",
    default=_FRAMEWORK_CHOICES[0],
    choices=_FRAMEWORK_CHOICES,
    help="The DL framework specifier.",
)
if __name__ == "__main__":
    args = parser.parse_args()

    # --- Test 1: API wrapper for a dueling Q-head on a discrete action space.
    obs_space = Box(-1.0, 1.0, (3,))
    action_space = Discrete(3)

    # Run in eager mode for value checking and debugging.
    # Only touch tf1 for the TF family of frameworks: on a torch-only install,
    # `try_import_tf()` returns a stub and calling into it would fail.
    if args.framework != "torch":
        tf1.enable_eager_execution()

    # __sphinx_doc_model_construct_1_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel
        if args.framework != "torch"
        else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_1_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)
    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_dueling_model(input_dict=input_dict)
    # 256 is the size of the default model's last hidden layer.
    assert out.shape == (batch_size, 256)
    # Pass `out` into `get_q_values`.
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (batch_size, action_space.n)

    # --- Test 2: API wrapper for a single-Q-value head fed by obs AND action.
    obs_space = Box(-1.0, 1.0, (3,))
    # NOTE(fix): the original used Box(-1.0, -1.0, (2,)) — low == high, a
    # degenerate space whose samples are always -1.0. Clearly a typo for a
    # symmetric [-1, 1] continuous action space.
    action_space = Box(-1.0, 1.0, (2,))

    # __sphinx_doc_model_construct_2_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel
        if args.framework != "torch"
        else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_2_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)
    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (batch_size, 256)
    # Pass `out` and a batch of actions into the single-Q-value head.
    action = np.array([action_space.sample() for _ in range(batch_size)])
    if args.framework == "torch":
        action = torch.from_numpy(action)
    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (batch_size, 1)