[RLlib] Fix failing test cases: Soft-deprecate ModelV2.from_batch (in favor of ModelV2.__call__). (#19693)
This commit is contained in:
parent 6e455e59d8
commit b213565783
25 changed files with 46 additions and 48 deletions
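The change is mechanical across call sites: every path that previously went through the soft-deprecated ModelV2.from_batch(train_batch) now calls the model object directly. A minimal before/after sketch (model, train_batch, and dist_class are placeholder objects, not tied to any single file below):

    # Before: still supported, but now logs a deprecation warning.
    model_out, state_out = model.from_batch(train_batch)

    # After: ModelV2.__call__ accepts the SampleBatch directly.
    model_out, state_out = model(train_batch)
    action_dist = dist_class(model_out, model)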
@@ -73,7 +73,7 @@ class A3CLoss:
 def actor_critic_loss(policy: Policy, model: ModelV2,
                       dist_class: ActionDistribution,
                       train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     if policy.is_recurrent():
         max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
@@ -39,7 +39,7 @@ def add_advantages(
 def actor_critic_loss(policy: Policy, model: ModelV2,
                       dist_class: ActionDistribution,
                       train_batch: SampleBatch) -> TensorType:
-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     values = model.value_function()

     if policy.is_recurrent():
@@ -212,12 +212,12 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,

 def get_distribution_inputs_and_class(policy: Policy,
                                       model: ModelV2,
-                                      obs_batch: TensorType,
+                                      input_dict: SampleBatch,
                                       *,
                                       explore=True,
                                       **kwargs):
     q_vals = compute_q_values(
-        policy, model, {"obs": obs_batch}, state_batches=None, explore=explore)
+        policy, model, input_dict, state_batches=None, explore=explore)
     q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

     policy.q_values = q_vals

@@ -342,7 +342,6 @@ def compute_q_values(policy: Policy,

     config = policy.config

-    input_dict.is_training = policy._get_is_training_placeholder()
     model_out, state = model(input_dict, state_batches or [], seq_lens)

     if config["num_atoms"] > 1:
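In the DQN TF path, get_distribution_inputs_and_class now receives the whole SampleBatch as input_dict instead of a bare obs tensor, and compute_q_values no longer injects a TF is_training placeholder. A rough sketch of the call compute_q_values ends up making, with obs_array and model as placeholders:

    from ray.rllib.policy.sample_batch import SampleBatch

    # Observations travel inside a SampleBatch; compute_q_values forwards it
    # to the model call unchanged (state_batches=[], seq_lens=None here).
    input_dict = SampleBatch({"obs": obs_array})
    model_out, state = model(input_dict, [], None)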
@@ -204,16 +204,13 @@ def build_q_model_and_distribution(
 def get_distribution_inputs_and_class(
         policy: Policy,
         model: ModelV2,
-        obs_batch: TensorType,
+        input_dict: SampleBatch,
         *,
         explore: bool = True,
         is_training: bool = False,
         **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
     q_vals = compute_q_values(
-        policy,
-        model, {"obs": obs_batch},
-        explore=explore,
-        is_training=is_training)
+        policy, model, input_dict, explore=explore, is_training=is_training)
     q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

     model.tower_stats["q_values"] = q_vals

@@ -350,7 +347,6 @@ def compute_q_values(policy: Policy,
                      is_training: bool = False):
     config = policy.config

-    input_dict.is_training = is_training
     model_out, state = model(input_dict, state_batches or [], seq_lens)

     if config["num_atoms"] > 1:
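The torch variant follows suit: compute_q_values no longer overwrites input_dict.is_training, so the caller is expected to hand in a batch whose training flag is already set. A hedged sketch using the SampleBatch _is_training constructor argument (the same pattern the DynamicTFPolicy change below uses; obs_tensor is a placeholder):

    from ray.rllib.policy.sample_batch import SampleBatch

    # Tag the batch as non-training up front instead of mutating it inside
    # compute_q_values.
    input_dict = SampleBatch({"obs": obs_tensor}, _is_training=False)
    model_out, state = model(input_dict, [], None)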
@@ -161,7 +161,7 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False):


 def build_vtrace_loss(policy, model, dist_class, train_batch):
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)

     if isinstance(policy.action_space, gym.spaces.Discrete):
@@ -113,7 +113,7 @@ class VTraceLoss:


 def build_vtrace_loss(policy, model, dist_class, train_batch):
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)

     if isinstance(policy.action_space, gym.spaces.Discrete):
@@ -308,7 +308,7 @@ class MAMLLoss(object):


 def maml_loss(policy, model, dist_class, train_batch):
-    logits, state = model.from_batch(train_batch)
+    logits, state = model(train_batch)
     policy.cur_lr = policy.config["lr"]

     if policy.config["worker_index"]:
@@ -246,7 +246,7 @@ class MAMLLoss(object):


 def maml_loss(policy, model, dist_class, train_batch):
-    logits, state = model.from_batch(train_batch)
+    logits, state = model(train_batch)
     policy.cur_lr = policy.config["lr"]

     if policy.config["worker_index"]:
@@ -30,7 +30,7 @@ class ValueNetworkMixin:
             # input_dict.
             @make_tf_callable(self.get_session())
             def value(**input_dict):
-                model_out, _ = self.model.from_batch(input_dict, is_training=False)
+                model_out, _ = self.model(input_dict)
                 # [0] = remove the batch dim.
                 return self.model.value_function()[0]

@@ -150,7 +150,7 @@ class MARWILLoss:

 def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
                 train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     value_estimates = model.value_function()

@@ -19,7 +19,7 @@ torch, _ = try_import_torch()

 def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
                 train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     actions = train_batch[SampleBatch.ACTIONS]
     # log\pi_\theta(a|s)
@@ -149,7 +149,7 @@ class TestMARWIL(unittest.TestCase):
         cummulative_rewards = torch.tensor(cummulative_rewards)
         if fw != "tf":
             batch = policy._lazy_tensor_dict(batch)
-        model_out, _ = model.from_batch(batch)
+        model_out, _ = model(batch)
         vf_estimates = model.value_function()
         if fw == "tf":
             model_out, vf_estimates = \
@@ -34,7 +34,7 @@ def pg_tf_loss(
             of loss tensors.
     """
     # Pass the training data through our model to get distribution parameters.
-    dist_inputs, _ = model.from_batch(train_batch)
+    dist_inputs, _ = model(train_batch)

     # Create an action distribution object.
     action_dist = dist_class(dist_inputs, model)
@@ -35,7 +35,7 @@ def pg_torch_loss(
             of loss tensors.
     """
     # Pass the training data through our model to get distribution parameters.
-    dist_inputs, _ = model.from_batch(train_batch)
+    dist_inputs, _ = model(train_batch)

     # Create an action distribution object.
     action_dist = dist_class(dist_inputs, model)
@@ -99,7 +99,7 @@ def appo_surrogate_loss(
         Union[TensorType, List[TensorType]]: A single loss tensor or a list
             of loss tensors.
     """
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)

     if isinstance(policy.action_space, gym.spaces.Discrete):

@@ -123,7 +123,7 @@ def appo_surrogate_loss(
         rewards = train_batch[SampleBatch.REWARDS]
         behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

-        target_model_out, _ = policy.target_model.from_batch(train_batch)
+        target_model_out, _ = policy.target_model(train_batch)
         prev_action_dist = dist_class(behaviour_logits, policy.model)
         values = policy.model.value_function()
         values_time_major = make_time_major(values)
@@ -56,7 +56,7 @@ def appo_surrogate_loss(policy: Policy, model: ModelV2,
     """
     target_model = policy.target_models[model]

-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)

     if isinstance(policy.action_space, gym.spaces.Discrete):

@@ -79,7 +79,7 @@ def appo_surrogate_loss(policy: Policy, model: ModelV2,
         rewards = train_batch[SampleBatch.REWARDS]
         behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

-        target_model_out, _ = target_model.from_batch(train_batch)
+        target_model_out, _ = target_model(train_batch)

         prev_action_dist = dist_class(behaviour_logits, model)
         values = model.value_function()
@@ -50,7 +50,7 @@ def ppo_surrogate_loss(
         logits, state, extra_outs = model(train_batch)
         value_fn_out = extra_outs[SampleBatch.VF_PREDS]
     else:
-        logits, state = model.from_batch(train_batch)
+        logits, state = model(train_batch)
         value_fn_out = model.value_function()

     curr_action_dist = dist_class(logits, model)
@@ -16,7 +16,7 @@ parser.add_argument("--num-cpus", type=int, default=0)


 def policy_gradient_loss(policy, model, dist_class, train_batch):
-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     action_dist = dist_class(logits, model)
     return -tf.reduce_mean(
         action_dist.logp(train_batch["actions"]) * train_batch["returns"])
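The same custom-loss pattern carries over to a torch policy once the model is called directly; a hypothetical torch counterpart (not part of this commit) could look like:

    import torch

    def policy_gradient_loss_torch(policy, model, dist_class, train_batch):
        # Direct model call replaces the deprecated model.from_batch().
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        return -torch.mean(
            action_dist.logp(train_batch["actions"]) * train_batch["returns"])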
@@ -80,7 +80,7 @@ def policy_gradient_loss(policy, model, dist_class, train_batch):
         print("The eagerly computed penalty is", penalty, actions, rewards)
         return penalty

-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     action_dist = dist_class(logits, model)

     actions = train_batch[SampleBatch.ACTIONS]
@@ -33,7 +33,7 @@ trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")

 # Let's run inference on the tensorflow model
 policy = trainer.get_policy()
-result_tf, _ = policy.model.from_batch(test_data)
+result_tf, _ = policy.model(test_data)

 # Evaluate tensor to fetch numpy array
 with policy._sess.as_default():
@@ -35,7 +35,7 @@ trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")

 # Let's run inference on the torch model
 policy = trainer.get_policy()
-result_pytorch, _ = policy.model.from_batch({
+result_pytorch, _ = policy.model({
     "obs": torch.tensor(test_data["obs"]),
 })

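As before the rename, the torch output can then be materialized for inspection; a minimal sketch, assuming result_pytorch is a plain torch tensor:

    # Convert the model output to a numpy array for comparison.
    result_numpy = result_pytorch.detach().cpu().numpy()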
@@ -146,7 +146,7 @@ def run_with_custom_entropy_loss(args, stop):
     This performs about the same as the default loss does."""

     def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
-        logits, _ = model.from_batch(train_batch)
+        logits, _ = model(train_batch)
         action_dist = dist_class(logits, model)
         if args.framework == "torch":
             # Required by PGTorchPolicy's stats fn.
@@ -10,7 +10,7 @@ from ray.rllib.models.repeated_values import RepeatedValues
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils import NullContextManager
-from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
+from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI
 from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
     TensorType
 from ray.rllib.utils.spaces.repeated import Repeated
@@ -204,18 +204,19 @@ class ModelV2:
         # where tensors get automatically converted).
         if isinstance(input_dict, SampleBatch):
             restored = input_dict.copy(shallow=True)
-            # Backward compatibility.
-            if seq_lens is None:
-                seq_lens = input_dict.get(SampleBatch.SEQ_LENS)
-            if not state:
-                state = []
-                i = 0
-                while "state_in_{}".format(i) in input_dict:
-                    state.append(input_dict["state_in_{}".format(i)])
-                    i += 1
         else:
             restored = input_dict.copy()
+
+        # Backward compatibility.
+        if not state:
+            state = []
+            i = 0
+            while "state_in_{}".format(i) in input_dict:
+                state.append(input_dict["state_in_{}".format(i)])
+                i += 1
+        if seq_lens is None:
+            seq_lens = input_dict.get(SampleBatch.SEQ_LENS)
+
         # No Preprocessor used: `config._disable_preprocessor_api`=True.
         # TODO: This is unnecessary for when no preprocessor is used.
         # Obs are not flat then anymore. However, we'll keep this
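The backward-compatibility block that now runs for both plain-dict and SampleBatch inputs recovers RNN state and sequence lengths from flat state_in_{i} columns. The same recovery logic as a standalone sketch (function name is illustrative, not part of the API):

    def recover_state_and_seq_lens(input_dict, state=None, seq_lens=None):
        # Collect state_in_0, state_in_1, ... columns if no explicit state
        # list was passed in.
        if not state:
            state = []
            i = 0
            while "state_in_{}".format(i) in input_dict:
                state.append(input_dict["state_in_{}".format(i)])
                i += 1
        # Fall back to the batch's "seq_lens" column (SampleBatch.SEQ_LENS).
        if seq_lens is None:
            seq_lens = input_dict.get("seq_lens")
        return state, seq_lens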
@@ -255,9 +256,7 @@ class ModelV2:
         self._last_output = outputs
         return outputs, state_out if len(state_out) > 0 else (state or [])

-    # TODO: (sven) obsolete this method at some point (replace by
-    # simply calling model directly with a sample_batch as only input).
-    @PublicAPI
+    @Deprecated(new="ModelV2.__call__()", error=False)
     def from_batch(self, train_batch: SampleBatch,
                    is_training: bool = True) -> (TensorType, List[TensorType]):
         """Convenience function that calls this model with a tensor batch.
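Because the decorator is applied with error=False, from_batch keeps working and merely logs a deprecation warning pointing at ModelV2.__call__(). Downstream code can therefore migrate at its own pace (train_batch is a placeholder SampleBatch):

    # Old style: still runs, but warns.
    model_out, state_out = model.from_batch(train_batch)

    # New style: equivalent direct call; the training flag is carried by the
    # batch itself rather than by an is_training argument.
    model_out, state_out = model(train_batch)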
@@ -267,7 +266,7 @@ class ModelV2:
         """

         input_dict = train_batch.copy()
-        input_dict.is_training = is_training
+        input_dict["is_training"] = is_training
         states = []
         i = 0
         while "state_in_{}".format(i) in input_dict:
@@ -621,7 +621,9 @@ class DynamicTFPolicy(TFPolicy):
         )

         train_batch = SampleBatch(
-            dict(self._input_dict, **self._loss_input_dict))
+            dict(self._input_dict, **self._loss_input_dict),
+            _is_training=True,
+        )

         if self._state_inputs:
             train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
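The loss-initialization batch is now flagged as a training batch at construction time via the SampleBatch _is_training argument rather than a separate assignment. In isolation that construction looks roughly like this (the column contents are placeholders):

    from ray.rllib.policy.sample_batch import SampleBatch

    # Mark the batch as a training batch when it is built.
    train_batch = SampleBatch({"obs": obs_array}, _is_training=True)
    assert train_batch.is_training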
@@ -264,7 +264,8 @@ class TorchPolicy(Policy):
         with torch.no_grad():
             seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
             input_dict = self._lazy_tensor_dict({
-                SampleBatch.CUR_OBS: obs_batch
+                SampleBatch.CUR_OBS: obs_batch,
+                "is_training": False,
             })
             if prev_action_batch is not None:
                 input_dict[SampleBatch.PREV_ACTIONS] = \

@@ -291,6 +292,7 @@ class TorchPolicy(Policy):
         with torch.no_grad():
             # Pass lazy (torch) tensor dict to Model as `input_dict`.
             input_dict = self._lazy_tensor_dict(input_dict)
+            input_dict.is_training = False
             # Pack internal state inputs into (separate) list.
             state_batches = [
                 input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
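Exploration and inference batches are now tagged as non-training up front, so user code that builds its own inference input can follow the same pattern; a hedged sketch (obs_tensor and policy are placeholders):

    from ray.rllib.policy.sample_batch import SampleBatch

    # Build an inference batch, mark it non-training, and call the model
    # directly.
    input_dict = SampleBatch({"obs": obs_tensor}, _is_training=False)
    dist_inputs, state_out = policy.model(input_dict)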
@@ -107,7 +107,7 @@ ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
 # Type of dict returned by get_weights() representing model weights.
 ModelWeights = dict

-# An input dict used for direct ModelV2 calls or `ModelV2.from_batch` calls.
+# An input dict used for direct ModelV2 calls.
 ModelInputDict = Dict[str, TensorType]

 # Some kind of sample batch.