[RLlib] MARWIL: Move to training_iteration API. (#23798)
parent 5dc958037e
commit a3d4fc74a6
3 changed files with 74 additions and 12 deletions
@@ -7,11 +7,30 @@ from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
-from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
-from ray.rllib.execution.train_ops import TrainOneStep
+from ray.rllib.execution.rollout_ops import (
+    ConcatBatches,
+    ParallelRollouts,
+    synchronous_parallel_sample,
+)
+from ray.rllib.execution.train_ops import (
+    multi_gpu_train_one_step,
+    TrainOneStep,
+    train_one_step,
+)
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.utils.metrics import (
+    NUM_AGENT_STEPS_SAMPLED,
+    NUM_AGENT_STEPS_TRAINED,
+    NUM_ENV_STEPS_SAMPLED,
+    NUM_ENV_STEPS_TRAINED,
+    WORKER_UPDATE_TIMER,
+)
+from ray.rllib.utils.typing import (
+    PartialTrainerConfigDict,
+    ResultDict,
+    TrainerConfigDict,
+)
 from ray.util.iter import LocalIterator

 # fmt: off
@@ -71,6 +90,9 @@ DEFAULT_CONFIG = with_common_config({

     # === Parallelism ===
     "num_workers": 0,
+
+    # Use new `training_iteration` API (instead of `execution_plan` method).
+    "_disable_execution_plan_api": True,
 })
 # __sphinx_doc_end__
 # fmt: on
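
For orientation (not part of the diff), here is a minimal usage sketch of how a MARWIL run picks up the new default. The environment name and offline-data path are illustrative placeholders; the config keys themselves come from MARWIL's DEFAULT_CONFIG, and setting `_disable_execution_plan_api` back to False would fall back to the old `execution_plan` path.

    # Hedged usage sketch; values marked "illustrative" are assumptions, not from this commit.
    import ray
    from ray.rllib.agents.marwil import DEFAULT_CONFIG, MARWILTrainer

    ray.init()
    config = DEFAULT_CONFIG.copy()
    config["env"] = "CartPole-v0"                 # illustrative environment
    config["input"] = "/tmp/cartpole-out"         # illustrative offline-data directory
    config["_disable_execution_plan_api"] = True  # new default added above

    trainer = MARWILTrainer(config=config)
    result = trainer.train()  # now dispatches to training_iteration() rather than execution_plan()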
@@ -105,6 +127,54 @@ class MARWILTrainer(Trainer):
         else:
             return MARWILTFPolicy

+    @override(Trainer)
+    def setup(self, config: PartialTrainerConfigDict):
+        super().setup(config)
+        # `training_iteration` implementation: Setup buffer in `setup`, not
+        # in `execution_plan` (deprecated).
+        if self.config["_disable_execution_plan_api"] is True:
+            self.local_replay_buffer = MultiAgentReplayBuffer(
+                learning_starts=self.config["learning_starts"],
+                capacity=self.config["replay_buffer_size"],
+                replay_batch_size=self.config["train_batch_size"],
+                replay_sequence_length=1,
+            )
+
+    @override(Trainer)
+    def training_iteration(self) -> ResultDict:
+        # Collect SampleBatches from sample workers.
+        batch = synchronous_parallel_sample(worker_set=self.workers)
+        batch = batch.as_multi_agent()
+        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
+        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
+        # Add batch to replay buffer.
+        self.local_replay_buffer.add_batch(batch)
+
+        # Pull batch from replay buffer and train on it.
+        train_batch = self.local_replay_buffer.replay()
+        # Train.
+        if self.config["simple_optimizer"]:
+            train_results = train_one_step(self, train_batch)
+        else:
+            train_results = multi_gpu_train_one_step(self, train_batch)
+        self._counters[NUM_AGENT_STEPS_TRAINED] += batch.agent_steps()
+        self._counters[NUM_ENV_STEPS_TRAINED] += batch.env_steps()
+
+        global_vars = {
+            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
+        }
+
+        # Update weights - after learning on the local worker - on all remote
+        # workers.
+        if self.workers.remote_workers():
+            with self._timers[WORKER_UPDATE_TIMER]:
+                self.workers.sync_weights(global_vars=global_vars)
+
+        # Update global vars on local worker as well.
+        self.workers.local_worker().set_global_vars(global_vars)
+
+        return train_results
+
     @staticmethod
     @override(Trainer)
     def execution_plan(
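
The new `training_iteration()` above runs a sample → add-to-buffer → replay → train → sync-weights sequence. As a standalone illustration of the buffer round-trip it relies on, here is a hedged sketch reusing the `MultiAgentReplayBuffer` constructor arguments and the `add_batch()`/`replay()` calls from the diff; the buffer sizes and the fake SampleBatch contents are invented for the example.

    # Standalone sketch of the replay-buffer round-trip used by training_iteration().
    # Buffer API calls mirror the diff; sizes and batch contents are made-up placeholders.
    from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
    from ray.rllib.policy.sample_batch import SampleBatch

    buffer = MultiAgentReplayBuffer(
        learning_starts=1,    # MARWIL reads this from config["learning_starts"]
        capacity=1000,        # stand-in for config["replay_buffer_size"]
        replay_batch_size=4,  # stand-in for config["train_batch_size"]
        replay_sequence_length=1,
    )

    # A tiny fake rollout batch: 4 timesteps of a 1-d observation space.
    batch = SampleBatch({
        "obs": [[0.0], [0.1], [0.2], [0.3]],
        "actions": [0, 1, 0, 1],
        "rewards": [1.0, 1.0, 1.0, 1.0],
        "dones": [False, False, False, True],
        "new_obs": [[0.1], [0.2], [0.3], [0.4]],
    }).as_multi_agent()

    buffer.add_batch(batch)        # same call as in training_iteration()
    train_batch = buffer.replay()  # a MultiAgentBatch once learning_starts samples were added

Keeping the buffer on the local worker (created once in `setup()`) is what lets `training_iteration()` stay a plain Python method instead of a distributed op graph.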
@@ -14,8 +14,7 @@ DEFAULT_CONFIG = with_common_config({
     # Experimental: By default, switch off preprocessors for PG.
     "_disable_preprocessor_api": True,

-    # PG is the first algo (experimental) to not use the distr. exec API
-    # anymore.
+    # Use new `training_iteration` API (instead of `execution_plan` method).
     "_disable_execution_plan_api": True,
 })

@@ -17,7 +17,6 @@ Note that unlike the paper, we currently do not implement straggler mitigation.
 """

 import logging
-import sys
 import time
 from typing import Callable, Optional, Union

@@ -158,12 +157,6 @@ class DDPPOTrainer(PPOTrainer):
         # setting.
         super().validate_config(config)

-        # Error if run on Win.
-        if sys.platform in ["win32", "cygwin"]:
-            raise ValueError(
-                "DD-PPO not supported on Win yet! Due to usage of torch.distributed."
-            )
-
         # Only supported for PyTorch so far.
         if config["framework"] != "torch":
             raise ValueError("Distributed data parallel is only supported for PyTorch")