[RLlib] Add Decision Transformer (DT) (#27890)
This commit is contained in: parent 6be4bf8be3, commit edde905741
12 changed files with 1050 additions and 3 deletions
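For orientation, a minimal end-to-end sketch of the new API added by this commit, distilled from the tests and tuned examples further down; the dataset path and hyperparameters here are illustrative assumptions, not values fixed by the commit:

from ray.rllib.algorithms.dt import DTConfig

# Train DT offline on a JSON dataset, then roll it out with a target return.
config = (
    DTConfig()
    .environment(env="Pendulum-v1", clip_actions=True, normalize_actions=True)
    .framework("torch")
    .offline_data(
        input_="dataset",
        # Illustrative path; any RLlib-readable JSON offline dataset works here.
        input_config={"paths": "tests/data/pendulum/large.json", "format": "json"},
        actions_in_input_normalized=True,
    )
    .training(
        train_batch_size=512,
        replay_buffer_config={"capacity": 20},
        model={"max_seq_len": 3},
        num_layers=1,
        num_heads=1,
        embed_dim=64,
    )
    .evaluation(target_return=-120.0)
    .rollouts(num_rollout_workers=0, horizon=200)
)
algo = config.build()
algo.train()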
rllib/BUILD (+47)
@@ -381,6 +381,35 @@ py_test(
    args = ["--yaml-dir=tuned_examples/dqn"]
)

# DT
py_test(
    name = "learning_tests_pendulum_dt",
    main = "tests/run_regression_tests.py",
    tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
    size = "large",
    srcs = ["tests/run_regression_tests.py"],
    # Include an offline json data file as well.
    data = [
        "tuned_examples/dt/pendulum-v1-dt.yaml",
        "tests/data/pendulum/pendulum_expert_sac_50eps.zip",
    ],
    args = ["--yaml-dir=tuned_examples/dt"]
)

py_test(
    name = "learning_tests_cartpole_dt",
    main = "tests/run_regression_tests.py",
    tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
    size = "large",
    srcs = ["tests/run_regression_tests.py"],
    # Include an offline json data file as well.
    data = [
        "tuned_examples/dt/cartpole-v0-dt.yaml",
        "tests/data/cartpole/large.json",
    ],
    args = ["--yaml-dir=tuned_examples/dt"]
)

# Simple-Q
py_test(
    name = "learning_tests_cartpole_simpleq",

@@ -928,6 +957,14 @@ py_test(
    srcs = ["algorithms/dt/tests/test_dt_policy.py"]
)

py_test(
    name = "test_dt",
    tags = ["team:rllib", "algorithms_dir"],
    size = "medium",
    srcs = ["algorithms/dt/tests/test_dt.py"],
    data = ["tests/data/pendulum/large.json"],
)

# ES
py_test(
    name = "test_es",

@@ -3148,6 +3185,16 @@ py_test(
    args = ["--stop-iters=2", "--framework=torch"]
)

py_test(
    name = "examples/inference_and_serving/policy_inference_after_training_with_dt_torch",
    main = "examples/inference_and_serving/policy_inference_after_training_with_dt.py",
    tags = ["team:rllib", "exclusive", "examples", "examples_P"],
    size = "medium",
    srcs = ["examples/inference_and_serving/policy_inference_after_training_with_dt.py"],
    data = ["tests/data/cartpole/large.json"],
    args = ["--input-files=tests/data/cartpole/large.json"]
)

py_test(
    name = "examples/inference_and_serving/policy_inference_after_training_with_lstm_tf",
    main = "examples/inference_and_serving/policy_inference_after_training_with_lstm.py",
rllib/algorithms/dt/__init__.py (new file, +6)
@@ -0,0 +1,6 @@
from ray.rllib.algorithms.dt.dt import DT, DTConfig

__all__ = [
    "DT",
    "DTConfig",
]
rllib/algorithms/dt/dt.py (new file, +401)
@@ -0,0 +1,401 @@
import logging
import math
from typing import List, Optional, Type, Tuple, Dict, Any, Union

from ray.rllib import SampleBatch
from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
from ray.rllib.algorithms.dt.segmentation_buffer import MultiAgentSegmentationBuffer
from ray.rllib.execution import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    SAMPLE_TIMER,
    NUM_AGENT_STEPS_TRAINED,
)
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    ResultDict,
    TensorStructType,
    PolicyID,
    TensorType,
)

logger = logging.getLogger(__name__)


class DTConfig(AlgorithmConfig):
    def __init__(self, algo_class=None):
        super().__init__(algo_class=algo_class or DT)

        # fmt: off
        # __sphinx_doc_begin__
        # DT-specific settings.
        # Required settings during training and evaluation:
        # Initial return to go used as target during rollout.
        self.target_return = None
        # Rollout horizon/maximum episode length.
        self.horizon = None

        # Model settings:
        self.model = {
            # Transformer (GPT) context length.
            "max_seq_len": 5,
        }

        # Transformer (GPT) settings:
        self.embed_dim = 128
        self.num_layers = 2
        self.num_heads = 1
        self.embed_pdrop = 0.1
        self.resid_pdrop = 0.1
        self.attn_pdrop = 0.1

        # Optimization settings:
        self.lr = 1e-4
        self.lr_schedule = None
        self.optimizer = {
            # Weight decay for Adam optimizer.
            "weight_decay": 1e-4,
            # Betas for Adam optimizer.
            "betas": (0.9, 0.95),
        }
        self.grad_clip = None
        # Coefficients on the loss for each of the heads.
        # By default, only use the actions outputs for training.
        self.loss_coef_actions = 1
        self.loss_coef_obs = 0
        self.loss_coef_returns_to_go = 0

        self.replay_buffer_config = {
            # How many trajectories/episodes the segmentation buffer holds.
            # Increase for more data shuffling but increased memory usage.
            "capacity": 20,
            # Do not change the type of replay buffer.
            "type": MultiAgentSegmentationBuffer,
        }
        # __sphinx_doc_end__
        # fmt: on

        # Overwriting the trainer config defaults.
        # If data ingestion/sample_time is slow, increase this.
        self.num_workers = 0
        # Number of training_step calls between evaluation rollouts.
        self.min_train_timesteps_per_iteration = 5000

        # Don't change.
        self.offline_sampling = True
        self.postprocess_inputs = True
        self.discount = None

    def training(
        self,
        *,
        replay_buffer_config: Optional[Dict[str, Any]],
        embed_dim: Optional[int] = None,
        num_layers: Optional[int] = None,
        num_heads: Optional[int] = None,
        embed_pdrop: Optional[float] = None,
        resid_pdrop: Optional[float] = None,
        attn_pdrop: Optional[float] = None,
        grad_clip: Optional[float] = None,
        loss_coef_actions: Optional[float] = None,
        loss_coef_obs: Optional[float] = None,
        loss_coef_returns_to_go: Optional[float] = None,
        lr_schedule: Optional[List[List[Union[int, float]]]] = None,
        **kwargs,
    ) -> "DTConfig":
        """
        === DT configs

        Args:
            replay_buffer_config: Replay buffer config.
                {
                    "capacity": How many trajectories/episodes the buffer holds.
                }
            embed_dim: Dimension of the embeddings in the GPT model.
            num_layers: Number of attention layers in the GPT model.
            num_heads: Number of attention heads in the GPT model. Must divide
                embed_dim evenly.
            embed_pdrop: Dropout probability of the embedding layer of the GPT model.
            resid_pdrop: Dropout probability of the residual layer of the GPT model.
            attn_pdrop: Dropout probability of the attention layer of the GPT model.
            grad_clip: If specified, clip the global norm of gradients by this amount.
            lr_schedule: Learning rate schedule. In the format of
                [[timestep, lr-value], [timestep, lr-value], ...]
                Intermediary timesteps will be assigned to interpolated learning rate
                values. A schedule should normally start from timestep 0.
            loss_coef_actions: Coefficient on the loss for the actions output.
                Defaults to 1.
            loss_coef_obs: Coefficient on the loss for the obs output. Defaults to 0.
                Set to a value greater than 0 to regress on the obs output.
            loss_coef_returns_to_go: Coefficient on the loss for the returns_to_go
                output. Defaults to 0. Set to a value greater than 0 to regress on
                the returns_to_go output.
            **kwargs: Forward compatibility kwargs.

        Returns:
            This updated DTConfig object.
        """
        super().training(**kwargs)
        if replay_buffer_config is not None:
            self.replay_buffer_config = replay_buffer_config
        if embed_dim is not None:
            self.embed_dim = embed_dim
        if num_layers is not None:
            self.num_layers = num_layers
        if num_heads is not None:
            self.num_heads = num_heads
        if embed_pdrop is not None:
            self.embed_pdrop = embed_pdrop
        if resid_pdrop is not None:
            self.resid_pdrop = resid_pdrop
        if attn_pdrop is not None:
            self.attn_pdrop = attn_pdrop
        if grad_clip is not None:
            self.grad_clip = grad_clip
        if lr_schedule is not None:
            self.lr_schedule = lr_schedule
        if loss_coef_actions is not None:
            self.loss_coef_actions = loss_coef_actions
        if loss_coef_obs is not None:
            self.loss_coef_obs = loss_coef_obs
        if loss_coef_returns_to_go is not None:
            self.loss_coef_returns_to_go = loss_coef_returns_to_go

        return self

    def evaluation(
        self,
        *,
        target_return: Optional[float] = None,
        **kwargs,
    ) -> "DTConfig":
        """
        === DT configs

        Args:
            target_return: The target return-to-go for inference/evaluation.
            **kwargs: Forward compatibility kwargs.

        Returns:
            This updated DTConfig object.
        """
        super().evaluation(**kwargs)
        if target_return is not None:
            self.target_return = target_return

        return self


class DT(Algorithm):
    """Implements Decision Transformer: https://arxiv.org/abs/2106.01345"""

    # TODO: we have a circular dependency for getting the default config:
    #  config -> Trainer -> config. Defining the Config class in the same
    #  file for now as a workaround.

    @override(Algorithm)
    def validate_config(self, config: AlgorithmConfigDict) -> None:
        """Validates the Trainer's config dict.

        Args:
            config: The Trainer's config to check.

        Raises:
            ValueError: In case something is wrong with the config.
        """
        # Call super's validation method.
        super().validate_config(config)

        # target_return must be specified.
        assert (
            self.config.get("target_return") is not None
        ), "Must specify a target return (total sum of rewards)."

        # horizon must be specified and >= 2.
        assert self.config.get("horizon") is not None, "Must specify rollout horizon."
        assert self.config["horizon"] >= 2, "rollout horizon must be at least 2."

        # replay_buffer's type must be MultiAgentSegmentationBuffer.
        assert (
            self.config.get("replay_buffer_config") is not None
        ), "Must specify replay_buffer_config."
        replay_buffer_type = self.config["replay_buffer_config"].get("type")
        assert (
            replay_buffer_type == MultiAgentSegmentationBuffer
        ), "replay_buffer's type must be MultiAgentSegmentationBuffer."

        # max_seq_len must be specified in model.
        model_max_seq_len = self.config["model"].get("max_seq_len")
        assert model_max_seq_len is not None, "Must specify model's max_seq_len."

        # User shouldn't need to specify replay_buffer's max_seq_len.
        # Autofill for the replay buffer API. If they did specify it, make sure it
        # matches the model's max_seq_len.
        buffer_max_seq_len = self.config["replay_buffer_config"].get("max_seq_len")
        if buffer_max_seq_len is None:
            self.config["replay_buffer_config"]["max_seq_len"] = model_max_seq_len
        else:
            assert (
                buffer_max_seq_len == model_max_seq_len
            ), "replay_buffer's max_seq_len must equal model's max_seq_len."

        # Same thing for the buffer's max_ep_len, which should be autofilled from
        # the rollout horizon, or checked against it if the user specified it.
        buffer_max_ep_len = self.config["replay_buffer_config"].get("max_ep_len")
        if buffer_max_ep_len is None:
            self.config["replay_buffer_config"]["max_ep_len"] = self.config["horizon"]
        else:
            assert (
                buffer_max_ep_len == self.config["horizon"]
            ), "replay_buffer's max_ep_len must equal rollout horizon."

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return DTConfig().to_dict()

    @override(Algorithm)
    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.algorithms.dt.dt_torch_policy import DTTorchPolicy

            return DTTorchPolicy
        else:
            raise ValueError("Non-torch frameworks are not supported yet!")

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        with self._timers[SAMPLE_TIMER]:
            # TODO: Add ability to do obs_filter for offline sampling.
            train_batch = synchronous_parallel_sample(worker_set=self.workers)

        train_batch = train_batch.as_multi_agent()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

        # Because each sample is a segment of max_seq_len transitions, doing
        # the division makes it so the total number of transitions per train
        # step is consistent.
        num_steps = train_batch.env_steps()
        batch_size = int(math.ceil(num_steps / self.config["model"]["max_seq_len"]))

        # Add the batch of episodes to the segmentation buffer.
        self.local_replay_buffer.add(train_batch)
        # Sample a batch of segments.
        train_batch = self.local_replay_buffer.sample(batch_size)

        # Postprocess batch before we learn on it.
        post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
        train_batch = post_fn(train_batch, self.workers, self.config)

        # Learn on training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU).
        if self.config.get("simple_optimizer", False):
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # Update learning rate scheduler.
        global_vars = {
            # Note: this counts the number of segments trained, not timesteps.
            # i.e. NUM_AGENT_STEPS_TRAINED: B, NUM_AGENT_STEPS_SAMPLED: B*T
            "timestep": self._counters[NUM_AGENT_STEPS_TRAINED],
        }
        self.workers.local_worker().set_global_vars(global_vars)

        return train_results

    @PublicAPI
    @override(Algorithm)
    def compute_single_action(
        self,
        *args,
        input_dict: Optional[SampleBatch] = None,
        full_fetch: bool = True,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes an action for the specified policy on the local worker.

        Note that you can also access the policy object through
        self.get_policy(policy_id) and call compute_single_action() on it
        directly.

        Args:
            input_dict: A SampleBatch taken from get_initial_input_dict or
                get_next_input_dict.
            full_fetch: Whether to return extra action fetch results.
                This is always True for DT.
            kwargs: forward compatibility args.

        Returns:
            A tuple containing: (
                the computed action,
                list of RNN states (empty for DT),
                extra action output (pass to get_next_input_dict),
            )
        """
        assert input_dict is not None, (
            "DT must take in input_dict for inference. "
            "See get_initial_input_dict() and get_next_input_dict()."
        )
        assert (
            full_fetch
        ), "DT needs full_fetch=True. Pass extra into get_next_input_dict()."

        return super().compute_single_action(
            *args, input_dict=input_dict.copy(), full_fetch=full_fetch, **kwargs
        )

    @PublicAPI
    def get_initial_input_dict(
        self,
        observation: TensorStructType,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
    ) -> SampleBatch:
        """Get the initial input_dict to be passed into compute_single_action.

        Args:
            observation: first (unbatched) observation from env.reset().
            policy_id: Policy to query (only applies to multi-agent).
                Default: "default_policy".

        Returns:
            The input_dict for inference.
        """
        policy = self.get_policy(policy_id)
        return policy.get_initial_input_dict(observation)

    @PublicAPI
    def get_next_input_dict(
        self,
        input_dict: SampleBatch,
        action: TensorStructType,
        reward: TensorStructType,
        next_obs: TensorStructType,
        extra: Dict[str, TensorType],
        policy_id: PolicyID = DEFAULT_POLICY_ID,
    ) -> SampleBatch:
        """Returns a new input_dict after stepping through the environment once.

        Args:
            input_dict: the input dict passed into compute_single_action.
            action: the (unbatched) action taken this step.
            reward: the (unbatched) reward from env.step.
            next_obs: the (unbatched) next observation from env.step.
            extra: the extra action out from compute_single_action. For DT, this
                contains the current return-to-go *before* the current reward is
                subtracted from target_return.
            policy_id: Policy to query (only applies to multi-agent).
                Default: "default_policy".

        Returns:
            A new input_dict to be passed into compute_single_action.
        """
        policy = self.get_policy(policy_id)
        return policy.get_next_input_dict(input_dict, action, reward, next_obs, extra)
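The three @PublicAPI methods above (get_initial_input_dict, compute_single_action, get_next_input_dict) are meant to be chained in an inference loop. A minimal sketch, assuming an already built `algo` and the 4-tuple gym step API used throughout this diff:

import gym

env = gym.make("Pendulum-v1")
obs = env.reset()
# Seed the rolling context (obs/actions/returns-to-go) with the first observation.
input_dict = algo.get_initial_input_dict(obs)

done = False
while not done:
    # full_fetch is always True for DT; `extra` carries the running return-to-go.
    action, _, extra = algo.compute_single_action(input_dict=input_dict)
    obs, reward, done, _ = env.step(action)
    if not done:
        # Shift the context window by one transition for the next forward pass.
        input_dict = algo.get_next_input_dict(input_dict, action, reward, obs, extra)
env.close()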
rllib/algorithms/dt/segmentation_buffer.py
@@ -58,9 +58,6 @@ class SegmentationBuffer:
        self._add_single_episode(episode)

    def _add_single_episode(self, episode: SampleBatch):
        # Truncate if episode too long.
        # Note: sometimes this happens if the dataset shuffles such that the
        # same episode is concatenated together twice (which is okay).
        ep_len = episode.env_steps()

        if ep_len > self.max_ep_len:
rllib/algorithms/dt/tests/test_dt.py (new file, +269)
@@ -0,0 +1,269 @@
from pathlib import Path
import os
import unittest
from typing import Dict

import gym
import numpy as np

import ray
from ray.rllib import SampleBatch
from ray.rllib.algorithms.dt import DTConfig
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_train_results

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


def _assert_input_dict_equals(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]):
    for key in d1.keys():
        assert key in d2.keys()

    for key in d2.keys():
        assert key in d1.keys()

    for key in d1.keys():
        assert isinstance(d1[key], np.ndarray)
        assert isinstance(d2[key], np.ndarray)
        assert d1[key].shape == d2[key].shape
        assert np.allclose(d1[key], d2[key])


class TestDT(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_dt_compilation(self):
        """Test whether a DT algorithm can be built with all supported frameworks."""

        rllib_dir = Path(__file__).parent.parent.parent.parent
        data_file = os.path.join(rllib_dir, "tests/data/pendulum/large.json")

        input_config = {
            "paths": data_file,
            "format": "json",
        }

        config = (
            DTConfig()
            .environment(
                env="Pendulum-v1",
                clip_actions=True,
                normalize_actions=True,
            )
            .framework("torch")
            .offline_data(
                input_="dataset",
                input_config=input_config,
                actions_in_input_normalized=True,
            )
            .training(
                train_batch_size=200,
                replay_buffer_config={
                    "capacity": 8,
                },
                model={
                    "max_seq_len": 4,
                },
                num_layers=1,
                num_heads=1,
                embed_dim=64,
            )
            .evaluation(
                target_return=-120,
                evaluation_interval=2,
                evaluation_num_workers=0,
                evaluation_duration=10,
                evaluation_duration_unit="episodes",
                evaluation_parallel_to_training=False,
                evaluation_config={"input": "sampler", "explore": False},
            )
            .rollouts(
                num_rollout_workers=0,
                horizon=200,
            )
            .reporting(
                min_train_timesteps_per_iteration=10,
            )
        )

        num_iterations = 4

        for _ in ["torch"]:
            algo = config.build()
            # check if 4 iterations raises any errors
            for i in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)
                if (i + 1) % 2 == 0:
                    # evaluation happens every 2 iterations
                    eval_results = results["evaluation"]
                    print(
                        f"iter={algo.iteration} "
                        f"R={eval_results['episode_reward_mean']}"
                    )

            # do example inference rollout
            env = gym.make("Pendulum-v1")

            obs = env.reset()
            input_dict = algo.get_initial_input_dict(obs)

            for _ in range(200):
                action, _, extra = algo.compute_single_action(input_dict=input_dict)
                obs, reward, done, _ = env.step(action)
                if done:
                    break
                else:
                    input_dict = algo.get_next_input_dict(
                        input_dict,
                        action,
                        reward,
                        obs,
                        extra,
                    )

            env.close()
            algo.stop()

    def test_inference_methods(self):
        """Test inference methods."""

        config = (
            DTConfig()
            .environment(
                env="Pendulum-v1",
                clip_actions=True,
                normalize_actions=True,
            )
            .framework("torch")
            .training(
                train_batch_size=200,
                replay_buffer_config={
                    "capacity": 8,
                },
                model={
                    "max_seq_len": 3,
                },
                num_layers=1,
                num_heads=1,
                embed_dim=64,
            )
            .evaluation(
                target_return=-120,
            )
            .rollouts(
                num_rollout_workers=0,
                horizon=200,
            )
        )
        algo = config.build()

        # Do a controlled fake rollout for 2 steps and check input_dict
        # first input_dict
        obs = np.array([0.0, 1.0, 2.0])

        input_dict = algo.get_initial_input_dict(obs)
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0],
                        [0.0, 1.0, 2.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.0], [0.0]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([0.0, 0.0], dtype=np.float32),
                SampleBatch.REWARDS: np.zeros((), dtype=np.float32),
                SampleBatch.T: np.array([-1, -1], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)

        # forward pass with first input_dict
        action, _, extra = algo.compute_single_action(input_dict=input_dict)
        assert action.shape == (1,)
        assert SampleBatch.RETURNS_TO_GO in extra
        assert np.isclose(extra[SampleBatch.RETURNS_TO_GO], -120.0)

        # second input_dict
        action = np.array([0.5])
        obs = np.array([3.0, 4.0, 5.0])
        reward = -10.0

        input_dict = algo.get_next_input_dict(
            input_dict,
            action,
            reward,
            obs,
            extra,
        )
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 0.0, 0.0],
                        [0.0, 1.0, 2.0],
                        [3.0, 4.0, 5.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.0], [0.5]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([0.0, -120.0], dtype=np.float32),
                SampleBatch.REWARDS: np.asarray(-10.0),
                SampleBatch.T: np.array([-1, 0], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)

        # forward pass with second input_dict
        action, _, extra = algo.compute_single_action(input_dict=input_dict)
        assert action.shape == (1,)
        assert SampleBatch.RETURNS_TO_GO in extra
        assert np.isclose(extra[SampleBatch.RETURNS_TO_GO], -110.0)

        # third input_dict
        action = np.array([-0.2])
        obs = np.array([6.0, 7.0, 8.0])
        reward = -20.0

        input_dict = algo.get_next_input_dict(
            input_dict,
            action,
            reward,
            obs,
            extra,
        )
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 1.0, 2.0],
                        [3.0, 4.0, 5.0],
                        [6.0, 7.0, 8.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.5], [-0.2]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([-120.0, -110.0], dtype=np.float32),
                SampleBatch.REWARDS: np.asarray(-20.0),
                SampleBatch.T: np.array([0, 1], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
rllib/algorithms/registry.py
@@ -108,6 +108,12 @@ def _import_dreamer():
    return dreamer.Dreamer, dreamer.DreamerConfig().to_dict()


def _import_dt():
    import ray.rllib.algorithms.dt as dt

    return dt.DT, dt.DTConfig().to_dict()


def _import_es():
    import ray.rllib.algorithms.es as es

@@ -215,6 +221,7 @@ ALGORITHMS = {
    "DDPPO": _import_ddppo,
    "DQN": _import_dqn,
    "Dreamer": _import_dreamer,
    "DT": _import_dt,
    "IMPALA": _import_impala,
    "APPO": _import_appo,
    "AlphaStar": _import_alpha_star,

@@ -309,6 +316,7 @@ POLICIES = {
    "DQNTFPolicy": "dqn.dqn_tf_policy",
    "DQNTorchPolicy": "dqn.dqn_torch_policy",
    "DreamerTorchPolicy": "dreamer.dreamer_torch_policy",
    "DTTorchPolicy": "dt.dt_torch_policy",
    "ESTFPolicy": "es.es_tf_policy",
    "ESTorchPolicy": "es.es_torch_policy",
    "ImpalaTF1Policy": "impala.impala_tf_policy",
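With these registry entries, the new algorithm and its policy resolve by name. A small sketch of name-based lookup, assuming `config` is a dict produced by DTConfig().to_dict():

from ray.rllib.algorithms.registry import get_algorithm_class

# Resolve the DT Algorithm class by its registry key and build it from a config dict.
dt_cls = get_algorithm_class("DT")
algo = dt_cls(config=config)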
rllib/examples/inference_and_serving/policy_inference_after_training_with_dt.py (new file, +182)
@@ -0,0 +1,182 @@
"""
Example showing how you can use your trained Decision Transformer (DT) policy for
inference (computing actions) in an environment.
"""
import argparse
from pathlib import Path

import gym
import os

import ray
from ray import air, tune
from ray.rllib.algorithms.dt import DTConfig
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.tune.utils.log import Verbosity

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument(
        "--input-files",
        nargs="+",
        default=[
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "../../tests/data/cartpole/large.json",
            )
        ],
        help="List of paths to offline json files/zips for training.",
    )
    parser.add_argument(
        "--num-episodes-during-inference",
        type=int,
        default=10,
        help="Number of episodes to do inference over after training.",
    )

    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # Bazel makes it hard to find files specified in `args` (and `data`).
    # Look for them here.
    input_files = []
    for input_file in args.input_files:
        if not os.path.exists(input_file):
            # This script runs in the ray/rllib/examples/inference_and_serving dir.
            rllib_dir = Path(__file__).parent.parent.parent
            input_dir = rllib_dir.absolute().joinpath(input_file)
            input_files.append(str(input_dir))
        else:
            input_files.append(input_file)

    # Get max_ep_len
    env = gym.make("CartPole-v0")
    max_ep_len = env.spec.max_episode_steps
    env.close()

    # Training config
    config = (
        DTConfig()
        .environment(
            env="CartPole-v0",
            clip_actions=False,
            normalize_actions=False,
        )
        .framework("torch")
        .offline_data(
            input_="dataset",
            input_config={
                "format": "json",
                "paths": input_files,
            },
            actions_in_input_normalized=True,
        )
        .training(
            lr=0.01,
            optimizer={
                "weight_decay": 0.1,
                "betas": [0.9, 0.999],
            },
            train_batch_size=512,
            replay_buffer_config={
                "capacity": 20,
            },
            model={
                "max_seq_len": 3,
            },
            num_layers=1,
            num_heads=1,
            embed_dim=64,
        )
        # Need to do evaluation rollouts for stopping condition.
        .evaluation(
            target_return=200.0,
            evaluation_interval=1,
            evaluation_num_workers=1,
            evaluation_duration=10,
            evaluation_duration_unit="episodes",
            evaluation_parallel_to_training=False,
            evaluation_config={"input": "sampler", "explore": False},
        )
        .rollouts(
            num_rollout_workers=0,
            # This needs to be specified.
            horizon=max_ep_len,
        )
        .reporting(
            min_train_timesteps_per_iteration=5000,
        )
        .resources(
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        )
    )
    config = config.to_dict()

    # Configure when to stop training.
    # Note that for an offline RL algorithm, we don't do training rollouts;
    # instead we have to rely on evaluation rollouts.
    stop = {
        "evaluation/episode_reward_mean": 200.0,
        "training_iteration": 100,
    }

    print("Training policy until desired reward/iterations. ...")
    tuner = tune.Tuner(
        "DT",
        param_space=config,
        run_config=air.RunConfig(
            stop=stop,
            verbose=Verbosity.V3_TRIAL_DETAILS,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=1,
                checkpoint_at_end=True,
            ),
        ),
    )
    results = tuner.fit()

    print("Training completed. Restoring new Algorithm for action inference.")
    # Get the last checkpoint from the above training run.
    checkpoint = results.get_best_result().checkpoint
    # Create new Algorithm and restore its state from the last checkpoint.
    algo = get_algorithm_class("DT")(config=config)
    algo.restore(checkpoint)

    # Create the env to do inference in.
    env = gym.make("CartPole-v0")

    obs = env.reset()
    input_dict = algo.get_initial_input_dict(obs)

    num_episodes = 0
    total_rewards = 0.0

    while num_episodes < args.num_episodes_during_inference:
        # Compute an action (`a`).
        a, _, extra = algo.compute_single_action(input_dict=input_dict)
        # Send the computed action `a` to the env.
        obs, reward, done, _ = env.step(a)
        # Add to total rewards.
        total_rewards += reward
        # Is the episode `done`? -> Reset.
        if done:
            print(f"Episode {num_episodes + 1} - return: {total_rewards}")
            obs = env.reset()
            input_dict = algo.get_initial_input_dict(obs)
            num_episodes += 1
            total_rewards = 0.0
        # Episode is still ongoing -> Continue.
        else:
            input_dict = algo.get_next_input_dict(
                input_dict,
                a,
                reward,
                obs,
                extra,
            )

    env.close()
    ray.shutdown()
rllib/tests/data/pendulum/pendulum_expert_sac_50eps.zip (new binary file, not shown)
rllib/tests/data/pendulum/pendulum_medium_sac_50eps.zip (new binary file, not shown)
rllib/tuned_examples/dt/cartpole-v0-dt.yaml (new file, +42)
@@ -0,0 +1,42 @@
cartpole_dt:
    env: 'CartPole-v0'
    run: DT
    stop:
        evaluation/episode_reward_mean: 200
        training_iteration: 100
    config:
        input: 'dataset'
        input_config:
            paths: 'tests/data/cartpole/large.json'
            format: 'json'
        num_workers: 3
        actions_in_input_normalized: True
        clip_actions: False
        # training
        framework: torch
        train_batch_size: 512
        min_train_timesteps_per_iteration: 5000
        target_return: 200
        lr: 0.01
        optimizer:
            weight_decay: 0.1
            betas: [0.9, 0.999]
        replay_buffer_config:
            capacity: 20
        # model
        model:
            max_seq_len: 3
        num_layers: 1
        num_heads: 1
        embed_dim: 64
        # rollout
        horizon: 200
        # evaluation
        evaluation_config:
            explore: False
            input: sampler
        evaluation_duration: 10
        evaluation_duration_unit: episodes
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True
rllib/tuned_examples/dt/pendulum-v1-dt.yaml (new file, +46)
@@ -0,0 +1,46 @@
pendulum_dt:
    env: 'Pendulum-v1'
    run: DT
    stop:
        # We could make this higher, but given that we have 4 cpus for our tests, we will have to settle for -300.
        evaluation/episode_reward_mean: -300
        timesteps_total: 20000000
    config:
        input: 'dataset'
        input_config:
            paths: 'tests/data/pendulum/pendulum_expert_sac_50eps.zip'
            format: 'json'
        num_workers: 3
        actions_in_input_normalized: True
        clip_actions: True
        normalize_actions: True
        # training
        framework: torch
        train_batch_size: 512
        min_train_timesteps_per_iteration: 5000
        target_return: -120.0
        lr: 0.0
        lr_schedule: [[0, 0.0], [10000, 0.01]]
        grad_clip: 1.0
        optimizer:
            weight_decay: 0.1
            betas: [0.9, 0.999]
        replay_buffer_config:
            capacity: 20
        # model
        model:
            max_seq_len: 3
        num_layers: 1
        num_heads: 1
        embed_dim: 64
        # rollout
        horizon: 200
        # evaluation
        evaluation_config:
            explore: False
            input: sampler
        evaluation_duration: 10
        evaluation_duration_unit: episodes
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True
rllib/tuned_examples/dt/pendulum-v1-medium-expert-dt.yaml (new file, +49)
@@ -0,0 +1,49 @@
pendulum_medium_expert_dt:
    env: 'Pendulum-v1'
    run: DT
    stop:
        # We could make this higher, but given that we have 4 cpus for our tests, we will have to settle for -350.
        evaluation/episode_reward_mean: -350
        timesteps_total: 20000000
    config:
        input: 'dataset'
        input_config:
            paths: [
                'tests/data/pendulum/pendulum_expert_sac_50eps.zip',
                'tests/data/pendulum/pendulum_medium_sac_50eps.zip',
            ]
            format: 'json'
        num_workers: 3
        actions_in_input_normalized: True
        clip_actions: True
        normalize_actions: True
        # training
        framework: torch
        train_batch_size: 512
        min_train_timesteps_per_iteration: 5000
        target_return: -120.0
        lr: 0.0
        lr_schedule: [[0, 0.0], [100000, 0.01]]
        grad_clip: 1.0
        optimizer:
            weight_decay: 0.1
            betas: [0.9, 0.999]
        replay_buffer_config:
            capacity: 20
        # model
        model:
            max_seq_len: 3
        num_layers: 1
        num_heads: 1
        embed_dim: 64
        # rollout
        horizon: 200
        # evaluation
        evaluation_config:
            explore: False
            input: sampler
        evaluation_duration: 10
        evaluation_duration_unit: episodes
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True
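The BUILD targets at the top of this diff exercise these tuned examples through the regression-test runner. A rough manual equivalent, assuming it is launched from the rllib/ directory:

# Runs the tuned examples under tuned_examples/dt as regression tests,
# mirroring the args in the py_test targets above.
import subprocess

subprocess.run(
    ["python", "tests/run_regression_tests.py", "--yaml-dir=tuned_examples/dt"],
    check=True,
)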