[RLlib] DDPG/TD3 + A3C/A2C + MARWIL/BC Annotation/Comments/Code Cleanup (#14707)

Michael Luo 2021-05-19 07:32:29 -07:00 committed by GitHub
parent 8790bb465b
commit 474f04e322
15 changed files with 316 additions and 115 deletions

View file

@ -1,4 +1,5 @@
import math
from typing import Optional, Type
from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \
validate_config, get_policy_class
@ -9,6 +10,9 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import ComputeGradients, AverageGradients, \
ApplyGradients, TrainTFMultiGPU, TrainOneStep
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.policy import Policy
from ray.util.iter import LocalIterator
A2C_DEFAULT_CONFIG = merge_dicts(
A3C_CONFIG,
@ -26,7 +30,19 @@ A2C_DEFAULT_CONFIG = merge_dicts(
)
def execution_plan(workers, config):
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
"""Execution plan of the A2C algorithm. Defines the distributed
dataflow.
Args:
workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
of the Trainer.
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
rollouts = ParallelRollouts(workers, mode="bulk_sync")
if config["microbatch_size"]:

View file

@ -1,4 +1,5 @@
import logging
from typing import Optional, Type
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import with_common_config
@ -6,6 +7,10 @@ from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.rollout_ops import AsyncGradients
from ray.rllib.execution.train_ops import ApplyGradients
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.util.iter import LocalIterator
from ray.rllib.policy.policy import Policy
logger = logging.getLogger(__name__)
@ -42,7 +47,16 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
def get_policy_class(config):
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
Args:
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
Optional[Type[Policy]]: The Policy class to use with A3CTrainer.
If None, use `default_policy` provided in build_trainer().
"""
if config["framework"] == "torch":
from ray.rllib.agents.a3c.a3c_torch_policy import \
A3CTorchPolicy
@ -51,14 +65,30 @@ def get_policy_class(config):
return A3CTFPolicy
def validate_config(config):
def validate_config(config: TrainerConfigDict) -> None:
"""Checks and updates the config based on settings.
Rewrites rollout_fragment_length to take into account n_step truncation.
"""
if config["entropy_coeff"] < 0:
raise ValueError("`entropy_coeff` must be >= 0.0!")
if config["num_workers"] <= 0 and config["sample_async"]:
raise ValueError("`num_workers` for A3C must be >= 1!")
def execution_plan(workers, config):
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
"""Execution plan of the MARWIL/BC algorithm. Defines the distributed
dataflow.
Args:
workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
of the Trainer.
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
# For A3C, compute policy gradients remotely on the rollout workers.
grads = AsyncGradients(workers)
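For orientation, a hedged sketch of how the asynchronous A3C dataflow started above typically continues (illustrative only; `update_all=False` is an assumption about the ApplyGradients flag, matching contemporary RLlib behavior).

from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import AsyncGradients
from ray.rllib.execution.train_ops import ApplyGradients

def sketch_a3c_plan(workers, config):
    # (grads, info) tuples stream in asynchronously from the rollout workers.
    grads = AsyncGradients(workers)
    # Apply each gradient on the local worker and push new weights back only
    # to the worker that produced it.
    train_op = grads.for_each(ApplyGradients(workers, update_all=False))
    return StandardMetricsReporting(train_op, workers, config)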

View file

@ -1,4 +1,6 @@
"""Note: Keep in sync with changes to VTraceTFPolicy."""
from typing import Optional, Dict
import gym
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
@ -10,14 +12,21 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import explained_variance
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
PolicyID, LocalOptimizer, ModelGradients
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.evaluation import MultiAgentEpisode
tf1, tf, tfv = try_import_tf()
def postprocess_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
def postprocess_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
# Stub serving backward compatibility.
deprecation_warning(
@ -31,15 +40,15 @@ def postprocess_advantages(policy,
class A3CLoss:
def __init__(self,
action_dist,
actions,
advantages,
v_target,
vf,
valid_mask,
vf_loss_coeff=0.5,
entropy_coeff=0.01,
use_critic=True):
action_dist: ActionDistribution,
actions: TensorType,
advantages: TensorType,
v_target: TensorType,
vf: TensorType,
valid_mask: TensorType,
vf_loss_coeff: float = 0.5,
entropy_coeff: float = 0.01,
use_critic: bool = True):
log_prob = action_dist.logp(actions)
# The "policy gradients" loss
@ -62,7 +71,9 @@ class A3CLoss:
self.entropy * entropy_coeff)
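As a reference for the terms above, a minimal numpy sketch of how the masked pieces usually combine into the scalar A3C loss; the coefficient defaults mirror the constructor signature shown, all variable names are illustrative.

import numpy as np

def a3c_total_loss(log_prob, advantages, vf, v_target, entropy,
                   vf_loss_coeff=0.5, entropy_coeff=0.01):
    # Policy-gradient term: raise log-prob of actions with positive advantage.
    pi_loss = -np.sum(log_prob * advantages)
    # Value-function regression toward the bootstrapped value targets.
    vf_loss = np.sum(np.square(vf - v_target))
    # Entropy bonus (note the minus sign) encourages exploration.
    return pi_loss + vf_loss * vf_loss_coeff - entropy * entropy_coeff

print(a3c_total_loss(np.array([-0.5]), np.array([1.0]),
                     np.array([0.8]), np.array([1.0]), entropy=1.2))
# -> 0.5 + 0.5 * 0.04 - 0.01 * 1.2 = 0.508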
def actor_critic_loss(policy, model, dist_class, train_batch):
def actor_critic_loss(policy: Policy, model: ModelV2,
dist_class: ActionDistribution,
train_batch: SampleBatch) -> TensorType:
model_out, _ = model.from_batch(train_batch)
action_dist = dist_class(model_out, model)
if policy.is_recurrent():
@ -81,11 +92,11 @@ def actor_critic_loss(policy, model, dist_class, train_batch):
return policy.loss.total_loss
def add_value_function_fetch(policy):
def add_value_function_fetch(policy: Policy) -> Dict[str, TensorType]:
return {SampleBatch.VF_PREDS: policy.model.value_function()}
def stats(policy, train_batch):
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
return {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
"policy_loss": policy.loss.pi_loss,
@ -96,7 +107,8 @@ def stats(policy, train_batch):
}
def grad_stats(policy, train_batch, grads):
def grad_stats(policy: Policy, train_batch: SampleBatch,
grads: ModelGradients) -> Dict[str, TensorType]:
return {
"grad_gnorm": tf.linalg.global_norm(grads),
"vf_explained_var": explained_variance(
@ -105,7 +117,8 @@ def grad_stats(policy, train_batch, grads):
}
def clip_gradients(policy, optimizer, loss):
def clip_gradients(policy: Policy, optimizer: LocalOptimizer,
loss: TensorType) -> ModelGradients:
grads_and_vars = optimizer.compute_gradients(
loss, policy.model.trainable_variables())
grads = [g for (g, v) in grads_and_vars]
@ -114,7 +127,9 @@ def clip_gradients(policy, optimizer, loss):
return clipped_grads
def setup_mixins(policy, obs_space, action_space, config):
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])

View file

@ -1,24 +1,30 @@
import gym
from typing import Dict, List, Optional
import ray
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping, sequence_mask
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
PolicyID, LocalOptimizer
torch, nn = try_import_torch()
def add_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
def add_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
# Stub serving backward compatibility.
deprecation_warning(
@ -30,7 +36,9 @@ def add_advantages(policy,
other_agent_batches, episode)
def actor_critic_loss(policy, model, dist_class, train_batch):
def actor_critic_loss(policy: Policy, model: ModelV2,
dist_class: ActionDistribution,
train_batch: SampleBatch) -> TensorType:
logits, _ = model.from_batch(train_batch)
values = model.value_function()
@ -68,7 +76,8 @@ def actor_critic_loss(policy, model, dist_class, train_batch):
return total_loss
def loss_and_entropy_stats(policy, train_batch):
def loss_and_entropy_stats(policy: Policy,
train_batch: SampleBatch) -> Dict[str, TensorType]:
return {
"policy_entropy": policy.entropy,
"policy_loss": policy.pi_err,
@ -76,12 +85,15 @@ def loss_and_entropy_stats(policy, train_batch):
}
def model_value_predictions(policy, input_dict, state_batches, model,
action_dist):
def model_value_predictions(
policy: Policy, input_dict: Dict[str, TensorType],
state_batches: List[TensorType], model: ModelV2,
action_dist: ActionDistribution) -> Dict[str, TensorType]:
return {SampleBatch.VF_PREDS: model.value_function()}
def torch_optimizer(policy, config):
def torch_optimizer(policy: Policy,
config: TrainerConfigDict) -> LocalOptimizer:
return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])

View file

@ -1,9 +1,12 @@
import logging
from typing import Optional, Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
@ -151,7 +154,11 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
def validate_config(config):
def validate_config(config: TrainerConfigDict) -> None:
"""Checks and updates the config based on settings.
Rewrites rollout_fragment_length to take into account n_step truncation.
"""
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for DDPG!")
if config["model"]["custom_model"]:
@ -176,7 +183,16 @@ def validate_config(config):
config["simple_optimizer"] = True
def get_policy_class(config):
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
Args:
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
Optional[Type[Policy]]: The Policy class to use with DDPGTrainer.
If None, use `default_policy` provided in build_trainer().
"""
if config["framework"] == "torch":
from ray.rllib.agents.ddpg.ddpg_torch_policy import DDPGTorchPolicy
return DDPGTorchPolicy

View file

@ -1,7 +1,10 @@
import numpy as np
import gym
from typing import List
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
@ -20,18 +23,18 @@ class DDPGTFModel(TFModelV2):
def __init__(
self,
obs_space,
action_space,
num_outputs,
model_config,
name,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
# Extra DDPGActionModel args:
actor_hiddens=(256, 256),
actor_hidden_activation="relu",
critic_hiddens=(256, 256),
critic_hidden_activation="relu",
twin_q=False,
add_layer_norm=False):
actor_hiddens: List[int] = [256, 256],
actor_hidden_activation: str = "relu",
critic_hiddens: List[int] = [256, 256],
critic_hidden_activation: str = "relu",
twin_q: bool = False,
add_layer_norm: bool = False):
"""Initialize variables of this model.
Extra model kwargs:
@ -122,7 +125,8 @@ class DDPGTFModel(TFModelV2):
else:
self.twin_q_model = None
def get_q_values(self, model_out, actions):
def get_q_values(self, model_out: TensorType,
actions: TensorType) -> TensorType:
"""Return the Q estimates for the most recent forward pass.
This implements Q(s, a).
@ -141,7 +145,8 @@ class DDPGTFModel(TFModelV2):
else:
return self.q_model(model_out)
def get_twin_q_values(self, model_out, actions):
def get_twin_q_values(self, model_out: TensorType,
actions: TensorType) -> TensorType:
"""Same as get_q_values but using the twin Q net.
This implements the twin Q(s, a).
@ -160,7 +165,7 @@ class DDPGTFModel(TFModelV2):
else:
return self.twin_q_model(model_out)
def get_policy_output(self, model_out):
def get_policy_output(self, model_out: TensorType) -> TensorType:
"""Return the action output for the most recent forward pass.
This outputs the support for pi(s). For continuous action spaces, this
@ -175,11 +180,11 @@ class DDPGTFModel(TFModelV2):
"""
return self.policy_model(model_out)
def policy_variables(self):
def policy_variables(self) -> List[TensorType]:
"""Return the list of variables for the policy net."""
return list(self.policy_model.variables)
def q_variables(self):
def q_variables(self) -> List[TensorType]:
"""Return the list of variables for Q / twin Q nets."""
return self.q_model.variables + (self.twin_q_model.variables

View file

@ -2,6 +2,8 @@ from gym.spaces import Box
from functools import partial
import logging
import numpy as np
import gym
from typing import Dict, Tuple, List
import ray
import ray.experimental.tf_utils
@ -23,13 +25,20 @@ from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import get_variable, try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
LocalOptimizer, ModelGradients
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
def build_ddpg_models(policy, observation_space, action_space, config):
def build_ddpg_models(policy: Policy, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> ModelV2:
if policy.config["use_state_preprocessor"]:
default_model = None # catalog decides
num_outputs = 256 # arbitrary
@ -80,13 +89,14 @@ def build_ddpg_models(policy, observation_space, action_space, config):
return policy.model
def get_distribution_inputs_and_class(policy,
model,
obs_batch,
*,
explore=True,
is_training=False,
**kwargs):
def get_distribution_inputs_and_class(
policy: Policy,
model: ModelV2,
obs_batch: SampleBatch,
*,
explore: bool = True,
is_training: bool = False,
**kwargs) -> Tuple[TensorType, ActionDistribution, List[TensorType]]:
model_out, _ = model({
"obs": obs_batch,
"is_training": is_training,
@ -102,7 +112,8 @@ def get_distribution_inputs_and_class(policy,
return dist_inputs, distr_class, [] # []=state out
def ddpg_actor_critic_loss(policy, model, _, train_batch):
def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
train_batch: SampleBatch) -> TensorType:
twin_q = policy.config["twin_q"]
gamma = policy.config["gamma"]
n_step = policy.config["n_step"]
@ -242,7 +253,7 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
return policy.critic_loss + policy.actor_loss
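For readers new to TD3, a small numpy sketch of the Bellman target this loss regresses the Q-net(s) toward when `twin_q=True` (clipped double-Q); names and shapes are illustrative, not taken from the file.

import numpy as np

def sketch_td3_target(rewards, dones, q1_tp1, q2_tp1, gamma=0.99, n_step=1):
    # Clipped double-Q: take the smaller of the two target-net estimates to
    # curb overestimation; with twin_q=False, q1_tp1 would be used alone.
    q_tp1 = np.minimum(q1_tp1, q2_tp1)
    # Zero out the bootstrap value at terminal transitions, then form the
    # (possibly n-step) discounted Bellman target.
    q_tp1_masked = (1.0 - dones) * q_tp1
    return rewards + gamma ** n_step * q_tp1_masked

print(sketch_td3_target(np.array([1.0]), np.array([0.0]),
                        np.array([2.0]), np.array([1.5])))
# -> [1.0 + 0.99 * 1.5] = [2.485]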
def make_ddpg_optimizers(policy, config):
def make_ddpg_optimizers(policy: Policy, config: TrainerConfigDict) -> None:
# Create separate optimizers for actor & critic losses.
if policy.config["framework"] in ["tf2", "tfe"]:
policy._actor_optimizer = tf.keras.optimizers.Adam(
@ -259,7 +270,8 @@ def make_ddpg_optimizers(policy, config):
return None
def build_apply_op(policy, optimizer, grads_and_vars):
def build_apply_op(policy: Policy, optimizer: LocalOptimizer,
grads_and_vars: ModelGradients) -> TensorType:
# For policy gradient, update policy net one time vs.
# update critic net `policy_delay` time(s).
should_apply_actor_opt = tf.equal(
@ -284,7 +296,8 @@ def build_apply_op(policy, optimizer, grads_and_vars):
return tf.group(actor_op, critic_op)
def gradients_fn(policy, optimizer, loss):
def gradients_fn(policy: Policy, optimizer: LocalOptimizer,
loss: TensorType) -> ModelGradients:
if policy.config["framework"] in ["tf2", "tfe"]:
tape = optimizer.tape
pol_weights = policy.model.policy_variables()
@ -320,7 +333,8 @@ def gradients_fn(policy, optimizer, loss):
return grads_and_vars
def build_ddpg_stats(policy, batch):
def build_ddpg_stats(policy: Policy,
batch: SampleBatch) -> Dict[str, TensorType]:
stats = {
"mean_q": tf.reduce_mean(policy.q_t),
"max_q": tf.reduce_max(policy.q_t),
@ -329,7 +343,9 @@ def build_ddpg_stats(policy, batch):
return stats
def before_init_fn(policy, obs_space, action_space, config):
def before_init_fn(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
# Create global step for counting the number of update operations.
if config["framework"] in ["tf2", "tfe"]:
policy.global_step = get_variable(0, tf_name="global_step")
@ -359,12 +375,14 @@ class ComputeTDErrorMixin:
self.compute_td_error = compute_td_error
def setup_mid_mixins(policy, obs_space, action_space, config):
def setup_mid_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
class TargetNetworkMixin:
def __init__(self, config):
def __init__(self, config: TrainerConfigDict):
@make_tf_callable(self.get_session())
def update_target_fn(tau):
tau = tf.convert_to_tensor(tau, dtype=tf.float32)
@ -384,19 +402,23 @@ class TargetNetworkMixin:
self.update_target(tau=1.0)
# Support both hard and soft sync.
def update_target(self, tau=None):
def update_target(self, tau: float = None) -> None:
self._do_update(np.float32(tau or self.config.get("tau")))
@override(TFPolicy)
def variables(self):
def variables(self) -> List[TensorType]:
return self.model.variables() + self.target_model.variables()
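The tau-weighted assignment built by `update_target_fn` above amounts to Polyak averaging; a tiny numpy sketch with illustrative names:

import numpy as np

def soft_update(target_vars, source_vars, tau):
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    # tau=1.0 degenerates to the hard copy done right after initialization.
    return [tau * src + (1.0 - tau) * tgt
            for tgt, src in zip(target_vars, source_vars)]

print(soft_update([np.ones(3)], [np.zeros(3)], tau=0.002))
# -> [array([0.998, 0.998, 0.998])]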
def setup_late_mixins(policy, obs_space, action_space, config):
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
TargetNetworkMixin.__init__(policy, config)
def validate_spaces(pid, observation_space, action_space, config):
def validate_spaces(pid: int, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
if not isinstance(action_space, Box):
raise UnsupportedSpaceException(
"Action space ({}) of {} is not supported for "

View file

@ -1,9 +1,12 @@
import numpy as np
import gym
from typing import List, Dict, Union
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
torch, nn = try_import_torch()
@ -20,18 +23,20 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
def __init__(self,
obs_space,
action_space,
num_outputs,
model_config,
name,
actor_hidden_activation="relu",
actor_hiddens=(256, 256),
critic_hidden_activation="relu",
critic_hiddens=(256, 256),
twin_q=False,
add_layer_norm=False):
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
# Extra DDPGActionModel args:
actor_hiddens: List[int] = [256, 256],
actor_hidden_activation: str = "relu",
critic_hiddens: List[int] = [256, 256],
critic_hidden_activation: str = "relu",
twin_q: bool = False,
add_layer_norm: bool = False):
"""Initialize variables of this model.
Extra model kwargs:
@ -137,7 +142,8 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
else:
self.twin_q_model = None
def get_q_values(self, model_out, actions):
def get_q_values(self, model_out: TensorType,
actions: TensorType) -> TensorType:
"""Return the Q estimates for the most recent forward pass.
This implements Q(s, a).
@ -153,7 +159,8 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
"""
return self.q_model(torch.cat([model_out, actions], -1))
def get_twin_q_values(self, model_out, actions):
def get_twin_q_values(self, model_out: TensorType,
actions: TensorType) -> TensorType:
"""Same as get_q_values but using the twin Q net.
This implements the twin Q(s, a).
@ -169,7 +176,7 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
"""
return self.twin_q_model(torch.cat([model_out, actions], -1))
def get_policy_output(self, model_out):
def get_policy_output(self, model_out: TensorType) -> TensorType:
"""Return the action output for the most recent forward pass.
This outputs the support for pi(s). For continuous action spaces, this
@ -184,13 +191,15 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
"""
return self.policy_model(model_out)
def policy_variables(self, as_dict=False):
def policy_variables(self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Return the list of variables for the policy net."""
if as_dict:
return self.policy_model.state_dict()
return list(self.policy_model.parameters())
def q_variables(self, as_dict=False):
def q_variables(self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Return the list of variables for Q / twin Q nets."""
if as_dict:
return {

View file

@ -1,24 +1,34 @@
import logging
import gym
from typing import Dict, Tuple
import ray
from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \
get_distribution_inputs_and_class, validate_spaces
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
PRIO_WEIGHTS
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDeterministic, \
TorchDirichlet
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss, l2_loss
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
LocalOptimizer, GradInfoDict
torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
def build_ddpg_models_and_action_dist(policy, obs_space, action_space, config):
def build_ddpg_models_and_action_dist(
policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> Tuple[ModelV2, ActionDistribution]:
model = build_ddpg_models(policy, obs_space, action_space, config)
# TODO(sven): Unify this once we generically support creating more than
# one Model per policy. Note: Device placement is done automatically
@ -33,7 +43,8 @@ def build_ddpg_models_and_action_dist(policy, obs_space, action_space, config):
return model, TorchDeterministic
def ddpg_actor_critic_loss(policy, model, _, train_batch):
def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
train_batch: SampleBatch) -> TensorType:
twin_q = policy.config["twin_q"]
gamma = policy.config["gamma"]
n_step = policy.config["n_step"]
@ -173,7 +184,8 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
return policy.actor_loss, policy.critic_loss
def make_ddpg_optimizers(policy, config):
def make_ddpg_optimizers(policy: Policy,
config: TrainerConfigDict) -> Tuple[LocalOptimizer, LocalOptimizer]:
"""Create separate optimizers for actor & critic losses."""
# Set epsilons to match tf.keras.optimizers.Adam's epsilon default.
@ -189,7 +201,7 @@ def make_ddpg_optimizers(policy, config):
return policy._actor_optimizer, policy._critic_optimizer
def apply_gradients_fn(policy, gradients):
def apply_gradients_fn(policy: Policy, gradients: GradInfoDict) -> None:
# For policy gradient, update policy net one time vs.
# update critic net `policy_delay` time(s).
if policy.global_step % policy.config["policy_delay"] == 0:
@ -201,7 +213,8 @@ def apply_gradients_fn(policy, gradients):
policy.global_step += 1
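The `policy_delay` pattern above (borrowed from TD3) updates the critic on every call but the actor only on every `policy_delay`-th call; a self-contained toy illustration with hypothetical names, not the file's body:

class DelayedUpdateSketch:
    """Toy counter mimicking the TD3-style update schedule."""

    def __init__(self, policy_delay=2):
        self.policy_delay = policy_delay
        self.global_step = 0
        self.actor_updates = 0
        self.critic_updates = 0

    def apply(self):
        if self.global_step % self.policy_delay == 0:
            self.actor_updates += 1   # delayed actor (policy) update
        self.critic_updates += 1      # critic updated on every call
        self.global_step += 1

sketch = DelayedUpdateSketch(policy_delay=2)
for _ in range(10):
    sketch.apply()
print(sketch.actor_updates, sketch.critic_updates)  # -> 5 10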
def build_ddpg_stats(policy, batch):
def build_ddpg_stats(policy: Policy,
batch: SampleBatch) -> Dict[str, TensorType]:
stats = {
"actor_loss": policy.actor_loss,
"critic_loss": policy.critic_loss,
@ -214,7 +227,9 @@ def build_ddpg_stats(policy, batch):
return stats
def before_init_fn(policy, obs_space, action_space, config):
def before_init_fn(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
# Create global step for counting the number of update operations.
policy.global_step = 0
@ -247,7 +262,7 @@ class TargetNetworkMixin:
# Hard initial update from Q-net(s) to target Q-net(s).
self.update_target(tau=1.0)
def update_target(self, tau=None):
def update_target(self, tau: float = None):
tau = tau or self.config.get("tau")
# Update_target_fn will be called periodically to copy Q network to
# target Q network, using (soft) tau-synching.
@ -265,7 +280,9 @@ class TargetNetworkMixin:
(1.0 - tau) * var_target.data
def setup_late_mixins(policy, obs_space, action_space, config):
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
TargetNetworkMixin.__init__(policy)

View file

@ -375,7 +375,9 @@ def compute_q_values(policy: Policy,
return value, logits, dist, state
def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
def _adjust_nstep(n_step: int, gamma: float, obs: TensorType,
actions: TensorType, rewards: TensorType,
new_obs: TensorType, dones: TensorType):
"""Rewrites the given trajectory fragments to encode n-step rewards.
reward[i] = (
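A worked example of the n-step rewrite this docstring describes, ignoring episode boundaries (`dones`) for brevity; the real helper mutates the arrays in place, this sketch returns a copy.

import numpy as np

def nstep_rewards(rewards, gamma, n_step):
    out = np.array(rewards, dtype=np.float64)
    for i in range(len(rewards)):
        for j in range(1, n_step):
            if i + j < len(rewards):
                # Fold the discounted future reward into step i.
                out[i] += gamma ** j * rewards[i + j]
    return out

print(nstep_rewards([1.0, 1.0, 1.0, 1.0], gamma=0.9, n_step=3))
# -> approximately [2.71, 2.71, 1.9, 1.0]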

View file

@ -23,7 +23,7 @@ BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
# yapf: enable
def validate_config(config: TrainerConfigDict):
def validate_config(config: TrainerConfigDict) -> None:
if config["beta"] != 0.0:
raise ValueError(
"For behavioral cloning, `beta` parameter must be 0.0!")

View file

@ -1,3 +1,5 @@
from typing import Optional, Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
@ -7,6 +9,10 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.util.iter import LocalIterator
from ray.rllib.policy.policy import Policy
# yapf: disable
# __sphinx_doc_begin__
@ -46,14 +52,36 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
def get_policy_class(config):
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
MARWIL/BC have both TF and Torch policy support.
Args:
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
Optional[Type[Policy]]: The Policy class to use with MARWILTrainer.
If None, use `default_policy` provided in build_trainer().
"""
if config["framework"] == "torch":
from ray.rllib.agents.marwil.marwil_torch_policy import \
MARWILTorchPolicy
return MARWILTorchPolicy
def execution_plan(workers, config):
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
"""Execution plan of the MARWIL/BC algorithm. Defines the distributed
dataflow.
Args:
workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
of the Trainer.
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
rollouts = ParallelRollouts(workers, mode="bulk_sync")
replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])
@ -74,7 +102,11 @@ def execution_plan(workers, config):
return StandardMetricsReporting(train_op, workers, config)
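A hedged sketch of the dataflow between the two hunks above, using standard ray.rllib.execution operators (`StoreToReplayBuffer` and `Replay` are assumptions based on contemporary RLlib and are not shown in this diff's imports):

from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep

def sketch_marwil_plan(workers, config, replay_buffer):
    # `replay_buffer` is the SimpleReplayBuffer created in the hunk above.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # Branch 1: keep writing (offline) rollout batches into the local buffer.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=replay_buffer))
    # Branch 2: sample from the buffer, build full train batches, train.
    replay_op = (Replay(local_buffer=replay_buffer)
                 .combine(ConcatBatches(
                     min_batch_size=config["train_batch_size"]))
                 .for_each(TrainOneStep(workers)))
    # Interleave both branches; only the training branch (index 1) emits the
    # results that StandardMetricsReporting consumes.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(train_op, workers, config)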
def validate_config(config):
def validate_config(config: TrainerConfigDict) -> None:
"""Checks and updates the config based on settings.
Rewrites rollout_fragment_length to take into account n_step truncation.
"""
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")

View file

@ -1,4 +1,6 @@
import logging
import gym
from typing import Optional, Dict
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import compute_and_clip_gradients
@ -8,6 +10,11 @@ from ray.rllib.evaluation.postprocessing import compute_advantages, \
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
PolicyID
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.evaluation.episode import MultiAgentEpisode
tf1, tf, tfv = try_import_tf()
@ -15,7 +22,8 @@ logger = logging.getLogger(__name__)
class ValueNetworkMixin:
def __init__(self, obs_space, action_space, config):
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, config: TrainerConfigDict):
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
@ -29,10 +37,11 @@ class ValueNetworkMixin:
self._value = value
def postprocess_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
def postprocess_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
"""Postprocesses a trajectory and returns the processed trajectory.
The trajectory contains only data from one episode and from one agent.
@ -84,8 +93,10 @@ def postprocess_advantages(policy,
class MARWILLoss:
def __init__(self, policy, value_estimates, action_dist, actions,
cumulative_rewards, vf_loss_coeff, beta):
def __init__(self, policy: Policy, value_estimates: TensorType,
action_dist: ActionDistribution, actions: TensorType,
cumulative_rewards: TensorType, vf_loss_coeff: float,
beta: float):
# Advantage Estimation.
adv = cumulative_rewards - value_estimates
@ -133,7 +144,8 @@ class MARWILLoss:
explained_variance(cumulative_rewards, value_estimates))
def marwil_loss(policy, model, dist_class, train_batch):
def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
train_batch: SampleBatch) -> TensorType:
model_out, _ = model.from_batch(train_batch)
action_dist = dist_class(model_out, model)
value_estimates = model.value_function()
@ -146,7 +158,7 @@ def marwil_loss(policy, model, dist_class, train_batch):
return policy.loss.total_loss
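The core of the MARWIL objective computed above is an exponential advantage weighting of the behavior-cloning log-likelihood; a minimal numpy sketch (beta=0.0 recovers plain BC, `ma_adv_norm` stands in for the moving average set up in `setup_mixins` further down; names are illustrative):

import numpy as np

def marwil_policy_weights(cumulative_rewards, value_estimates, ma_adv_norm,
                          beta=1.0):
    # Advantage of the logged actions over the learned value baseline.
    adv = cumulative_rewards - value_estimates
    # Exponential advantage weighting, normalized by a running estimate of
    # the squared-advantage magnitude ("c^2" in the MARWIL paper).
    # The policy loss is then roughly -mean(weights * logp(actions)).
    return np.exp(beta * adv / np.sqrt(ma_adv_norm))

print(marwil_policy_weights(np.array([2.0, 0.5]), np.array([1.0, 1.0]),
                            ma_adv_norm=4.0, beta=1.0))
# -> approximately [1.65, 0.78]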
def stats(policy, train_batch):
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
return {
"policy_loss": policy.loss.p_loss,
"vf_loss": policy.loss.v_loss,
@ -155,7 +167,9 @@ def stats(policy, train_batch):
}
def setup_mixins(policy, obs_space, action_space, config):
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
# Set up a tf-var for the moving avg (do this here to make it work with
# eager mode); "c^2" in the paper.

View file

@ -1,3 +1,6 @@
import gym
from typing import Dict
import ray
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
@ -6,11 +9,16 @@ from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping, explained_variance
from ray.rllib.utils.typing import TrainerConfigDict, TensorType
from ray.rllib.policy.policy import Policy
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
torch, _ = try_import_torch()
def marwil_loss(policy, model, dist_class, train_batch):
def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
train_batch: SampleBatch) -> TensorType:
model_out, _ = model.from_batch(train_batch)
action_dist = dist_class(model_out, model)
state_values = model.value_function()
@ -43,7 +51,7 @@ def marwil_loss(policy, model, dist_class, train_batch):
return policy.total_loss
def stats(policy, train_batch):
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
return {
"policy_loss": policy.p_loss,
"vf_loss": policy.v_loss,
@ -52,7 +60,9 @@ def stats(policy, train_batch):
}
def setup_mixins(policy, obs_space, action_space, config):
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
# Create a var.
policy.ma_adv_norm = torch.tensor(
[100.0], dtype=torch.float32, requires_grad=False).to(policy.device)

View file

@ -123,7 +123,8 @@ class SampleBatch(dict):
"""Concatenates n data dicts or MultiAgentBatches.
Args:
samples (List[Dict[TensorType]]]): List of dicts of data (numpy).
samples (List[Dict[str, TensorType]]): List of dicts of data
(numpy).
Returns:
Union[SampleBatch, MultiAgentBatch]: A new (compressed)
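A small usage sketch of the documented helper (`SampleBatch.concat_samples` is public RLlib API; the keys and values below are made up):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

b1 = SampleBatch({"obs": np.array([[1.0], [2.0]]),
                  "rewards": np.array([0.1, 0.2])})
b2 = SampleBatch({"obs": np.array([[3.0]]), "rewards": np.array([0.3])})
merged = SampleBatch.concat_samples([b1, b2])
print(merged.count)        # 3
print(merged["rewards"])   # [0.1 0.2 0.3]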