[AIR] Add RLTrainer interface, implementation, and examples (#23465)

This PR adds an RLTrainer to Ray AIR. It works for both offline and online use cases. In offline training, it leverages the datasets key of the Trainer API to specify a dataset reader input, used e.g. in Behavioral Cloning (BC). In online training, it is a wrapper around the RLlib trainables, making use of the parameter layering enabled by the Trainer API.
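For orientation, here is a condensed sketch of the two modes, adapted from the `rl_example.py` added in this PR (imports, algorithms, and config values are all taken from that example):

```python
import ray
from ray.ml.config import RunConfig
from ray.ml.train.integrations.rl import RLTrainer
from ray.rllib.agents.marwil import BCTrainer

# Online: RLTrainer wraps an RLlib trainable referenced by name.
online_trainer = RLTrainer(
    algorithm="PPO",
    run_config=RunConfig(stop={"training_iteration": 5}),
    scaling_config={"num_workers": 2, "use_gpu": False},
    config={"env": "CartPole-v0", "framework": "tf"},
)

# Offline: the "train" dataset is wired into RLlib's dataset input,
# e.g. for Behavioral Cloning on previously recorded episodes.
offline_trainer = RLTrainer(
    algorithm=BCTrainer,
    run_config=RunConfig(stop={"training_iteration": 5}),
    datasets={"train": ray.data.read_json("/tmp/out")},
    config={"env": "CartPole-v0", "framework": "tf"},
)

result = offline_trainer.fit()
```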
Kai Fricke, 2022-04-08 17:16:42 -07:00 (committed by GitHub)
parent 5a41fb18bd
commit 8c2e471265
13 changed files with 390 additions and 17 deletions


@@ -19,11 +19,30 @@ py_test(
     deps = [":ml_lib"]
 )
 
+py_test(
+    name = "rl_example_offline",
+    main = "examples/rl_example.py",
+    size = "medium",
+    srcs = ["examples/rl_example.py"],
+    args = ["--offline"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"]
+)
+
+py_test(
+    name = "rl_example_online",
+    main = "examples/rl_example.py",
+    size = "medium",
+    srcs = ["examples/rl_example.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"]
+)
+
 py_test(
     name = "tensorflow_linear_dataset_example",
     size = "medium",
-    main = "examples/tensorflow/tensorflow_linear_dataset_example.py",
-    srcs = ["examples/tensorflow/tensorflow_linear_dataset_example.py"],
+    main = "examples/tf/tensorflow_linear_dataset_example.py",
+    srcs = ["examples/tf/tensorflow_linear_dataset_example.py"],
     tags = ["team:ml", "exclusive"],
     deps = [":ml_lib"],
     args = ["--smoke-test"]
@@ -32,8 +51,8 @@ py_test(
 py_test(
     name = "tensorflow_mnist_example",
     size = "medium",
-    main = "examples/tensorflow/tensorflow_mnist_example.py",
-    srcs = ["examples/tensorflow/tensorflow_mnist_example.py"],
+    main = "examples/tf/tensorflow_mnist_example.py",
+    srcs = ["examples/tf/tensorflow_mnist_example.py"],
     tags = ["team:ml", "exclusive"],
     deps = [":ml_lib"],
     args = ["--smoke-test"]
@@ -82,8 +101,8 @@ py_test(
 py_test(
     name = "tune_tensorflow_mnist_example",
     size = "medium",
-    main = "examples/tensorflow/tune_tensorflow_mnist_example.py",
-    srcs = ["examples/tensorflow/tune_tensorflow_mnist_example.py"],
+    main = "examples/tf/tune_tensorflow_mnist_example.py",
+    srcs = ["examples/tf/tune_tensorflow_mnist_example.py"],
     tags = ["team:ml", "exclusive"],
     deps = [":ml_lib"],
     args = ["--smoke-test"]


@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Any, Optional, List, TYPE_CHECKING
+from typing import Dict, Any, Optional, List, Mapping, Callable, Union, TYPE_CHECKING
 
 from ray.tune.syncer import SyncConfig
 from ray.util import PublicAPI
@@ -7,7 +7,7 @@ from ray.util import PublicAPI
 if TYPE_CHECKING:
     from ray.tune.trainable import PlacementGroupFactory
     from ray.tune.callback import Callback
+    from ray.tune.stopper import Stopper
 
 ScalingConfig = Dict[str, Any]
@@ -137,6 +137,8 @@ class RunConfig:
             from the Trainable.
         local_dir: Local dir to save training results to.
             Defaults to ``~/ray_results``.
+        stop: Stop conditions to consider. Refer to ray.tune.stopper.Stopper
+            for more info. Stoppers should be serializable.
         callbacks: Callbacks to invoke.
             Refer to ray.tune.callback.Callback for more info.
             Callbacks should be serializable.
@@ -151,5 +153,6 @@ class RunConfig:
     name: Optional[str] = None
     local_dir: Optional[str] = None
     callbacks: Optional[List["Callback"]] = None
+    stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
     failure: Optional[FailureConfig] = None
    sync_config: Optional[SyncConfig] = None
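The new ``stop`` field accepts the same forms as Tune, per the annotation above: a dict of metric thresholds, a ``Stopper`` instance, or a ``(trial_id, result) -> bool`` callable. A minimal sketch (the stopping values are illustrative):

```python
from ray.ml.config import RunConfig

# Dict form: stop when any of the criteria is met.
run_config = RunConfig(stop={"training_iteration": 5, "timesteps_total": 5000})

# Callable form, matching Callable[[str, Mapping], bool].
run_config = RunConfig(
    stop=lambda trial_id, result: result["episode_reward_mean"] > 190
)
```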


@@ -0,0 +1,110 @@
import argparse
import os

import ray
from ray.ml.config import RunConfig
from ray.ml.train.integrations.rl.rl_trainer import RLTrainer
from ray.ml.result import Result
from ray.rllib.agents.marwil import BCTrainer


def generate_offline_data(path: str):
    print(f"Generating offline data for training at {path}")
    trainer = RLTrainer(
        algorithm="PPO",
        run_config=RunConfig(stop={"timesteps_total": 5000}),
        config={
            "env": "CartPole-v0",
            "output": "dataset",
            "output_config": {
                "format": "json",
                "path": path,
                "max_num_samples_per_file": 1,
            },
            "batch_mode": "complete_episodes",
        },
    )
    trainer.fit()


def train_rl_bc_offline(path: str, num_workers: int, use_gpu: bool = False) -> Result:
    print("Starting offline training")
    dataset = ray.data.read_json(
        path, parallelism=num_workers, ray_remote_args={"num_cpus": 1}
    )

    trainer = RLTrainer(
        run_config=RunConfig(stop={"training_iteration": 5}),
        scaling_config={
            "num_workers": num_workers,
            "use_gpu": use_gpu,
        },
        datasets={"train": dataset},
        algorithm=BCTrainer,
        config={
            "env": "CartPole-v0",
            "framework": "tf",
            "evaluation_num_workers": 1,
            "evaluation_interval": 1,
            "evaluation_config": {"input": "sampler"},
        },
    )

    result = trainer.fit()
    return result


def train_rl_ppo_online(num_workers: int, use_gpu: bool = False) -> Result:
    print("Starting online training")
    trainer = RLTrainer(
        run_config=RunConfig(stop={"training_iteration": 5}),
        scaling_config={
            "num_workers": num_workers,
            "use_gpu": use_gpu,
        },
        algorithm="PPO",
        config={
            "env": "CartPole-v0",
            "framework": "tf",
            "evaluation_num_workers": 1,
            "evaluation_interval": 1,
            "evaluation_config": {"input": "sampler"},
        },
    )
    result = trainer.fit()
    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--offline", default=False, action="store_true")
    parser.add_argument(
        "--path", required=False, default="/tmp/out", help="Path to (offline) data"
    )
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
    args, _ = parser.parse_known_args()

    ray.init(address=args.address)

    if args.offline:
        if not os.path.exists(args.path) or not os.listdir(args.path):
            generate_offline_data(args.path)
        result = train_rl_bc_offline(
            path=args.path, num_workers=args.num_workers, use_gpu=args.use_gpu
        )
    else:
        result = train_rl_ppo_online(num_workers=args.num_workers, use_gpu=args.use_gpu)

    print(result.metrics)
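For reference, given the argparse flags above: `python rl_example.py` runs online PPO training, while `python rl_example.py --offline` first generates data under the default `/tmp/out` path (if missing or empty) and then runs offline BC training on it.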


@@ -4,7 +4,7 @@ import ray
 from ray import tune
 from ray.ml.train.integrations.tensorflow import TensorflowTrainer
-from ray.ml.examples.tensorflow.tensorflow_mnist_example import train_func
+from ray.ml.examples.tf.tensorflow_mnist_example import train_func
 from ray.tune.tune_config import TuneConfig
 from ray.tune.tuner import Tuner


@@ -4,7 +4,7 @@ import numpy as np
 import ray
 from ray import train
 from ray.ml.train.integrations.tensorflow import TensorflowTrainer
-from ray.ml.examples.tensorflow.tensorflow_linear_dataset_example import (
+from ray.ml.examples.tf.tensorflow_linear_dataset_example import (
     train_func as tensorflow_linear_train_func,
     get_dataset,
 )


@@ -0,0 +1,3 @@
from ray.ml.train.integrations.rl.rl_trainer import RLTrainer

__all__ = ["RLTrainer"]


@@ -0,0 +1,224 @@
import inspect
from typing import Optional, Dict, Type, Union, Callable, Any

from ray.ml.checkpoint import Checkpoint
from ray.ml.config import ScalingConfig, RunConfig
from ray.ml.preprocessor import Preprocessor
from ray.ml.trainer import Trainer, GenDataset
from ray.rllib.agents.trainer import Trainer as RLLibTrainer
from ray.rllib.utils.typing import PartialTrainerConfigDict, EnvType
from ray.tune import Trainable, PlacementGroupFactory
from ray.tune.logger import Logger
from ray.tune.registry import get_trainable_cls
from ray.tune.resources import Resources
from ray.util.annotations import PublicAPI
from ray.util.ml_utils.dict import merge_dicts
@PublicAPI(stability="alpha")
class RLTrainer(Trainer):
"""Reinforcement learning trainer.
This trainer provides an interface to RLLib trainables.
If datasets and preprocessors are used, they can be utilized for
offline training, e.g. using behavior cloning. Otherwise, this trainer
will use online training.
Args:
algorithm: Algorithm to train on. Can be a string reference,
(e.g. ``"PPO"``) or a RLLib trainer class.
scaling_config: Configuration for how to scale training.
run_config: Configuration for the execution of the training run.
datasets: Any Ray Datasets to use for training. Use the key "train"
to denote which dataset is the training
dataset. If a ``preprocessor`` is provided and has not already been fit,
it will be fit on the training dataset. All datasets will be transformed
by the ``preprocessor`` if one is provided.
If specified, datasets will be used for offline training. Will be
configured as an RLLib ``input`` config item.
preprocessor: A preprocessor to preprocess the provided datasets.
resume_from_checkpoint: A checkpoint to resume training from.
Example:
Online training:
.. code-block:: python
from ray.ml.config import RunConfig
from ray.ml.train.integrations.rl import RLTrainer
trainer = RLTrainer(
run_config=RunConfig(stop={"training_iteration": 5}),
scaling_config={
"num_workers": 2,
"use_gpu": False,
},
algorithm="PPO",
config={
"env": "CartPole-v0",
"framework": "tf",
"evaluation_num_workers": 1,
"evaluation_interval": 1,
"evaluation_config": {"input": "sampler"},
},
)
result = trainer.fit()
Example:
Offline training (assumes data is stored in ``/tmp/data-dir``):
.. code-block:: python
import ray
from ray.ml.config import RunConfig
from ray.ml.train.integrations.rl import RLTrainer
from ray.rllib.agents.marwil.bc import BCTrainer
dataset = ray.data.read_json(
"/tmp/data-dir", parallelism=2, ray_remote_args={"num_cpus": 1}
)
trainer = RLTrainer(
run_config=RunConfig(stop={"training_iteration": 5}),
scaling_config={
"num_workers": 2,
"use_gpu": False,
},
datasets={"train": dataset},
algorithm=BCTrainer,
config={
"env": "CartPole-v0",
"framework": "tf",
"evaluation_num_workers": 1,
"evaluation_interval": 1,
"evaluation_config": {"input": "sampler"},
},
)
result = trainer.fit()
"""

    def __init__(
        self,
        algorithm: Union[str, Type[RLLibTrainer]],
        config: Optional[Dict[str, Any]] = None,
        scaling_config: Optional[ScalingConfig] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        preprocessor: Optional[Preprocessor] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        self._algorithm = algorithm
        self._config = config if config is not None else {}

        super(RLTrainer, self).__init__(
            scaling_config=scaling_config,
            run_config=run_config,
            datasets=datasets,
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
        )

    def _validate_attributes(self):
        super(RLTrainer, self)._validate_attributes()

        if not isinstance(self._algorithm, str) and not (
            inspect.isclass(self._algorithm)
            and issubclass(self._algorithm, RLLibTrainer)
        ):
            raise ValueError(
                f"`algorithm` should be either a string or an RLlib trainer class, "
                f"found {type(self._algorithm)} with value `{self._algorithm}`."
            )

        if not isinstance(self._config, dict):
            raise ValueError(
                f"`config` should be a dict, "
                f"found {type(self._config)} with value `{self._config}`."
            )
def _get_rllib_config(self, process_datasets: bool = False) -> Dict:
config = self._config.copy()
num_workers = self.scaling_config.get("num_workers")
if num_workers is not None:
config["num_workers"] = num_workers
worker_resources = self.scaling_config.get("resources_per_worker")
if worker_resources:
res = worker_resources.copy()
config["num_cpus_per_worker"] = res.pop("CPU", 1)
config["num_gpus_per_worker"] = res.pop("GPU", 0)
config["custom_resources_per_worker"] = res
use_gpu = self.scaling_config.get("use_gpu")
if use_gpu:
config["num_gpus"] = 1
trainer_resources = self.scaling_config.get("trainer_resources")
if trainer_resources:
config["num_cpus_for_driver"] = trainer_resources.get("CPU", 1)
if process_datasets:
self.preprocess_datasets()
# Up for discussion: If datasets is passed, should we always
# set the input config? Is the sampler config required here, too?
if self.datasets:
config["input"] = "dataset"
config["input_config"] = {
"loader_fn": lambda: self.datasets["train"],
}
return config
def training_loop(self) -> None:
pass
def as_trainable(self) -> Type[Trainable]:
param_dict = self._param_dict
base_config = self._config
trainer_cls = self.__class__
if isinstance(self._algorithm, str):
rllib_trainer = get_trainable_cls(self._algorithm)
else:
rllib_trainer = self._algorithm
class AIRRLTrainer(rllib_trainer):
def __init__(
self,
config: Optional[PartialTrainerConfigDict] = None,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None,
):
resolved_config = merge_dicts(base_config, config)
param_dict["config"] = resolved_config
trainer = trainer_cls(**param_dict)
rllib_config = trainer._get_rllib_config(process_datasets=True)
super(AIRRLTrainer, self).__init__(
rllib_config,
env,
logger_creator,
remote_checkpoint_dir,
sync_function_tpl,
)
@classmethod
def default_resource_request(
cls, config: PartialTrainerConfigDict
) -> Union[Resources, PlacementGroupFactory]:
resolved_config = merge_dicts(base_config, config)
param_dict["config"] = resolved_config
trainer = trainer_cls(**param_dict)
rllib_config = trainer._get_rllib_config(process_datasets=False)
return rllib_trainer.default_resource_request(rllib_config)
AIRRLTrainer.__name__ = f"AIR{rllib_trainer.__name__}"
return AIRRLTrainer
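Because `training_loop` is a no-op and all work happens inside the wrapped RLlib trainable, the class returned by `as_trainable()` behaves like any Tune trainable. A hypothetical sketch of driving it with `tune.run` directly (normally `trainer.fit()` does this wiring; the `lr` override is illustrative):

```python
from ray import tune
from ray.ml.train.integrations.rl import RLTrainer

trainer = RLTrainer(
    algorithm="PPO",
    config={"env": "CartPole-v0", "framework": "tf"},
)

# The Tune config is layered on top of the base config passed to
# RLTrainer, via merge_dicts(base_config, config) in AIRRLTrainer.
analysis = tune.run(
    trainer.as_trainable(),
    config={"lr": 1e-3},
    stop={"training_iteration": 2},
)
```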


@@ -306,8 +306,8 @@ class Trainer(abc.ABC):
             result = result_grid[0]
             if result.error:
                 raise result.error
-        except TuneError:
-            raise TrainingFailedError
+        except TuneError as e:
+            raise TrainingFailedError from e
         return result

     def as_trainable(self) -> Type[Trainable]:


@@ -168,6 +168,7 @@ class TunerInternal:
             name=self._run_config.name,
             callbacks=self._run_config.callbacks,
             sync_config=self._run_config.sync_config,
+            stop=self._run_config.stop,
             max_failures=(
                 self._run_config.failure.max_failures if self._run_config.failure else 0
             ),
@@ -186,6 +187,7 @@ class TunerInternal:
             metric=self._tune_config.metric,
             callbacks=self._run_config.callbacks,
             sync_config=self._run_config.sync_config,
+            stop=self._run_config.stop,
             max_failures=(
                 self._run_config.failure.max_failures if self._run_config.failure else 0
             ),


@@ -32,9 +32,19 @@ def get_dataset_and_shards(
     ), "Must specify input_config dict if using Dataset input."
     input_config = config["input_config"]

-    if not input_config.get("format", None) or not input_config.get("path", None):
+    format = input_config.get("format")
+    path = input_config.get("path")
+    loader_fn = input_config.get("loader_fn")
+
+    if loader_fn and (format or path):
+        raise ValueError(
+            "When using a `loader_fn`, you cannot specify a `format` or `path`."
+        )
+
+    if not (format and path) and not loader_fn:
         raise ValueError(
-            "Must specify format and path via input_config key"
+            "Must specify format and path, or a loader_fn via input_config key"
             " when using Ray dataset input."
         )
@@ -43,9 +53,11 @@ def get_dataset_and_shards(
         "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
     )

-    format = input_config["format"]
-    path = input_config["path"]
-    if format == "json":
+    assert loader_fn or (format and path)
+    if loader_fn:
+        dataset = loader_fn()
+    elif format == "json":
         dataset = ray.data.read_json(
             path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
         )
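To tie the reader changes back to the trainer: `input_config` now accepts either a zero-argument `loader_fn` returning a Ray Dataset (this is what `RLTrainer._get_rllib_config` sets), or the pre-existing `format`/`path` pair, but not both. A sketch of the two variants (the JSON path is illustrative):

```python
import ray

# Variant 1: loader_fn returns a Ray Dataset; format/path must be absent.
config = {
    "input": "dataset",
    "input_config": {
        "loader_fn": lambda: ray.data.read_json("/tmp/out"),
    },
}

# Variant 2: format and path; loader_fn must be absent.
config = {
    "input": "dataset",
    "input_config": {
        "format": "json",
        "path": "/tmp/out",
    },
}
```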