Mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)
[RLlib] Deprecate all Model(v1) usage. (#8146)
parent eb91619175
commit bf25aee392
28 changed files with 551 additions and 676 deletions
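
The recurring pattern in this commit: replace ModelCatalog.get_model() (ModelV1) with ModelCatalog.get_model_v2() and call the returned ModelV2 on an input dict to obtain the distribution inputs. A condensed before/after sketch, using the names that appear in the ARS/ES policy hunks below (self.inputs, self.preprocessor, config):

    # Old (ModelV1, deprecated by this commit): the catalog builds the graph
    # and exposes the output tensor as model.outputs.
    dist_class, dist_dim = ModelCatalog.get_action_dist(
        self.action_space, config["model"], dist_type="deterministic")
    model = ModelCatalog.get_model(
        {SampleBatch.CUR_OBS: self.inputs}, self.observation_space,
        self.action_space, dist_dim, config["model"])
    dist = dist_class(model.outputs, model)

    # New (ModelV2): the catalog returns a ModelV2 instance; calling it with
    # an input dict runs forward() and yields the distribution inputs.
    self.model = ModelCatalog.get_model_v2(
        obs_space=self.preprocessor.observation_space,
        action_space=self.action_space,
        num_outputs=dist_dim,
        model_config=config["model"])
    dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
    dist = dist_class(dist_inputs, self.model)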
|
@@ -1406,7 +1406,7 @@ py_test(
|
|||
name = "examples/nested_action_spaces_ppo",
|
||||
main = "examples/nested_action_spaces.py",
|
||||
tags = ["examples", "examples_N"],
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["examples/nested_action_spaces.py"],
|
||||
args = ["--stop=-500", "--run=PPO"]
|
||||
)
|
||||
|
|
|
@@ -36,15 +36,18 @@ class ARSTFPolicy:
|
|||
dist_class, dist_dim = ModelCatalog.get_action_dist(
|
||||
self.action_space, config["model"], dist_type="deterministic")
|
||||
|
||||
model = ModelCatalog.get_model({
|
||||
SampleBatch.CUR_OBS: self.inputs
|
||||
}, self.observation_space, self.action_space, dist_dim,
|
||||
config["model"])
|
||||
dist = dist_class(model.outputs, model)
|
||||
self.model = ModelCatalog.get_model_v2(
|
||||
obs_space=self.preprocessor.observation_space,
|
||||
action_space=self.action_space,
|
||||
num_outputs=dist_dim,
|
||||
model_config=config["model"])
|
||||
dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
|
||||
dist = dist_class(dist_inputs, self.model)
|
||||
|
||||
self.sampler = dist.sample()
|
||||
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
model.outputs, self.sess)
|
||||
dist_inputs, self.sess)
|
||||
|
||||
self.num_params = sum(
|
||||
np.prod(variable.shape.as_list())
|
||||
|
|
|
@@ -1,403 +0,0 @@
|
|||
from gym.spaces import Box
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.ddpg.ddpg_model import DDPGModel
|
||||
from ray.rllib.agents.ddpg.noop_model import NoopModel
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
|
||||
PRIO_WEIGHTS
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.tf.tf_action_dist import Deterministic
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import huber_loss, minimize_and_clip, \
|
||||
make_tf_callable
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACTION_SCOPE = "action"
|
||||
POLICY_SCOPE = "policy"
|
||||
POLICY_TARGET_SCOPE = "target_policy"
|
||||
Q_SCOPE = "critic"
|
||||
Q_TARGET_SCOPE = "target_critic"
|
||||
TWIN_Q_SCOPE = "twin_critic"
|
||||
TWIN_Q_TARGET_SCOPE = "twin_target_critic"
|
||||
|
||||
|
||||
def build_ddpg_models(policy, observation_space, action_space, config):
|
||||
if config["model"]["custom_model"]:
|
||||
logger.warning(
|
||||
"Setting use_state_preprocessor=True since a custom model "
|
||||
"was specified.")
|
||||
config["use_state_preprocessor"] = True
|
||||
|
||||
if not isinstance(action_space, Box):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DDPG.".format(action_space))
|
||||
elif len(action_space.shape) > 1:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space has multiple dimensions "
|
||||
"{}. ".format(action_space.shape) +
|
||||
"Consider reshaping this into a single dimension, "
|
||||
"using a Tuple action space, or the multi-agent API.")
|
||||
|
||||
if policy.config["use_state_preprocessor"]:
|
||||
default_model = None # catalog decides
|
||||
num_outputs = 256 # arbitrary
|
||||
config["model"]["no_final_linear"] = True
|
||||
else:
|
||||
default_model = NoopModel
|
||||
num_outputs = int(np.product(observation_space.shape))
|
||||
|
||||
policy.model = ModelCatalog.get_model_v2(
|
||||
obs_space=observation_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
model_config=config["model"],
|
||||
framework="tf",
|
||||
model_interface=DDPGModel,
|
||||
default_model=default_model,
|
||||
name="ddpg_model",
|
||||
actor_hidden_activation=config["actor_hidden_activation"],
|
||||
actor_hiddens=config["actor_hiddens"],
|
||||
critic_hidden_activation=config["critic_hidden_activation"],
|
||||
critic_hiddens=config["critic_hiddens"],
|
||||
twin_q=config["twin_q"],
|
||||
add_layer_norm=(policy.config["exploration_config"].get("type") ==
|
||||
"ParameterNoise"),
|
||||
)
|
||||
|
||||
policy.target_model = ModelCatalog.get_model_v2(
|
||||
obs_space=observation_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
model_config=config["model"],
|
||||
framework="tf",
|
||||
model_interface=DDPGModel,
|
||||
default_model=default_model,
|
||||
name="target_ddpg_model",
|
||||
actor_hidden_activation=config["actor_hidden_activation"],
|
||||
actor_hiddens=config["actor_hiddens"],
|
||||
critic_hidden_activation=config["critic_hidden_activation"],
|
||||
critic_hiddens=config["critic_hiddens"],
|
||||
twin_q=config["twin_q"],
|
||||
add_layer_norm=(policy.config["exploration_config"].get("type") ==
|
||||
"ParameterNoise"),
|
||||
)
|
||||
|
||||
return policy.model
|
||||
|
||||
|
||||
def get_distribution_inputs_and_class(policy,
|
||||
model,
|
||||
obs_batch,
|
||||
*,
|
||||
explore=True,
|
||||
**kwargs):
|
||||
model_out, _ = model({
|
||||
"obs": obs_batch,
|
||||
"is_training": policy._get_is_training_placeholder()
|
||||
}, [], None)
|
||||
dist_inputs = model.get_policy_output(model_out)
|
||||
|
||||
return dist_inputs, Deterministic, [] # []=state out
|
||||
|
||||
|
||||
def ddpg_actor_critic_loss(policy, model, _, train_batch):
|
||||
twin_q = policy.config["twin_q"]
|
||||
gamma = policy.config["gamma"]
|
||||
n_step = policy.config["n_step"]
|
||||
use_huber = policy.config["use_huber"]
|
||||
huber_threshold = policy.config["huber_threshold"]
|
||||
l2_reg = policy.config["l2_reg"]
|
||||
|
||||
input_dict = {
|
||||
"obs": train_batch[SampleBatch.CUR_OBS],
|
||||
"is_training": policy._get_is_training_placeholder(),
|
||||
}
|
||||
input_dict_next = {
|
||||
"obs": train_batch[SampleBatch.NEXT_OBS],
|
||||
"is_training": policy._get_is_training_placeholder(),
|
||||
}
|
||||
|
||||
model_out_t, _ = model(input_dict, [], None)
|
||||
model_out_tp1, _ = model(input_dict_next, [], None)
|
||||
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
|
||||
|
||||
# Policy network evaluation.
|
||||
with tf.variable_scope(POLICY_SCOPE, reuse=True):
|
||||
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
policy_t = model.get_policy_output(model_out_t)
|
||||
# policy_batchnorm_update_ops = list(
|
||||
# set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
with tf.variable_scope(POLICY_TARGET_SCOPE):
|
||||
policy_tp1 = \
|
||||
policy.target_model.get_policy_output(target_model_out_tp1)
|
||||
|
||||
# Action outputs.
|
||||
with tf.variable_scope(ACTION_SCOPE, reuse=True):
|
||||
if policy.config["smooth_target_policy"]:
|
||||
target_noise_clip = policy.config["target_noise_clip"]
|
||||
clipped_normal_sample = tf.clip_by_value(
|
||||
tf.random_normal(
|
||||
tf.shape(policy_tp1),
|
||||
stddev=policy.config["target_noise"]), -target_noise_clip,
|
||||
target_noise_clip)
|
||||
policy_tp1_smoothed = tf.clip_by_value(
|
||||
policy_tp1 + clipped_normal_sample,
|
||||
policy.action_space.low * tf.ones_like(policy_tp1),
|
||||
policy.action_space.high * tf.ones_like(policy_tp1))
|
||||
else:
|
||||
# No smoothing, just use deterministic actions.
|
||||
policy_tp1_smoothed = policy_tp1
|
||||
|
||||
# Q-net(s) evaluation.
|
||||
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
with tf.variable_scope(Q_SCOPE):
|
||||
# Q-values for given actions & observations in given current state.
|
||||
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
|
||||
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
# Q-values for current policy (no noise) in given current state
|
||||
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
|
||||
|
||||
if twin_q:
|
||||
with tf.variable_scope(TWIN_Q_SCOPE):
|
||||
twin_q_t = model.get_twin_q_values(
|
||||
model_out_t, train_batch[SampleBatch.ACTIONS])
|
||||
# q_batchnorm_update_ops = list(
|
||||
# set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
# Target q-net(s) evaluation.
|
||||
with tf.variable_scope(Q_TARGET_SCOPE):
|
||||
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
|
||||
policy_tp1_smoothed)
|
||||
|
||||
if twin_q:
|
||||
with tf.variable_scope(TWIN_Q_TARGET_SCOPE):
|
||||
twin_q_tp1 = policy.target_model.get_twin_q_values(
|
||||
target_model_out_tp1, policy_tp1_smoothed)
|
||||
|
||||
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
|
||||
if twin_q:
|
||||
twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
|
||||
q_tp1 = tf.minimum(q_tp1, twin_q_tp1)
|
||||
|
||||
q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
|
||||
q_tp1_best_masked = \
|
||||
(1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
|
||||
q_tp1_best
|
||||
|
||||
# Compute RHS of bellman equation.
|
||||
q_t_selected_target = tf.stop_gradient(train_batch[SampleBatch.REWARDS] +
|
||||
gamma**n_step * q_tp1_best_masked)
|
||||
|
||||
# Compute the error (potentially clipped).
|
||||
if twin_q:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
twin_td_error = twin_q_t_selected - q_t_selected_target
|
||||
td_error = td_error + twin_td_error
|
||||
if use_huber:
|
||||
errors = huber_loss(td_error, huber_threshold) \
|
||||
+ huber_loss(twin_td_error, huber_threshold)
|
||||
else:
|
||||
errors = 0.5 * tf.square(td_error) + 0.5 * tf.square(twin_td_error)
|
||||
else:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
if use_huber:
|
||||
errors = huber_loss(td_error, huber_threshold)
|
||||
else:
|
||||
errors = 0.5 * tf.square(td_error)
|
||||
|
||||
critic_loss = tf.reduce_mean(train_batch[PRIO_WEIGHTS] * errors)
|
||||
actor_loss = -tf.reduce_mean(q_t_det_policy)
|
||||
|
||||
# Add l2-regularization if required.
|
||||
if l2_reg is not None:
|
||||
for var in policy.model.policy_variables():
|
||||
if "bias" not in var.name:
|
||||
actor_loss += (l2_reg * tf.nn.l2_loss(var))
|
||||
for var in policy.model.q_variables():
|
||||
if "bias" not in var.name:
|
||||
critic_loss += (l2_reg * tf.nn.l2_loss(var))
|
||||
|
||||
# Model self-supervised losses.
|
||||
if policy.config["use_state_preprocessor"]:
|
||||
# Expand input_dict in case custom losses need them.
|
||||
input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
|
||||
input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
|
||||
input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
|
||||
input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
|
||||
actor_loss, critic_loss = model.custom_loss([actor_loss, critic_loss],
|
||||
input_dict)
|
||||
|
||||
# Store values for stats function.
|
||||
policy.actor_loss = actor_loss
|
||||
policy.critic_loss = critic_loss
|
||||
policy.td_error = td_error
|
||||
policy.q_t = q_t
|
||||
|
||||
# Return one loss value (even though we treat them separately in our
|
||||
# 2 optimizers: actor and critic).
|
||||
return policy.critic_loss + policy.actor_loss
|
||||
|
||||
|
||||
def make_ddpg_optimizers(policy, config):
|
||||
# Create separate optimizers for actor & critic losses.
|
||||
policy._actor_optimizer = tf.train.AdamOptimizer(
|
||||
learning_rate=config["actor_lr"])
|
||||
policy._critic_optimizer = tf.train.AdamOptimizer(
|
||||
learning_rate=config["critic_lr"])
|
||||
return None
|
||||
|
||||
|
||||
def build_apply_op(policy, optimizer, grads_and_vars):
|
||||
# For policy gradient, update the policy net once vs.
# updating the critic net `policy_delay` time(s).
|
||||
should_apply_actor_opt = tf.equal(
|
||||
tf.mod(policy.global_step, policy.config["policy_delay"]), 0)
|
||||
|
||||
def make_apply_op():
|
||||
return policy._actor_optimizer.apply_gradients(
|
||||
policy._actor_grads_and_vars)
|
||||
|
||||
actor_op = tf.cond(
|
||||
should_apply_actor_opt,
|
||||
true_fn=make_apply_op,
|
||||
false_fn=lambda: tf.no_op())
|
||||
critic_op = policy._critic_optimizer.apply_gradients(
|
||||
policy._critic_grads_and_vars)
|
||||
# Increment global step & apply ops.
|
||||
with tf.control_dependencies([tf.assign_add(policy.global_step, 1)]):
|
||||
return tf.group(actor_op, critic_op)
|
||||
|
||||
|
||||
def gradients_fn(policy, optimizer, loss):
|
||||
if policy.config["grad_norm_clipping"] is not None:
|
||||
actor_grads_and_vars = minimize_and_clip(
|
||||
policy._actor_optimizer,
|
||||
policy.actor_loss,
|
||||
var_list=policy.model.policy_variables(),
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
critic_grads_and_vars = minimize_and_clip(
|
||||
policy._critic_optimizer,
|
||||
policy.critic_loss,
|
||||
var_list=policy.model.q_variables(),
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
else:
|
||||
actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
|
||||
policy.actor_loss, var_list=policy.model.policy_variables())
|
||||
critic_grads_and_vars = policy._critic_optimizer.compute_gradients(
|
||||
policy.critic_loss, var_list=policy.model.q_variables())
|
||||
# Save these for later use in build_apply_op.
|
||||
policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
|
||||
if g is not None]
|
||||
policy._critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
|
||||
if g is not None]
|
||||
grads_and_vars = policy._actor_grads_and_vars + \
|
||||
policy._critic_grads_and_vars
|
||||
return grads_and_vars
|
||||
|
||||
|
||||
def build_ddpg_stats(policy, batch):
|
||||
stats = {
|
||||
"mean_q": tf.reduce_mean(policy.q_t),
|
||||
"max_q": tf.reduce_max(policy.q_t),
|
||||
"min_q": tf.reduce_min(policy.q_t),
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def before_init_fn(policy, obs_space, action_space, config):
|
||||
# Create global step for counting the number of update operations.
|
||||
policy.global_step = tf.train.get_or_create_global_step()
|
||||
|
||||
|
||||
class ComputeTDErrorMixin:
|
||||
def __init__(self, loss_fn):
|
||||
@make_tf_callable(self.get_session(), dynamic_shape=True)
|
||||
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
|
||||
importance_weights):
|
||||
# Do forward pass on loss to update td errors attribute
|
||||
# (one TD-error value per item in batch to update PR weights).
|
||||
loss_fn(
|
||||
self, self.model, None, {
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
|
||||
SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
|
||||
SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
|
||||
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
|
||||
SampleBatch.DONES: tf.convert_to_tensor(done_mask),
|
||||
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
|
||||
})
|
||||
# `self.td_error` is set in loss_fn.
|
||||
return self.td_error
|
||||
|
||||
self.compute_td_error = compute_td_error
|
||||
|
||||
|
||||
def setup_mid_mixins(policy, obs_space, action_space, config):
|
||||
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
|
||||
|
||||
|
||||
class TargetNetworkMixin:
|
||||
def __init__(self, config):
|
||||
@make_tf_callable(self.get_session())
|
||||
def update_target_fn(tau):
|
||||
tau = tf.convert_to_tensor(tau, dtype=tf.float32)
|
||||
update_target_expr = []
|
||||
model_vars = self.model.trainable_variables()
|
||||
target_model_vars = self.target_model.trainable_variables()
|
||||
assert len(model_vars) == len(target_model_vars), \
|
||||
(model_vars, target_model_vars)
|
||||
for var, var_target in zip(model_vars, target_model_vars):
|
||||
update_target_expr.append(
|
||||
var_target.assign(tau * var + (1.0 - tau) * var_target))
|
||||
logger.debug("Update target op {}".format(var_target))
|
||||
return tf.group(*update_target_expr)
|
||||
|
||||
# Hard initial update.
|
||||
self._do_update = update_target_fn
|
||||
self.update_target(tau=1.0)
|
||||
|
||||
# Support both hard and soft sync.
|
||||
def update_target(self, tau=None):
|
||||
self._do_update(np.float32(tau or self.config.get("tau")))
|
||||
|
||||
@override(TFPolicy)
|
||||
def variables(self):
|
||||
return self.model.variables() + self.target_model.variables()
|
||||
|
||||
|
||||
def setup_late_mixins(policy, obs_space, action_space, config):
|
||||
TargetNetworkMixin.__init__(policy, config)
|
||||
|
||||
|
||||
DDPGTFPolicy = build_tf_policy(
|
||||
name="DQNTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
|
||||
make_model=build_ddpg_models,
|
||||
action_distribution_fn=get_distribution_inputs_and_class,
|
||||
loss_fn=ddpg_actor_critic_loss,
|
||||
stats_fn=build_ddpg_stats,
|
||||
postprocess_fn=postprocess_nstep_and_prio,
|
||||
optimizer_fn=make_ddpg_optimizers,
|
||||
gradients_fn=gradients_fn,
|
||||
apply_gradients_fn=build_apply_op,
|
||||
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
|
||||
before_init=before_init_fn,
|
||||
before_loss_init=setup_mid_mixins,
|
||||
after_init=setup_late_mixins,
|
||||
obs_include_prev_action_reward=False,
|
||||
mixins=[
|
||||
TargetNetworkMixin,
|
||||
ComputeTDErrorMixin,
|
||||
])
|
|
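
The TargetNetworkMixin in the removed policy file above keeps the target networks in sync via a Polyak (soft) update; stripped of the make_tf_callable bookkeeping, the op it builds per variable pair is roughly the following sketch (tau=1.0 reproduces the initial hard copy):

    def soft_update(model_vars, target_model_vars, tau):
        # For each (online, target) variable pair, move the target a fraction
        # tau toward the online weights; tau=1.0 copies them outright.
        ops = []
        for var, var_target in zip(model_vars, target_model_vars):
            ops.append(var_target.assign(tau * var + (1.0 - tau) * var_target))
        return tf.group(*ops)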
@@ -1,3 +1,4 @@
|
|||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@@ -11,7 +12,7 @@ class NoopModel(TFModelV2):
|
|||
|
||||
This is the model used if use_state_preprocessor=False."""
|
||||
|
||||
@override(TFModelV2)
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
return tf.cast(input_dict["obs_flat"], tf.float32), state
|
||||
|
||||
|
@@ -21,6 +22,6 @@ class TorchNoopModel(TorchModelV2):
|
|||
|
||||
This is the model used if use_state_preprocessor=False."""
|
||||
|
||||
@override(TorchModelV2)
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
return input_dict["obs_flat"].float(), state
|
||||
|
|
|
@@ -82,14 +82,17 @@ class ESTFPolicy:
|
|||
# Policy network.
|
||||
dist_class, dist_dim = ModelCatalog.get_action_dist(
|
||||
self.action_space, config["model"], dist_type="deterministic")
|
||||
model = ModelCatalog.get_model({
|
||||
SampleBatch.CUR_OBS: self.inputs
|
||||
}, obs_space, action_space, dist_dim, config["model"])
|
||||
dist = dist_class(model.outputs, model)
|
||||
self.model = ModelCatalog.get_model_v2(
|
||||
obs_space=self.preprocessor.observation_space,
|
||||
action_space=action_space,
|
||||
num_outputs=dist_dim,
|
||||
model_config=config["model"])
|
||||
dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
|
||||
dist = dist_class(dist_inputs, self.model)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
model.outputs, self.sess)
|
||||
dist_inputs, self.sess)
|
||||
|
||||
self.num_params = sum(
|
||||
np.prod(variable.shape.as_list())
|
||||
|
|
|
@@ -235,7 +235,7 @@ class ValueNetworkMixin:
|
|||
[prev_action]),
|
||||
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
|
||||
[prev_reward]),
|
||||
"is_training": tf.convert_to_tensor(False),
|
||||
"is_training": tf.convert_to_tensor([False]),
|
||||
}, [tf.convert_to_tensor([s]) for s in state],
|
||||
tf.convert_to_tensor([1]))
|
||||
return self.model.value_function()[0]
|
||||
|
|
|
@@ -1,3 +1,4 @@
|
|||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@@ -20,12 +21,12 @@ class RNNModel(TorchModelV2, nn.Module):
|
|||
self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
|
||||
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
|
||||
|
||||
@override(TorchModelV2)
|
||||
@override(ModelV2)
|
||||
def get_initial_state(self):
|
||||
# make hidden states on same device as model
|
||||
return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
|
||||
|
||||
@override(TorchModelV2)
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, hidden_state, seq_lens):
|
||||
x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
|
||||
h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
|
||||
|
|
|
@@ -10,7 +10,7 @@ from ray.rllib.policy.policy import Policy
|
|||
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.model import _unpack_obs
|
||||
from ray.rllib.models.modelv2 import _unpack_obs
|
||||
from ray.rllib.env.constants import GROUP_REWARDS
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
|
|
@@ -1,4 +1,6 @@
|
|||
import gym
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@@ -111,7 +113,7 @@ class DiscreteLinearModel(TorchModelV2, nn.Module):
|
|||
self._cur_value = None
|
||||
self._cur_ctx = None
|
||||
|
||||
@override(TorchModelV2)
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
x = input_dict["obs"]
|
||||
scores = self.predict(x)
|
||||
|
@@ -137,7 +139,7 @@ class DiscreteLinearModel(TorchModelV2, nn.Module):
|
|||
f"It should be 0 <= arm < {len(self.arms)}"
|
||||
self.arms[arm].partial_fit(x, y)
|
||||
|
||||
@override(TorchModelV2)
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
assert self._cur_value is not None, "must call forward() first"
|
||||
return self._cur_value
|
||||
|
@@ -190,7 +192,7 @@ class ParametricLinearModel(TorchModelV2, nn.Module):
|
|||
assert x.size()[
|
||||
0] == 1, "Only batch size of 1 is supported for now."
|
||||
|
||||
@override(TorchModelV2)
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
x = input_dict["obs"]["item"]
|
||||
self._check_inputs(x)
|
||||
|
@@ -214,7 +216,7 @@ class ParametricLinearModel(TorchModelV2, nn.Module):
|
|||
action_id = arm.item()
|
||||
self.arm.partial_fit(x[:, action_id], y)
|
||||
|
||||
@override(TorchModelV2)
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
assert self._cur_value is not None, "must call forward() first"
|
||||
return self._cur_value
|
||||
|
|
|
@@ -4,9 +4,12 @@ import argparse
|
|||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.models import Model, ModelCatalog
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
@@ -15,28 +18,127 @@ parser.add_argument("--num-iters", type=int, default=200)
|
|||
parser.add_argument("--run", type=str, default="PPO")
|
||||
|
||||
|
||||
class BatchNormModel(Model):
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
class BatchNormModel(TFModelV2):
|
||||
"""Example of a TFModelV2 that is built w/o using tf.keras.
|
||||
|
||||
NOTE: This example does not work when using a keras-based TFModelV2 due
|
||||
to a bug in keras related to missing values for input placeholders, even
|
||||
though these input values have been provided in a forward pass through the
|
||||
actual keras Model.
|
||||
|
||||
All Model logic (layers) is defined in the `forward` method (incl.
|
||||
the batch_normalization layers). Also, all variables are registered
|
||||
(only once) at the end of `forward`, so an optimizer knows which tensors
|
||||
to train on. A standard `value_function` override is used.
|
||||
"""
|
||||
capture_index = 0
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super().__init__(obs_space, action_space, num_outputs, model_config,
|
||||
name)
|
||||
# Have we registered our vars yet (see `forward`)?
|
||||
self._registered = False
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
last_layer = input_dict["obs"]
|
||||
hiddens = [256, 256]
|
||||
with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
|
||||
for i, size in enumerate(hiddens):
|
||||
last_layer = tf.layers.dense(
|
||||
last_layer,
|
||||
size,
|
||||
kernel_initializer=normc_initializer(1.0),
|
||||
activation=tf.nn.tanh,
|
||||
name="fc{}".format(i))
|
||||
# Add a batch norm layer
|
||||
last_layer = tf.layers.batch_normalization(
|
||||
last_layer,
|
||||
training=input_dict["is_training"],
|
||||
name="bn_{}".format(i))
|
||||
|
||||
output = tf.layers.dense(
|
||||
last_layer,
|
||||
self.num_outputs,
|
||||
kernel_initializer=normc_initializer(0.01),
|
||||
activation=None,
|
||||
name="out")
|
||||
self._value_out = tf.layers.dense(
|
||||
last_layer,
|
||||
1,
|
||||
kernel_initializer=normc_initializer(1.0),
|
||||
activation=None,
|
||||
name="vf")
|
||||
if not self._registered:
|
||||
self.register_variables(
|
||||
tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_VARIABLES, scope=".+/model/.+"))
|
||||
self._registered = True
|
||||
|
||||
return output, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
|
||||
class KerasBatchNormModel(TFModelV2):
|
||||
"""Keras version of above BatchNormModel with exactly the same structure.
|
||||
|
||||
IMPORTANT NOTE: This model will not work with PPO due to a bug in keras
that surfaces when having more than one input placeholder (here: `inputs`
and `is_training`) AND using the `make_tf_callable` helper (e.g. used by
PPO), in which auto-placeholders are generated, then passed through the
tf.keras.models.Model. In this last step, the connection between 1) the
provided value in the auto-placeholder and 2) the keras `is_training`
Input is broken and keras complains.
Use the above `BatchNormModel` (a non-keras based TFModelV2) instead.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super().__init__(obs_space, action_space, num_outputs, model_config,
|
||||
name)
|
||||
inputs = tf.keras.layers.Input(shape=obs_space.shape, name="inputs")
|
||||
is_training = tf.keras.layers.Input(
|
||||
shape=(), dtype=tf.bool, batch_size=1, name="is_training")
|
||||
last_layer = inputs
|
||||
hiddens = [256, 256]
|
||||
for i, size in enumerate(hiddens):
|
||||
label = "fc{}".format(i)
|
||||
last_layer = tf.layers.dense(
|
||||
last_layer,
|
||||
size,
|
||||
last_layer = tf.keras.layers.Dense(
|
||||
units=size,
|
||||
kernel_initializer=normc_initializer(1.0),
|
||||
activation=tf.nn.tanh,
|
||||
name=label)
|
||||
name=label)(last_layer)
|
||||
# Add a batch norm layer
|
||||
last_layer = tf.layers.batch_normalization(
|
||||
last_layer, training=input_dict["is_training"])
|
||||
output = tf.layers.dense(
|
||||
last_layer,
|
||||
num_outputs,
|
||||
last_layer = tf.keras.layers.BatchNormalization()(
|
||||
last_layer, training=is_training[0])
|
||||
output = tf.keras.layers.Dense(
|
||||
units=self.num_outputs,
|
||||
kernel_initializer=normc_initializer(0.01),
|
||||
activation=None,
|
||||
name="fc_out")
|
||||
return output, last_layer
|
||||
name="fc_out")(last_layer)
|
||||
value_out = tf.keras.layers.Dense(
|
||||
units=1,
|
||||
kernel_initializer=normc_initializer(0.01),
|
||||
activation=None,
|
||||
name="value_out")(last_layer)
|
||||
|
||||
self.base_model = tf.keras.models.Model(
|
||||
inputs=[inputs, is_training], outputs=[output, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
out, self._value_out = self.base_model(
|
||||
[input_dict["obs"], input_dict["is_training"]])
|
||||
return out, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -44,14 +146,21 @@ if __name__ == "__main__":
|
|||
ray.init()
|
||||
|
||||
ModelCatalog.register_custom_model("bn_model", BatchNormModel)
|
||||
|
||||
config = {
|
||||
"env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0",
|
||||
"model": {
|
||||
"custom_model": "bn_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
}
|
||||
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
trainer = PPOTrainer(config=config)
|
||||
trainer.train()
|
||||
|
||||
tune.run(
|
||||
args.run,
|
||||
stop={"training_iteration": args.num_iters},
|
||||
config={
|
||||
"env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0",
|
||||
"model": {
|
||||
"custom_model": "bn_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
},
|
||||
config=config,
|
||||
)
|
||||
|
|
|
@@ -169,7 +169,6 @@ if __name__ == "__main__":
|
|||
"max_seq_len": 20,
|
||||
},
|
||||
}
|
||||
|
||||
tune.run(
|
||||
args.run,
|
||||
config=config,
|
||||
|
|
|
@@ -4,11 +4,14 @@ import random
|
|||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.models import Model, ModelCatalog
|
||||
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
@@ -16,7 +19,7 @@ parser = argparse.ArgumentParser()
|
|||
parser.add_argument("--iters", type=int, default=200)
|
||||
|
||||
|
||||
class EagerModel(Model):
|
||||
class EagerModel(TFModelV2):
|
||||
"""Example of using embedded eager execution in a custom model.
|
||||
|
||||
This shows how to use tf.py_function() to execute a snippet of TF code
|
||||
|
@@ -25,15 +28,39 @@ class EagerModel(Model):
|
|||
perform any TF eager operation in tf.py_function().
|
||||
"""
|
||||
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
|
||||
self.action_space, num_outputs,
|
||||
options)
|
||||
feature_out = tf.py_function(self.forward_eager,
|
||||
[self.fcnet.last_layer], tf.float32)
|
||||
def __init__(self, observation_space, action_space, num_outputs,
|
||||
model_config, name):
|
||||
super().__init__(observation_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
with tf.control_dependencies([feature_out]):
|
||||
return tf.identity(self.fcnet.outputs), feature_out
|
||||
inputs = tf.keras.layers.Input(shape=observation_space.shape)
|
||||
self.fcnet = FullyConnectedNetwork(
|
||||
obs_space=self.obs_space,
|
||||
action_space=self.action_space,
|
||||
num_outputs=self.num_outputs,
|
||||
model_config=self.model_config,
|
||||
name="fc1")
|
||||
out, value_out = self.fcnet.base_model(inputs)
|
||||
|
||||
def lambda_(x):
|
||||
eager_out = tf.py_function(self.forward_eager, [x], tf.float32)
|
||||
with tf.control_dependencies([eager_out]):
|
||||
eager_out.set_shape(x.shape)
|
||||
return eager_out
|
||||
|
||||
out = tf.keras.layers.Lambda(lambda_)(out)
|
||||
self.base_model = tf.keras.models.Model(inputs, [out, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
out, self._value_out = self.base_model(input_dict["obs"], state,
|
||||
seq_lens)
|
||||
return out, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
def forward_eager(self, feature_layer):
|
||||
assert tf.executing_eagerly()
|
||||
|
@@ -84,13 +111,13 @@ if __name__ == "__main__":
|
|||
ray.init()
|
||||
args = parser.parse_args()
|
||||
ModelCatalog.register_custom_model("eager_model", EagerModel)
|
||||
tune.run(
|
||||
MyTrainer,
|
||||
stop={"training_iteration": args.iters},
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"num_workers": 0,
|
||||
"model": {
|
||||
"custom_model": "eager_model"
|
||||
},
|
||||
})
|
||||
|
||||
config = {
|
||||
"env": "CartPole-v0",
|
||||
"num_workers": 0,
|
||||
"model": {
|
||||
"custom_model": "eager_model"
|
||||
},
|
||||
}
|
||||
|
||||
tune.run(MyTrainer, stop={"training_iteration": args.iters}, config=config)
|
||||
|
|
|
@@ -15,10 +15,13 @@ import random
|
|||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.models import Model, ModelCatalog
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.registry import register_env
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
@@ -31,8 +34,13 @@ parser.add_argument("--simple", action="store_true")
|
|||
parser.add_argument("--num-cpus", type=int, default=0)
|
||||
|
||||
|
||||
class CustomModel1(Model):
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
class CustomModel1(TFModelV2):
|
||||
def __init__(self, observation_space, action_space, num_outputs,
|
||||
model_config, name):
|
||||
super().__init__(observation_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
inputs = tf.keras.layers.Input(observation_space.shape)
|
||||
# Example of (optional) weight sharing between two different policies.
|
||||
# Here, we share the variables defined in the 'shared' variable scope
|
||||
# by entering it explicitly with tf.AUTO_REUSE. This creates the
|
||||
|
@@ -42,29 +50,55 @@ class CustomModel1(Model):
|
|||
tf.VariableScope(tf.AUTO_REUSE, "shared"),
|
||||
reuse=tf.AUTO_REUSE,
|
||||
auxiliary_name_scope=False):
|
||||
last_layer = tf.layers.dense(
|
||||
input_dict["obs"], 64, activation=tf.nn.relu, name="fc1")
|
||||
last_layer = tf.layers.dense(
|
||||
last_layer, 64, activation=tf.nn.relu, name="fc2")
|
||||
output = tf.layers.dense(
|
||||
last_layer, num_outputs, activation=None, name="fc_out")
|
||||
return output, last_layer
|
||||
last_layer = tf.keras.layers.Dense(
|
||||
units=64, activation=tf.nn.relu, name="fc1")(inputs)
|
||||
output = tf.keras.layers.Dense(
|
||||
units=num_outputs, activation=None, name="fc_out")(last_layer)
|
||||
vf = tf.keras.layers.Dense(
|
||||
units=1, activation=None, name="value_out")(last_layer)
|
||||
self.base_model = tf.keras.models.Model(inputs, [output, vf])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
out, self._value_out = self.base_model(input_dict["obs"])
|
||||
return out, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
|
||||
class CustomModel2(Model):
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
# Weights shared with CustomModel1
|
||||
class CustomModel2(TFModelV2):
|
||||
def __init__(self, observation_space, action_space, num_outputs,
|
||||
model_config, name):
|
||||
super().__init__(observation_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
inputs = tf.keras.layers.Input(observation_space.shape)
|
||||
|
||||
# Weights shared with CustomModel1.
|
||||
with tf.variable_scope(
|
||||
tf.VariableScope(tf.AUTO_REUSE, "shared"),
|
||||
reuse=tf.AUTO_REUSE,
|
||||
auxiliary_name_scope=False):
|
||||
last_layer = tf.layers.dense(
|
||||
input_dict["obs"], 64, activation=tf.nn.relu, name="fc1")
|
||||
last_layer = tf.layers.dense(
|
||||
last_layer, 64, activation=tf.nn.relu, name="fc2")
|
||||
output = tf.layers.dense(
|
||||
last_layer, num_outputs, activation=None, name="fc_out")
|
||||
return output, last_layer
|
||||
last_layer = tf.keras.layers.Dense(
|
||||
units=64, activation=tf.nn.relu, name="fc1")(inputs)
|
||||
output = tf.keras.layers.Dense(
|
||||
units=num_outputs, activation=None, name="fc_out")(last_layer)
|
||||
vf = tf.keras.layers.Dense(
|
||||
units=1, activation=None, name="value_out")(last_layer)
|
||||
self.base_model = tf.keras.models.Model(inputs, [output, vf])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
out, self._value_out = self.base_model(input_dict["obs"])
|
||||
return out, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@@ -23,6 +23,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
|
|||
TorchMultiActionDistribution, TorchMultiCategorical
|
||||
from ray.rllib.utils import try_import_tf, try_import_tree
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.space_utils import flatten_space
|
||||
|
||||
|
@@ -471,6 +472,7 @@ class ModelCatalog:
|
|||
seq_lens=None):
|
||||
"""Deprecated: use get_model_v2() instead."""
|
||||
|
||||
deprecation_warning("get_model", "get_model_v2", error=False)
|
||||
assert isinstance(input_dict, dict)
|
||||
options = options or MODEL_DEFAULTS
|
||||
model = ModelCatalog._get_model(input_dict, obs_space, action_space,
|
||||
|
@@ -496,6 +498,7 @@ class ModelCatalog:
|
|||
@staticmethod
|
||||
def _get_model(input_dict, obs_space, action_space, num_outputs, options,
|
||||
state_in, seq_lens):
|
||||
deprecation_warning("_get_model", "get_model_v2", error=False)
|
||||
if options.get("custom_model"):
|
||||
model = options["custom_model"]
|
||||
logger.debug("Using custom model {}".format(model))
|
||||
|
|
|
@@ -5,7 +5,8 @@ import gym
|
|||
from ray.rllib.models.tf.misc import linear, normc_initializer
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
@@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Model:
|
||||
"""This class is deprecated, please use TFModelV2 instead."""
|
||||
"""This class is deprecated! Use ModelV2 instead."""
|
||||
|
||||
def __init__(self,
|
||||
input_dict,
|
||||
|
@@ -24,6 +25,9 @@ class Model:
|
|||
options,
|
||||
state_in=None,
|
||||
seq_lens=None):
|
||||
# Soft-deprecate this class. All Models should use the ModelV2
|
||||
# API from here on.
|
||||
deprecation_warning("Model", "ModelV2", error=False)
|
||||
assert isinstance(input_dict, dict), input_dict
|
||||
|
||||
# Default attribute values for the non-RNN case
|
||||
|
|
|
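
For a Model(v1) subclass that only overrode _build_layers_v2(), the ModelV2 replacement used throughout this commit is a TFModelV2 that builds a tf.keras model in __init__ and overrides forward() and value_function(). A minimal sketch assembled from the example models elsewhere in this diff (layer sizes are illustrative):

    class MyModel(TFModelV2):
        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            super().__init__(obs_space, action_space, num_outputs,
                             model_config, name)
            inputs = tf.keras.layers.Input(shape=obs_space.shape)
            hidden = tf.keras.layers.Dense(64, activation=tf.nn.relu)(inputs)
            logits = tf.keras.layers.Dense(num_outputs, activation=None)(hidden)
            value = tf.keras.layers.Dense(1, activation=None)(hidden)
            self.base_model = tf.keras.models.Model(inputs, [logits, value])
            self.register_variables(self.base_model.variables)

        def forward(self, input_dict, state, seq_lens):
            # input_dict["obs"] is the (restored) observation batch.
            model_out, self._value_out = self.base_model(input_dict["obs"])
            return model_out, state

        def value_function(self):
            return tf.reshape(self._value_out, [-1])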
@@ -1,6 +1,13 @@
|
|||
from collections import OrderedDict
|
||||
import gym
|
||||
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models.model import restore_original_dimensions, flatten
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
@@ -70,7 +77,7 @@ class ModelV2:
|
|||
|
||||
Custom models should override this instead of __call__.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
input_dict (dict): dictionary of input tensors, including "obs",
|
||||
"obs_flat", "prev_action", "prev_reward", "is_training"
|
||||
state (list): list of state tensors with sizes matching those
|
||||
|
@@ -80,6 +87,12 @@ class ModelV2:
|
|||
Returns:
|
||||
(outputs, state): The model output tensor of size
|
||||
[BATCH, num_outputs]
|
||||
|
||||
Examples:
|
||||
>>> def forward(self, input_dict, state, seq_lens):
|
||||
>>> model_out, self._value_out = self.base_model(
|
||||
... input_dict["obs"])
|
||||
>>> return model_out, state
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -267,3 +280,98 @@ class NullContextManager:
|
|||
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def flatten(obs, framework):
|
||||
"""Flatten the given tensor."""
|
||||
if framework == "tf":
|
||||
return tf.layers.flatten(obs)
|
||||
elif framework == "torch":
|
||||
assert torch is not None
|
||||
return torch.flatten(obs, start_dim=1)
|
||||
else:
|
||||
raise NotImplementedError("flatten", framework)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
|
||||
"""Unpacks Dict and Tuple space observations into their original form.
|
||||
|
||||
This is needed since we flatten Dict and Tuple observations in transit.
|
||||
Before sending them to the model though, we should unflatten them into
|
||||
Dicts or Tuples of tensors.
|
||||
|
||||
Arguments:
|
||||
obs: The flattened observation tensor.
|
||||
obs_space: The flattened obs space. If this has the `original_space`
|
||||
attribute, we will unflatten the tensor to that shape.
|
||||
tensorlib: The library used to unflatten (reshape) the array/tensor.
|
||||
|
||||
Returns:
|
||||
single tensor or dict / tuple of tensors matching the original
|
||||
observation space.
|
||||
"""
|
||||
|
||||
if hasattr(obs_space, "original_space"):
|
||||
if tensorlib == "tf":
|
||||
tensorlib = tf
|
||||
elif tensorlib == "torch":
|
||||
assert torch is not None
|
||||
tensorlib = torch
|
||||
return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib)
|
||||
else:
|
||||
return obs
|
||||
|
||||
|
||||
# Cache of preprocessors, in case the user calls unpack_obs often.
|
||||
_cache = {}
|
||||
|
||||
|
||||
def _unpack_obs(obs, space, tensorlib=tf):
|
||||
"""Unpack a flattened Dict or Tuple observation array/tensor.
|
||||
|
||||
Arguments:
|
||||
obs: The flattened observation tensor
|
||||
space: The original space prior to flattening
|
||||
tensorlib: The library used to unflatten (reshape) the array/tensor
|
||||
"""
|
||||
|
||||
if (isinstance(space, gym.spaces.Dict)
|
||||
or isinstance(space, gym.spaces.Tuple)):
|
||||
if id(space) in _cache:
|
||||
prep = _cache[id(space)]
|
||||
else:
|
||||
prep = get_preprocessor(space)(space)
|
||||
# Make an attempt to cache the result, if enough space left.
|
||||
if len(_cache) < 999:
|
||||
_cache[id(space)] = prep
|
||||
if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]:
|
||||
raise ValueError(
|
||||
"Expected flattened obs shape of [None, {}], got {}".format(
|
||||
prep.shape[0], obs.shape))
|
||||
assert len(prep.preprocessors) == len(space.spaces), \
|
||||
(len(prep.preprocessors) == len(space.spaces))
|
||||
offset = 0
|
||||
if isinstance(space, gym.spaces.Tuple):
|
||||
u = []
|
||||
for p, v in zip(prep.preprocessors, space.spaces):
|
||||
obs_slice = obs[:, offset:offset + p.size]
|
||||
offset += p.size
|
||||
u.append(
|
||||
_unpack_obs(
|
||||
tensorlib.reshape(obs_slice, [-1] + list(p.shape)),
|
||||
v,
|
||||
tensorlib=tensorlib))
|
||||
else:
|
||||
u = OrderedDict()
|
||||
for p, (k, v) in zip(prep.preprocessors, space.spaces.items()):
|
||||
obs_slice = obs[:, offset:offset + p.size]
|
||||
offset += p.size
|
||||
u[k] = _unpack_obs(
|
||||
tensorlib.reshape(obs_slice, [-1] + list(p.shape)),
|
||||
v,
|
||||
tensorlib=tensorlib)
|
||||
return u
|
||||
else:
|
||||
return obs
|
||||
|
|
|
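
As a rough illustration of what _unpack_obs() / restore_original_dimensions() undo (hypothetical spaces and shapes, not taken from this diff): a Dict observation is flattened to one [batch, N] tensor in transit and sliced back per sub-space before it reaches the model.

    # Hypothetical sketch: Dict(pos=Box(2,), vel=Box(3,)) flattens to width 5;
    # _unpack_obs slices the [batch, 5] tensor back into per-key tensors.
    import numpy as np
    from gym.spaces import Box, Dict

    space = Dict({"pos": Box(-1.0, 1.0, shape=(2, )),
                  "vel": Box(-1.0, 1.0, shape=(3, ))})
    flat = np.zeros((4, 5), dtype=np.float32)  # e.g. preprocessor output
    restored = _unpack_obs(flat, space, tensorlib=np)
    # restored["pos"].shape == (4, 2); restored["vel"].shape == (4, 3)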
@@ -1,14 +1,12 @@
|
|||
# TODO(sven): Add once ModelV1 is deprecated and we no longer cause circular
|
||||
# dependencies b/c of that.
|
||||
# from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
# from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
|
||||
# from ray.rllib.models.tf.recurrent_tf_modelv2 import \
|
||||
# RecurrentTFModelV2
|
||||
# from ray.rllib.models.tf.visionnet_v2 import VisionNetwork
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
|
||||
from ray.rllib.models.tf.recurrent_tf_modelv2 import \
|
||||
RecurrentTFModelV2
|
||||
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork
|
||||
|
||||
# __all__ = [
|
||||
# "FullyConnectedNetwork",
|
||||
# "RecurrentTFModelV2",
|
||||
# "TFModelV2",
|
||||
# "VisionNetwork",
|
||||
# ]
|
||||
__all__ = [
|
||||
"FullyConnectedNetwork",
|
||||
"RecurrentTFModelV2",
|
||||
"TFModelV2",
|
||||
"VisionNetwork",
|
||||
]
|
||||
|
|
|
@@ -1,6 +1,7 @@
|
|||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
@@ -17,6 +18,12 @@ class FullyConnectedNetwork(Model):
|
|||
Note that dict inputs will be flattened into a vector. To define a
|
||||
model that processes the components separately, use _build_layers_v2().
|
||||
"""
|
||||
# Soft deprecate this class. All Models should use the ModelV2
|
||||
# API from here on.
|
||||
deprecation_warning(
|
||||
"Model->FullyConnectedNetwork",
|
||||
"ModelV2->FullyConnectedNetwork",
|
||||
error=False)
|
||||
|
||||
hiddens = options.get("fcnet_hiddens")
|
||||
activation = get_activation_fn(options.get("fcnet_activation"))
|
||||
|
|
|
@@ -16,7 +16,7 @@ class FullyConnectedNetwork(TFModelV2):
|
|||
obs_space, action_space, num_outputs, model_config, name)
|
||||
|
||||
activation = get_activation_fn(model_config.get("fcnet_activation"))
|
||||
hiddens = model_config.get("fcnet_hiddens")
|
||||
hiddens = model_config.get("fcnet_hiddens", [])
|
||||
no_final_linear = model_config.get("no_final_linear")
|
||||
vf_share_layers = model_config.get("vf_share_layers")
|
||||
|
||||
|
|
|
@@ -4,7 +4,8 @@ from ray.rllib.models.model import Model
|
|||
from ray.rllib.models.tf.misc import linear, normc_initializer
|
||||
from ray.rllib.policy.rnn_sequencing import add_time_dimension
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
@@ -21,6 +22,10 @@ class LSTM(Model):
|
|||
|
||||
@override(Model)
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
# Hard deprecate this class. All Models should use the ModelV2
|
||||
# API from here on.
|
||||
deprecation_warning("Model->LSTM", "RecurrentTFModelV2", error=False)
|
||||
|
||||
cell_size = options.get("lstm_cell_size")
|
||||
if options.get("lstm_use_prev_action_reward"):
|
||||
action_dim = int(
|
||||
|
|
|
@@ -52,39 +52,6 @@ class TFModelV2(ModelV2):
|
|||
else:
|
||||
return ModelV2.context(self)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
"""Call the model with the given input tensors and state.
|
||||
|
||||
Any complex observations (dicts, tuples, etc.) will be unpacked by
|
||||
__call__ before being passed to forward(). To access the flattened
|
||||
observation tensor, refer to input_dict["obs_flat"].
|
||||
|
||||
This method can be called any number of times. In eager execution,
|
||||
each call to forward() will eagerly evaluate the model. In symbolic
|
||||
execution, each call to forward creates a computation graph that
|
||||
operates over the variables of this model (i.e., shares weights).
|
||||
|
||||
Custom models should override this instead of __call__.
|
||||
|
||||
Args:
|
||||
input_dict (dict): dictionary of input tensors, including "obs",
|
||||
"obs_flat", "prev_action", "prev_reward", "is_training"
|
||||
state (list): list of state tensors with sizes matching those
|
||||
returned by get_initial_state + the batch dimension
|
||||
seq_lens (Tensor): 1d tensor holding input sequence lengths
|
||||
|
||||
Returns:
|
||||
(outputs, state): The model output tensor of size
|
||||
[BATCH, num_outputs]
|
||||
|
||||
Examples:
|
||||
>>> def forward(self, input_dict, state, seq_lens):
|
||||
>>> model_out, self._value_out = self.base_model(
|
||||
... input_dict["obs"])
|
||||
>>> return model_out, state
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_ops(self):
|
||||
"""Return the list of update ops for this model.
|
||||
|
||||
|
|
|
@@ -1,6 +1,7 @@
|
|||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.tf.misc import flatten
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
@@ -12,6 +13,10 @@ class VisionNetwork(Model):
|
|||
|
||||
@override(Model)
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
# Hard deprecate this class. All Models should use the ModelV2
|
||||
# API from here on.
|
||||
deprecation_warning(
|
||||
"Model->VisionNetwork", "ModelV2->VisionNetwork", error=False)
|
||||
inputs = input_dict["obs"]
|
||||
filters = options.get("conv_filters")
|
||||
if not filters:
|
||||
|
|
|
@@ -11,7 +11,7 @@ torch, nn = try_import_torch()
|
|||
|
||||
@DeveloperAPI
|
||||
class RecurrentTorchModel(TorchModelV2, nn.Module):
|
||||
"""Helper class to simplify implementing RNN models with TFModelV2.
|
||||
"""Helper class to simplify implementing RNN models with TorchModelV2.
|
||||
|
||||
Instead of implementing forward(), you can implement forward_rnn() which
|
||||
takes batches with the time dimension added already.
|
||||
|
|
|
@@ -30,7 +30,7 @@ class TorchModelV2(ModelV2):
|
|||
if not isinstance(self, nn.Module):
|
||||
raise ValueError(
|
||||
"Subclasses of TorchModelV2 must also inherit from "
|
||||
"nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")
|
||||
"nn.Module, e.g., MyModel(TorchModel, nn.Module)")
|
||||
|
||||
ModelV2.__init__(
|
||||
self,
|
||||
|
@@ -41,39 +41,6 @@ class TorchModelV2(ModelV2):
|
|||
name,
|
||||
framework="torch")
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
"""Call the model with the given input tensors and state.
|
||||
|
||||
Any complex observations (dicts, tuples, etc.) will be unpacked by
|
||||
__call__ before being passed to forward(). To access the flattened
|
||||
observation tensor, refer to input_dict["obs_flat"].
|
||||
|
||||
This method can be called any number of times. In eager execution,
|
||||
each call to forward() will eagerly evaluate the model. In symbolic
|
||||
execution, each call to forward creates a computation graph that
|
||||
operates over the variables of this model (i.e., shares weights).
|
||||
|
||||
Custom models should override this instead of __call__.
|
||||
|
||||
Args:
|
||||
input_dict (dict): dictionary of input tensors, including "obs",
|
||||
"obs_flat", "prev_action", "prev_reward", "is_training"
|
||||
state (list): list of state tensors with sizes matching those
|
||||
returned by get_initial_state + the batch dimension
|
||||
seq_lens (Tensor): 1d tensor holding input sequence lengths
|
||||
|
||||
Returns:
|
||||
(outputs, state): The model output tensor of size
|
||||
[BATCH, num_outputs]
|
||||
|
||||
Examples:
|
||||
>>> def forward(self, input_dict, state, seq_lens):
|
||||
>>> features = self._hidden_layers(input_dict["obs"])
|
||||
>>> self._value_out = self._value_branch(features)
|
||||
>>> return self._logits(features), state
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override(ModelV2)
|
||||
def variables(self, as_dict=False):
|
||||
if as_dict:
|
||||
|
|
|
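
The torch side follows the same shape: a TorchModelV2 subclass must also inherit from nn.Module (enforced by the check above) and override forward() and value_function(). A minimal sketch in the spirit of the removed forward() docstring example, with illustrative layer sizes:

    class MyTorchModel(TorchModelV2, nn.Module):
        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)
            nn.Module.__init__(self)
            in_size = int(np.product(obs_space.shape))
            self._hidden_layers = nn.Sequential(
                nn.Linear(in_size, 64), nn.ReLU())
            self._logits = nn.Linear(64, num_outputs)
            self._value_branch = nn.Linear(64, 1)

        def forward(self, input_dict, state, seq_lens):
            # Use the flattened observation tensor, as the noop/RNN models do.
            features = self._hidden_layers(input_dict["obs_flat"].float())
            self._value_out = self._value_branch(features).squeeze(1)
            return self._logits(features), state

        def value_function(self):
            return self._value_out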
@@ -5,12 +5,12 @@ import unittest
|
|||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS, ActionDistribution
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
from ray.rllib.models.preprocessors import (NoPreprocessor, OneHotPreprocessor,
|
||||
Preprocessor)
|
||||
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
|
||||
from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
|
||||
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
|
||||
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
|
@@ -27,7 +27,7 @@ class CustomPreprocessor2(Preprocessor):
|
|||
return [1]
|
||||
|
||||
|
||||
class CustomModel(Model):
|
||||
class CustomModel(TFModelV2):
|
||||
def _build_layers(self, *args):
|
||||
return tf.constant([[0] * 5]), None
|
||||
|
||||
|
@@ -101,25 +101,29 @@ class ModelCatalogTest(unittest.TestCase):
|
|||
ray.init(object_store_memory=1000 * 1024 * 1024)
|
||||
|
||||
with tf.variable_scope("test1"):
|
||||
p1 = ModelCatalog.get_model({
|
||||
"obs": tf.zeros((10, 3), dtype=tf.float32)
|
||||
}, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5, {})
|
||||
p1 = ModelCatalog.get_model_v2(
|
||||
obs_space=Box(0, 1, shape=(3, ), dtype=np.float32),
|
||||
action_space=Discrete(5),
|
||||
num_outputs=5,
|
||||
model_config={})
|
||||
self.assertEqual(type(p1), FullyConnectedNetwork)
|
||||
|
||||
with tf.variable_scope("test2"):
|
||||
p2 = ModelCatalog.get_model({
|
||||
"obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32)
|
||||
}, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), Discrete(5), 5,
|
||||
{})
|
||||
p2 = ModelCatalog.get_model_v2(
|
||||
obs_space=Box(0, 1, shape=(84, 84, 3), dtype=np.float32),
|
||||
action_space=Discrete(5),
|
||||
num_outputs=5,
|
||||
model_config={})
|
||||
self.assertEqual(type(p2), VisionNetwork)
|
||||
|
||||
def test_custom_model(self):
|
||||
ray.init(object_store_memory=1000 * 1024 * 1024)
|
||||
ModelCatalog.register_custom_model("foo", CustomModel)
|
||||
p1 = ModelCatalog.get_model({
|
||||
"obs": tf.constant([1, 2, 3])
|
||||
}, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5,
|
||||
{"custom_model": "foo"})
|
||||
p1 = ModelCatalog.get_model_v2(
|
||||
obs_space=Box(0, 1, shape=(3, ), dtype=np.float32),
|
||||
action_space=Discrete(5),
|
||||
num_outputs=5,
|
||||
model_config={"custom_model": "foo"})
|
||||
self.assertEqual(str(type(p1)), str(CustomModel))
|
||||
|
||||
def test_custom_action_distribution(self):
|
||||
|
|
|
@@ -5,13 +5,14 @@ import unittest
|
|||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.policy.rnn_sequencing import chop_into_sequences, \
|
||||
add_time_dimension
|
||||
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.tf.misc import linear, normc_initializer
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2
|
||||
from ray.tune.registry import register_env
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
@@ -91,83 +92,109 @@ class TestLSTMUtils(unittest.TestCase):
|
|||
self.assertEqual(seq_lens.tolist(), [1, 2])
|
||||
|
||||
|
||||
class RNNSpyModel(Model):
|
||||
class RNNSpyModel(RecurrentTFModelV2):
|
||||
capture_index = 0
|
||||
cell_size = 3
|
||||
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
# Previously, a new class object was created during
|
||||
# deserialization and this `capture_index`
|
||||
# variable would be refreshed between class instantiations.
|
||||
# This behavior is no longer the case, so we manually refresh
|
||||
# the variable.
|
||||
RNNSpyModel.capture_index = 0
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super().__init__(obs_space, action_space, num_outputs, model_config,
|
||||
name)
|
||||
self.cell_size = RNNSpyModel.cell_size
|
||||
|
||||
def spy(sequences, state_in, state_out, seq_lens):
|
||||
if len(sequences) == 1:
|
||||
def spy(inputs, seq_lens, h_in, c_in, h_out, c_out):
|
||||
if len(inputs) == 1:
|
||||
return 0 # don't capture inference inputs
|
||||
# TF runs this function in an isolated context, so we have to use
|
||||
# redis to communicate back to our suite
|
||||
ray.experimental.internal_kv._internal_kv_put(
|
||||
"rnn_spy_in_{}".format(RNNSpyModel.capture_index),
|
||||
pickle.dumps({
|
||||
"sequences": sequences,
|
||||
"state_in": state_in,
|
||||
"state_out": state_out,
|
||||
"seq_lens": seq_lens
|
||||
"sequences": inputs,
|
||||
"seq_lens": seq_lens,
|
||||
"state_in": [h_in, c_in],
|
||||
"state_out": [h_out, c_out]
|
||||
}),
|
||||
overwrite=True)
|
||||
RNNSpyModel.capture_index += 1
|
||||
return 0
|
||||
|
||||
        features = input_dict["obs"]
        cell_size = 3
        last_layer = add_time_dimension(features, self.seq_lens)
        # Create a keras LSTM model.
        inputs = tf.keras.layers.Input(
            shape=(None, ) + obs_space.shape, name="input")
        state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ), name="c")
        seq_lens = tf.keras.layers.Input(
            shape=(), name="seq_lens", dtype=tf.int32)

        # Setup the LSTM cell
        lstm = tf.nn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        lstm_out, state_out_h, state_out_c = tf.keras.layers.LSTM(
            self.cell_size,
            return_sequences=True,
            return_state=True,
            name="lstm")(
                inputs=inputs,
                mask=tf.sequence_mask(seq_lens),
                initial_state=[state_in_h, state_in_c])
        self.dense = tf.keras.layers.Dense(
            units=self.num_outputs, kernel_initializer=normc_initializer(0.01))

        def lambda_(inputs):
            spy_fn = tf.py_func(
                spy,
                [
                    inputs[0],  # observations
                    inputs[2],  # seq_lens
                    inputs[3],  # h_in
                    inputs[4],  # c_in
                    inputs[5],  # h_out
                    inputs[6],  # c_out
                ],
                tf.int64,
                stateful=True)

            # Compute outputs
            with tf.control_dependencies([spy_fn]):
                return self.dense(inputs[1])  # lstm_out

        logits = tf.keras.layers.Lambda(lambda_)([
            inputs, lstm_out, seq_lens, state_in_h, state_in_c, state_out_h,
            state_out_c
        ])

        # Value branch.
        value_out = tf.keras.layers.Dense(
            units=1, kernel_initializer=normc_initializer(1.0))(lstm_out)

        self.base_model = tf.keras.Model(
            [inputs, seq_lens, state_in_h, state_in_c],
            [logits, value_out, state_out_h, state_out_c])
        self.base_model.summary()
        self.register_variables(self.base_model.variables)

    @override(RecurrentTFModelV2)
    def forward_rnn(self, inputs, state, seq_lens):
        # Previously, a new class object was created during
        # deserialization and this `capture_index`
        # variable would be refreshed between class instantiations.
        # This behavior is no longer the case, so we manually refresh
        # the variable.
        RNNSpyModel.capture_index = 0
        model_out, value_out, h, c = self.base_model(
            [inputs, seq_lens, state[0], state[1]])
        self._value_out = value_out
        return model_out, [h, c]

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    @override(ModelV2)
    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.c], name="c")
            h_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.h], name="h")
        self.state_in = [c_in, h_in]

        # Setup LSTM outputs
        state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(
            lstm,
            last_layer,
            initial_state=state_in,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        self.state_out = list(lstm_state)
        spy_fn = tf.py_func(
            spy, [
                last_layer,
                self.state_in,
                self.state_out,
                self.seq_lens,
            ],
            tf.int64,
            stateful=True)

        # Compute outputs
        with tf.control_dependencies([spy_fn]):
            last_layer = tf.reshape(lstm_out, [-1, cell_size])
            logits = linear(last_layer, num_outputs, "action",
                            normc_initializer(0.01))
        return logits, last_layer


class DebugCounterEnv(gym.Env):
    def __init__(self):
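For orientation (not part of this commit): most of the RNNSpyModel port above is py_func spy plumbing used by the test suite. Below is a stripped-down sketch of the same RecurrentTFModelV2 pattern without the spy hooks, assuming the API already used in the hunk above; the class name MinimalRNNModel and the cell size are hypothetical.

# Illustrative sketch only -- not part of the diff.
import numpy as np

from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2
from ray.rllib.utils import try_import_tf

tf = try_import_tf()


class MinimalRNNModel(RecurrentTFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, cell_size=3):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.cell_size = cell_size

        # Time-major=False keras inputs: [B, T, obs], plus h/c states and
        # per-sample sequence lengths.
        inputs = tf.keras.layers.Input(
            shape=(None, ) + obs_space.shape, name="input")
        state_in_h = tf.keras.layers.Input(shape=(cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(cell_size, ), name="c")
        seq_lens = tf.keras.layers.Input(
            shape=(), name="seq_lens", dtype=tf.int32)

        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True)(
                inputs=inputs,
                mask=tf.sequence_mask(seq_lens),
                initial_state=[state_in_h, state_in_c])
        logits = tf.keras.layers.Dense(num_outputs)(lstm_out)
        values = tf.keras.layers.Dense(1)(lstm_out)

        self.base_model = tf.keras.Model(
            [inputs, seq_lens, state_in_h, state_in_c],
            [logits, values, state_h, state_c])
        self.register_variables(self.base_model.variables)

    def forward_rnn(self, inputs, state, seq_lens):
        logits, self._value_out, h, c = self.base_model(
            [inputs, seq_lens, state[0], state[1]])
        return logits, [h, c]

    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])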

@@ -12,7 +12,7 @@ from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.model import Model
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.rollout import rollout

@@ -119,13 +119,13 @@ class NestedMultiAgentEnv(MultiAgentEnv):
        return obs, rew, dones, infos


class InvalidModel(Model):
    def _build_layers_v2(self, input_dict, num_outputs, options):
class InvalidModel(TorchModelV2):
    def forward(self, input_dict, state, seq_lens):
        return "not", "valid"


class InvalidModel2(Model):
    def _build_layers_v2(self, input_dict, num_outputs, options):
class InvalidModel2(TFModelV2):
    def forward(self, input_dict, state, seq_lens):
        return tf.constant(0), tf.constant(0)
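For orientation (not part of this commit): InvalidModel above deliberately subclasses TorchModelV2 without nn.Module, which is exactly what the new "Subclasses of TorchModelV2 must also inherit from" error checked in test_invalid_model below catches. A sketch of a valid counterpart, assuming a flat Box observation space; the name ValidTorchModel and the layer sizes are hypothetical.

# Illustrative sketch only -- not part of the diff.
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class ValidTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # Both base classes must be initialized explicitly.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.fc = nn.Linear(obs_space.shape[0], num_outputs)
        self.value_branch = nn.Linear(obs_space.shape[0], 1)

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].float()
        self._value_out = self.value_branch(obs).squeeze(1)
        # forward() returns (outputs, state_list), not (outputs, last_layer).
        return self.fc(obs), state

    def value_function(self):
        return self._value_out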

@@ -158,10 +158,10 @@ class TorchSpyModel(TorchModelV2, nn.Module):
        return self.fc.value_function()


class DictSpyModel(Model):
class DictSpyModel(TFModelV2):
    capture_index = 0

    def _build_layers_v2(self, input_dict, num_outputs, options):
    def forward(self, input_dict, state, seq_lens):
        def spy(pos, front_cam, task):
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite

@@ -183,14 +183,14 @@ class DictSpyModel(Model):

        with tf.control_dependencies([spy_fn]):
            output = tf.layers.dense(input_dict["obs"]["sensors"]["position"],
                                     num_outputs)
            return output, output
                                     self.num_outputs)
            return output, []


class TupleSpyModel(Model):
class TupleSpyModel(TFModelV2):
    capture_index = 0

    def _build_layers_v2(self, input_dict, num_outputs, options):
    def forward(self, input_dict, state, seq_lens):
        def spy(pos, cam, task):
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite

@@ -211,8 +211,8 @@ class TupleSpyModel(Model):
            stateful=True)

        with tf.control_dependencies([spy_fn]):
            output = tf.layers.dense(input_dict["obs"][0], num_outputs)
            return output, output
            output = tf.layers.dense(input_dict["obs"][0], self.num_outputs)
            return output, []


class NestedSpacesTest(unittest.TestCase):

@@ -226,17 +226,20 @@ class NestedSpacesTest(unittest.TestCase):

    def test_invalid_model(self):
        ModelCatalog.register_custom_model("invalid", InvalidModel)
        self.assertRaises(ValueError, lambda: PGTrainer(
            env="CartPole-v0", config={
                "model": {
                    "custom_model": "invalid",
                },
            }))
        self.assertRaisesRegexp(
            ValueError,
            "Subclasses of TorchModelV2 must also inherit from",
            lambda: PGTrainer(
                env="CartPole-v0", config={
                    "model": {
                        "custom_model": "invalid",
                    },
                }))

    def test_invalid_model2(self):
        ModelCatalog.register_custom_model("invalid2", InvalidModel2)
        self.assertRaisesRegexp(
            ValueError, "Expected output.*",
            ValueError, "Expected output shape of",
            lambda: PGTrainer(
                env="CartPole-v0", config={
                    "model": {

@@ -86,7 +86,8 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
                            name="arg_{}".format(i)))
                    symbolic_out[0] = fn(*placeholders)
                feed_dict = dict(zip(placeholders, args))
                return session_or_none.run(symbolic_out[0], feed_dict)
                ret = session_or_none.run(symbolic_out[0], feed_dict)
                return ret

            return call
        else:

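For orientation (not part of this commit): the hunk above only splits the session_or_none.run() call into a named ret variable before returning it. Below is a sketch of how make_tf_callable is typically used, assuming the behavior visible in this hunk (placeholders built on the first call, then fed on every call) and a TF1-style session as in the tests above; sess and scaled_sum are hypothetical names.

# Illustrative sketch only -- not part of the diff.
import numpy as np

from ray.rllib.utils import try_import_tf
from ray.rllib.utils.tf_ops import make_tf_callable

tf = try_import_tf()
sess = tf.Session()


@make_tf_callable(sess)
def scaled_sum(x):
    # Built as a graph on the first call; later calls reuse the cached
    # placeholders and just run the session (returning `ret`, as above).
    return tf.reduce_sum(x) * 2.0


print(scaled_sum(np.arange(4, dtype=np.float32)))  # -> 12.0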