mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[AIR] Add RLTrainer interface, implementation, and examples (#23465)
This PR adds an RLTrainer to Ray AIR. It works for both offline and online use cases. In offline training, it leverages the `datasets` key of the Trainer API to specify a dataset reader input, as used e.g. in Behavioral Cloning (BC). In online training, it is a wrapper around the RLlib trainables that makes use of the parameter layering enabled by the Trainer API.
This commit is contained in:
parent
5a41fb18bd
commit
8c2e471265
13 changed files with 390 additions and 17 deletions
|
@ -19,11 +19,30 @@ py_test (
|
|||
deps = [":ml_lib"]
|
||||
)
|
||||
|
||||
# Runs examples/rl_example.py with the --offline flag.
py_test(
    name = "rl_example_offline",
    size = "medium",
    srcs = ["examples/rl_example.py"],
    main = "examples/rl_example.py",
    args = ["--offline"],
    tags = ["team:ml", "exclusive"],
    deps = [":ml_lib"],
)
|
||||
|
||||
# Runs examples/rl_example.py without extra flags (the script's default,
# non-offline path).
py_test(
    name = "rl_example_online",
    size = "medium",
    srcs = ["examples/rl_example.py"],
    main = "examples/rl_example.py",
    tags = ["team:ml", "exclusive"],
    deps = [":ml_lib"],
)
|
||||
|
||||
py_test(
|
||||
name = "tensorflow_linear_dataset_example",
|
||||
size = "medium",
|
||||
main = "examples/tensorflow/tensorflow_linear_dataset_example.py",
|
||||
srcs = ["examples/tensorflow/tensorflow_linear_dataset_example.py"],
|
||||
main = "examples/tf/tensorflow_linear_dataset_example.py",
|
||||
srcs = ["examples/tf/tensorflow_linear_dataset_example.py"],
|
||||
tags = ["team:ml", "exclusive"],
|
||||
deps = [":ml_lib"],
|
||||
args = ["--smoke-test"]
|
||||
|
@ -32,8 +51,8 @@ py_test(
|
|||
py_test(
|
||||
name = "tensorflow_mnist_example",
|
||||
size = "medium",
|
||||
main = "examples/tensorflow/tensorflow_mnist_example.py",
|
||||
srcs = ["examples/tensorflow/tensorflow_mnist_example.py"],
|
||||
main = "examples/tf/tensorflow_mnist_example.py",
|
||||
srcs = ["examples/tf/tensorflow_mnist_example.py"],
|
||||
tags = ["team:ml", "exclusive"],
|
||||
deps = [":ml_lib"],
|
||||
args = ["--smoke-test"]
|
||||
|
@ -82,8 +101,8 @@ py_test(
|
|||
py_test(
|
||||
name = "tune_tensorflow_mnist_example",
|
||||
size = "medium",
|
||||
main = "examples/tensorflow/tune_tensorflow_mnist_example.py",
|
||||
srcs = ["examples/tensorflow/tune_tensorflow_mnist_example.py"],
|
||||
main = "examples/tf/tune_tensorflow_mnist_example.py",
|
||||
srcs = ["examples/tf/tune_tensorflow_mnist_example.py"],
|
||||
tags = ["team:ml", "exclusive"],
|
||||
deps = [":ml_lib"],
|
||||
args = ["--smoke-test"]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
||||
from typing import Dict, Any, Optional, List, Mapping, Callable, Union, TYPE_CHECKING
|
||||
|
||||
from ray.tune.syncer import SyncConfig
|
||||
from ray.util import PublicAPI
|
||||
|
@ -7,7 +7,7 @@ from ray.util import PublicAPI
|
|||
if TYPE_CHECKING:
|
||||
from ray.tune.trainable import PlacementGroupFactory
|
||||
from ray.tune.callback import Callback
|
||||
|
||||
from ray.tune.stopper import Stopper
|
||||
|
||||
ScalingConfig = Dict[str, Any]
|
||||
|
||||
|
@ -137,6 +137,8 @@ class RunConfig:
|
|||
from the Trainable.
|
||||
local_dir: Local dir to save training results to.
|
||||
Defaults to ``~/ray_results``.
|
||||
stop: Stop conditions to consider. Refer to ray.tune.stopper.Stopper
|
||||
for more info. Stoppers should be serializable.
|
||||
callbacks: Callbacks to invoke.
|
||||
Refer to ray.tune.callback.Callback for more info.
|
||||
Callbacks should be serializable.
|
||||
|
@ -151,5 +153,6 @@ class RunConfig:
|
|||
name: Optional[str] = None
|
||||
local_dir: Optional[str] = None
|
||||
callbacks: Optional[List["Callback"]] = None
|
||||
stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
|
||||
failure: Optional[FailureConfig] = None
|
||||
sync_config: Optional[SyncConfig] = None
|
||||
|
|
110
python/ray/ml/examples/rl_example.py
Normal file
110
python/ray/ml/examples/rl_example.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import ray
|
||||
from ray.ml.config import RunConfig
|
||||
from ray.ml.train.integrations.rl.rl_trainer import RLTrainer
|
||||
from ray.ml.result import Result
|
||||
from ray.rllib.agents.marwil import BCTrainer
|
||||
|
||||
|
||||
def generate_offline_data(path: str):
    """Run a short PPO training whose output settings point at ``path``.

    The run uses ``CartPole-v0`` and stops after 5000 total timesteps. The
    RLlib config requests ``"dataset"`` output in JSON format at ``path``
    with complete-episode batches.

    Args:
        path: Location passed to RLlib as the output path.
    """
    print(f"Generating offline data for training at {path}")

    output_settings = {
        "format": "json",
        "path": path,
        "max_num_samples_per_file": 1,
    }
    rllib_config = {
        "env": "CartPole-v0",
        "output": "dataset",
        "output_config": output_settings,
        "batch_mode": "complete_episodes",
    }
    stop_criteria = {"timesteps_total": 5000}

    RLTrainer(
        algorithm="PPO",
        run_config=RunConfig(stop=stop_criteria),
        config=rllib_config,
    ).fit()
|
||||
|
||||
|
||||
def train_rl_bc_offline(path: str, num_workers: int, use_gpu: bool = False) -> Result:
    """Train a ``BCTrainer`` on a JSON dataset read from ``path``.

    The dataset is read with ``ray.data.read_json`` and handed to
    ``RLTrainer`` under the ``"train"`` key. Training stops after five
    training iterations.

    Args:
        path: Location of the JSON data to read.
        num_workers: Used both as the read parallelism and as the
            ``num_workers`` entry of the scaling config.
        use_gpu: Forwarded as the ``use_gpu`` entry of the scaling config.

    Returns:
        The result returned by ``RLTrainer.fit()``.
    """
    print("Starting offline training")

    read_args = {"num_cpus": 1}
    train_dataset = ray.data.read_json(
        path, parallelism=num_workers, ray_remote_args=read_args
    )

    scaling = {"num_workers": num_workers, "use_gpu": use_gpu}
    rllib_config = {
        "env": "CartPole-v0",
        "framework": "tf",
        "evaluation_num_workers": 1,
        "evaluation_interval": 1,
        "evaluation_config": {"input": "sampler"},
    }

    bc_trainer = RLTrainer(
        run_config=RunConfig(stop={"training_iteration": 5}),
        scaling_config=scaling,
        datasets={"train": train_dataset},
        algorithm=BCTrainer,
        config=rllib_config,
    )
    return bc_trainer.fit()
|
||||
|
||||
|
||||
def train_rl_ppo_online(num_workers: int, use_gpu: bool = False) -> Result:
    """Train PPO on ``CartPole-v0`` through ``RLTrainer`` without a dataset.

    Training stops after five training iterations.

    Args:
        num_workers: Forwarded as the ``num_workers`` entry of the scaling
            config.
        use_gpu: Forwarded as the ``use_gpu`` entry of the scaling config.

    Returns:
        The result returned by ``RLTrainer.fit()``.
    """
    print("Starting online training")

    scaling = {"num_workers": num_workers, "use_gpu": use_gpu}
    rllib_config = {
        "env": "CartPole-v0",
        "framework": "tf",
        "evaluation_num_workers": 1,
        "evaluation_interval": 1,
        "evaluation_config": {"input": "sampler"},
    }

    ppo_trainer = RLTrainer(
        run_config=RunConfig(stop={"training_iteration": 5}),
        scaling_config=scaling,
        algorithm="PPO",
        config=rllib_config,
    )
    return ppo_trainer.fit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line interface of the example script.
    cli = argparse.ArgumentParser()
    cli.add_argument("--offline", action="store_true", default=False)
    cli.add_argument(
        "--path", default="/tmp/out", required=False, help="Path to (offline) data"
    )
    cli.add_argument(
        "--address", type=str, required=False, help="the address to use for Ray"
    )
    cli.add_argument(
        "--num-workers",
        "-n",
        default=1,
        type=int,
        help="Sets number of workers for training.",
    )
    cli.add_argument(
        "--use-gpu", default=False, action="store_true", help="Enables GPU training"
    )
    # Unknown arguments are tolerated and discarded.
    args, _ = cli.parse_known_args()

    ray.init(address=args.address)

    if not args.offline:
        result = train_rl_ppo_online(num_workers=args.num_workers, use_gpu=args.use_gpu)
    else:
        # Only generate data when the target directory is missing or empty.
        has_data = os.path.exists(args.path) and os.listdir(args.path)
        if not has_data:
            generate_offline_data(args.path)
        result = train_rl_bc_offline(
            path=args.path, num_workers=args.num_workers, use_gpu=args.use_gpu
        )

    print(result.metrics)
|
|
@ -4,7 +4,7 @@ import ray
|
|||
from ray import tune
|
||||
from ray.ml.train.integrations.tensorflow import TensorflowTrainer
|
||||
|
||||
from ray.ml.examples.tensorflow.tensorflow_mnist_example import train_func
|
||||
from ray.ml.examples.tf.tensorflow_mnist_example import train_func
|
||||
from ray.tune.tune_config import TuneConfig
|
||||
from ray.tune.tuner import Tuner
|
||||
|
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
import ray
|
||||
from ray import train
|
||||
from ray.ml.train.integrations.tensorflow import TensorflowTrainer
|
||||
from ray.ml.examples.tensorflow.tensorflow_linear_dataset_example import (
|
||||
from ray.ml.examples.tf.tensorflow_linear_dataset_example import (
|
||||
train_func as tensorflow_linear_train_func,
|
||||
get_dataset,
|
||||
)
|
||||
|
|
3
python/ray/ml/train/integrations/rl/__init__.py
Normal file
3
python/ray/ml/train/integrations/rl/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from ray.ml.train.integrations.rl.rl_trainer import RLTrainer
|
||||
|
||||
__all__ = ["RLTrainer"]
|
224
python/ray/ml/train/integrations/rl/rl_trainer.py
Normal file
224
python/ray/ml/train/integrations/rl/rl_trainer.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
import inspect
|
||||
from typing import Optional, Dict, Type, Union, Callable, Any
|
||||
|
||||
from ray.ml.checkpoint import Checkpoint
|
||||
from ray.ml.config import ScalingConfig, RunConfig
|
||||
from ray.ml.preprocessor import Preprocessor
|
||||
from ray.ml.trainer import Trainer, GenDataset
|
||||
from ray.rllib.agents.trainer import Trainer as RLLibTrainer
|
||||
from ray.rllib.utils.typing import PartialTrainerConfigDict, EnvType
|
||||
from ray.tune import Trainable, PlacementGroupFactory
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.registry import get_trainable_cls
|
||||
from ray.tune.resources import Resources
|
||||
from ray.util.annotations import PublicAPI
|
||||
from ray.util.ml_utils.dict import merge_dicts
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
class RLTrainer(Trainer):
    """Reinforcement learning trainer.

    This trainer provides an interface to RLLib trainables.

    If datasets and preprocessors are used, they can be utilized for
    offline training, e.g. using behavior cloning. Otherwise, this trainer
    will use online training.

    Args:
        algorithm: Algorithm to train on. Can be a string reference,
            (e.g. ``"PPO"``) or a RLLib trainer class.
        config: RLLib config dict used as the base configuration.
            Defaults to an empty dict.
        scaling_config: Configuration for how to scale training.
        run_config: Configuration for the execution of the training run.
        datasets: Any Ray Datasets to use for training. Use the key "train"
            to denote which dataset is the training
            dataset. If a ``preprocessor`` is provided and has not already been fit,
            it will be fit on the training dataset. All datasets will be transformed
            by the ``preprocessor`` if one is provided.
            If specified, datasets will be used for offline training. Will be
            configured as an RLLib ``input`` config item.
        preprocessor: A preprocessor to preprocess the provided datasets.
        resume_from_checkpoint: A checkpoint to resume training from.

    Example:
        Online training:

        .. code-block:: python

            from ray.ml.config import RunConfig
            from ray.ml.train.integrations.rl import RLTrainer

            trainer = RLTrainer(
                run_config=RunConfig(stop={"training_iteration": 5}),
                scaling_config={
                    "num_workers": 2,
                    "use_gpu": False,
                },
                algorithm="PPO",
                config={
                    "env": "CartPole-v0",
                    "framework": "tf",
                    "evaluation_num_workers": 1,
                    "evaluation_interval": 1,
                    "evaluation_config": {"input": "sampler"},
                },
            )
            result = trainer.fit()


    Example:
        Offline training (assumes data is stored in ``/tmp/data-dir``):

        .. code-block:: python

            import ray
            from ray.ml.config import RunConfig
            from ray.ml.train.integrations.rl import RLTrainer
            from ray.rllib.agents.marwil.bc import BCTrainer

            dataset = ray.data.read_json(
                "/tmp/data-dir", parallelism=2, ray_remote_args={"num_cpus": 1}
            )

            trainer = RLTrainer(
                run_config=RunConfig(stop={"training_iteration": 5}),
                scaling_config={
                    "num_workers": 2,
                    "use_gpu": False,
                },
                datasets={"train": dataset},
                algorithm=BCTrainer,
                config={
                    "env": "CartPole-v0",
                    "framework": "tf",
                    "evaluation_num_workers": 1,
                    "evaluation_interval": 1,
                    "evaluation_config": {"input": "sampler"},
                },
            )
            result = trainer.fit()

    """

    def __init__(
        self,
        algorithm: Union[str, Type[RLLibTrainer]],
        config: Optional[Dict[str, Any]] = None,
        scaling_config: Optional[ScalingConfig] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        preprocessor: Optional[Preprocessor] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        # Assigned before the base constructor runs.
        # NOTE(review): presumably the base constructor calls
        # ``_validate_attributes``, which reads both attributes - confirm.
        self._algorithm = algorithm
        self._config = config if config is not None else {}

        super(RLTrainer, self).__init__(
            scaling_config=scaling_config,
            run_config=run_config,
            datasets=datasets,
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
        )

    def _validate_attributes(self):
        """Validate ``algorithm`` and ``config`` on top of the base checks.

        Raises:
            ValueError: If ``algorithm`` is neither a string nor a subclass
                of the RLLib trainer class, or if ``config`` is not a dict.
        """
        super(RLTrainer, self)._validate_attributes()

        if not isinstance(self._algorithm, str) and not (
            inspect.isclass(self._algorithm)
            and issubclass(self._algorithm, RLLibTrainer)
        ):
            raise ValueError(
                f"`algorithm` should be either a string or a RLLib trainer class, "
                f"found {type(self._algorithm)} with value `{self._algorithm}`."
            )

        if not isinstance(self._config, dict):
            raise ValueError(
                f"`config` should be either a dict, "
                f"found {type(self._config)} with value `{self._config}`."
            )

    def _get_rllib_config(self, process_datasets: bool = False) -> Dict:
        """Translate this trainer's settings into an RLLib config dict.

        Starts from a shallow copy of the user config and overlays values
        derived from ``scaling_config``:

        * ``num_workers`` -> ``num_workers``
        * ``resources_per_worker`` -> ``num_cpus_per_worker`` (``"CPU"``,
          default 1), ``num_gpus_per_worker`` (``"GPU"``, default 0) and
          ``custom_resources_per_worker`` (all remaining keys)
        * ``use_gpu`` -> ``num_gpus = 1``
        * ``trainer_resources`` -> ``num_cpus_for_driver`` (``"CPU"``,
          default 1)

        Args:
            process_datasets: If True, run ``preprocess_datasets()`` and,
                when datasets are present, configure them as RLLib
                ``"dataset"`` input served through a ``loader_fn``.

        Returns:
            The resulting RLLib config dict.
        """
        config = self._config.copy()

        num_workers = self.scaling_config.get("num_workers")
        if num_workers is not None:
            config["num_workers"] = num_workers

        worker_resources = self.scaling_config.get("resources_per_worker")
        if worker_resources:
            # Work on a copy so popping CPU/GPU leaves the scaling config
            # untouched; whatever remains is treated as custom resources.
            res = worker_resources.copy()
            config["num_cpus_per_worker"] = res.pop("CPU", 1)
            config["num_gpus_per_worker"] = res.pop("GPU", 0)
            config["custom_resources_per_worker"] = res

        use_gpu = self.scaling_config.get("use_gpu")
        if use_gpu:
            config["num_gpus"] = 1

        trainer_resources = self.scaling_config.get("trainer_resources")
        if trainer_resources:
            config["num_cpus_for_driver"] = trainer_resources.get("CPU", 1)

        if process_datasets:
            self.preprocess_datasets()
            # Up for discussion: If datasets is passed, should we always
            # set the input config? Is the sampler config required here, too?
            if self.datasets:
                config["input"] = "dataset"
                config["input_config"] = {
                    # Evaluated lazily; requires a "train" entry in datasets.
                    "loader_fn": lambda: self.datasets["train"],
                }

        return config

    def training_loop(self) -> None:
        """No-op; ``as_trainable`` returns an RLLib trainer subclass instead."""
        pass

    def as_trainable(self) -> Type[Trainable]:
        """Build a trainable class that subclasses the selected RLLib trainer.

        A string ``algorithm`` is resolved through ``get_trainable_cls``;
        a class is used as is. The returned class merges the config it is
        given on top of this trainer's base config, re-creates an
        ``RLTrainer`` from the original constructor arguments, and passes
        the derived RLLib config to the RLLib trainer.

        Returns:
            A subclass of the RLLib trainer named ``AIR<RLLibTrainerName>``.
        """
        param_dict = self._param_dict
        base_config = self._config
        trainer_cls = self.__class__

        if isinstance(self._algorithm, str):
            rllib_trainer = get_trainable_cls(self._algorithm)
        else:
            rllib_trainer = self._algorithm

        def _resolve_rllib_config(
            config: Optional[PartialTrainerConfigDict], process_datasets: bool
        ) -> Dict:
            # ``config`` may be None (it is the declared default of
            # ``AIRRLTrainer.__init__``); merge an empty dict in that case
            # instead of handing None to ``merge_dicts``.
            resolved_config = merge_dicts(base_config, config or {})
            # Build per-call kwargs rather than writing into the shared
            # ``param_dict``, so the outer trainer's captured constructor
            # arguments are never mutated.
            trainer = trainer_cls(**{**param_dict, "config": resolved_config})
            return trainer._get_rllib_config(process_datasets=process_datasets)

        class AIRRLTrainer(rllib_trainer):
            def __init__(
                self,
                config: Optional[PartialTrainerConfigDict] = None,
                env: Optional[Union[str, EnvType]] = None,
                logger_creator: Optional[Callable[[], Logger]] = None,
                remote_checkpoint_dir: Optional[str] = None,
                sync_function_tpl: Optional[str] = None,
            ):
                # Datasets are preprocessed here, when the trainable is
                # actually instantiated.
                rllib_config = _resolve_rllib_config(config, process_datasets=True)

                super(AIRRLTrainer, self).__init__(
                    rllib_config,
                    env,
                    logger_creator,
                    remote_checkpoint_dir,
                    sync_function_tpl,
                )

            @classmethod
            def default_resource_request(
                cls, config: PartialTrainerConfigDict
            ) -> Union[Resources, PlacementGroupFactory]:
                # Resource requests only need the scaling-derived settings,
                # so dataset preprocessing is skipped.
                rllib_config = _resolve_rllib_config(config, process_datasets=False)

                return rllib_trainer.default_resource_request(rllib_config)

        AIRRLTrainer.__name__ = f"AIR{rllib_trainer.__name__}"
        return AIRRLTrainer
|
|
@ -306,8 +306,8 @@ class Trainer(abc.ABC):
|
|||
result = result_grid[0]
|
||||
if result.error:
|
||||
raise result.error
|
||||
except TuneError:
|
||||
raise TrainingFailedError
|
||||
except TuneError as e:
|
||||
raise TrainingFailedError from e
|
||||
return result
|
||||
|
||||
def as_trainable(self) -> Type[Trainable]:
|
||||
|
|
|
@ -168,6 +168,7 @@ class TunerInternal:
|
|||
name=self._run_config.name,
|
||||
callbacks=self._run_config.callbacks,
|
||||
sync_config=self._run_config.sync_config,
|
||||
stop=self._run_config.stop,
|
||||
max_failures=(
|
||||
self._run_config.failure.max_failures if self._run_config.failure else 0
|
||||
),
|
||||
|
@ -186,6 +187,7 @@ class TunerInternal:
|
|||
metric=self._tune_config.metric,
|
||||
callbacks=self._run_config.callbacks,
|
||||
sync_config=self._run_config.sync_config,
|
||||
stop=self._run_config.stop,
|
||||
max_failures=(
|
||||
self._run_config.failure.max_failures if self._run_config.failure else 0
|
||||
),
|
||||
|
|
|
@ -32,9 +32,19 @@ def get_dataset_and_shards(
|
|||
), "Must specify input_config dict if using Dataset input."
|
||||
|
||||
input_config = config["input_config"]
|
||||
if not input_config.get("format", None) or not input_config.get("path", None):
|
||||
|
||||
format = input_config.get("format")
|
||||
path = input_config.get("path")
|
||||
loader_fn = input_config.get("loader_fn")
|
||||
|
||||
if loader_fn and (format or path):
|
||||
raise ValueError(
|
||||
"Must specify format and path via input_config key"
|
||||
"When using a `loader_fn`, you cannot specify a `format` or `path`."
|
||||
)
|
||||
|
||||
if not (format and path) and not loader_fn:
|
||||
raise ValueError(
|
||||
"Must specify format and path, or a loader_fn via input_config key"
|
||||
" when using Ray dataset input."
|
||||
)
|
||||
|
||||
|
@ -43,9 +53,11 @@ def get_dataset_and_shards(
|
|||
"num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
|
||||
)
|
||||
|
||||
format = input_config["format"]
|
||||
path = input_config["path"]
|
||||
if format == "json":
|
||||
assert loader_fn or (format and path)
|
||||
|
||||
if loader_fn:
|
||||
dataset = loader_fn()
|
||||
elif format == "json":
|
||||
dataset = ray.data.read_json(
|
||||
path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue