Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00

[RLlib] Issue 9071 A3C w/ RNN not working due to VF assuming no RNN. (#13238)

This commit is contained in:
parent e74947cc94
commit 2e3655e8a9

24 changed files with 251 additions and 255 deletions
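The hunks below centralize advantage postprocessing in `compute_gae_for_sample_batch` and route the bootstrap-value call through the Model's trajectory-view input dict, so recurrent state actually reaches the value function. A minimal sketch of the call pattern that changes (`policy`, `sample_batch` and the model are assumed to already exist; this is not the complete RLlib code path):

    # Old Torch A3C mixin: the value function was called with empty state,
    # which breaks any RNN model:
    #     _ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1])
    #     return self.model.value_function()[0]

    # New pattern (trajectory view API enabled): build the last-timestep
    # input dict from the Model's view requirements -- including prev.
    # actions/rewards and any state_in_* columns -- and unpack it into
    # the value function.
    input_dict = policy.model.get_input_dict(sample_batch, index="last")
    last_r = policy._value(**input_dict)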
@@ -1,17 +1,34 @@
 """Note: Keep in sync with changes to VTraceTFPolicy."""

 import ray
+from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.evaluation.postprocessing import compute_advantages, \
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.policy.tf_policy import LearningRateSchedule
+from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
+from ray.rllib.utils.tf_ops import explained_variance

 tf1, tf, tfv = try_import_tf()


+def postprocess_advantages(policy,
+                           sample_batch,
+                           other_agent_batches=None,
+                           episode=None):
+
+    # Stub serving backward compatibility.
+    deprecation_warning(
+        old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
+        new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
+        error=False)
+
+    return compute_gae_for_sample_batch(policy, sample_batch,
+                                        other_agent_batches, episode)
+
+
 class A3CLoss:
     def __init__(self,
                  action_dist,

@@ -45,46 +62,10 @@ def actor_critic_loss(policy, model, dist_class, train_batch):
     return policy.loss.total_loss


-def postprocess_advantages(policy,
-                           sample_batch,
-                           other_agent_batches=None,
-                           episode=None):
-    completed = sample_batch[SampleBatch.DONES][-1]
-    if completed:
-        last_r = 0.0
-    else:
-        next_state = []
-        for i in range(policy.num_state_tensors()):
-            next_state.append(sample_batch["state_out_{}".format(i)][-1])
-        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
-                               sample_batch[SampleBatch.ACTIONS][-1],
-                               sample_batch[SampleBatch.REWARDS][-1],
-                               *next_state)
-    return compute_advantages(
-        sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
-        policy.config["use_gae"], policy.config["use_critic"])
-
-
 def add_value_function_fetch(policy):
     return {SampleBatch.VF_PREDS: policy.model.value_function()}


-class ValueNetworkMixin:
-    def __init__(self):
-        @make_tf_callable(self.get_session())
-        def value(ob, prev_action, prev_reward, *state):
-            model_out, _ = self.model({
-                SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
-                SampleBatch.PREV_ACTIONS: tf.convert_to_tensor([prev_action]),
-                SampleBatch.PREV_REWARDS: tf.convert_to_tensor([prev_reward]),
-                "is_training": tf.convert_to_tensor(False),
-            }, [tf.convert_to_tensor([s]) for s in state],
-                tf.convert_to_tensor([1]))
-            return self.model.value_function()[0]
-
-        self._value = value
-
-
 def stats(policy, train_batch):
     return {
         "cur_lr": tf.cast(policy.cur_lr, tf.float64),

@@ -115,7 +96,7 @@ def clip_gradients(policy, optimizer, loss):

 def setup_mixins(policy, obs_space, action_space, config):
-    ValueNetworkMixin.__init__(policy)
+    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])

@@ -126,7 +107,7 @@ A3CTFPolicy = build_tf_policy(
     stats_fn=stats,
     grad_stats_fn=grad_stats,
     gradients_fn=clip_gradients,
-    postprocess_fn=postprocess_advantages,
+    postprocess_fn=compute_gae_for_sample_batch,
     extra_action_fetches_fn=add_value_function_fetch,
     before_loss_init=setup_mixins,
     mixins=[ValueNetworkMixin, LearningRateSchedule])
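The removed `postprocess_advantages` is replaced by a thin stub (first hunk above), so user code that still imports it keeps working while emitting a deprecation warning. A hypothetical caller, assuming `policy` and `sample_batch` exist:

    from ray.rllib.agents.a3c.a3c_tf_policy import postprocess_advantages

    # Logs a deprecation warning pointing at
    # rllib.evaluation.postprocessing.compute_gae_for_sample_batch,
    # then forwards to it unchanged.
    batch = postprocess_advantages(policy, sample_batch)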
@@ -1,13 +1,35 @@
+import gym
+
 import ray
-from ray.rllib.evaluation.postprocessing import compute_advantages, \
+from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
+from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.torch_ops import apply_grad_clipping
+from ray.rllib.utils.typing import TrainerConfigDict

 torch, nn = try_import_torch()


+def add_advantages(policy,
+                   sample_batch,
+                   other_agent_batches=None,
+                   episode=None):
+
+    # Stub serving backward compatibility.
+    deprecation_warning(
+        old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
+        new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
+        error=False)
+
+    return compute_gae_for_sample_batch(policy, sample_batch,
+                                        other_agent_batches, episode)
+
+
 def actor_critic_loss(policy, model, dist_class, train_batch):
     logits, _ = model.from_batch(train_batch)
     values = model.value_function()

@@ -36,52 +58,27 @@ def loss_and_entropy_stats(policy, train_batch):
     }


-def add_advantages(policy,
-                   sample_batch,
-                   other_agent_batches=None,
-                   episode=None):
-
-    completed = sample_batch[SampleBatch.DONES][-1]
-    if completed:
-        last_r = 0.0
-    else:
-        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
-
-    return compute_advantages(
-        sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
-        policy.config["use_gae"], policy.config["use_critic"])
-
-
 def model_value_predictions(policy, input_dict, state_batches, model,
                             action_dist):
     return {SampleBatch.VF_PREDS: model.value_function()}


-def apply_grad_clipping(policy, optimizer, loss):
-    info = {}
-    if policy.config["grad_clip"]:
-        for param_group in optimizer.param_groups:
-            # Make sure we only pass params with grad != None into torch
-            # clip_grad_norm_. Would fail otherwise.
-            params = list(
-                filter(lambda p: p.grad is not None, param_group["params"]))
-            if params:
-                grad_gnorm = nn.utils.clip_grad_norm_(
-                    params, policy.config["grad_clip"])
-                if isinstance(grad_gnorm, torch.Tensor):
-                    grad_gnorm = grad_gnorm.cpu().numpy()
-                info["grad_gnorm"] = grad_gnorm
-    return info
-
-
 def torch_optimizer(policy, config):
     return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])


-class ValueNetworkMixin:
-    def _value(self, obs):
-        _ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1])
-        return self.model.value_function()[0]
+def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
+                 config: TrainerConfigDict) -> None:
+    """Call all mixin classes' constructors before PPOPolicy initialization.
+
+    Args:
+        policy (Policy): The Policy object.
+        obs_space (gym.spaces.Space): The Policy's observation space.
+        action_space (gym.spaces.Space): The Policy's action space.
+        config (TrainerConfigDict): The Policy's config.
+    """
+    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)


 A3CTorchPolicy = build_policy_class(

@@ -90,9 +87,10 @@ A3CTorchPolicy = build_policy_class(
     get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
     loss_fn=actor_critic_loss,
     stats_fn=loss_and_entropy_stats,
-    postprocess_fn=add_advantages,
+    postprocess_fn=compute_gae_for_sample_batch,
     extra_action_out_fn=model_value_predictions,
     extra_grad_process_fn=apply_grad_clipping,
     optimizer_fn=torch_optimizer,
+    before_loss_init=setup_mixins,
     mixins=[ValueNetworkMixin],
 )
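With GAE postprocessing and Torch grad clipping now shared utilities, a custom Torch policy can be wired the same way the A3C policy above is. A sketch under the assumption that `my_loss` and `my_stats` are defined elsewhere and that the defaults of `build_policy_class` are acceptable:

    from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
    from ray.rllib.policy.policy_template import build_policy_class
    from ray.rllib.utils.torch_ops import apply_grad_clipping

    MyTorchPolicy = build_policy_class(
        name="MyTorchPolicy",
        framework="torch",
        loss_fn=my_loss,                              # assumed to exist
        stats_fn=my_stats,                            # assumed to exist
        postprocess_fn=compute_gae_for_sample_batch,  # GAE, RNN-aware bootstrap
        extra_grad_process_fn=apply_grad_clipping,
    )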
@@ -8,7 +8,6 @@ from typing import Dict, List, Tuple, Type, Union

 import ray
 import ray.experimental.tf_utils
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \
     validate_spaces
 from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \

@@ -22,7 +21,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
     TrainerConfigDict
-from ray.rllib.utils.torch_ops import convert_to_torch_tensor
+from ray.rllib.utils.torch_ops import apply_grad_clipping, \
+    convert_to_torch_tensor

 torch, nn = try_import_torch()
 F = nn.functional
@@ -1,7 +1,6 @@
 import logging

 import ray
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \
     get_distribution_inputs_and_class, validate_spaces
 from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \

@@ -10,7 +9,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchDeterministic
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import huber_loss, l2_loss
+from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss, l2_loss

 torch, nn = try_import_torch()
@@ -301,17 +301,11 @@ def adam_optimizer(policy: Policy, config: TrainerConfigDict

 def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer",
                    loss: TensorType) -> ModelGradients:
-    if policy.config["grad_clip"] is not None:
-        grads_and_vars = minimize_and_clip(
-            optimizer,
-            loss,
-            var_list=policy.q_func_vars,
-            clip_val=policy.config["grad_clip"])
-    else:
-        grads_and_vars = optimizer.compute_gradients(
-            loss, var_list=policy.q_func_vars)
-    grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
-    return grads_and_vars
+    return minimize_and_clip(
+        optimizer,
+        loss,
+        var_list=policy.q_func_vars,
+        clip_val=policy.config["grad_clip"])


 def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
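`clip_gradients` can now pass `grad_clip` through unconditionally because `minimize_and_clip` (see the tf_ops hunk near the end of this diff) accepts `clip_val=None` and skips clipping in that case. Behavioral sketch, with illustrative config values:

    # Both configs now take the same code path through clip_gradients:
    config_clipped = {"grad_clip": 40.0}    # gradients norm-clipped to 40
    config_unclipped = {"grad_clip": None}  # gradients returned unclipped
                                            # (vars with None grads are dropped)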
@@ -4,7 +4,6 @@ from typing import Dict, List, Tuple

 import gym
 import ray
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 from ray.rllib.agents.dqn.dqn_tf_policy import (
     PRIO_WEIGHTS, Q_SCOPE, Q_TARGET_SCOPE, postprocess_nstep_and_prio)
 from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel

@@ -20,9 +19,8 @@ from ray.rllib.policy.torch_policy import LearningRateSchedule
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import (FLOAT_MIN, huber_loss,
-                                       reduce_mean_ignore_inf,
-                                       softmax_cross_entropy_with_logits)
+from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \
+    huber_loss, reduce_mean_ignore_inf, softmax_cross_entropy_with_logits
 from ray.rllib.utils.typing import TensorType, TrainerConfigDict

 torch, nn = try_import_torch()
@@ -1,11 +1,11 @@
 import logging

 import ray
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 from ray.rllib.agents.dreamer.utils import FreezeParameters
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.torch_ops import apply_grad_clipping

 torch, nn = try_import_torch()
 if torch:
@@ -1,7 +1,6 @@
 import logging

 import ray
-from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
 from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
 from ray.rllib.agents.trainer import Trainer, with_common_config
 from ray.rllib.agents.trainer_template import build_trainer

@@ -160,6 +159,7 @@ def get_policy_class(config):
     if config["vtrace"]:
         return VTraceTFPolicy
     else:
+        from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
         return A3CTFPolicy
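The `A3CTFPolicy` import moves from module level into the non-vtrace branch of `get_policy_class`. Deferring an import like this is a common way to break an import cycle or to avoid loading a module that only one code path needs; the diff itself does not state the motivation. Pattern sketch:

    def get_policy_class(config):
        if config["vtrace"]:
            return VTraceTFPolicy
        else:
            # Imported only when actually needed (e.g. to avoid a circular
            # import between the policy modules -- assumed reason).
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
            return A3CTFPolicy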
@@ -3,7 +3,6 @@ import logging
 import numpy as np

 import ray
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 import ray.rllib.agents.impala.vtrace_torch as vtrace
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical
 from ray.rllib.policy.policy_template import build_policy_class

@@ -11,8 +10,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule, \
     EntropyCoeffSchedule
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
-    sequence_mask
+from ray.rllib.utils.torch_ops import apply_grad_clipping, \
+    explained_variance, global_norm, sequence_mask

 torch, nn = try_import_torch()
@@ -1,10 +1,10 @@
 import logging

 import ray
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
-    vf_preds_fetches, compute_and_clip_gradients, setup_config, \
-    ValueNetworkMixin
-from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.agents.ppo.ppo_tf_policy import vf_preds_fetches, \
+    compute_and_clip_gradients, setup_config, ValueNetworkMixin
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
+    Postprocessing
 from ray.rllib.models.utils import get_activation_fn
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy_template import build_tf_policy

@@ -422,7 +422,7 @@ MAMLTFPolicy = build_tf_policy(
     stats_fn=maml_stats,
     optimizer_fn=maml_optimizer_fn,
     extra_action_fetches_fn=vf_preds_fetches,
-    postprocess_fn=postprocess_ppo_gae,
+    postprocess_fn=compute_gae_for_sample_batch,
     gradients_fn=compute_and_clip_gradients,
     before_init=setup_config,
     before_loss_init=setup_mixins,
@@ -1,15 +1,15 @@
 import logging

 import ray
-from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
+    Postprocessing
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
-    setup_config
+from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
 from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
     ValueNetworkMixin
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.torch_ops import apply_grad_clipping

 torch, nn = try_import_torch()

@@ -355,7 +355,7 @@ MAMLTorchPolicy = build_policy_class(
     stats_fn=maml_stats,
     optimizer_fn=maml_optimizer_fn,
     extra_action_out_fn=vf_preds_fetches,
-    postprocess_fn=postprocess_ppo_gae,
+    postprocess_fn=compute_gae_for_sample_batch,
     extra_grad_process_fn=apply_grad_clipping,
     before_init=setup_config,
     after_init=setup_mixins,
@@ -3,18 +3,18 @@ import logging
 from typing import Tuple, Type

 import ray
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \
     maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
-    setup_config
+from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
 from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.torch_ops import apply_grad_clipping
 from ray.rllib.utils.typing import TrainerConfigDict

 torch, nn = try_import_torch()

@@ -85,7 +85,7 @@ MBMPOTorchPolicy = build_policy_class(
     stats_fn=maml_stats,
     optimizer_fn=maml_optimizer_fn,
     extra_action_out_fn=vf_preds_fetches,
-    postprocess_fn=postprocess_ppo_gae,
+    postprocess_fn=compute_gae_for_sample_batch,
     extra_grad_process_fn=apply_grad_clipping,
     before_init=setup_config,
     after_init=setup_mixins,
@@ -13,9 +13,9 @@ from typing import Dict, List, Optional, Type, Union
 from ray.rllib.agents.impala import vtrace_tf as vtrace
 from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
     clip_gradients, choose_optimizer
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae
 from ray.rllib.evaluation.episode import MultiAgentEpisode
-from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
+    Postprocessing
 from ray.rllib.models.tf.tf_action_dist import Categorical
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch

@@ -338,8 +338,8 @@ def postprocess_trajectory(
         SampleBatch: The postprocessed, modified SampleBatch (or a new one).
     """
     if not policy.config["vtrace"]:
-        sample_batch = postprocess_ppo_gae(policy, sample_batch,
-                                           other_agent_batches, episode)
+        sample_batch = compute_gae_for_sample_batch(
+            policy, sample_batch, other_agent_batches, episode)

     # TODO: (sven) remove this del once we have trajectory view API fully in
     # place.
@@ -10,7 +10,6 @@ import numpy as np
 import logging
 from typing import Type

-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 import ray.rllib.agents.impala.vtrace_torch as vtrace
 from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
     choose_optimizer

@@ -27,8 +26,8 @@ from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
-    sequence_mask
+from ray.rllib.utils.torch_ops import apply_grad_clipping, explained_variance,\
+    global_norm, sequence_mask
 from ray.rllib.utils.typing import TensorType, TrainerConfigDict

 torch, nn = try_import_torch()
@@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Type, Union

 import ray
 from ray.rllib.evaluation.episode import MultiAgentEpisode
-from ray.rllib.evaluation.postprocessing import compute_advantages, \
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution

@@ -160,71 +160,6 @@ def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
     }


-def postprocess_ppo_gae(
-        policy: Policy,
-        sample_batch: SampleBatch,
-        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
-        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
-    """Postprocesses a trajectory and returns the processed trajectory.
-
-    The trajectory contains only data from one episode and from one agent.
-    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
-      contain a truncated (at-the-end) episode, in case the
-      `config.rollout_fragment_length` was reached by the sampler.
-    - If `config.batch_mode=complete_episodes`, sample_batch will contain
-      exactly one episode (no matter how long).
-    New columns can be added to sample_batch and existing ones may be altered.
-
-    Args:
-        policy (Policy): The Policy used to generate the trajectory
-            (`sample_batch`)
-        sample_batch (SampleBatch): The SampleBatch to postprocess.
-        other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
-            dict of AgentIDs mapping to other agents' trajectory data (from the
-            same episode). NOTE: The other agents use the same policy.
-        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
-            object in which the agents operated.
-
-    Returns:
-        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
-    """
-
-    # Trajectory is actually complete -> last r=0.0.
-    if sample_batch[SampleBatch.DONES][-1]:
-        last_r = 0.0
-    # Trajectory has been truncated -> last r=VF estimate of last obs.
-    else:
-        # Input dict is provided to us automatically via the Model's
-        # requirements. It's a single-timestep (last one in trajectory)
-        # input_dict.
-        if policy.config["_use_trajectory_view_api"]:
-            # Create an input dict according to the Model's requirements.
-            input_dict = policy.model.get_input_dict(
-                sample_batch, index="last")
-            last_r = policy._value(**input_dict)
-        # TODO: (sven) Remove once trajectory view API is all-algo default.
-        else:
-            next_state = []
-            for i in range(policy.num_state_tensors()):
-                next_state.append(sample_batch["state_out_{}".format(i)][-1])
-            last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
-                                   sample_batch[SampleBatch.ACTIONS][-1],
-                                   sample_batch[SampleBatch.REWARDS][-1],
-                                   *next_state)
-
-    # Adds the policy logits, VF preds, and advantages to the batch,
-    # using GAE ("generalized advantage estimation") or not.
-    batch = compute_advantages(
-        sample_batch,
-        last_r,
-        policy.config["gamma"],
-        policy.config["lambda"],
-        use_gae=policy.config["use_gae"],
-        use_critic=policy.config.get("use_critic", True))
-
-    return batch
-
-
 def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
                                loss: TensorType) -> ModelGradients:
     """Gradients computing function (from loss tensor, using local optimizer).

@@ -392,13 +327,29 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


+def postprocess_ppo_gae(
+        policy: Policy,
+        sample_batch: SampleBatch,
+        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
+        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+
+    # Stub serving backward compatibility.
+    deprecation_warning(
+        old="rllib.agents.ppo.ppo_tf_policy.postprocess_ppo_gae",
+        new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
+        error=False)
+
+    return compute_gae_for_sample_batch(policy, sample_batch,
+                                        other_agent_batches, episode)
+
+
 # Build a child class of `DynamicTFPolicy`, given the custom functions defined
 # above.
 PPOTFPolicy = build_tf_policy(
     name="PPOTFPolicy",
     loss_fn=ppo_surrogate_loss,
     get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
-    postprocess_fn=postprocess_ppo_gae,
+    postprocess_fn=compute_gae_for_sample_batch,
     stats_fn=kl_and_loss_stats,
     gradients_fn=compute_and_clip_gradients,
     extra_action_fetches_fn=vf_preds_fetches,
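The docstring moved out of this file (now on `compute_gae_for_sample_batch`) distinguishes the two batch modes that decide whether a bootstrap value is needed at all: with truncated episodes the last return must be estimated by the value function, while with complete episodes the trajectory ends on done=True and last_r is simply 0.0. A hedged config sketch (key names as in the docstring, values illustrative):

    # Rollouts may end mid-episode -> compute_gae_for_sample_batch calls
    # policy._value(...) to bootstrap the return at the cut point.
    config_truncated = {
        "batch_mode": "truncate_episodes",  # default
        "rollout_fragment_length": 200,     # illustrative value
    }

    # Rollouts always contain whole episodes -> last done=True, last_r=0.0.
    config_complete = {"batch_mode": "complete_episodes"}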
@@ -7,10 +7,9 @@ import numpy as np
 from typing import Dict, List, Type, Union

 import ray
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
-    setup_config
-from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
+    Postprocessing
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy

@@ -19,8 +18,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
     LearningRateSchedule
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
-    explained_variance, sequence_mask
+from ray.rllib.utils.torch_ops import apply_grad_clipping, \
+    convert_to_torch_tensor, explained_variance, sequence_mask
 from ray.rllib.utils.typing import TensorType, TrainerConfigDict

 torch, nn = try_import_torch()

@@ -279,7 +278,7 @@ PPOTorchPolicy = build_policy_class(
     loss_fn=ppo_surrogate_loss,
     stats_fn=kl_and_loss_stats,
     extra_action_out_fn=vf_preds_fetches,
-    postprocess_fn=postprocess_ppo_gae,
+    postprocess_fn=compute_gae_for_sample_batch,
     extra_grad_process_fn=apply_grad_clipping,
     before_init=setup_config,
     before_loss_init=setup_mixins,
@@ -5,11 +5,12 @@ import unittest
 import ray
 from ray.rllib.agents.callbacks import DefaultCallbacks
 import ray.rllib.agents.ppo as ppo
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae as \
-    postprocess_ppo_gae_tf, ppo_surrogate_loss as ppo_surrogate_loss_tf
-from ray.rllib.agents.ppo.ppo_torch_policy import postprocess_ppo_gae as \
-    postprocess_ppo_gae_torch, ppo_surrogate_loss as ppo_surrogate_loss_torch
-from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \
+    ppo_surrogate_loss_tf
+from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \
+    ppo_surrogate_loss_torch
+from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
+    Postprocessing
 from ray.rllib.models.tf.tf_action_dist import Categorical
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical

@@ -212,11 +213,8 @@ class TestPPO(unittest.TestCase):
         # Check the variable is initially zero.
         init_std = get_value()
         assert init_std == 0.0, init_std
-        if fw in ["tf2", "tf", "tfe"]:
-            batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy())
-        else:
-            batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH.copy())
+        batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
+        if fw == "torch":
             batch = policy._lazy_tensor_dict(batch)
         policy.learn_on_batch(batch)

@@ -255,11 +253,9 @@ class TestPPO(unittest.TestCase):
         # to train_batch dict.
         # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
         # [0.50005, -0.505, 0.5]
-        if fw in ["tf2", "tf", "tfe"]:
-            train_batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy())
-        else:
-            train_batch = postprocess_ppo_gae_torch(
-                policy, FAKE_BATCH.copy())
+        train_batch = compute_gae_for_sample_batch(policy,
+                                                   FAKE_BATCH.copy())
+        if fw == "torch":
             train_batch = policy._lazy_tensor_dict(train_batch)

         # Check Advantage values.
@@ -73,8 +73,7 @@ DEFAULT_CONFIG = with_common_config({
     "timesteps_per_iteration": 100,

     # === Replay buffer ===
-    # Size of the replay buffer. Note that if async_updates is set, then
-    # each worker will have a replay buffer of this size.
+    # Size of the replay buffer (in time steps).
     "buffer_size": int(1e6),
     # If True prioritized replay buffer will be used.
     "prioritized_replay": False,

@@ -104,9 +103,7 @@ DEFAULT_CONFIG = with_common_config({
     # Update the replay buffer with this many samples at once. Note that this
     # setting applies per-worker if num_workers > 1.
     "rollout_fragment_length": 1,
-    # Size of a batched sampled from replay buffer for training. Note that
-    # if async_updates is set, then each worker returns gradients for a
-    # batch of this size.
+    # Size of a batched sampled from replay buffer for training.
     "train_batch_size": 256,
     # Update the target network every `target_network_update_freq` steps.
     "target_network_update_freq": 0,
@@ -9,7 +9,6 @@ from typing import Dict, List, Optional, Tuple, Type, Union

 import ray
 import ray.experimental.tf_utils
-from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 from ray.rllib.agents.sac.sac_tf_policy import build_sac_model, \
     postprocess_trajectory, validate_spaces
 from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS

@@ -23,7 +22,7 @@ from ray.rllib.models.torch.torch_action_dist import (
     TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.simplex import Simplex
-from ray.rllib.utils.torch_ops import huber_loss
+from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss
 from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
     TrainerConfigDict
@@ -265,17 +265,11 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):

     @override(TFPolicy)
     def gradients(self, optimizer, loss):
-        if self.config["grad_norm_clipping"] is not None:
-            self.gvs = {
-                k: minimize_and_clip(optimizer, self.losses[k], self.vars[k],
-                                     self.config["grad_norm_clipping"])
-                for k, optimizer in self.optimizers.items()
-            }
-        else:
-            self.gvs = {
-                k: optimizer.compute_gradients(self.losses[k], self.vars[k])
-                for k, optimizer in self.optimizers.items()
-            }
+        self.gvs = {
+            k: minimize_and_clip(optimizer, self.losses[k], self.vars[k],
+                                 self.config["grad_norm_clipping"])
+            for k, optimizer in self.optimizers.items()
+        }
         return self.gvs["critic"] + self.gvs["actor"]

     @override(TFPolicy)
@@ -1,22 +1,12 @@
 import numpy as np
 import scipy.signal
+from typing import Dict, Optional

+from ray.rllib.evaluation.episode import MultiAgentEpisode
+from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.typing import AgentID


-def discount_cumsum(x: np.ndarray, gamma: float) -> float:
-    """Calculates the discounted cumulative sum over a reward sequence `x`.
-
-    y[t] - discount*y[t+1] = x[t]
-    reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
-
-    Args:
-        gamma (float): The discount factor gamma.
-
-    Returns:
-        float: The discounted cumulative sum over the reward sequence `x`.
-    """
-    return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
-
-
 class Postprocessing:

@@ -89,3 +79,83 @@ def compute_advantages(rollout: SampleBatch,
             Postprocessing.ADVANTAGES].astype(np.float32)

     return rollout
+
+
+def compute_gae_for_sample_batch(
+        policy: Policy,
+        sample_batch: SampleBatch,
+        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
+        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
+    """Adds GAE (generalized advantage estimations) to a trajectory.
+
+    The trajectory contains only data from one episode and from one agent.
+    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
+      contain a truncated (at-the-end) episode, in case the
+      `config.rollout_fragment_length` was reached by the sampler.
+    - If `config.batch_mode=complete_episodes`, sample_batch will contain
+      exactly one episode (no matter how long).
+    New columns can be added to sample_batch and existing ones may be altered.
+
+    Args:
+        policy (Policy): The Policy used to generate the trajectory
+            (`sample_batch`)
+        sample_batch (SampleBatch): The SampleBatch to postprocess.
+        other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
+            dict of AgentIDs mapping to other agents' trajectory data (from the
+            same episode). NOTE: The other agents use the same policy.
+        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
+            object in which the agents operated.
+
+    Returns:
+        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
+    """
+
+    # Trajectory is actually complete -> last r=0.0.
+    if sample_batch[SampleBatch.DONES][-1]:
+        last_r = 0.0
+    # Trajectory has been truncated -> last r=VF estimate of last obs.
+    else:
+        # Input dict is provided to us automatically via the Model's
+        # requirements. It's a single-timestep (last one in trajectory)
+        # input_dict.
+        if policy.config.get("_use_trajectory_view_api"):
+            # Create an input dict according to the Model's requirements.
+            input_dict = policy.model.get_input_dict(
+                sample_batch, index="last")
+            last_r = policy._value(**input_dict)
+        # TODO: (sven) Remove once trajectory view API is all-algo default.
+        else:
+            next_state = []
+            for i in range(policy.num_state_tensors()):
+                next_state.append(sample_batch["state_out_{}".format(i)][-1])
+            last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
+                                   sample_batch[SampleBatch.ACTIONS][-1],
+                                   sample_batch[SampleBatch.REWARDS][-1],
+                                   *next_state)
+
+    # Adds the policy logits, VF preds, and advantages to the batch,
+    # using GAE ("generalized advantage estimation") or not.
+    batch = compute_advantages(
+        sample_batch,
+        last_r,
+        policy.config["gamma"],
+        policy.config["lambda"],
+        use_gae=policy.config["use_gae"],
+        use_critic=policy.config.get("use_critic", True))
+
+    return batch
+
+
+def discount_cumsum(x: np.ndarray, gamma: float) -> float:
+    """Calculates the discounted cumulative sum over a reward sequence `x`.
+
+    y[t] - discount*y[t+1] = x[t]
+    reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
+
+    Args:
+        gamma (float): The discount factor gamma.
+
+    Returns:
+        float: The discounted cumulative sum over the reward sequence `x`.
+    """
+    return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
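A small numeric check of `discount_cumsum` as defined above; the numbers match the expected-advantage comment in the PPO test earlier in this diff (gamma=0.99, rewards [1.0, -1.0, 0.5]):

    import numpy as np
    import scipy.signal

    def discount_cumsum(x, gamma):
        return scipy.signal.lfilter(
            [1], [1, float(-gamma)], x[::-1], axis=0)[::-1]

    r = np.array([1.0, -1.0, 0.5], dtype=np.float32)
    # y[2] = 0.5
    # y[1] = -1.0 + 0.99 * 0.5      = -0.505
    # y[0] =  1.0 + 0.99 * (-0.505) =  0.50005
    print(discount_cumsum(r, 0.99))  # -> [0.50005, -0.505, 0.5]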
@@ -14,11 +14,11 @@ mspacman-sac-tf:
     # state-preprocessor=Our default Atari Conv2D-net.
     use_state_preprocessor: true
     Q_model:
-      hidden_activation: relu
-      hidden_layer_sizes: [512]
+      fcnet_hiddens: [512]
+      fcnet_activation: relu
     policy_model:
-      hidden_activation: relu
-      hidden_layer_sizes: [512]
+      fcnet_hiddens: [512]
+      fcnet_activation: relu
     # Do hard syncs.
     # Soft-syncs seem to work less reliably for discrete action spaces.
     tau: 1.0
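The tuned example switches the SAC sub-model configs to the standard model-config key names. The Python-side equivalent of the YAML above, as a sketch (only the two renamed keys shown):

    config = {
        "Q_model": {
            "fcnet_hiddens": [512],
            "fcnet_activation": "relu",
        },
        "policy_model": {
            "fcnet_hiddens": [512],
            "fcnet_activation": "relu",
        },
    }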
@@ -92,7 +92,7 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0):
         variable is clipped to `clip_val`
     """
     # Accidentally passing values < 0.0 will break all gradients.
-    assert clip_val > 0.0, clip_val
+    assert clip_val is None or clip_val > 0.0, clip_val

     if tf.executing_eagerly():
         tape = optimizer.tape

@@ -102,10 +102,8 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0):
         grads_and_vars = optimizer.compute_gradients(
             objective, var_list=var_list)

-    for i, (grad, var) in enumerate(grads_and_vars):
-        if grad is not None:
-            grads_and_vars[i] = (tf.clip_by_norm(grad, clip_val), var)
-    return grads_and_vars
+    return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v)
+            for (g, v) in grads_and_vars if g is not None]


 def make_tf_callable(session_or_none, dynamic_shape=False):
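Semantics of the rewritten return value: `clip_val=None` now means "no clipping", and pairs whose gradient is None are dropped instead of being passed through. A pure-Python sketch of that expression, with `clip` standing in for `tf.clip_by_norm`:

    def _postprocess(grads_and_vars, clip_val, clip):
        return [(clip(g, clip_val) if clip_val is not None else g, v)
                for (g, v) in grads_and_vars if g is not None]

    gv = [("g1", "v1"), (None, "v2")]
    print(_postprocess(gv, None, lambda g, c: g))
    # [('g1', 'v1')]
    print(_postprocess(gv, 40.0, lambda g, c: "clip({},{})".format(g, c)))
    # [('clip(g1,40.0)', 'v1')]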
@@ -14,6 +14,30 @@ FLOAT_MIN = -3.4e38
 FLOAT_MAX = 3.4e38


+def apply_grad_clipping(policy, optimizer, loss):
+    """Applies gradient clipping to already computed grads inside `optimizer`.
+
+    Args:
+        policy (TorchPolicy): The TorchPolicy, which calculated `loss`.
+        optimizer (torch.optim.Optimizer): A local torch optimizer object.
+        loss (torch.Tensor): The torch loss tensor.
+    """
+    info = {}
+    if policy.config["grad_clip"]:
+        for param_group in optimizer.param_groups:
+            # Make sure we only pass params with grad != None into torch
+            # clip_grad_norm_. Would fail otherwise.
+            params = list(
+                filter(lambda p: p.grad is not None, param_group["params"]))
+            if params:
+                grad_gnorm = nn.utils.clip_grad_norm_(
+                    params, policy.config["grad_clip"])
+                if isinstance(grad_gnorm, torch.Tensor):
+                    grad_gnorm = grad_gnorm.cpu().numpy()
+                info["grad_gnorm"] = grad_gnorm
+    return info
+
+
 def atanh(x):
     return 0.5 * torch.log((1 + x) / (1 - x))
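`apply_grad_clipping` now lives in `ray.rllib.utils.torch_ops` and is wired into the Torch policies above via `extra_grad_process_fn`. A sketch of its call contract inside a training step (`policy`, `optimizer` and `loss` assumed to exist):

    # After backward(), clip every param group of the optimizer in place and
    # collect the resulting global norm for the stats dict.
    loss.backward()
    info = apply_grad_clipping(policy, optimizer, loss)
    optimizer.step()
    print(info)  # {"grad_gnorm": ...} when config["grad_clip"] is set, else {}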