[RLlib] Unify the way we create local replay buffer for all agents (#19627)

* [RLlib] Unify the way we create and use LocalReplayBuffer for all agents.

This change:
1. Gets rid of the try...except clause around the execution_plan() call,
   and with it the resulting deprecation warning.
2. Fixes the execution_plan() call in Trainer._try_recover() as well.
3. Most importantly, makes it much easier to create and use different types
   of local replay buffers for all our agents, e.g., it allows us to easily
   add a reservoir-sampling replay buffer for the APPO agent for Riot in the
   near future.
* Introduce explicit configuration for replay buffer types.
* Fix is_training key error.
* Actually deprecate the buffer_size field.
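
To make the new contract concrete, here is a minimal sketch of the unified calling convention, assembled from the hunks below rather than copied verbatim from any single file: the Trainer now builds the LocalReplayBuffer from replay_buffer_config during setup() and hands it to execution_plan() through **kwargs, so agents only consume it. Agents that do not use a replay buffer keep the same signature and simply assert that kwargs is empty.

from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator


def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    # The Trainer created the buffer from config["replay_buffer_config"]
    # and passed it in; the agent no longer constructs its own.
    assert "local_replay_buffer" in kwargs, (
        "This execution plan requires a local replay buffer.")
    local_replay_buffer = kwargs["local_replay_buffer"]

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # (1) Store rollouts in the shared local replay buffer.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    # (2) Replay from the same buffer and take SGD steps.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(TrainOneStep(workers))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(train_op, workers, config)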
gjoliver 2021-10-26 11:56:02 -07:00 committed by GitHub
parent ab15dfd478
commit 99a0088233
22 changed files with 180 additions and 117 deletions

View file

@ -29,8 +29,8 @@ A2C_DEFAULT_CONFIG = merge_dicts(
)
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the A2C algorithm. Defines the distributed
dataflow.
@ -42,6 +42,9 @@ def execution_plan(workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
assert len(kwargs) == 0, (
"A2C execution_plan does NOT take any additional parameters")
rollouts = ParallelRollouts(workers, mode="bulk_sync")
if config["microbatch_size"]:

View file

@ -78,8 +78,8 @@ def validate_config(config: TrainerConfigDict) -> None:
raise ValueError("`num_workers` for A3C must be >= 1!")
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the MARWIL/BC algorithm. Defines the distributed
dataflow.
@ -91,6 +91,9 @@ def execution_plan(workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
assert len(kwargs) == 0, (
"A3C execution_plan does NOT take any additional parameters")
# For A3C, compute policy gradients remotely on the rollout workers.
grads = AsyncGradients(workers)

View file

@ -9,7 +9,6 @@ from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.agents.sac.sac import SACTrainer, \
DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
UpdateTargetNetwork
@ -17,6 +16,7 @@ from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict
@ -24,7 +24,6 @@ from ray.rllib.utils.typing import TrainerConfigDict
tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
logger = logging.getLogger(__name__)
replay_buffer = None
# yapf: disable
# __sphinx_doc_begin__
@ -48,7 +47,11 @@ CQL_DEFAULT_CONFIG = merge_dicts(
"min_q_weight": 5.0,
# Replay buffer should be larger or equal the size of the offline
# dataset.
"buffer_size": int(1e6),
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"type": "LocalReplayBuffer",
"capacity": int(1e6),
},
})
# __sphinx_doc_end__
# yapf: enable
@ -74,29 +77,11 @@ def validate_config(config: TrainerConfigDict):
try_import_tfp(error=True)
def execution_plan(workers, config):
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
else:
prio_args = {}
def execution_plan(workers, config, **kwargs):
assert "local_replay_buffer" in kwargs, (
"CQL execution plan requires a local replay buffer.")
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args)
global replay_buffer
replay_buffer = local_replay_buffer
local_replay_buffer = kwargs["local_replay_buffer"]
def update_prio(item):
samples, info_dict = item
@ -150,8 +135,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
def after_init(trainer):
# Add the entire dataset to Replay Buffer (global variable)
global replay_buffer
reader = trainer.workers.local_worker().input_reader
replay_buffer = trainer.local_replay_buffer
# For d4rl, add the D4RLReaders' dataset to the buffer.
if isinstance(trainer.config["input"], str) and \

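From a user's perspective, the buffer capacity for CQL now lives under the nested replay_buffer_config key instead of the flat buffer_size field. A small, hedged example (the import path follows the hunk above; the capacity value is illustrative):

import copy

from ray.rllib.agents.cql.cql import CQL_DEFAULT_CONFIG

config = copy.deepcopy(CQL_DEFAULT_CONFIG)
# Replay buffer should be larger than or equal to the offline dataset size.
config["replay_buffer_config"]["capacity"] = int(2e6)
# Setting the old flat "buffer_size" key still works for now, but it emits a
# deprecation warning (see the Trainer hunk further down).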
View file

@ -88,8 +88,8 @@ class TestCQL(unittest.TestCase):
# Example on how to do evaluation on the trained Trainer
# using the data from CQL's global replay buffer.
# Get a sample (MultiAgentBatch -> SampleBatch).
from ray.rllib.agents.cql.cql import replay_buffer
batch = replay_buffer.replay().policy_batches["default_policy"]
batch = trainer.local_replay_buffer.replay().policy_batches[
"default_policy"]
if fw == "torch":
obs = torch.from_numpy(batch["obs"])

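Since the buffer is now attached to the Trainer, the access pattern in this test works for any agent whose execution plan consumes the local buffer, not just CQL. A self-contained sketch, assuming illustrative hyperparameters and a Gym CartPole environment (SimpleQTrainer is used here only because it needs no offline data):

import ray
from ray.rllib.agents.dqn import SimpleQTrainer

ray.init()
trainer = SimpleQTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 0,
        # Nested buffer config introduced by this commit.
        "replay_buffer_config": {
            "type": "LocalReplayBuffer",
            "capacity": 10000,
        },
        "learning_starts": 100,
    })
trainer.train()

# The buffer hangs off the Trainer, so tests and checkpointing code can read
# samples back directly (a MultiAgentBatch keyed by policy ID).
# Note: replay() returns None until learning_starts samples were collected.
batch = trainer.local_replay_buffer.replay().policy_batches["default_policy"]
print(batch["obs"].shape)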
View file

@ -55,6 +55,9 @@ APEX_DEFAULT_CONFIG = merge_dicts(
"num_gpus": 1,
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": None,
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
@ -141,8 +144,11 @@ class UpdateWorkerWeights:
metrics.counters["num_weight_syncs"] += 1
def apex_execution_plan(workers: WorkerSet,
config: dict) -> LocalIterator[dict]:
def apex_execution_plan(workers: WorkerSet, config: dict,
**kwargs) -> LocalIterator[dict]:
assert len(kwargs) == 0, (
"Apex execution_plan does NOT take any additional parameters")
# Create a number of replay buffer actors.
num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
replay_actors = create_colocated(ReplayActor, [

View file

@ -20,7 +20,6 @@ from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \
@ -145,8 +144,8 @@ def validate_config(config: TrainerConfigDict) -> None:
"simple_optimizer=True if this doesn't work for you.")
def execution_plan(trainer: Trainer, workers: WorkerSet,
config: TrainerConfigDict, **kwargs) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the DQN algorithm. Defines the distributed dataflow.
Args:
@ -158,28 +157,12 @@ def execution_plan(trainer: Trainer, workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
else:
prio_args = {}
assert "local_replay_buffer" in kwargs, (
"DQN execution plan requires a local replay buffer.")
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args)
# Assign to Trainer, so we can store the LocalReplayBuffer's
# data when we save checkpoints.
trainer.local_replay_buffer = local_replay_buffer
local_replay_buffer = kwargs["local_replay_buffer"]
rollouts = ParallelRollouts(workers, mode="bulk_sync")

View file

@ -14,17 +14,17 @@ from typing import Optional, Type
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
@ -62,7 +62,11 @@ DEFAULT_CONFIG = with_common_config({
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": 50000,
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"type": "LocalReplayBuffer",
"capacity": 50000,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
# Warnings will be created if:
@ -122,8 +126,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
return SimpleQTorchPolicy
def execution_plan(trainer: Trainer, workers: WorkerSet,
config: TrainerConfigDict, **kwargs) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the Simple Q algorithm. Defines the distributed dataflow.
Args:
@ -135,16 +139,10 @@ def execution_plan(trainer: Trainer, workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config["replay_sequence_length"])
# Assign to Trainer, so we can store the LocalReplayBuffer's
# data when we save checkpoints.
trainer.local_replay_buffer = local_replay_buffer
assert "local_replay_buffer" in kwargs, (
"SimpleQ execution plan requires a local replay buffer.")
local_replay_buffer = kwargs["local_replay_buffer"]
rollouts = ParallelRollouts(workers, mode="bulk_sync")

View file

@ -185,7 +185,10 @@ class DreamerIteration:
return fetches[DEFAULT_POLICY_ID]["learner_stats"]
def execution_plan(workers, config):
def execution_plan(workers, config, **kwargs):
assert len(kwargs) == 0, (
"Dreamer execution_plan does NOT take any additional parameters")
# Special replay buffer for Dreamer agent.
episode_buffer = EpisodicBuffer(length=config["batch_length"])

View file

@ -312,7 +312,10 @@ def gather_experiences_directly(workers, config):
return train_batches
def execution_plan(workers, config):
def execution_plan(workers, config, **kwargs):
assert len(kwargs) == 0, (
"IMPALA execution_plan does NOT take any additional parameters")
if config["num_aggregation_workers"] > 0:
train_batches = gather_experiences_tree_aggregation(workers, config)
else:

View file

@ -156,7 +156,10 @@ def inner_adaptation(workers, samples):
e.learn_on_batch.remote(samples[i])
def execution_plan(workers, config):
def execution_plan(workers, config, **kwargs):
assert len(kwargs) == 0, (
"MAML execution_plan does NOT take any additional parameters")
# Sync workers with meta policy
workers.sync_weights()

View file

@ -90,8 +90,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
return MARWILTorchPolicy
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the MARWIL/BC algorithm. Defines the distributed
dataflow.
@ -103,6 +103,9 @@ def execution_plan(workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
assert len(kwargs) == 0, (
"Marwill execution_plan does NOT take any additional parameters")
rollouts = ParallelRollouts(workers, mode="bulk_sync")
replay_buffer = LocalReplayBuffer(
learning_starts=config["learning_starts"],

View file

@ -336,8 +336,8 @@ def post_process_samples(samples, config: TrainerConfigDict):
return samples, split_lst
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the PPO algorithm. Defines the distributed dataflow.
Args:
@ -349,6 +349,9 @@ def execution_plan(workers: WorkerSet,
LocalIterator[dict]: The Policy class to use with PPOTrainer.
If None, use `default_policy` provided in build_trainer().
"""
assert len(kwargs) == 0, (
"MBMPO execution_plan does NOT take any additional parameters")
# Train TD Models on the driver.
workers.local_worker().foreach_policy(fit_dynamics)

View file

@ -146,8 +146,8 @@ def validate_config(config):
raise ValueError("DDPPO doesn't support KL penalties like PPO-1")
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the DD-PPO algorithm. Defines the distributed dataflow.
Args:
@ -159,6 +159,9 @@ def execution_plan(workers: WorkerSet,
LocalIterator[dict]: The Policy class to use with PGTrainer.
If None, use `default_policy` provided in build_trainer().
"""
assert len(kwargs) == 0, (
"DDPPO execution_plan does NOT take any additional parameters")
rollouts = ParallelRollouts(workers, mode="raw")
# Setup the distributed processes.

View file

@ -253,8 +253,8 @@ def warn_about_bad_reward_scales(config, result):
return result
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the PPO algorithm. Defines the distributed dataflow.
Args:
@ -266,6 +266,9 @@ def execution_plan(workers: WorkerSet,
LocalIterator[dict]: The Policy class to use with PPOTrainer.
If None, use `default_policy` provided in build_trainer().
"""
assert len(kwargs) == 0, (
"PPO execution_plan does NOT take any additional parameters")
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# Collect batches for the trainable policies.

View file

@ -100,7 +100,10 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
def execution_plan(workers, config):
def execution_plan(workers, config, **kwargs):
assert len(kwargs) == 0, (
"QMIX execution_plan does NOT take any additional parameters")
rollouts = ParallelRollouts(workers, mode="bulk_sync")
replay_buffer = SimpleReplayBuffer(config["buffer_size"])

View file

@ -248,7 +248,7 @@ class SACTFModel(TFModelV2):
input_dict = {"obs": model_out}
# Switch on training mode (when getting Q-values, we are usually in
# training).
input_dict.is_training = True
input_dict["is_training"] = True
out, _ = net(input_dict, [], None)
return out

View file

@ -256,7 +256,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
input_dict = {"obs": model_out}
# Switch on training mode (when getting Q-values, we are usually in
# training).
input_dict.is_training = True
input_dict["is_training"] = True
out, _ = net(input_dict, [], None)
return out

View file

@ -22,11 +22,11 @@ from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
@ -82,7 +82,11 @@ DEFAULT_CONFIG = with_common_config({
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": 50000,
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"type": "LocalReplayBuffer",
"capacity": 50000,
},
# The number of contiguous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
@ -152,8 +156,8 @@ def validate_config(config: TrainerConfigDict) -> None:
"For SARSA strategy, batch_mode must be 'complete_episodes'")
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the SlateQ algorithm. Defines the distributed dataflow.
Args:
@ -164,14 +168,8 @@ def execution_plan(workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config["replay_sequence_length"],
)
assert "local_replay_buffer" in kwargs, (
"SlateQ execution plan requires a local replay buffer.")
rollouts = ParallelRollouts(workers, mode="bulk_sync")
@ -179,12 +177,12 @@ def execution_plan(workers: WorkerSet,
# (1) Generate rollouts and store them in our local replay buffer. Calling
# next() on store_op drives this.
store_op = rollouts.for_each(
StoreToReplayBuffer(local_buffer=local_replay_buffer))
StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"]))
# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take a SGD step.
replay_op = Replay(local_buffer=local_replay_buffer) \
replay_op = Replay(local_buffer=kwargs["local_replay_buffer"]) \
.for_each(TrainOneStep(workers))
if config["slateq_strategy"] != "RANDOM":

View file

@ -23,6 +23,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
@ -706,6 +707,67 @@ class Trainer(Trainable):
# to mutate the result
Trainable.log_result(self, result)
@DeveloperAPI
def _create_local_replay_buffer_if_necessary(self, config):
"""Create a LocalReplayBuffer instance if necessary.
Args:
config (dict): Algorithm-specific configuration data.
Returns:
LocalReplayBuffer instance based on trainer config.
None, if local replay buffer is not needed.
"""
# These are the agents that utilize a local replay buffer.
if ("replay_buffer_config" not in config
or not config["replay_buffer_config"]):
# Does not need a replay buffer.
return None
replay_buffer_config = config["replay_buffer_config"]
if ("type" not in replay_buffer_config
or replay_buffer_config["type"] != "LocalReplayBuffer"):
# DistributedReplayBuffer coming soon.
return None
capacity = config.get("buffer_size", DEPRECATED_VALUE)
if capacity != DEPRECATED_VALUE:
# Print a deprecation warning.
deprecation_warning(
old="config['buffer_size']",
new="config['replay_buffer_config']['capacity']",
error=False)
else:
# Get capacity out of replay_buffer_config.
capacity = replay_buffer_config["capacity"]
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
else:
prio_args = {}
return LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=capacity,
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args)
@DeveloperAPI
def _kwargs_for_execution_plan(self):
kwargs = {}
if self.local_replay_buffer:
kwargs["local_replay_buffer"] = self.local_replay_buffer
return kwargs
@override(Trainable)
def setup(self, config: PartialTrainerConfigDict):
env = self._env_id
@ -773,6 +835,10 @@ class Trainer(Trainable):
if self.config.get("log_level"):
logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
# Create local replay buffer if necessary.
self.local_replay_buffer = (
self._create_local_replay_buffer_if_necessary(self.config))
self._init(self.config, self.env_creator)
# Evaluation setup.
@ -1747,7 +1813,8 @@ class Trainer(Trainable):
logger.warning("Recreating execution plan after failure")
workers.reset(healthy_workers)
self.train_exec_impl = self.execution_plan(workers, self.config)
self.train_exec_impl = self.execution_plan(
workers, self.config, **self._kwargs_for_execution_plan())
@override(Trainable)
def _export_model(self, export_formats: List[str],

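In summary, the new Trainer hook resolves the buffer settings in three ways; the configs below are illustrative sketches of each path rather than complete agent configs:

# 1) New style: nested replay_buffer_config; capacity is read from the dict.
new_style = {
    "replay_buffer_config": {
        "type": "LocalReplayBuffer",
        "capacity": 50000,
    },
}

# 2) Old style: a flat buffer_size that is not DEPRECATED_VALUE still wins,
#    but triggers a deprecation warning pointing at the nested key.
old_style = {
    "replay_buffer_config": {"type": "LocalReplayBuffer", "capacity": 50000},
    "buffer_size": 100000,  # deprecated; used as the capacity, with a warning
}

# 3) Disabled: agents that manage replay themselves (e.g. Ape-X with its
#    distributed ReplayActor shards) set the config to None, so no local
#    buffer is created and execution_plan() receives no extra kwargs.
disabled = {
    "replay_buffer_config": None,
}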
View file

@ -18,7 +18,11 @@ from ray.rllib.utils.typing import EnvConfigDict, EnvType, \
logger = logging.getLogger(__name__)
def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs):
assert len(kwargs) == 0, (
"Default execution_plan does NOT take any additional parameters")
# Collects experiences in parallel from multiple RolloutWorker actors.
rollouts = ParallelRollouts(workers, mode="bulk_sync")
@ -175,19 +179,8 @@ def build_trainer(
config=config,
num_workers=self.config["num_workers"])
self.execution_plan = execution_plan
try:
self.train_exec_impl = execution_plan(self, self.workers,
config)
except TypeError as e:
# Keyword error: Try old way w/o kwargs.
if "() takes 2 positional arguments but 3" in e.args[0]:
self.train_exec_impl = execution_plan(self.workers, config)
logger.warning(
"`execution_plan` functions should accept "
"`trainer`, `workers`, and `config` as args!")
# Other error -> re-raise.
else:
raise e
self.train_exec_impl = execution_plan(
self.workers, config, **self._kwargs_for_execution_plan())
if after_init:
after_init(self)

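Because the try/except fallback above is gone, custom trainers registered through build_trainer() should give their execution plans the same signature. A minimal sketch under these assumptions: MyTrainer and my_execution_plan are hypothetical names, and PG's TF policy is reused purely for illustration:

from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep


def my_execution_plan(workers, config, **kwargs):
    # No replay buffer is configured for this trainer, so nothing should be
    # passed in; mirror the assertion used by the built-in agents.
    assert len(kwargs) == 0, (
        "MyTrainer execution_plan does NOT take any additional parameters")
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    train_op = rollouts \
        .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))
    return StandardMetricsReporting(train_op, workers, config)


MyTrainer = build_trainer(
    name="MyTrainer",
    default_config=with_common_config({}),
    default_policy=PGTFPolicy,
    execution_plan=my_execution_plan)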
View file

@ -160,7 +160,10 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
_env_creator)
def execution_plan(workers, config):
def execution_plan(workers, config, **kwargs):
assert len(kwargs) == 0, (
"Alpha zero execution_plan does NOT take any additional parameters")
rollouts = ParallelRollouts(workers, mode="bulk_sync")
if config["simple_optimizer"]:

View file

@ -63,8 +63,8 @@ class RandomParametriclPolicy(Policy, ABC):
pass
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
rollouts = ParallelRollouts(workers, mode="async")
# Collect batches for the trainable policies.