mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
106 lines
3.7 KiB
Python
import argparse

from gym.spaces import Box, Discrete
import numpy as np

from ray.rllib.examples.models.custom_model_api import (
    DuelingQModel, TorchDuelingQModel, ContActionQModel,
    TorchContActionQModel)
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.utils.framework import try_import_tf, try_import_torch

# Lazily resolve the DL frameworks; either may be absent at runtime.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

# Command-line interface: the only option is which DL framework to use.
parser = argparse.ArgumentParser()
_FRAMEWORKS = ["tf", "tf2", "tfe", "torch"]
parser.add_argument(
    "--framework",
    default="tf",
    choices=_FRAMEWORKS,
    help="The DL framework specifier.")

if __name__ == "__main__":
    args = parser.parse_args()

    # --- Test 1: API wrapper for a dueling Q-head (discrete actions). ---

    obs_space = Box(-1.0, 1.0, (3, ))
    action_space = Discrete(3)

    # Run in eager mode for value checking and debugging.
    tf1.enable_eager_execution()

    # __sphinx_doc_model_construct_1_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel
        if args.framework != "torch" else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_1_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = {
        "obs": input_,
        "is_training": False,
    }
    # Forward pass produces the 256-dim feature embedding (default
    # fcnet output size) plus any RNN state outputs (none here).
    out, state_outs = my_dueling_model(input_dict=input_dict)
    assert out.shape == (10, 256)
    # Pass `out` into `get_q_values`
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (10, action_space.n)

    # --- Test 2: API wrapper for a single-value Q-head computed from
    # --- obs/action inputs (continuous actions). ---

    obs_space = Box(-1.0, 1.0, (3, ))
    # Fixed: the high bound used to be -1.0, making this a degenerate
    # space whose `sample()` always returns -1.0; use a real [-1, 1]
    # range instead.
    action_space = Box(-1.0, 1.0, (2, ))

    # __sphinx_doc_model_construct_2_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel
        if args.framework != "torch" else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_2_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])

    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = {
        "obs": input_,
        "is_training": False,
    }
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (10, 256)
    # Pass `out` and an action into `my_cont_action_q_model`
    action = np.array([action_space.sample() for _ in range(batch_size)])
    if args.framework == "torch":
        action = torch.from_numpy(action)

    # One scalar Q-value per (obs-embedding, action) pair.
    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (10, 1)