[RLlib] Fix failing test cases: Soft-deprecate ModelV2.from_batch (in favor of ModelV2.__call__). (#19693)
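The change is a call-site swap throughout RLlib: instead of ModelV2.from_batch(train_batch), the model is now called directly with the SampleBatch. A minimal sketch of the pattern (assuming a ModelV2 instance model, an action-distribution class dist_class, and a SampleBatch train_batch are already in scope, as they are inside the loss functions touched below):

# Old, now soft-deprecated (still works, but logs a deprecation warning
# pointing to ModelV2.__call__):
#   model_out, _ = model.from_batch(train_batch)

# New: call the ModelV2 instance directly. __call__ accepts the SampleBatch
# as its input dict and recovers state_in_* columns and seq_lens from it.
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)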

Sven Mika 2021-10-25 15:00:00 +02:00 committed by GitHub
parent 6e455e59d8
commit b213565783
25 changed files with 46 additions and 48 deletions

View file

@@ -73,7 +73,7 @@ class A3CLoss:
 def actor_critic_loss(policy: Policy, model: ModelV2,
                       dist_class: ActionDistribution,
                       train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     if policy.is_recurrent():
         max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])

View file

@@ -39,7 +39,7 @@ def add_advantages(
 def actor_critic_loss(policy: Policy, model: ModelV2,
                       dist_class: ActionDistribution,
                       train_batch: SampleBatch) -> TensorType:
-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     values = model.value_function()
     if policy.is_recurrent():

View file

@@ -212,12 +212,12 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
 def get_distribution_inputs_and_class(policy: Policy,
                                       model: ModelV2,
-                                      obs_batch: TensorType,
+                                      input_dict: SampleBatch,
                                       *,
                                       explore=True,
                                       **kwargs):
     q_vals = compute_q_values(
-        policy, model, {"obs": obs_batch}, state_batches=None, explore=explore)
+        policy, model, input_dict, state_batches=None, explore=explore)
     q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
     policy.q_values = q_vals
@@ -342,7 +342,6 @@ def compute_q_values(policy: Policy,
     config = policy.config
-    input_dict.is_training = policy._get_is_training_placeholder()
     model_out, state = model(input_dict, state_batches or [], seq_lens)
     if config["num_atoms"] > 1:

View file

@@ -204,16 +204,13 @@ def build_q_model_and_distribution(
 def get_distribution_inputs_and_class(
         policy: Policy,
         model: ModelV2,
-        obs_batch: TensorType,
+        input_dict: SampleBatch,
         *,
         explore: bool = True,
         is_training: bool = False,
         **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
     q_vals = compute_q_values(
-        policy,
-        model, {"obs": obs_batch},
-        explore=explore,
-        is_training=is_training)
+        policy, model, input_dict, explore=explore, is_training=is_training)
     q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
     model.tower_stats["q_values"] = q_vals
@@ -350,7 +347,6 @@ def compute_q_values(policy: Policy,
                      is_training: bool = False):
     config = policy.config
-    input_dict.is_training = is_training
     model_out, state = model(input_dict, state_batches or [], seq_lens)
     if config["num_atoms"] > 1:

View file

@@ -161,7 +161,7 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False):
 def build_vtrace_loss(policy, model, dist_class, train_batch):
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     if isinstance(policy.action_space, gym.spaces.Discrete):

View file

@@ -113,7 +113,7 @@ class VTraceLoss:
 def build_vtrace_loss(policy, model, dist_class, train_batch):
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     if isinstance(policy.action_space, gym.spaces.Discrete):

View file

@@ -308,7 +308,7 @@ class MAMLLoss(object):
 def maml_loss(policy, model, dist_class, train_batch):
-    logits, state = model.from_batch(train_batch)
+    logits, state = model(train_batch)
     policy.cur_lr = policy.config["lr"]
     if policy.config["worker_index"]:

View file

@@ -246,7 +246,7 @@ class MAMLLoss(object):
 def maml_loss(policy, model, dist_class, train_batch):
-    logits, state = model.from_batch(train_batch)
+    logits, state = model(train_batch)
     policy.cur_lr = policy.config["lr"]
     if policy.config["worker_index"]:

View file

@@ -30,7 +30,7 @@ class ValueNetworkMixin:
         # input_dict.
         @make_tf_callable(self.get_session())
         def value(**input_dict):
-            model_out, _ = self.model.from_batch(input_dict, is_training=False)
+            model_out, _ = self.model(input_dict)
             # [0] = remove the batch dim.
             return self.model.value_function()[0]
@@ -150,7 +150,7 @@ class MARWILLoss:
 def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
                 train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     value_estimates = model.value_function()

View file

@@ -19,7 +19,7 @@ torch, _ = try_import_torch()
 def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
                 train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     actions = train_batch[SampleBatch.ACTIONS]
     # log\pi_\theta(a|s)

View file

@@ -149,7 +149,7 @@ class TestMARWIL(unittest.TestCase):
         cummulative_rewards = torch.tensor(cummulative_rewards)
         if fw != "tf":
             batch = policy._lazy_tensor_dict(batch)
-        model_out, _ = model.from_batch(batch)
+        model_out, _ = model(batch)
         vf_estimates = model.value_function()
         if fw == "tf":
             model_out, vf_estimates = \

View file

@@ -34,7 +34,7 @@ def pg_tf_loss(
             of loss tensors.
     """
     # Pass the training data through our model to get distribution parameters.
-    dist_inputs, _ = model.from_batch(train_batch)
+    dist_inputs, _ = model(train_batch)
     # Create an action distribution object.
     action_dist = dist_class(dist_inputs, model)

View file

@@ -35,7 +35,7 @@ def pg_torch_loss(
             of loss tensors.
     """
    # Pass the training data through our model to get distribution parameters.
-    dist_inputs, _ = model.from_batch(train_batch)
+    dist_inputs, _ = model(train_batch)
     # Create an action distribution object.
     action_dist = dist_class(dist_inputs, model)

View file

@@ -99,7 +99,7 @@ def appo_surrogate_loss(
         Union[TensorType, List[TensorType]]: A single loss tensor or a list
             of loss tensors.
     """
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     if isinstance(policy.action_space, gym.spaces.Discrete):
@@ -123,7 +123,7 @@ def appo_surrogate_loss(
     rewards = train_batch[SampleBatch.REWARDS]
     behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
-    target_model_out, _ = policy.target_model.from_batch(train_batch)
+    target_model_out, _ = policy.target_model(train_batch)
     prev_action_dist = dist_class(behaviour_logits, policy.model)
     values = policy.model.value_function()
     values_time_major = make_time_major(values)

View file

@@ -56,7 +56,7 @@ def appo_surrogate_loss(policy: Policy, model: ModelV2,
     """
     target_model = policy.target_models[model]
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     if isinstance(policy.action_space, gym.spaces.Discrete):
@@ -79,7 +79,7 @@ def appo_surrogate_loss(policy: Policy, model: ModelV2,
     rewards = train_batch[SampleBatch.REWARDS]
     behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
-    target_model_out, _ = target_model.from_batch(train_batch)
+    target_model_out, _ = target_model(train_batch)
     prev_action_dist = dist_class(behaviour_logits, model)
     values = model.value_function()

View file

@@ -50,7 +50,7 @@ def ppo_surrogate_loss(
         logits, state, extra_outs = model(train_batch)
         value_fn_out = extra_outs[SampleBatch.VF_PREDS]
     else:
-        logits, state = model.from_batch(train_batch)
+        logits, state = model(train_batch)
         value_fn_out = model.value_function()
     curr_action_dist = dist_class(logits, model)

View file

@@ -16,7 +16,7 @@ parser.add_argument("--num-cpus", type=int, default=0)
 def policy_gradient_loss(policy, model, dist_class, train_batch):
-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     action_dist = dist_class(logits, model)
     return -tf.reduce_mean(
         action_dist.logp(train_batch["actions"]) * train_batch["returns"])

View file

@@ -80,7 +80,7 @@ def policy_gradient_loss(policy, model, dist_class, train_batch):
         print("The eagerly computed penalty is", penalty, actions, rewards)
         return penalty
-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     action_dist = dist_class(logits, model)
     actions = train_batch[SampleBatch.ACTIONS]

View file

@@ -33,7 +33,7 @@ trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
 # Let's run inference on the tensorflow model
 policy = trainer.get_policy()
-result_tf, _ = policy.model.from_batch(test_data)
+result_tf, _ = policy.model(test_data)
 # Evaluate tensor to fetch numpy array
 with policy._sess.as_default():

View file

@@ -35,7 +35,7 @@ trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
 # Let's run inference on the torch model
 policy = trainer.get_policy()
-result_pytorch, _ = policy.model.from_batch({
+result_pytorch, _ = policy.model({
     "obs": torch.tensor(test_data["obs"]),
 })

View file

@@ -146,7 +146,7 @@ def run_with_custom_entropy_loss(args, stop):
     This performs about the same as the default loss does."""
     def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
-        logits, _ = model.from_batch(train_batch)
+        logits, _ = model(train_batch)
         action_dist = dist_class(logits, model)
         if args.framework == "torch":
             # Required by PGTorchPolicy's stats fn.

View file

@@ -10,7 +10,7 @@ from ray.rllib.models.repeated_values import RepeatedValues
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils import NullContextManager
-from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
+from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI
 from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
     TensorType
 from ray.rllib.utils.spaces.repeated import Repeated
@@ -204,18 +204,19 @@ class ModelV2:
         # where tensors get automatically converted).
         if isinstance(input_dict, SampleBatch):
             restored = input_dict.copy(shallow=True)
+            # Backward compatibility.
+            if seq_lens is None:
+                seq_lens = input_dict.get(SampleBatch.SEQ_LENS)
+            if not state:
+                state = []
+                i = 0
+                while "state_in_{}".format(i) in input_dict:
+                    state.append(input_dict["state_in_{}".format(i)])
+                    i += 1
         else:
             restored = input_dict.copy()
-        # Backward compatibility.
-        if not state:
-            state = []
-            i = 0
-            while "state_in_{}".format(i) in input_dict:
-                state.append(input_dict["state_in_{}".format(i)])
-                i += 1
-        if seq_lens is None:
-            seq_lens = input_dict.get(SampleBatch.SEQ_LENS)
         # No Preprocessor used: `config._disable_preprocessor_api`=True.
         # TODO: This is unnecessary for when no preprocessor is used.
         #  Obs are not flat then anymore. However, we'll keep this
@@ -255,9 +256,7 @@ class ModelV2:
         self._last_output = outputs
         return outputs, state_out if len(state_out) > 0 else (state or [])
-    # TODO: (sven) obsolete this method at some point (replace by
-    #  simply calling model directly with a sample_batch as only input).
-    @PublicAPI
+    @Deprecated(new="ModelV2.__call__()", error=False)
     def from_batch(self, train_batch: SampleBatch,
                    is_training: bool = True) -> (TensorType, List[TensorType]):
         """Convenience function that calls this model with a tensor batch.
@@ -267,7 +266,7 @@ class ModelV2:
         """
         input_dict = train_batch.copy()
-        input_dict.is_training = is_training
+        input_dict["is_training"] = is_training
         states = []
         i = 0
         while "state_in_{}".format(i) in input_dict:
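For reference, after this change the deprecated wrapper reduces to roughly the following (a sketch reconstructed from the hunk above; the hunk is truncated, so the final delegation to __call__ shown here is an assumption rather than a verbatim copy of the file):

@Deprecated(new="ModelV2.__call__()", error=False)
def from_batch(self, train_batch, is_training=True):
    """Convenience wrapper kept for backward compatibility."""
    input_dict = train_batch.copy()
    input_dict["is_training"] = is_training
    states = []
    i = 0
    while "state_in_{}".format(i) in input_dict:
        states.append(input_dict["state_in_{}".format(i)])
        i += 1
    # Assumed final step (cut off in the hunk above): delegate to __call__
    # with the collected state columns and the batch's seq_lens, if any.
    return self.__call__(input_dict, states, input_dict.get("seq_lens"))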

View file

@@ -621,7 +621,9 @@ class DynamicTFPolicy(TFPolicy):
         )
         train_batch = SampleBatch(
-            dict(self._input_dict, **self._loss_input_dict))
+            dict(self._input_dict, **self._loss_input_dict),
+            _is_training=True,
+        )
         if self._state_inputs:
             train_batch[SampleBatch.SEQ_LENS] = self._seq_lens

View file

@@ -264,7 +264,8 @@ class TorchPolicy(Policy):
         with torch.no_grad():
             seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
             input_dict = self._lazy_tensor_dict({
-                SampleBatch.CUR_OBS: obs_batch
+                SampleBatch.CUR_OBS: obs_batch,
+                "is_training": False,
             })
             if prev_action_batch is not None:
                 input_dict[SampleBatch.PREV_ACTIONS] = \
@@ -291,6 +292,7 @@ class TorchPolicy(Policy):
         with torch.no_grad():
             # Pass lazy (torch) tensor dict to Model as `input_dict`.
             input_dict = self._lazy_tensor_dict(input_dict)
+            input_dict.is_training = False
             # Pack internal state inputs into (separate) list.
             state_batches = [
                 input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
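The two TorchPolicy hunks above move the is_training flag onto the input batch itself instead of passing it through from_batch. A minimal sketch of the resulting inference-time pattern (assuming an existing policy with a ModelV2 under policy.model and a NumPy obs_batch; names mirror the hunk, not a runnable script on its own):

import torch

with torch.no_grad():
    # Mark the batch as inference-time input; the flag now travels with the
    # SampleBatch rather than as a from_batch() argument.
    input_dict = policy._lazy_tensor_dict({
        SampleBatch.CUR_OBS: obs_batch,
        "is_training": False,
    })
    # Direct ModelV2 call; no RNN state inputs in this simple case.
    model_out, state_out = policy.model(input_dict)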

View file

@@ -107,7 +107,7 @@ ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
 # Type of dict returned by get_weights() representing model weights.
 ModelWeights = dict
-# An input dict used for direct ModelV2 calls or `ModelV2.from_batch` calls.
+# An input dict used for direct ModelV2 calls.
 ModelInputDict = Dict[str, TensorType]
 # Some kind of sample batch.