mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Issue 9071 A3C w/ RNN not working due to VF assuming no RNN. (#13238)
This commit is contained in:
parent
e74947cc94
commit
2e3655e8a9
24 changed files with 251 additions and 255 deletions
|
@ -1,17 +1,34 @@
|
|||
"""Note: Keep in sync with changes to VTraceTFPolicy."""
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
|
||||
Postprocessing
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
||||
from ray.rllib.utils.tf_ops import explained_variance
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
def postprocess_advantages(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
|
||||
# Stub serving backward compatibility.
|
||||
deprecation_warning(
|
||||
old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
|
||||
new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
|
||||
error=False)
|
||||
|
||||
return compute_gae_for_sample_batch(policy, sample_batch,
|
||||
other_agent_batches, episode)
|
||||
|
||||
|
||||
class A3CLoss:
|
||||
def __init__(self,
|
||||
action_dist,
|
||||
|
@ -45,46 +62,10 @@ def actor_critic_loss(policy, model, dist_class, train_batch):
|
|||
return policy.loss.total_loss
|
||||
|
||||
|
||||
def postprocess_advantages(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch[SampleBatch.DONES][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(policy.num_state_tensors()):
|
||||
next_state.append(sample_batch["state_out_{}".format(i)][-1])
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
return compute_advantages(
|
||||
sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
|
||||
policy.config["use_gae"], policy.config["use_critic"])
|
||||
|
||||
|
||||
def add_value_function_fetch(policy):
|
||||
return {SampleBatch.VF_PREDS: policy.model.value_function()}
|
||||
|
||||
|
||||
class ValueNetworkMixin:
|
||||
def __init__(self):
|
||||
@make_tf_callable(self.get_session())
|
||||
def value(ob, prev_action, prev_reward, *state):
|
||||
model_out, _ = self.model({
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
|
||||
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor([prev_action]),
|
||||
SampleBatch.PREV_REWARDS: tf.convert_to_tensor([prev_reward]),
|
||||
"is_training": tf.convert_to_tensor(False),
|
||||
}, [tf.convert_to_tensor([s]) for s in state],
|
||||
tf.convert_to_tensor([1]))
|
||||
return self.model.value_function()[0]
|
||||
|
||||
self._value = value
|
||||
|
||||
|
||||
def stats(policy, train_batch):
|
||||
return {
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
|
@ -115,7 +96,7 @@ def clip_gradients(policy, optimizer, loss):
|
|||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
ValueNetworkMixin.__init__(policy)
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
||||
|
@ -126,7 +107,7 @@ A3CTFPolicy = build_tf_policy(
|
|||
stats_fn=stats,
|
||||
grad_stats_fn=grad_stats,
|
||||
gradients_fn=clip_gradients,
|
||||
postprocess_fn=postprocess_advantages,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
extra_action_fetches_fn=add_value_function_fetch,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[ValueNetworkMixin, LearningRateSchedule])
|
||||
|
|
|
@ -1,13 +1,35 @@
|
|||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
|
||||
Postprocessing
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
def add_advantages(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
|
||||
# Stub serving backward compatibility.
|
||||
deprecation_warning(
|
||||
old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
|
||||
new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
|
||||
error=False)
|
||||
|
||||
return compute_gae_for_sample_batch(policy, sample_batch,
|
||||
other_agent_batches, episode)
|
||||
|
||||
|
||||
def actor_critic_loss(policy, model, dist_class, train_batch):
|
||||
logits, _ = model.from_batch(train_batch)
|
||||
values = model.value_function()
|
||||
|
@ -36,52 +58,27 @@ def loss_and_entropy_stats(policy, train_batch):
|
|||
}
|
||||
|
||||
|
||||
def add_advantages(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
|
||||
completed = sample_batch[SampleBatch.DONES][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
|
||||
|
||||
return compute_advantages(
|
||||
sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
|
||||
policy.config["use_gae"], policy.config["use_critic"])
|
||||
|
||||
|
||||
def model_value_predictions(policy, input_dict, state_batches, model,
|
||||
action_dist):
|
||||
return {SampleBatch.VF_PREDS: model.value_function()}
|
||||
|
||||
|
||||
def apply_grad_clipping(policy, optimizer, loss):
|
||||
info = {}
|
||||
if policy.config["grad_clip"]:
|
||||
for param_group in optimizer.param_groups:
|
||||
# Make sure we only pass params with grad != None into torch
|
||||
# clip_grad_norm_. Would fail otherwise.
|
||||
params = list(
|
||||
filter(lambda p: p.grad is not None, param_group["params"]))
|
||||
if params:
|
||||
grad_gnorm = nn.utils.clip_grad_norm_(
|
||||
params, policy.config["grad_clip"])
|
||||
if isinstance(grad_gnorm, torch.Tensor):
|
||||
grad_gnorm = grad_gnorm.cpu().numpy()
|
||||
info["grad_gnorm"] = grad_gnorm
|
||||
return info
|
||||
|
||||
|
||||
def torch_optimizer(policy, config):
|
||||
return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
|
||||
|
||||
|
||||
class ValueNetworkMixin:
|
||||
def _value(self, obs):
|
||||
_ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1])
|
||||
return self.model.value_function()[0]
|
||||
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
"""Call all mixin classes' constructors before PPOPolicy initialization.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
obs_space (gym.spaces.Space): The Policy's observation space.
|
||||
action_space (gym.spaces.Space): The Policy's action space.
|
||||
config (TrainerConfigDict): The Policy's config.
|
||||
"""
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
|
||||
|
||||
A3CTorchPolicy = build_policy_class(
|
||||
|
@ -90,9 +87,10 @@ A3CTorchPolicy = build_policy_class(
|
|||
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
|
||||
loss_fn=actor_critic_loss,
|
||||
stats_fn=loss_and_entropy_stats,
|
||||
postprocess_fn=add_advantages,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
extra_action_out_fn=model_value_predictions,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
optimizer_fn=torch_optimizer,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[ValueNetworkMixin],
|
||||
)
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import Dict, List, Tuple, Type, Union
|
|||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \
|
||||
validate_spaces
|
||||
from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \
|
||||
|
@ -22,7 +21,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
|||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
|
||||
TrainerConfigDict
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping, \
|
||||
convert_to_torch_tensor
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
F = nn.functional
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \
|
||||
get_distribution_inputs_and_class, validate_spaces
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
|
||||
|
@ -10,7 +9,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchDeterministic
|
|||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import huber_loss, l2_loss
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss, l2_loss
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
|
|
@ -301,17 +301,11 @@ def adam_optimizer(policy: Policy, config: TrainerConfigDict
|
|||
|
||||
def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer",
|
||||
loss: TensorType) -> ModelGradients:
|
||||
if policy.config["grad_clip"] is not None:
|
||||
grads_and_vars = minimize_and_clip(
|
||||
optimizer,
|
||||
loss,
|
||||
var_list=policy.q_func_vars,
|
||||
clip_val=policy.config["grad_clip"])
|
||||
else:
|
||||
grads_and_vars = optimizer.compute_gradients(
|
||||
loss, var_list=policy.q_func_vars)
|
||||
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
|
||||
return grads_and_vars
|
||||
return minimize_and_clip(
|
||||
optimizer,
|
||||
loss,
|
||||
var_list=policy.q_func_vars,
|
||||
clip_val=policy.config["grad_clip"])
|
||||
|
||||
|
||||
def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import Dict, List, Tuple
|
|||
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import (
|
||||
PRIO_WEIGHTS, Q_SCOPE, Q_TARGET_SCOPE, postprocess_nstep_and_prio)
|
||||
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
|
||||
|
@ -20,9 +19,8 @@ from ray.rllib.policy.torch_policy import LearningRateSchedule
|
|||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import (FLOAT_MIN, huber_loss,
|
||||
reduce_mean_ignore_inf,
|
||||
softmax_cross_entropy_with_logits)
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \
|
||||
huber_loss, reduce_mean_ignore_inf, softmax_cross_entropy_with_logits
|
||||
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.dreamer.utils import FreezeParameters
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
if torch:
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
|
@ -160,6 +159,7 @@ def get_policy_class(config):
|
|||
if config["vtrace"]:
|
||||
return VTraceTFPolicy
|
||||
else:
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
return A3CTFPolicy
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ import logging
|
|||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
import ray.rllib.agents.impala.vtrace_torch as vtrace
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
|
@ -11,8 +10,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
|||
from ray.rllib.policy.torch_policy import LearningRateSchedule, \
|
||||
EntropyCoeffSchedule
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
|
||||
sequence_mask
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping, \
|
||||
explained_variance, global_norm, sequence_mask
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
vf_preds_fetches, compute_and_clip_gradients, setup_config, \
|
||||
ValueNetworkMixin
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import vf_preds_fetches, \
|
||||
compute_and_clip_gradients, setup_config, ValueNetworkMixin
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
|
||||
Postprocessing
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
|
@ -422,7 +422,7 @@ MAMLTFPolicy = build_tf_policy(
|
|||
stats_fn=maml_stats,
|
||||
optimizer_fn=maml_optimizer_fn,
|
||||
extra_action_fetches_fn=vf_preds_fetches,
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
gradients_fn=compute_and_clip_gradients,
|
||||
before_init=setup_config,
|
||||
before_loss_init=setup_mixins,
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
|
||||
Postprocessing
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
setup_config
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
|
||||
ValueNetworkMixin
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
@ -355,7 +355,7 @@ MAMLTorchPolicy = build_policy_class(
|
|||
stats_fn=maml_stats,
|
||||
optimizer_fn=maml_optimizer_fn,
|
||||
extra_action_out_fn=vf_preds_fetches,
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
before_init=setup_config,
|
||||
after_init=setup_mixins,
|
||||
|
|
|
@ -3,18 +3,18 @@ import logging
|
|||
from typing import Tuple, Type
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \
|
||||
maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
setup_config
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
@ -85,7 +85,7 @@ MBMPOTorchPolicy = build_policy_class(
|
|||
stats_fn=maml_stats,
|
||||
optimizer_fn=maml_optimizer_fn,
|
||||
extra_action_out_fn=vf_preds_fetches,
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
before_init=setup_config,
|
||||
after_init=setup_mixins,
|
||||
|
|
|
@ -13,9 +13,9 @@ from typing import Dict, List, Optional, Type, Union
|
|||
from ray.rllib.agents.impala import vtrace_tf as vtrace
|
||||
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
|
||||
clip_gradients, choose_optimizer
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
|
||||
Postprocessing
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
|
@ -338,8 +338,8 @@ def postprocess_trajectory(
|
|||
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
||||
"""
|
||||
if not policy.config["vtrace"]:
|
||||
sample_batch = postprocess_ppo_gae(policy, sample_batch,
|
||||
other_agent_batches, episode)
|
||||
sample_batch = compute_gae_for_sample_batch(
|
||||
policy, sample_batch, other_agent_batches, episode)
|
||||
|
||||
# TODO: (sven) remove this del once we have trajectory view API fully in
|
||||
# place.
|
||||
|
|
|
@ -10,7 +10,6 @@ import numpy as np
|
|||
import logging
|
||||
from typing import Type
|
||||
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
import ray.rllib.agents.impala.vtrace_torch as vtrace
|
||||
from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
|
||||
choose_optimizer
|
||||
|
@ -27,8 +26,8 @@ from ray.rllib.policy.policy_template import build_policy_class
|
|||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import LearningRateSchedule
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
|
||||
sequence_mask
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping, explained_variance,\
|
||||
global_norm, sequence_mask
|
||||
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Type, Union
|
|||
|
||||
import ray
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
|
||||
Postprocessing
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
|
@ -160,71 +160,6 @@ def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
|
|||
}
|
||||
|
||||
|
||||
def postprocess_ppo_gae(
|
||||
policy: Policy,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
|
||||
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
|
||||
"""Postprocesses a trajectory and returns the processed trajectory.
|
||||
|
||||
The trajectory contains only data from one episode and from one agent.
|
||||
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
|
||||
contain a truncated (at-the-end) episode, in case the
|
||||
`config.rollout_fragment_length` was reached by the sampler.
|
||||
- If `config.batch_mode=complete_episodes`, sample_batch will contain
|
||||
exactly one episode (no matter how long).
|
||||
New columns can be added to sample_batch and existing ones may be altered.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy used to generate the trajectory
|
||||
(`sample_batch`)
|
||||
sample_batch (SampleBatch): The SampleBatch to postprocess.
|
||||
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
|
||||
dict of AgentIDs mapping to other agents' trajectory data (from the
|
||||
same episode). NOTE: The other agents use the same policy.
|
||||
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
|
||||
object in which the agents operated.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
||||
"""
|
||||
|
||||
# Trajectory is actually complete -> last r=0.0.
|
||||
if sample_batch[SampleBatch.DONES][-1]:
|
||||
last_r = 0.0
|
||||
# Trajectory has been truncated -> last r=VF estimate of last obs.
|
||||
else:
|
||||
# Input dict is provided to us automatically via the Model's
|
||||
# requirements. It's a single-timestep (last one in trajectory)
|
||||
# input_dict.
|
||||
if policy.config["_use_trajectory_view_api"]:
|
||||
# Create an input dict according to the Model's requirements.
|
||||
input_dict = policy.model.get_input_dict(
|
||||
sample_batch, index="last")
|
||||
last_r = policy._value(**input_dict)
|
||||
# TODO: (sven) Remove once trajectory view API is all-algo default.
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(policy.num_state_tensors()):
|
||||
next_state.append(sample_batch["state_out_{}".format(i)][-1])
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
|
||||
# Adds the policy logits, VF preds, and advantages to the batch,
|
||||
# using GAE ("generalized advantage estimation") or not.
|
||||
batch = compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
policy.config["gamma"],
|
||||
policy.config["lambda"],
|
||||
use_gae=policy.config["use_gae"],
|
||||
use_critic=policy.config.get("use_critic", True))
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
|
||||
loss: TensorType) -> ModelGradients:
|
||||
"""Gradients computing function (from loss tensor, using local optimizer).
|
||||
|
@ -392,13 +327,29 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
|||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
||||
def postprocess_ppo_gae(
|
||||
policy: Policy,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
|
||||
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
|
||||
|
||||
# Stub serving backward compatibility.
|
||||
deprecation_warning(
|
||||
old="rllib.agents.ppo.ppo_tf_policy.postprocess_ppo_gae",
|
||||
new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
|
||||
error=False)
|
||||
|
||||
return compute_gae_for_sample_batch(policy, sample_batch,
|
||||
other_agent_batches, episode)
|
||||
|
||||
|
||||
# Build a child class of `DynamicTFPolicy`, given the custom functions defined
|
||||
# above.
|
||||
PPOTFPolicy = build_tf_policy(
|
||||
name="PPOTFPolicy",
|
||||
loss_fn=ppo_surrogate_loss,
|
||||
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
stats_fn=kl_and_loss_stats,
|
||||
gradients_fn=compute_and_clip_gradients,
|
||||
extra_action_fetches_fn=vf_preds_fetches,
|
||||
|
|
|
@ -7,10 +7,9 @@ import numpy as np
|
|||
from typing import Dict, List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
setup_config
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
|
||||
Postprocessing
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.policy import Policy
|
||||
|
@ -19,8 +18,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
|||
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
|
||||
explained_variance, sequence_mask
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping, \
|
||||
convert_to_torch_tensor, explained_variance, sequence_mask
|
||||
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
@ -279,7 +278,7 @@ PPOTorchPolicy = build_policy_class(
|
|||
loss_fn=ppo_surrogate_loss,
|
||||
stats_fn=kl_and_loss_stats,
|
||||
extra_action_out_fn=vf_preds_fetches,
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
before_init=setup_config,
|
||||
before_loss_init=setup_mixins,
|
||||
|
|
|
@ -5,11 +5,12 @@ import unittest
|
|||
import ray
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae as \
|
||||
postprocess_ppo_gae_tf, ppo_surrogate_loss as ppo_surrogate_loss_tf
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import postprocess_ppo_gae as \
|
||||
postprocess_ppo_gae_torch, ppo_surrogate_loss as ppo_surrogate_loss_torch
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \
|
||||
ppo_surrogate_loss_tf
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \
|
||||
ppo_surrogate_loss_torch
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
|
||||
Postprocessing
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
|
@ -212,11 +213,8 @@ class TestPPO(unittest.TestCase):
|
|||
# Check the variable is initially zero.
|
||||
init_std = get_value()
|
||||
assert init_std == 0.0, init_std
|
||||
|
||||
if fw in ["tf2", "tf", "tfe"]:
|
||||
batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy())
|
||||
else:
|
||||
batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH.copy())
|
||||
batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
|
||||
if fw == "torch":
|
||||
batch = policy._lazy_tensor_dict(batch)
|
||||
policy.learn_on_batch(batch)
|
||||
|
||||
|
@ -255,11 +253,9 @@ class TestPPO(unittest.TestCase):
|
|||
# to train_batch dict.
|
||||
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
|
||||
# [0.50005, -0.505, 0.5]
|
||||
if fw in ["tf2", "tf", "tfe"]:
|
||||
train_batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy())
|
||||
else:
|
||||
train_batch = postprocess_ppo_gae_torch(
|
||||
policy, FAKE_BATCH.copy())
|
||||
train_batch = compute_gae_for_sample_batch(policy,
|
||||
FAKE_BATCH.copy())
|
||||
if fw == "torch":
|
||||
train_batch = policy._lazy_tensor_dict(train_batch)
|
||||
|
||||
# Check Advantage values.
|
||||
|
|
|
@ -73,8 +73,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
"timesteps_per_iteration": 100,
|
||||
|
||||
# === Replay buffer ===
|
||||
# Size of the replay buffer. Note that if async_updates is set, then
|
||||
# each worker will have a replay buffer of this size.
|
||||
# Size of the replay buffer (in time steps).
|
||||
"buffer_size": int(1e6),
|
||||
# If True prioritized replay buffer will be used.
|
||||
"prioritized_replay": False,
|
||||
|
@ -104,9 +103,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# Update the replay buffer with this many samples at once. Note that this
|
||||
# setting applies per-worker if num_workers > 1.
|
||||
"rollout_fragment_length": 1,
|
||||
# Size of a batched sampled from replay buffer for training. Note that
|
||||
# if async_updates is set, then each worker returns gradients for a
|
||||
# batch of this size.
|
||||
# Size of a batched sampled from replay buffer for training.
|
||||
"train_batch_size": 256,
|
||||
# Update the target network every `target_network_update_freq` steps.
|
||||
"target_network_update_freq": 0,
|
||||
|
|
|
@ -9,7 +9,6 @@ from typing import Dict, List, Optional, Tuple, Type, Union
|
|||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.sac.sac_tf_policy import build_sac_model, \
|
||||
postprocess_trajectory, validate_spaces
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS
|
||||
|
@ -23,7 +22,7 @@ from ray.rllib.models.torch.torch_action_dist import (
|
|||
TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.torch_ops import huber_loss
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss
|
||||
from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
|
||||
TrainerConfigDict
|
||||
|
||||
|
|
|
@ -265,17 +265,11 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
|
|||
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
if self.config["grad_norm_clipping"] is not None:
|
||||
self.gvs = {
|
||||
k: minimize_and_clip(optimizer, self.losses[k], self.vars[k],
|
||||
self.config["grad_norm_clipping"])
|
||||
for k, optimizer in self.optimizers.items()
|
||||
}
|
||||
else:
|
||||
self.gvs = {
|
||||
k: optimizer.compute_gradients(self.losses[k], self.vars[k])
|
||||
for k, optimizer in self.optimizers.items()
|
||||
}
|
||||
self.gvs = {
|
||||
k: minimize_and_clip(optimizer, self.losses[k], self.vars[k],
|
||||
self.config["grad_norm_clipping"])
|
||||
for k, optimizer in self.optimizers.items()
|
||||
}
|
||||
return self.gvs["critic"] + self.gvs["actor"]
|
||||
|
||||
@override(TFPolicy)
|
||||
|
|
|
@ -1,22 +1,12 @@
|
|||
import numpy as np
|
||||
import scipy.signal
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
def discount_cumsum(x: np.ndarray, gamma: float) -> float:
|
||||
"""Calculates the discounted cumulative sum over a reward sequence `x`.
|
||||
|
||||
y[t] - discount*y[t+1] = x[t]
|
||||
reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
|
||||
|
||||
Args:
|
||||
gamma (float): The discount factor gamma.
|
||||
|
||||
Returns:
|
||||
float: The discounted cumulative sum over the reward sequence `x`.
|
||||
"""
|
||||
return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
|
||||
from ray.rllib.utils.typing import AgentID
|
||||
|
||||
|
||||
class Postprocessing:
|
||||
|
@ -89,3 +79,83 @@ def compute_advantages(rollout: SampleBatch,
|
|||
Postprocessing.ADVANTAGES].astype(np.float32)
|
||||
|
||||
return rollout
|
||||
|
||||
|
||||
def compute_gae_for_sample_batch(
|
||||
policy: Policy,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
|
||||
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
|
||||
"""Adds GAE (generalized advantage estimations) to a trajectory.
|
||||
|
||||
The trajectory contains only data from one episode and from one agent.
|
||||
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
|
||||
contain a truncated (at-the-end) episode, in case the
|
||||
`config.rollout_fragment_length` was reached by the sampler.
|
||||
- If `config.batch_mode=complete_episodes`, sample_batch will contain
|
||||
exactly one episode (no matter how long).
|
||||
New columns can be added to sample_batch and existing ones may be altered.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy used to generate the trajectory
|
||||
(`sample_batch`)
|
||||
sample_batch (SampleBatch): The SampleBatch to postprocess.
|
||||
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
|
||||
dict of AgentIDs mapping to other agents' trajectory data (from the
|
||||
same episode). NOTE: The other agents use the same policy.
|
||||
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
|
||||
object in which the agents operated.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
||||
"""
|
||||
|
||||
# Trajectory is actually complete -> last r=0.0.
|
||||
if sample_batch[SampleBatch.DONES][-1]:
|
||||
last_r = 0.0
|
||||
# Trajectory has been truncated -> last r=VF estimate of last obs.
|
||||
else:
|
||||
# Input dict is provided to us automatically via the Model's
|
||||
# requirements. It's a single-timestep (last one in trajectory)
|
||||
# input_dict.
|
||||
if policy.config.get("_use_trajectory_view_api"):
|
||||
# Create an input dict according to the Model's requirements.
|
||||
input_dict = policy.model.get_input_dict(
|
||||
sample_batch, index="last")
|
||||
last_r = policy._value(**input_dict)
|
||||
# TODO: (sven) Remove once trajectory view API is all-algo default.
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(policy.num_state_tensors()):
|
||||
next_state.append(sample_batch["state_out_{}".format(i)][-1])
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
|
||||
# Adds the policy logits, VF preds, and advantages to the batch,
|
||||
# using GAE ("generalized advantage estimation") or not.
|
||||
batch = compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
policy.config["gamma"],
|
||||
policy.config["lambda"],
|
||||
use_gae=policy.config["use_gae"],
|
||||
use_critic=policy.config.get("use_critic", True))
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def discount_cumsum(x: np.ndarray, gamma: float) -> float:
|
||||
"""Calculates the discounted cumulative sum over a reward sequence `x`.
|
||||
|
||||
y[t] - discount*y[t+1] = x[t]
|
||||
reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
|
||||
|
||||
Args:
|
||||
gamma (float): The discount factor gamma.
|
||||
|
||||
Returns:
|
||||
float: The discounted cumulative sum over the reward sequence `x`.
|
||||
"""
|
||||
return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
|
||||
|
|
|
@ -14,11 +14,11 @@ mspacman-sac-tf:
|
|||
# state-preprocessor=Our default Atari Conv2D-net.
|
||||
use_state_preprocessor: true
|
||||
Q_model:
|
||||
hidden_activation: relu
|
||||
hidden_layer_sizes: [512]
|
||||
fcnet_hiddens: [512]
|
||||
fcnet_activation: relu
|
||||
policy_model:
|
||||
hidden_activation: relu
|
||||
hidden_layer_sizes: [512]
|
||||
fcnet_hiddens: [512]
|
||||
fcnet_activation: relu
|
||||
# Do hard syncs.
|
||||
# Soft-syncs seem to work less reliably for discrete action spaces.
|
||||
tau: 1.0
|
||||
|
|
|
@ -92,7 +92,7 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0):
|
|||
variable is clipped to `clip_val`
|
||||
"""
|
||||
# Accidentally passing values < 0.0 will break all gradients.
|
||||
assert clip_val > 0.0, clip_val
|
||||
assert clip_val is None or clip_val > 0.0, clip_val
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tape = optimizer.tape
|
||||
|
@ -102,10 +102,8 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0):
|
|||
grads_and_vars = optimizer.compute_gradients(
|
||||
objective, var_list=var_list)
|
||||
|
||||
for i, (grad, var) in enumerate(grads_and_vars):
|
||||
if grad is not None:
|
||||
grads_and_vars[i] = (tf.clip_by_norm(grad, clip_val), var)
|
||||
return grads_and_vars
|
||||
return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v)
|
||||
for (g, v) in grads_and_vars if g is not None]
|
||||
|
||||
|
||||
def make_tf_callable(session_or_none, dynamic_shape=False):
|
||||
|
|
|
@ -14,6 +14,30 @@ FLOAT_MIN = -3.4e38
|
|||
FLOAT_MAX = 3.4e38
|
||||
|
||||
|
||||
def apply_grad_clipping(policy, optimizer, loss):
|
||||
"""Applies gradient clipping to already computed grads inside `optimizer`.
|
||||
|
||||
Args:
|
||||
policy (TorchPolicy): The TorchPolicy, which calculated `loss`.
|
||||
optimizer (torch.optim.Optimizer): A local torch optimizer object.
|
||||
loss (torch.Tensor): The torch loss tensor.
|
||||
"""
|
||||
info = {}
|
||||
if policy.config["grad_clip"]:
|
||||
for param_group in optimizer.param_groups:
|
||||
# Make sure we only pass params with grad != None into torch
|
||||
# clip_grad_norm_. Would fail otherwise.
|
||||
params = list(
|
||||
filter(lambda p: p.grad is not None, param_group["params"]))
|
||||
if params:
|
||||
grad_gnorm = nn.utils.clip_grad_norm_(
|
||||
params, policy.config["grad_clip"])
|
||||
if isinstance(grad_gnorm, torch.Tensor):
|
||||
grad_gnorm = grad_gnorm.cpu().numpy()
|
||||
info["grad_gnorm"] = grad_gnorm
|
||||
return info
|
||||
|
||||
|
||||
def atanh(x):
|
||||
return 0.5 * torch.log((1 + x) / (1 - x))
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue