Mirror of https://github.com/vale981/ray (synced 2025-03-05 10:01:43 -05:00)
[RLlib] DDPG/TD3 + A3C/A2C + MARWIL/BC Annotation/Comments/Code Cleanup (#14707)
Commit: 474f04e322
Parent: 8790bb465b
15 changed files with 316 additions and 115 deletions
@@ -1,4 +1,5 @@
 import math
+from typing import Optional, Type

 from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \
     validate_config, get_policy_class
@@ -9,6 +10,9 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
 from ray.rllib.execution.train_ops import ComputeGradients, AverageGradients, \
     ApplyGradients, TrainTFMultiGPU, TrainOneStep
 from ray.rllib.utils import merge_dicts
+from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.rllib.policy.policy import Policy

 A2C_DEFAULT_CONFIG = merge_dicts(
     A3C_CONFIG,
@@ -26,7 +30,19 @@ A2C_DEFAULT_CONFIG = merge_dicts(
 )


-def execution_plan(workers, config):
+def execution_plan(workers: WorkerSet,
+                   config: TrainerConfigDict) -> Optional[Type[Policy]]:
+    """Execution plan of the MARWIL/BC algorithm. Defines the distributed
+    dataflow.
+
+    Args:
+        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
+            of the Trainer.
+        config (TrainerConfigDict): The trainer's configuration dict.
+
+    Returns:
+        LocalIterator[dict]: A local iterator over training metrics.
+    """
     rollouts = ParallelRollouts(workers, mode="bulk_sync")

     if config["microbatch_size"]:
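For context, the A2C default config in the hunk above is built by overlaying A2C-specific keys on the A3C config via merge_dicts. Below is a minimal, self-contained sketch of that merge pattern in plain Python; it is not RLlib's merge_dicts, and any key other than sample_async and microbatch_size is made up for illustration.

import copy


def deep_merge(base: dict, overrides: dict) -> dict:
    """Return a new dict with `overrides` recursively merged into `base`."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical base (A3C-like) config and A2C-style overrides.
A3C_LIKE_CONFIG = {"lr": 0.0001, "sample_async": True,
                   "model": {"fcnet_hiddens": [256, 256]}}
A2C_LIKE_CONFIG = deep_merge(A3C_LIKE_CONFIG,
                             {"sample_async": False, "microbatch_size": None})

assert A2C_LIKE_CONFIG["sample_async"] is False
assert A2C_LIKE_CONFIG["model"]["fcnet_hiddens"] == [256, 256]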
@@ -1,4 +1,5 @@
 import logging
+from typing import Optional, Type

 from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
 from ray.rllib.agents.trainer import with_common_config
@@ -6,6 +7,10 @@ from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.execution.rollout_ops import AsyncGradients
 from ray.rllib.execution.train_ops import ApplyGradients
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.util.iter import LocalIterator
+from ray.rllib.policy.policy import Policy

 logger = logging.getLogger(__name__)

@@ -42,7 +47,16 @@ DEFAULT_CONFIG = with_common_config({
 # yapf: enable


-def get_policy_class(config):
+def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
+    """Policy class picker function. Class is chosen based on DL-framework.
+
+    Args:
+        config (TrainerConfigDict): The trainer's configuration dict.
+
+    Returns:
+        Optional[Type[Policy]]: The Policy class to use with DQNTrainer.
+            If None, use `default_policy` provided in build_trainer().
+    """
     if config["framework"] == "torch":
         from ray.rllib.agents.a3c.a3c_torch_policy import \
             A3CTorchPolicy
@@ -51,14 +65,30 @@ def get_policy_class(config):
         return A3CTFPolicy


-def validate_config(config):
+def validate_config(config: TrainerConfigDict) -> None:
+    """Checks and updates the config based on settings.
+
+    Rewrites rollout_fragment_length to take into account n_step truncation.
+    """
     if config["entropy_coeff"] < 0:
         raise ValueError("`entropy_coeff` must be >= 0.0!")
     if config["num_workers"] <= 0 and config["sample_async"]:
         raise ValueError("`num_workers` for A3C must be >= 1!")


-def execution_plan(workers, config):
+def execution_plan(workers: WorkerSet,
+                   config: TrainerConfigDict) -> LocalIterator[dict]:
+    """Execution plan of the MARWIL/BC algorithm. Defines the distributed
+    dataflow.
+
+    Args:
+        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
+            of the Trainer.
+        config (TrainerConfigDict): The trainer's configuration dict.
+
+    Returns:
+        LocalIterator[dict]: A local iterator over training metrics.
+    """
     # For A3C, compute policy gradients remotely on the rollout workers.
     grads = AsyncGradients(workers)

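The get_policy_class functions annotated in this commit all dispatch on config["framework"]. A stripped-down sketch of that dispatch idea using dummy classes and no RLlib imports (class names here are purely illustrative):

from typing import Optional, Type


class DummyTFPolicy:
    pass


class DummyTorchPolicy:
    pass


def pick_policy_class(config: dict) -> Optional[Type]:
    """Return a policy class based on the configured DL framework.

    Returning None would mean "fall back to the trainer's default_policy"
    in the pattern shown in the diff above.
    """
    if config.get("framework") == "torch":
        return DummyTorchPolicy
    return DummyTFPolicy


assert pick_policy_class({"framework": "torch"}) is DummyTorchPolicy
assert pick_policy_class({"framework": "tf"}) is DummyTFPolicy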
@@ -1,4 +1,6 @@
 """Note: Keep in sync with changes to VTraceTFPolicy."""
+from typing import Optional, Dict
+import gym

 import ray
 from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
@@ -10,14 +12,21 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import explained_variance
+from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
+    PolicyID, LocalOptimizer, ModelGradients
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.evaluation import MultiAgentEpisode

 tf1, tf, tfv = try_import_tf()


-def postprocess_advantages(policy,
-                           sample_batch,
-                           other_agent_batches=None,
-                           episode=None):
+def postprocess_advantages(
+        policy: Policy,
+        sample_batch: SampleBatch,
+        other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
+        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:

     # Stub serving backward compatibility.
     deprecation_warning(
@@ -31,15 +40,15 @@ def postprocess_advantages(policy,

 class A3CLoss:
     def __init__(self,
-                 action_dist,
-                 actions,
-                 advantages,
-                 v_target,
-                 vf,
-                 valid_mask,
-                 vf_loss_coeff=0.5,
-                 entropy_coeff=0.01,
-                 use_critic=True):
+                 action_dist: ActionDistribution,
+                 actions: TensorType,
+                 advantages: TensorType,
+                 v_target: TensorType,
+                 vf: TensorType,
+                 valid_mask: TensorType,
+                 vf_loss_coeff: float = 0.5,
+                 entropy_coeff: float = 0.01,
+                 use_critic: bool = True):
         log_prob = action_dist.logp(actions)

         # The "policy gradients" loss
@@ -62,7 +71,9 @@ class A3CLoss:
                            self.entropy * entropy_coeff)


-def actor_critic_loss(policy, model, dist_class, train_batch):
+def actor_critic_loss(policy: Policy, model: ModelV2,
+                      dist_class: ActionDistribution,
+                      train_batch: SampleBatch) -> TensorType:
     model_out, _ = model.from_batch(train_batch)
     action_dist = dist_class(model_out, model)
     if policy.is_recurrent():
@@ -81,11 +92,11 @@ def actor_critic_loss(policy, model, dist_class, train_batch):
     return policy.loss.total_loss


-def add_value_function_fetch(policy):
+def add_value_function_fetch(policy: Policy) -> Dict[str, TensorType]:
     return {SampleBatch.VF_PREDS: policy.model.value_function()}


-def stats(policy, train_batch):
+def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
     return {
         "cur_lr": tf.cast(policy.cur_lr, tf.float64),
         "policy_loss": policy.loss.pi_loss,
@@ -96,7 +107,8 @@ def stats(policy, train_batch):
     }


-def grad_stats(policy, train_batch, grads):
+def grad_stats(policy: Policy, train_batch: SampleBatch,
+               grads: ModelGradients) -> Dict[str, TensorType]:
     return {
         "grad_gnorm": tf.linalg.global_norm(grads),
         "vf_explained_var": explained_variance(
@@ -105,7 +117,8 @@ def grad_stats(policy, train_batch, grads):
     }


-def clip_gradients(policy, optimizer, loss):
+def clip_gradients(policy: Policy, optimizer: LocalOptimizer,
+                   loss: TensorType) -> ModelGradients:
     grads_and_vars = optimizer.compute_gradients(
         loss, policy.model.trainable_variables())
     grads = [g for (g, v) in grads_and_vars]
@@ -114,7 +127,9 @@ def clip_gradients(policy, optimizer, loss):
     return clipped_grads


-def setup_mixins(policy, obs_space, action_space, config):
+def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
+                 config: TrainerConfigDict) -> None:
     ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
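The A3CLoss constructor annotated above combines three terms: a policy-gradient loss weighted by advantages, a value-function loss against the value targets, and an entropy bonus, with the 0.5 and 0.01 default coefficients shown in the signature. A small numpy sketch of that arithmetic, assuming summed losses and made-up numbers (this is not RLlib's implementation):

import numpy as np


def a3c_style_loss(log_probs, advantages, values, value_targets, entropy,
                   vf_loss_coeff=0.5, entropy_coeff=0.01):
    """Combine policy, value and entropy terms the way an A3C-style loss does."""
    pi_loss = -np.sum(log_probs * advantages)               # policy-gradient term
    vf_loss = 0.5 * np.sum((values - value_targets) ** 2)   # critic (L2) term
    return pi_loss + vf_loss_coeff * vf_loss - entropy_coeff * entropy


log_probs = np.array([-0.1, -0.3, -0.2])
advantages = np.array([1.0, -0.5, 0.25])
values = np.array([0.9, 0.4, 0.1])
value_targets = np.array([1.0, 0.0, 0.5])
entropy = 1.2  # summed entropy of the action distribution

print(a3c_style_loss(log_probs, advantages, values, value_targets, entropy))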
@@ -1,24 +1,30 @@
 import gym
+from typing import Optional, Dict

 import ray
 from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
 from ray.rllib.evaluation.episode import MultiAgentEpisode
 from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
     Postprocessing
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import apply_grad_clipping, sequence_mask
-from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
+    PolicyID, LocalOptimizer

 torch, nn = try_import_torch()


-def add_advantages(policy,
-                   sample_batch,
-                   other_agent_batches=None,
-                   episode=None):
+def add_advantages(
+        policy: Policy,
+        sample_batch: SampleBatch,
+        other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
+        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:

     # Stub serving backward compatibility.
     deprecation_warning(
@@ -30,7 +36,9 @@ def add_advantages(policy,
                                         other_agent_batches, episode)


-def actor_critic_loss(policy, model, dist_class, train_batch):
+def actor_critic_loss(policy: Policy, model: ModelV2,
+                      dist_class: ActionDistribution,
+                      train_batch: SampleBatch) -> TensorType:
     logits, _ = model.from_batch(train_batch)
     values = model.value_function()

@@ -68,7 +76,8 @@ def actor_critic_loss(policy, model, dist_class, train_batch):
     return total_loss


-def loss_and_entropy_stats(policy, train_batch):
+def loss_and_entropy_stats(policy: Policy,
+                           train_batch: SampleBatch) -> Dict[str, TensorType]:
     return {
         "policy_entropy": policy.entropy,
         "policy_loss": policy.pi_err,
@@ -76,12 +85,15 @@ def loss_and_entropy_stats(policy, train_batch):
     }


-def model_value_predictions(policy, input_dict, state_batches, model,
-                            action_dist):
+def model_value_predictions(
+        policy: Policy, input_dict: Dict[str, TensorType], state_batches,
+        model: ModelV2,
+        action_dist: ActionDistribution) -> Dict[str, TensorType]:
     return {SampleBatch.VF_PREDS: model.value_function()}


-def torch_optimizer(policy, config):
+def torch_optimizer(policy: Policy,
+                    config: TrainerConfigDict) -> LocalOptimizer:
     return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])

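add_advantages above appears to be a deprecation stub that forwards to compute_gae_for_sample_batch. For intuition, here is a compact numpy sketch of generalized advantage estimation (GAE) over a single finished episode; it is illustrative only and not RLlib's implementation:

import numpy as np


def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Compute GAE advantages for one complete episode (terminal value = 0)."""
    values = np.append(values, 0.0)  # bootstrap value after the last step
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages


print(gae_advantages(np.array([1.0, 0.0, 1.0]), np.array([0.5, 0.4, 0.3])))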
@@ -1,9 +1,12 @@
 import logging
+from typing import Optional, Type

 from ray.rllib.agents.trainer import with_common_config
 from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
 from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy
+from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.typing import TrainerConfigDict

 logger = logging.getLogger(__name__)

@@ -151,7 +154,11 @@ DEFAULT_CONFIG = with_common_config({
 # yapf: enable


-def validate_config(config):
+def validate_config(config: TrainerConfigDict) -> None:
+    """Checks and updates the config based on settings.
+
+    Rewrites rollout_fragment_length to take into account n_step truncation.
+    """
     if config["num_gpus"] > 1:
         raise ValueError("`num_gpus` > 1 not yet supported for DDPG!")
     if config["model"]["custom_model"]:
@@ -176,7 +183,16 @@ def validate_config(config):
         config["simple_optimizer"] = True


-def get_policy_class(config):
+def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
+    """Policy class picker function. Class is chosen based on DL-framework.
+
+    Args:
+        config (TrainerConfigDict): The trainer's configuration dict.
+
+    Returns:
+        Optional[Type[Policy]]: The Policy class to use with DQNTrainer.
+            If None, use `default_policy` provided in build_trainer().
+    """
     if config["framework"] == "torch":
         from ray.rllib.agents.ddpg.ddpg_torch_policy import DDPGTorchPolicy
         return DDPGTorchPolicy
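The validate_config helpers in this commit follow one shape: inspect the config dict, raise ValueError on unsupported combinations, and quietly rewrite settings that only need adjusting. A tiny standalone sketch of that pattern, with the specific rewrite rule here being an assumption for illustration:

def validate_my_config(config: dict) -> None:
    """Raise on unsupported settings; patch ones that can be auto-corrected."""
    if config.get("num_gpus", 0) > 1:
        raise ValueError("`num_gpus` > 1 is not supported in this sketch!")
    if config.get("n_step", 1) > config.get("rollout_fragment_length", 1):
        # Auto-correct instead of raising, mirroring the rewrite-style checks.
        config["rollout_fragment_length"] = config["n_step"]


cfg = {"num_gpus": 1, "n_step": 3, "rollout_fragment_length": 1}
validate_my_config(cfg)
assert cfg["rollout_fragment_length"] == 3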
@@ -1,7 +1,10 @@
 import numpy as np
+import gym
+from typing import List

 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.typing import ModelConfigDict, TensorType

 tf1, tf, tfv = try_import_tf()

@@ -20,18 +23,18 @@ class DDPGTFModel(TFModelV2):

     def __init__(
             self,
-            obs_space,
-            action_space,
-            num_outputs,
-            model_config,
-            name,
+            obs_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            num_outputs: int,
+            model_config: ModelConfigDict,
+            name: str,
+            # Extra DDPGActionModel args:
-            actor_hiddens=(256, 256),
-            actor_hidden_activation="relu",
-            critic_hiddens=(256, 256),
-            critic_hidden_activation="relu",
-            twin_q=False,
-            add_layer_norm=False):
+            actor_hiddens: List[int] = [256, 256],
+            actor_hidden_activation: str = "relu",
+            critic_hiddens: List[int] = [256, 256],
+            critic_hidden_activation: str = "relu",
+            twin_q: bool = False,
+            add_layer_norm: bool = False):
         """Initialize variables of this model.

         Extra model kwargs:
@@ -122,7 +125,8 @@ class DDPGTFModel(TFModelV2):
         else:
             self.twin_q_model = None

-    def get_q_values(self, model_out, actions):
+    def get_q_values(self, model_out: TensorType,
+                     actions: TensorType) -> TensorType:
         """Return the Q estimates for the most recent forward pass.

         This implements Q(s, a).
@@ -141,7 +145,8 @@ class DDPGTFModel(TFModelV2):
         else:
             return self.q_model(model_out)

-    def get_twin_q_values(self, model_out, actions):
+    def get_twin_q_values(self, model_out: TensorType,
+                          actions: TensorType) -> TensorType:
         """Same as get_q_values but using the twin Q net.

         This implements the twin Q(s, a).
@@ -160,7 +165,7 @@ class DDPGTFModel(TFModelV2):
         else:
             return self.twin_q_model(model_out)

-    def get_policy_output(self, model_out):
+    def get_policy_output(self, model_out: TensorType) -> TensorType:
         """Return the action output for the most recent forward pass.

         This outputs the support for pi(s). For continuous action spaces, this
@@ -175,11 +180,11 @@ class DDPGTFModel(TFModelV2):
         """
         return self.policy_model(model_out)

-    def policy_variables(self):
+    def policy_variables(self) -> List[TensorType]:
         """Return the list of variables for the policy net."""
         return list(self.policy_model.variables)

-    def q_variables(self):
+    def q_variables(self) -> List[TensorType]:
         """Return the list of variables for Q / twin Q nets."""

         return self.q_model.variables + (self.twin_q_model.variables
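DDPGTFModel keeps an optional twin Q network next to the main Q network (get_q_values vs. get_twin_q_values). The point of the twin, as in TD3, is to take the minimum of two independent Q estimates when forming the critic target, which damps overestimation. A toy numpy sketch of that idea with linear "critics" that are purely illustrative:

import numpy as np

rng = np.random.default_rng(0)

# Two independently initialized toy critics Q1(s, a) and Q2(s, a).
w1, w2 = rng.normal(size=4), rng.normal(size=4)


def q1(s_a):
    return float(w1 @ s_a)


def q2(s_a):
    return float(w2 @ s_a)


next_state_action = np.array([0.1, -0.2, 0.3, 0.05])
reward, gamma = 1.0, 0.99

# TD3-style target: bootstrap from the smaller of the two estimates.
target = reward + gamma * min(q1(next_state_action), q2(next_state_action))
print(target)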
@@ -2,6 +2,8 @@ from gym.spaces import Box
 from functools import partial
 import logging
 import numpy as np
+import gym
+from typing import Dict, Tuple, List

 import ray
 import ray.experimental.tf_utils
@@ -23,13 +25,20 @@ from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.framework import get_variable, try_import_tf
 from ray.rllib.utils.spaces.simplex import Simplex
 from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable
+from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
+    LocalOptimizer, ModelGradients
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.policy.policy import Policy

 tf1, tf, tfv = try_import_tf()

 logger = logging.getLogger(__name__)


-def build_ddpg_models(policy, observation_space, action_space, config):
+def build_ddpg_models(policy: Policy, observation_space: gym.spaces.Space,
+                      action_space: gym.spaces.Space,
+                      config: TrainerConfigDict) -> ModelV2:
     if policy.config["use_state_preprocessor"]:
         default_model = None  # catalog decides
         num_outputs = 256  # arbitrary
@@ -80,13 +89,14 @@ def build_ddpg_models(policy, observation_space, action_space, config):
     return policy.model


-def get_distribution_inputs_and_class(policy,
-                                      model,
-                                      obs_batch,
-                                      *,
-                                      explore=True,
-                                      is_training=False,
-                                      **kwargs):
+def get_distribution_inputs_and_class(
+        policy: Policy,
+        model: ModelV2,
+        obs_batch: SampleBatch,
+        *,
+        explore: bool = True,
+        is_training: bool = False,
+        **kwargs) -> Tuple[TensorType, ActionDistribution, List[TensorType]]:
     model_out, _ = model({
         "obs": obs_batch,
         "is_training": is_training,
@@ -102,7 +112,8 @@ def get_distribution_inputs_and_class(policy,
     return dist_inputs, distr_class, []  # []=state out


-def ddpg_actor_critic_loss(policy, model, _, train_batch):
+def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
+                           train_batch: SampleBatch) -> TensorType:
     twin_q = policy.config["twin_q"]
     gamma = policy.config["gamma"]
     n_step = policy.config["n_step"]
@@ -242,7 +253,7 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
     return policy.critic_loss + policy.actor_loss


-def make_ddpg_optimizers(policy, config):
+def make_ddpg_optimizers(policy: Policy, config: TrainerConfigDict) -> None:
     # Create separate optimizers for actor & critic losses.
     if policy.config["framework"] in ["tf2", "tfe"]:
         policy._actor_optimizer = tf.keras.optimizers.Adam(
@@ -259,7 +270,8 @@ def make_ddpg_optimizers(policy, config):
     return None


-def build_apply_op(policy, optimizer, grads_and_vars):
+def build_apply_op(policy: Policy, optimizer: LocalOptimizer,
+                   grads_and_vars: ModelGradients) -> TensorType:
     # For policy gradient, update policy net one time v.s.
     # update critic net `policy_delay` time(s).
     should_apply_actor_opt = tf.equal(
@@ -284,7 +296,8 @@ def build_apply_op(policy, optimizer, grads_and_vars):
         return tf.group(actor_op, critic_op)


-def gradients_fn(policy, optimizer, loss):
+def gradients_fn(policy: Policy, optimizer: LocalOptimizer,
+                 loss: TensorType) -> ModelGradients:
     if policy.config["framework"] in ["tf2", "tfe"]:
         tape = optimizer.tape
         pol_weights = policy.model.policy_variables()
@@ -320,7 +333,8 @@ def gradients_fn(policy, optimizer, loss):
     return grads_and_vars


-def build_ddpg_stats(policy, batch):
+def build_ddpg_stats(policy: Policy,
+                     batch: SampleBatch) -> Dict[str, TensorType]:
     stats = {
         "mean_q": tf.reduce_mean(policy.q_t),
         "max_q": tf.reduce_max(policy.q_t),
@@ -329,7 +343,9 @@ def build_ddpg_stats(policy, batch):
     return stats


-def before_init_fn(policy, obs_space, action_space, config):
+def before_init_fn(policy: Policy, obs_space: gym.spaces.Space,
+                   action_space: gym.spaces.Space,
+                   config: TrainerConfigDict) -> None:
     # Create global step for counting the number of update operations.
     if config["framework"] in ["tf2", "tfe"]:
         policy.global_step = get_variable(0, tf_name="global_step")
@@ -359,12 +375,14 @@ class ComputeTDErrorMixin:
         self.compute_td_error = compute_td_error


-def setup_mid_mixins(policy, obs_space, action_space, config):
+def setup_mid_mixins(policy: Policy, obs_space: gym.spaces.Space,
+                     action_space: gym.spaces.Space,
+                     config: TrainerConfigDict) -> None:
     ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)


 class TargetNetworkMixin:
-    def __init__(self, config):
+    def __init__(self, config: TrainerConfigDict):
         @make_tf_callable(self.get_session())
         def update_target_fn(tau):
             tau = tf.convert_to_tensor(tau, dtype=tf.float32)
@@ -384,19 +402,23 @@ class TargetNetworkMixin:
         self.update_target(tau=1.0)

     # Support both hard and soft sync.
-    def update_target(self, tau=None):
+    def update_target(self, tau: int = None) -> None:
         self._do_update(np.float32(tau or self.config.get("tau")))

     @override(TFPolicy)
-    def variables(self):
+    def variables(self) -> List[TensorType]:
         return self.model.variables() + self.target_model.variables()


-def setup_late_mixins(policy, obs_space, action_space, config):
+def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
+                      action_space: gym.spaces.Space,
+                      config: TrainerConfigDict) -> None:
     TargetNetworkMixin.__init__(policy, config)


-def validate_spaces(pid, observation_space, action_space, config):
+def validate_spaces(pid: int, observation_space: gym.spaces.Space,
+                    action_space: gym.spaces.Space,
+                    config: TrainerConfigDict) -> None:
     if not isinstance(action_space, Box):
         raise UnsupportedSpaceException(
             "Action space ({}) of {} is not supported for "
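TargetNetworkMixin.update_target above supports both a hard sync (tau=1.0) and a soft, Polyak-style sync. The update itself is just a convex blend of online and target weights; a minimal standalone numpy sketch of that blend (not the tf op built in the mixin):

import numpy as np


def soft_update(online_vars, target_vars, tau):
    """Blend online weights into target weights: tau=1.0 is a hard copy."""
    return [tau * o + (1.0 - tau) * t for o, t in zip(online_vars, target_vars)]


online = [np.array([1.0, 2.0]), np.array([[0.5]])]
target = [np.array([0.0, 0.0]), np.array([[0.0]])]

target = soft_update(online, target, tau=1.0)    # hard initial sync
online = [o + 1.0 for o in online]               # online net keeps training
target = soft_update(online, target, tau=0.005)  # slow soft sync afterwards
print(target[0])  # -> [1.005 2.005]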
@@ -1,9 +1,12 @@
 import numpy as np
+import gym
+from typing import List, Dict, Union

 from ray.rllib.models.torch.misc import SlimFC
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.models.utils import get_activation_fn
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.typing import ModelConfigDict, TensorType

 torch, nn = try_import_torch()

@@ -20,18 +23,20 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
     Note that this class by itself is not a valid model unless you
     implement forward() in a subclass."""

-    def __init__(self,
-                 obs_space,
-                 action_space,
-                 num_outputs,
-                 model_config,
-                 name,
-                 actor_hidden_activation="relu",
-                 actor_hiddens=(256, 256),
-                 critic_hidden_activation="relu",
-                 critic_hiddens=(256, 256),
-                 twin_q=False,
-                 add_layer_norm=False):
+    def __init__(
+            self,
+            obs_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            num_outputs: int,
+            model_config: ModelConfigDict,
+            name: str,
+            # Extra DDPGActionModel args:
+            actor_hiddens: List[int] = [256, 256],
+            actor_hidden_activation: str = "relu",
+            critic_hiddens: List[int] = [256, 256],
+            critic_hidden_activation: str = "relu",
+            twin_q: bool = False,
+            add_layer_norm: bool = False):
         """Initialize variables of this model.

         Extra model kwargs:
@@ -137,7 +142,8 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
         else:
             self.twin_q_model = None

-    def get_q_values(self, model_out, actions):
+    def get_q_values(self, model_out: TensorType,
+                     actions: TensorType) -> TensorType:
         """Return the Q estimates for the most recent forward pass.

         This implements Q(s, a).
@@ -153,7 +159,8 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
         """
         return self.q_model(torch.cat([model_out, actions], -1))

-    def get_twin_q_values(self, model_out, actions):
+    def get_twin_q_values(self, model_out: TensorType,
+                          actions: TensorType) -> TensorType:
         """Same as get_q_values but using the twin Q net.

         This implements the twin Q(s, a).
@@ -169,7 +176,7 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
         """
         return self.twin_q_model(torch.cat([model_out, actions], -1))

-    def get_policy_output(self, model_out):
+    def get_policy_output(self, model_out: TensorType) -> TensorType:
         """Return the action output for the most recent forward pass.

         This outputs the support for pi(s). For continuous action spaces, this
@@ -184,13 +191,15 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
         """
         return self.policy_model(model_out)

-    def policy_variables(self, as_dict=False):
+    def policy_variables(self, as_dict: bool = False
+                         ) -> Union[List[TensorType], Dict[str, TensorType]]:
         """Return the list of variables for the policy net."""
         if as_dict:
             return self.policy_model.state_dict()
         return list(self.policy_model.parameters())

-    def q_variables(self, as_dict=False):
+    def q_variables(self, as_dict=False
+                    ) -> Union[List[TensorType], Dict[str, TensorType]]:
         """Return the list of variables for Q / twin Q nets."""
         if as_dict:
             return {

@@ -1,24 +1,34 @@
 import logging
+import gym
+from typing import Dict, Tuple

 import ray
 from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \
     get_distribution_inputs_and_class, validate_spaces
 from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
     PRIO_WEIGHTS
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDeterministic, \
     TorchDirichlet
+from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.simplex import Simplex
 from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss, l2_loss
+from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
+    LocalOptimizer, GradInfoDict

 torch, nn = try_import_torch()

 logger = logging.getLogger(__name__)


-def build_ddpg_models_and_action_dist(policy, obs_space, action_space, config):
+def build_ddpg_models_and_action_dist(
+        policy: Policy, obs_space: gym.spaces.Space,
+        action_space: gym.spaces.Space,
+        config: TrainerConfigDict) -> Tuple[ModelV2, ActionDistribution]:
     model = build_ddpg_models(policy, obs_space, action_space, config)
     # TODO(sven): Unify this once we generically support creating more than
     # one Model per policy. Note: Device placement is done automatically
@@ -33,7 +43,8 @@ def build_ddpg_models_and_action_dist(policy, obs_space, action_space, config):
         return model, TorchDeterministic


-def ddpg_actor_critic_loss(policy, model, _, train_batch):
+def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
+                           train_batch: SampleBatch) -> TensorType:
     twin_q = policy.config["twin_q"]
     gamma = policy.config["gamma"]
     n_step = policy.config["n_step"]
@@ -173,7 +184,8 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
     return policy.actor_loss, policy.critic_loss


-def make_ddpg_optimizers(policy, config):
+def make_ddpg_optimizers(policy: Policy,
+                         config: TrainerConfigDict) -> Tuple[LocalOptimizer]:
     """Create separate optimizers for actor & critic losses."""

     # Set epsilons to match tf.keras.optimizers.Adam's epsilon default.
@@ -189,7 +201,7 @@ def make_ddpg_optimizers(policy, config):
     return policy._actor_optimizer, policy._critic_optimizer


-def apply_gradients_fn(policy, gradients):
+def apply_gradients_fn(policy: Policy, gradients: GradInfoDict) -> None:
     # For policy gradient, update policy net one time v.s.
     # update critic net `policy_delay` time(s).
     if policy.global_step % policy.config["policy_delay"] == 0:
@@ -201,7 +213,8 @@ def apply_gradients_fn(policy, gradients):
     policy.global_step += 1


-def build_ddpg_stats(policy, batch):
+def build_ddpg_stats(policy: Policy,
+                     batch: SampleBatch) -> Dict[str, TensorType]:
     stats = {
         "actor_loss": policy.actor_loss,
         "critic_loss": policy.critic_loss,
@@ -214,7 +227,9 @@ def build_ddpg_stats(policy, batch):
     return stats


-def before_init_fn(policy, obs_space, action_space, config):
+def before_init_fn(policy: Policy, obs_space: gym.spaces.Space,
+                   action_space: gym.spaces.Space,
+                   config: TrainerConfigDict) -> None:
     # Create global step for counting the number of update operations.
     policy.global_step = 0

@@ -247,7 +262,7 @@ class TargetNetworkMixin:
         # Hard initial update from Q-net(s) to target Q-net(s).
         self.update_target(tau=1.0)

-    def update_target(self, tau=None):
+    def update_target(self, tau: int = None):
         tau = tau or self.config.get("tau")
         # Update_target_fn will be called periodically to copy Q network to
         # target Q network, using (soft) tau-synching.
@@ -265,7 +280,9 @@ class TargetNetworkMixin:
                     (1.0 - tau) * var_target.data


-def setup_late_mixins(policy, obs_space, action_space, config):
+def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
+                      action_space: gym.spaces.Space,
+                      config: TrainerConfigDict) -> None:
     ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
     TargetNetworkMixin.__init__(policy)

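apply_gradients_fn above applies the critic update on every call but only steps the actor on every policy_delay-th call (the TD3 "delayed policy updates" trick). A bare-bones sketch of that bookkeeping with plain counters and no real optimizers:

def run_updates(num_steps, policy_delay=2):
    """Count how often actor vs. critic would be stepped."""
    actor_steps = critic_steps = 0
    for global_step in range(num_steps):
        if global_step % policy_delay == 0:
            actor_steps += 1   # delayed actor update
        critic_steps += 1      # critic updated every step
    return actor_steps, critic_steps


assert run_updates(10, policy_delay=2) == (5, 10)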
@@ -375,7 +375,9 @@ def compute_q_values(policy: Policy,
     return value, logits, dist, state


-def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
+def _adjust_nstep(n_step: int, gamma: int, obs: TensorType,
+                  actions: TensorType, rewards: TensorType,
+                  new_obs: TensorType, dones: TensorType):
     """Rewrites the given trajectory fragments to encode n-step rewards.

        reward[i] = (
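_adjust_nstep rewrites a sampled trajectory so that each transition carries an n-step discounted return and points at the observation n steps ahead. A self-contained sketch of that rewrite for a single fragment without episode boundaries; it is simplified relative to the real helper, which also respects dones:

def adjust_nstep(n_step, gamma, rewards, new_obs):
    """Fold the next n_step rewards into each step and shift new_obs forward."""
    horizon = len(rewards)
    out_rewards = list(rewards)
    out_new_obs = list(new_obs)
    for i in range(horizon):
        for j in range(1, n_step):
            if i + j < horizon:
                out_rewards[i] += (gamma ** j) * rewards[i + j]
                out_new_obs[i] = new_obs[i + j]
    return out_rewards, out_new_obs


r, o2 = adjust_nstep(3, 0.9, [1.0, 1.0, 1.0, 1.0], ["s1", "s2", "s3", "s4"])
# First entry now holds r0 + 0.9*r1 + 0.81*r2 = 2.71 and points at "s3".
print(r[0], o2[0])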
@@ -23,7 +23,7 @@ BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
 # yapf: enable


-def validate_config(config: TrainerConfigDict):
+def validate_config(config: TrainerConfigDict) -> None:
     if config["beta"] != 0.0:
         raise ValueError(
             "For behavioral cloning, `beta` parameter must be 0.0!")

@@ -1,3 +1,5 @@
+from typing import Optional, Type
+
 from ray.rllib.agents.trainer import with_common_config
 from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
@@ -7,6 +9,10 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.train_ops import TrainOneStep
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.util.iter import LocalIterator
+from ray.rllib.policy.policy import Policy

 # yapf: disable
 # __sphinx_doc_begin__
@@ -46,14 +52,36 @@ DEFAULT_CONFIG = with_common_config({
 # yapf: enable


-def get_policy_class(config):
+def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
+    """Policy class picker function. Class is chosen based on DL-framework.
+    MARWIL/BC have both TF and Torch policy support.
+
+    Args:
+        config (TrainerConfigDict): The trainer's configuration dict.
+
+    Returns:
+        Optional[Type[Policy]]: The Policy class to use with DQNTrainer.
+            If None, use `default_policy` provided in build_trainer().
+    """
     if config["framework"] == "torch":
         from ray.rllib.agents.marwil.marwil_torch_policy import \
             MARWILTorchPolicy
         return MARWILTorchPolicy


-def execution_plan(workers, config):
+def execution_plan(workers: WorkerSet,
+                   config: TrainerConfigDict) -> LocalIterator[dict]:
+    """Execution plan of the MARWIL/BC algorithm. Defines the distributed
+    dataflow.
+
+    Args:
+        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
+            of the Trainer.
+        config (TrainerConfigDict): The trainer's configuration dict.
+
+    Returns:
+        LocalIterator[dict]: A local iterator over training metrics.
+    """
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
     replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])

@@ -74,7 +102,11 @@ def execution_plan(workers, config):
     return StandardMetricsReporting(train_op, workers, config)


-def validate_config(config):
+def validate_config(config: TrainerConfigDict) -> None:
+    """Checks and updates the config based on settings.
+
+    Rewrites rollout_fragment_length to take into account n_step truncation.
+    """
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")

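MARWIL weights its imitation (log-likelihood) loss by exponentiated advantages, and with beta = 0.0 that weight collapses to 1.0, which is why the BC trainer's validate_config above insists on beta == 0.0 (plain behavioral cloning). A small numpy sketch of that weighting; the advantage normalizer c is taken as given here and the numbers are illustrative:

import numpy as np


def marwil_style_weights(advantages, beta, c=1.0):
    """Per-sample weights on the log-prob loss: exp(beta * A / c)."""
    return np.exp(beta * advantages / c)


advantages = np.array([1.0, -1.0, 0.5])
print(marwil_style_weights(advantages, beta=1.0))  # favors high-advantage actions
print(marwil_style_weights(advantages, beta=0.0))  # all ones -> behavioral cloning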
@@ -1,4 +1,6 @@
 import logging
+import gym
+from typing import Optional, Dict

 import ray
 from ray.rllib.agents.ppo.ppo_tf_policy import compute_and_clip_gradients
@@ -8,6 +10,11 @@ from ray.rllib.evaluation.postprocessing import compute_advantages, \
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils.framework import try_import_tf, get_variable
 from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
+from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
+    PolicyID
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2

 tf1, tf, tfv = try_import_tf()

@@ -15,7 +22,8 @@ logger = logging.getLogger(__name__)


 class ValueNetworkMixin:
-    def __init__(self, obs_space, action_space, config):
+    def __init__(self, obs_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space, config: TrainerConfigDict):

         # Input dict is provided to us automatically via the Model's
         # requirements. It's a single-timestep (last one in trajectory)
@@ -29,10 +37,11 @@ class ValueNetworkMixin:
         self._value = value


-def postprocess_advantages(policy,
-                           sample_batch,
-                           other_agent_batches=None,
-                           episode=None):
+def postprocess_advantages(
+        policy: Policy,
+        sample_batch: SampleBatch,
+        other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
+        episode=None) -> SampleBatch:
     """Postprocesses a trajectory and returns the processed trajectory.

     The trajectory contains only data from one episode and from one agent.
@@ -84,8 +93,10 @@ def postprocess_advantages(policy,


 class MARWILLoss:
-    def __init__(self, policy, value_estimates, action_dist, actions,
-                 cumulative_rewards, vf_loss_coeff, beta):
+    def __init__(self, policy: Policy, value_estimates: TensorType,
+                 action_dist: ActionDistribution, actions: TensorType,
+                 cumulative_rewards: TensorType, vf_loss_coeff: float,
+                 beta: float):

         # Advantage Estimation.
         adv = cumulative_rewards - value_estimates
@@ -133,7 +144,8 @@ class MARWILLoss:
             explained_variance(cumulative_rewards, value_estimates))


-def marwil_loss(policy, model, dist_class, train_batch):
+def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
+                train_batch: SampleBatch) -> TensorType:
     model_out, _ = model.from_batch(train_batch)
     action_dist = dist_class(model_out, model)
     value_estimates = model.value_function()
@@ -146,7 +158,7 @@ def marwil_loss(policy, model, dist_class, train_batch):
     return policy.loss.total_loss


-def stats(policy, train_batch):
+def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
     return {
         "policy_loss": policy.loss.p_loss,
         "vf_loss": policy.loss.v_loss,
@@ -155,7 +167,9 @@ def stats(policy, train_batch):
     }


-def setup_mixins(policy, obs_space, action_space, config):
+def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
+                 config: TrainerConfigDict) -> None:
     ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
     # Set up a tf-var for the moving avg (do this here to make it work with
     # eager mode); "c^2" in the paper.
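The setup_mixins comment above mentions a moving average of squared advantages ("c^2" in the MARWIL paper), which normalizes the exponent of the weights sketched earlier. A minimal sketch of such an exponential moving average; the update rate is an assumed placeholder, and the 100.0 start value mirrors the initialization used in the torch policy shown below:

import numpy as np


def update_ma_adv_norm(ma_sq, advantages, rate=1e-8):
    """Exponentially track the mean squared advantage (a stand-in for c^2)."""
    return ma_sq + rate * (np.mean(np.square(advantages)) - ma_sq)


ma_sq = 100.0
for batch_adv in (np.array([0.5, -0.2]), np.array([1.0, 0.3])):
    ma_sq = update_ma_adv_norm(ma_sq, batch_adv, rate=0.1)
print(ma_sq)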
@@ -1,3 +1,6 @@
+import gym
+from typing import Dict
+
 import ray
 from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
 from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
@@ -6,11 +9,16 @@ from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import apply_grad_clipping, explained_variance
+from ray.rllib.utils.typing import TrainerConfigDict, TensorType
+from ray.rllib.policy.policy import Policy
+from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2

 torch, _ = try_import_torch()


-def marwil_loss(policy, model, dist_class, train_batch):
+def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
+                train_batch: SampleBatch) -> TensorType:
     model_out, _ = model.from_batch(train_batch)
     action_dist = dist_class(model_out, model)
     state_values = model.value_function()
@@ -43,7 +51,7 @@ def marwil_loss(policy, model, dist_class, train_batch):
     return policy.total_loss


-def stats(policy, train_batch):
+def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
     return {
         "policy_loss": policy.p_loss,
         "vf_loss": policy.v_loss,
@@ -52,7 +60,9 @@ def stats(policy, train_batch):
     }


-def setup_mixins(policy, obs_space, action_space, config):
+def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
+                 config: TrainerConfigDict) -> None:
     # Create a var.
     policy.ma_adv_norm = torch.tensor(
         [100.0], dtype=torch.float32, requires_grad=False).to(policy.device)

@@ -123,7 +123,8 @@ class SampleBatch(dict):
         """Concatenates n data dicts or MultiAgentBatches.

         Args:
-            samples (List[Dict[TensorType]]]): List of dicts of data (numpy).
+            samples (List[Dict[str, TensorType]]]): List of dicts of data
+                (numpy).

         Returns:
             Union[SampleBatch, MultiAgentBatch]: A new (compressed)
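The SampleBatch concatenation docstring fixed above describes concatenating a list of dicts of numpy arrays. A tiny standalone sketch of that operation, with no RLlib types involved:

import numpy as np


def concat_sample_dicts(samples):
    """Concatenate a list of dicts of numpy arrays key by key."""
    keys = samples[0].keys()
    return {k: np.concatenate([s[k] for s in samples], axis=0) for k in keys}


batch = concat_sample_dicts([
    {"obs": np.zeros((2, 3)), "rewards": np.array([1.0, 0.0])},
    {"obs": np.ones((1, 3)), "rewards": np.array([0.5])},
])
print(batch["obs"].shape, batch["rewards"])  # (3, 3) [1.  0.  0.5]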