mirror of https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[RLlib] Add Decision Transformer (DT) (#27890)
This commit is contained in:
parent 6be4bf8be3
commit edde905741

12 changed files with 1050 additions and 3 deletions
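
At a glance, the new algorithm is configured and trained through the usual RLlib builder API. The following is a minimal sketch condensed from rllib/algorithms/dt/tests/test_dt.py in this PR (the dataset path and hyperparameter values are only illustrative, not recommended settings):

from ray.rllib.algorithms.dt import DTConfig

config = (
    DTConfig()
    .environment(env="Pendulum-v1", clip_actions=True, normalize_actions=True)
    .framework("torch")
    .offline_data(
        # Any RLlib offline JSON dataset works here; this path is a placeholder.
        input_="dataset",
        input_config={"paths": "tests/data/pendulum/large.json", "format": "json"},
        actions_in_input_normalized=True,
    )
    .training(
        train_batch_size=200,
        replay_buffer_config={"capacity": 8},
        model={"max_seq_len": 4},
        num_layers=1,
        num_heads=1,
        embed_dim=64,
    )
    # target_return is required; it is the return-to-go the policy conditions on.
    .evaluation(target_return=-120.0)
    # horizon is required as well (see DT.validate_config below).
    .rollouts(num_rollout_workers=0, horizon=200)
)

algo = config.build()
results = algo.train()  # one offline training iteration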

rllib/BUILD (+47 lines)

@@ -381,6 +381,35 @@ py_test(
     args = ["--yaml-dir=tuned_examples/dqn"]
 )

+# DT
+py_test(
+    name = "learning_tests_pendulum_dt",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    # Include an offline json data file as well.
+    data = [
+        "tuned_examples/dt/pendulum-v1-dt.yaml",
+        "tests/data/pendulum/pendulum_expert_sac_50eps.zip",
+    ],
+    args = ["--yaml-dir=tuned_examples/dt"]
+)
+
+py_test(
+    name = "learning_tests_cartpole_dt",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    # Include an offline json data file as well.
+    data = [
+        "tuned_examples/dt/cartpole-v0-dt.yaml",
+        "tests/data/cartpole/large.json",
+    ],
+    args = ["--yaml-dir=tuned_examples/dt"]
+)
+
 # Simple-Q
 py_test(
     name = "learning_tests_cartpole_simpleq",

@@ -928,6 +957,14 @@ py_test(
     srcs = ["algorithms/dt/tests/test_dt_policy.py"]
 )

+py_test(
+    name = "test_dt",
+    tags = ["team:rllib", "algorithms_dir"],
+    size = "medium",
+    srcs = ["algorithms/dt/tests/test_dt.py"],
+    data = ["tests/data/pendulum/large.json"],
+)
+
 # ES
 py_test(
     name = "test_es",

@@ -3148,6 +3185,16 @@ py_test(
     args = ["--stop-iters=2", "--framework=torch"]
 )

+py_test(
+    name = "examples/inference_and_serving/policy_inference_after_training_with_dt_torch",
+    main = "examples/inference_and_serving/policy_inference_after_training_with_dt.py",
+    tags = ["team:rllib", "exclusive", "examples", "examples_P"],
+    size = "medium",
+    srcs = ["examples/inference_and_serving/policy_inference_after_training_with_dt.py"],
+    data = ["tests/data/cartpole/large.json"],
+    args = ["--input-files=tests/data/cartpole/large.json"]
+)
+
 py_test(
     name = "examples/inference_and_serving/policy_inference_after_training_with_lstm_tf",
     main = "examples/inference_and_serving/policy_inference_after_training_with_lstm.py",

rllib/algorithms/dt/__init__.py (+6 lines, new file)

@@ -0,0 +1,6 @@
from ray.rllib.algorithms.dt.dt import DT, DTConfig

__all__ = [
    "DT",
    "DTConfig",
]

rllib/algorithms/dt/dt.py (+401 lines, new file)

@@ -0,0 +1,401 @@
import logging
import math
from typing import List, Optional, Type, Tuple, Dict, Any, Union

from ray.rllib import SampleBatch
from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
from ray.rllib.algorithms.dt.segmentation_buffer import MultiAgentSegmentationBuffer
from ray.rllib.execution import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    SAMPLE_TIMER,
    NUM_AGENT_STEPS_TRAINED,
)
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    ResultDict,
    TensorStructType,
    PolicyID,
    TensorType,
)

logger = logging.getLogger(__name__)


class DTConfig(AlgorithmConfig):
    def __init__(self, algo_class=None):
        super().__init__(algo_class=algo_class or DT)

        # fmt: off
        # __sphinx_doc_begin__
        # DT-specific settings.
        # Required settings during training and evaluation:
        # Initial return to go used as target during rollout.
        self.target_return = None
        # Rollout horizon/maximum episode length.
        self.horizon = None

        # Model settings:
        self.model = {
            # Transformer (GPT) context length.
            "max_seq_len": 5,
        }

        # Transformer (GPT) settings:
        self.embed_dim = 128
        self.num_layers = 2
        self.num_heads = 1
        self.embed_pdrop = 0.1
        self.resid_pdrop = 0.1
        self.attn_pdrop = 0.1

        # Optimization settings:
        self.lr = 1e-4
        self.lr_schedule = None
        self.optimizer = {
            # Weight decay for Adam optimizer.
            "weight_decay": 1e-4,
            # Betas for Adam optimizer.
            "betas": (0.9, 0.95),
        }
        self.grad_clip = None
        # Coefficients on the loss for each of the heads.
        # By default, only use the actions outputs for training.
        self.loss_coef_actions = 1
        self.loss_coef_obs = 0
        self.loss_coef_returns_to_go = 0

        self.replay_buffer_config = {
            # How many trajectories/episodes does the segmentation buffer hold.
            # Increase for more data shuffling but increased memory usage.
            "capacity": 20,
            # Do not change the type of replay buffer.
            "type": MultiAgentSegmentationBuffer,
        }
        # __sphinx_doc_end__
        # fmt: on

        # Overwriting the trainer config default.
        # If data ingestion/sample_time is slow, increase this.
        self.num_workers = 0
        # Number of training_step calls between evaluation rollouts.
        self.min_train_timesteps_per_iteration = 5000

        # Don't change.
        self.offline_sampling = True
        self.postprocess_inputs = True
        self.discount = None

    def training(
        self,
        *,
        replay_buffer_config: Optional[Dict[str, Any]] = None,
        embed_dim: Optional[int] = None,
        num_layers: Optional[int] = None,
        num_heads: Optional[int] = None,
        embed_pdrop: Optional[float] = None,
        resid_pdrop: Optional[float] = None,
        attn_pdrop: Optional[float] = None,
        grad_clip: Optional[float] = None,
        loss_coef_actions: Optional[float] = None,
        loss_coef_obs: Optional[float] = None,
        loss_coef_returns_to_go: Optional[float] = None,
        lr_schedule: Optional[List[List[Union[int, float]]]] = None,
        **kwargs,
    ) -> "DTConfig":
        """
        === DT configs

        Args:
            replay_buffer_config: Replay buffer config.
                {
                "capacity": How many trajectories/episodes does the buffer hold.
                }
            embed_dim: Dimension of the embeddings in the GPT model.
            num_layers: Number of attention layers in the GPT model.
            num_heads: Number of attention heads in the GPT model. Must divide
                embed_dim evenly.
            embed_pdrop: Dropout probability of the embedding layer of the GPT model.
            resid_pdrop: Dropout probability of the residual layer of the GPT model.
            attn_pdrop: Dropout probability of the attention layer of the GPT model.
            grad_clip: If specified, clip the global norm of gradients by this amount.
            lr_schedule: Learning rate schedule. In the format of
                [[timestep, lr-value], [timestep, lr-value], ...]
                Intermediary timesteps will be assigned to interpolated learning rate
                values. A schedule should normally start from timestep 0.
            loss_coef_actions: Coefficient on the loss for the actions output.
                Defaults to 1.
            loss_coef_obs: Coefficient on the loss for the obs output. Defaults to 0.
                Set to a value greater than 0 to regress on the obs output.
            loss_coef_returns_to_go: Coefficient on the loss for the returns_to_go
                output. Defaults to 0. Set to a value greater than 0 to regress on the
                returns_to_go output.
            **kwargs: Forward compatibility kwargs.

        Returns:
            This updated DTConfig object.
        """
        super().training(**kwargs)
        if replay_buffer_config is not None:
            self.replay_buffer_config = replay_buffer_config
        if embed_dim is not None:
            self.embed_dim = embed_dim
        if num_layers is not None:
            self.num_layers = num_layers
        if num_heads is not None:
            self.num_heads = num_heads
        if embed_pdrop is not None:
            self.embed_pdrop = embed_pdrop
        if resid_pdrop is not None:
            self.resid_pdrop = resid_pdrop
        if attn_pdrop is not None:
            self.attn_pdrop = attn_pdrop
        if grad_clip is not None:
            self.grad_clip = grad_clip
        if lr_schedule is not None:
            self.lr_schedule = lr_schedule
        if loss_coef_actions is not None:
            self.loss_coef_actions = loss_coef_actions
        if loss_coef_obs is not None:
            self.loss_coef_obs = loss_coef_obs
        if loss_coef_returns_to_go is not None:
            self.loss_coef_returns_to_go = loss_coef_returns_to_go

        return self

    def evaluation(
        self,
        *,
        target_return: Optional[float] = None,
        **kwargs,
    ) -> "DTConfig":
        """
        === DT configs

        Args:
            target_return: The target return-to-go for inference/evaluation.
            **kwargs: Forward compatibility kwargs.

        Returns:
            This updated DTConfig object.
        """
        super().evaluation(**kwargs)
        if target_return is not None:
            self.target_return = target_return

        return self


class DT(Algorithm):
    """Implements Decision Transformer: https://arxiv.org/abs/2106.01345"""

    # TODO: We have a circular dependency for getting the default config
    #  (config -> Trainer -> config), so the Config class is defined in the
    #  same file for now as a workaround.

    @override(Algorithm)
    def validate_config(self, config: AlgorithmConfigDict) -> None:
        """Validates the Trainer's config dict.

        Args:
            config: The Trainer's config to check.

        Raises:
            ValueError: In case something is wrong with the config.
        """
        # Call super's validation method.
        super().validate_config(config)

        # target_return must be specified.
        assert (
            self.config.get("target_return") is not None
        ), "Must specify a target return (total sum of rewards)."

        # horizon must be specified and >= 2.
        assert self.config.get("horizon") is not None, "Must specify rollout horizon."
        assert self.config["horizon"] >= 2, "rollout horizon must be at least 2."

        # replay_buffer's type must be MultiAgentSegmentationBuffer.
        assert (
            self.config.get("replay_buffer_config") is not None
        ), "Must specify replay_buffer_config."
        replay_buffer_type = self.config["replay_buffer_config"].get("type")
        assert (
            replay_buffer_type == MultiAgentSegmentationBuffer
        ), "replay_buffer's type must be MultiAgentSegmentationBuffer."

        # max_seq_len must be specified in model.
        model_max_seq_len = self.config["model"].get("max_seq_len")
        assert model_max_seq_len is not None, "Must specify model's max_seq_len."

        # User shouldn't need to specify replay_buffer's max_seq_len.
        # Autofill for replay buffer API. If they did specify, make sure it
        # matches with model's max_seq_len.
        buffer_max_seq_len = self.config["replay_buffer_config"].get("max_seq_len")
        if buffer_max_seq_len is None:
            self.config["replay_buffer_config"]["max_seq_len"] = model_max_seq_len
        else:
            assert (
                buffer_max_seq_len == model_max_seq_len
            ), "replay_buffer's max_seq_len must equal model's max_seq_len."

        # Same thing for buffer's max_ep_len, which should be autofilled from
        # rollout's horizon, or check that it matches if user specified.
        buffer_max_ep_len = self.config["replay_buffer_config"].get("max_ep_len")
        if buffer_max_ep_len is None:
            self.config["replay_buffer_config"]["max_ep_len"] = self.config["horizon"]
        else:
            assert (
                buffer_max_ep_len == self.config["horizon"]
            ), "replay_buffer's max_ep_len must equal rollout horizon."

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return DTConfig().to_dict()

    @override(Algorithm)
    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.algorithms.dt.dt_torch_policy import DTTorchPolicy

            return DTTorchPolicy
        else:
            raise ValueError("Non-torch frameworks are not supported yet!")

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        with self._timers[SAMPLE_TIMER]:
            # TODO: Add ability to do obs_filter for offline sampling.
            train_batch = synchronous_parallel_sample(worker_set=self.workers)

        train_batch = train_batch.as_multi_agent()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

        # Because each sample is a segment of max_seq_len transitions, doing
        # the division makes it so the total number of transitions per train
        # step is consistent.
        num_steps = train_batch.env_steps()
        batch_size = int(math.ceil(num_steps / self.config["model"]["max_seq_len"]))

        # Add the batch of episodes to the segmentation buffer.
        self.local_replay_buffer.add(train_batch)
        # Sample a batch of segments.
        train_batch = self.local_replay_buffer.sample(batch_size)

        # Postprocess batch before we learn on it.
        post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
        train_batch = post_fn(train_batch, self.workers, self.config)

        # Learn on training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU).
        if self.config.get("simple_optimizer", False):
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # Update learning rate scheduler.
        global_vars = {
            # Note: this counts the number of segments trained, not timesteps,
            # i.e. NUM_AGENT_STEPS_TRAINED: B, NUM_AGENT_STEPS_SAMPLED: B*T.
            "timestep": self._counters[NUM_AGENT_STEPS_TRAINED],
        }
        self.workers.local_worker().set_global_vars(global_vars)

        return train_results

    @PublicAPI
    @override(Algorithm)
    def compute_single_action(
        self,
        *args,
        input_dict: Optional[SampleBatch] = None,
        full_fetch: bool = True,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes an action for the specified policy on the local worker.

        Note that you can also access the policy object through
        self.get_policy(policy_id) and call compute_single_action() on it
        directly.

        Args:
            input_dict: A SampleBatch taken from get_initial_input_dict or
                get_next_input_dict.
            full_fetch: Whether to return extra action fetch results.
                This is always True for DT.
            kwargs: forward compatibility args.

        Returns:
            A tuple containing: (
                the computed action,
                list of RNN states (empty for DT),
                extra action output (pass to get_next_input_dict),
            )
        """
        assert input_dict is not None, (
            "DT must take in input_dict for inference. "
            "See get_initial_input_dict() and get_next_input_dict()."
        )
        assert (
            full_fetch
        ), "DT needs full_fetch=True. Pass extra into get_next_input_dict()."

        return super().compute_single_action(
            *args, input_dict=input_dict.copy(), full_fetch=full_fetch, **kwargs
        )

    @PublicAPI
    def get_initial_input_dict(
        self,
        observation: TensorStructType,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
    ) -> SampleBatch:
        """Get the initial input_dict to be passed into compute_single_action.

        Args:
            observation: first (unbatched) observation from env.reset()
            policy_id: Policy to query (only applies to multi-agent).
                Default: "default_policy".

        Returns:
            The input_dict for inference.
        """
        policy = self.get_policy(policy_id)
        return policy.get_initial_input_dict(observation)

    @PublicAPI
    def get_next_input_dict(
        self,
        input_dict: SampleBatch,
        action: TensorStructType,
        reward: TensorStructType,
        next_obs: TensorStructType,
        extra: Dict[str, TensorType],
        policy_id: PolicyID = DEFAULT_POLICY_ID,
    ) -> SampleBatch:
        """Returns a new input_dict after stepping through the environment once.

        Args:
            input_dict: the input dict passed into compute_single_action.
            action: the (unbatched) action taken this step.
            reward: the (unbatched) reward from env.step
            next_obs: the (unbatched) next observation from env.step
            extra: the extra action out from compute_single_action.
                For DT this contains the current return-to-go *before* the current
                reward is subtracted from target_return.
            policy_id: Policy to query (only applies to multi-agent).
                Default: "default_policy".

        Returns:
            A new input_dict to be passed into compute_single_action.
        """
        policy = self.get_policy(policy_id)
        return policy.get_next_input_dict(input_dict, action, reward, next_obs, extra)
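
The three @PublicAPI methods above are meant to be chained into a per-step inference loop. A condensed sketch, assuming `algo` is an already trained DT Algorithm (the full version is the example script added later in this PR):

import gym

env = gym.make("Pendulum-v1")
obs = env.reset()
input_dict = algo.get_initial_input_dict(obs)

done = False
while not done:
    # `extra` carries the current return-to-go and must be fed back in.
    action, _, extra = algo.compute_single_action(input_dict=input_dict)
    obs, reward, done, _ = env.step(action)
    if not done:
        input_dict = algo.get_next_input_dict(input_dict, action, reward, obs, extra)
env.close()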

rllib/algorithms/dt/segmentation_buffer.py

@@ -58,9 +58,6 @@ class SegmentationBuffer:
             self._add_single_episode(episode)

     def _add_single_episode(self, episode: SampleBatch):
-        # Truncate if episode too long.
-        # Note: sometimes this happens if the dataset shuffles such that the
-        # same episode is concatenated together twice (which is okay).
         ep_len = episode.env_steps()

         if ep_len > self.max_ep_len:
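
Not part of this diff, but as a mental model for what the MultiAgentSegmentationBuffer feeds the model: each sampled item is a window of at most max_seq_len consecutive steps cut from a stored episode. A toy, hypothetical sketch of that windowing idea only (not the buffer's actual implementation, which also handles padding, returns-to-go, and multi-agent mapping):

import numpy as np

def sample_toy_segment(episode: np.ndarray, max_seq_len: int, rng: np.random.Generator) -> np.ndarray:
    """Toy illustration: return a random window of at most `max_seq_len` steps."""
    end = int(rng.integers(1, len(episode) + 1))  # exclusive end of the window
    start = max(0, end - max_seq_len)             # window covers at most max_seq_len steps
    return episode[start:end]

# Example: a 10-step episode of scalar observations, context length 3.
rng = np.random.default_rng(0)
segment = sample_toy_segment(np.arange(10), max_seq_len=3, rng=rng)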

rllib/algorithms/dt/tests/test_dt.py (+269 lines, new file)

@@ -0,0 +1,269 @@
from pathlib import Path
import os
import unittest
from typing import Dict

import gym
import numpy as np

import ray
from ray.rllib import SampleBatch
from ray.rllib.algorithms.dt import DTConfig
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_train_results

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


def _assert_input_dict_equals(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]):
    for key in d1.keys():
        assert key in d2.keys()

    for key in d2.keys():
        assert key in d1.keys()

    for key in d1.keys():
        assert isinstance(d1[key], np.ndarray)
        assert isinstance(d2[key], np.ndarray)
        assert d1[key].shape == d2[key].shape
        assert np.allclose(d1[key], d2[key])


class TestDT(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_dt_compilation(self):
        """Test whether a DT algorithm can be built with all supported frameworks."""

        rllib_dir = Path(__file__).parent.parent.parent.parent
        data_file = os.path.join(rllib_dir, "tests/data/pendulum/large.json")

        input_config = {
            "paths": data_file,
            "format": "json",
        }

        config = (
            DTConfig()
            .environment(
                env="Pendulum-v1",
                clip_actions=True,
                normalize_actions=True,
            )
            .framework("torch")
            .offline_data(
                input_="dataset",
                input_config=input_config,
                actions_in_input_normalized=True,
            )
            .training(
                train_batch_size=200,
                replay_buffer_config={
                    "capacity": 8,
                },
                model={
                    "max_seq_len": 4,
                },
                num_layers=1,
                num_heads=1,
                embed_dim=64,
            )
            .evaluation(
                target_return=-120,
                evaluation_interval=2,
                evaluation_num_workers=0,
                evaluation_duration=10,
                evaluation_duration_unit="episodes",
                evaluation_parallel_to_training=False,
                evaluation_config={"input": "sampler", "explore": False},
            )
            .rollouts(
                num_rollout_workers=0,
                horizon=200,
            )
            .reporting(
                min_train_timesteps_per_iteration=10,
            )
        )

        num_iterations = 4

        for _ in ["torch"]:
            algo = config.build()
            # Check if 4 iterations raises any errors.
            for i in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)
                if (i + 1) % 2 == 0:
                    # Evaluation happens every 2 iterations.
                    eval_results = results["evaluation"]
                    print(
                        f"iter={algo.iteration} "
                        f"R={eval_results['episode_reward_mean']}"
                    )

            # Do example inference rollout.
            env = gym.make("Pendulum-v1")

            obs = env.reset()
            input_dict = algo.get_initial_input_dict(obs)

            for _ in range(200):
                action, _, extra = algo.compute_single_action(input_dict=input_dict)
                obs, reward, done, _ = env.step(action)
                if done:
                    break
                else:
                    input_dict = algo.get_next_input_dict(
                        input_dict,
                        action,
                        reward,
                        obs,
                        extra,
                    )

            env.close()
            algo.stop()

    def test_inference_methods(self):
        """Test inference methods."""

        config = (
            DTConfig()
            .environment(
                env="Pendulum-v1",
                clip_actions=True,
                normalize_actions=True,
            )
            .framework("torch")
            .training(
                train_batch_size=200,
                replay_buffer_config={
                    "capacity": 8,
                },
                model={
                    "max_seq_len": 3,
                },
                num_layers=1,
                num_heads=1,
                embed_dim=64,
            )
            .evaluation(
                target_return=-120,
            )
            .rollouts(
                num_rollout_workers=0,
                horizon=200,
            )
        )
        algo = config.build()

        # Do a controlled fake rollout for 2 steps and check the input_dict.
        # First input_dict:
        obs = np.array([0.0, 1.0, 2.0])

        input_dict = algo.get_initial_input_dict(obs)
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0],
                        [0.0, 1.0, 2.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.0], [0.0]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([0.0, 0.0], dtype=np.float32),
                SampleBatch.REWARDS: np.zeros((), dtype=np.float32),
                SampleBatch.T: np.array([-1, -1], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)

        # Forward pass with first input_dict.
        action, _, extra = algo.compute_single_action(input_dict=input_dict)
        assert action.shape == (1,)
        assert SampleBatch.RETURNS_TO_GO in extra
        assert np.isclose(extra[SampleBatch.RETURNS_TO_GO], -120.0)

        # Second input_dict:
        action = np.array([0.5])
        obs = np.array([3.0, 4.0, 5.0])
        reward = -10.0

        input_dict = algo.get_next_input_dict(
            input_dict,
            action,
            reward,
            obs,
            extra,
        )
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 0.0, 0.0],
                        [0.0, 1.0, 2.0],
                        [3.0, 4.0, 5.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.0], [0.5]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([0.0, -120.0], dtype=np.float32),
                SampleBatch.REWARDS: np.asarray(-10.0),
                SampleBatch.T: np.array([-1, 0], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)

        # Forward pass with second input_dict.
        action, _, extra = algo.compute_single_action(input_dict=input_dict)
        assert action.shape == (1,)
        assert SampleBatch.RETURNS_TO_GO in extra
        assert np.isclose(extra[SampleBatch.RETURNS_TO_GO], -110.0)

        # Third input_dict:
        action = np.array([-0.2])
        obs = np.array([6.0, 7.0, 8.0])
        reward = -20.0

        input_dict = algo.get_next_input_dict(
            input_dict,
            action,
            reward,
            obs,
            extra,
        )
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 1.0, 2.0],
                        [3.0, 4.0, 5.0],
                        [6.0, 7.0, 8.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.5], [-0.2]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([-120.0, -110.0], dtype=np.float32),
                SampleBatch.REWARDS: np.asarray(-20.0),
                SampleBatch.T: np.array([0, 1], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))

rllib/algorithms/registry.py

@@ -108,6 +108,12 @@ def _import_dreamer():
     return dreamer.Dreamer, dreamer.DreamerConfig().to_dict()


+def _import_dt():
+    import ray.rllib.algorithms.dt as dt
+
+    return dt.DT, dt.DTConfig().to_dict()
+
+
 def _import_es():
     import ray.rllib.algorithms.es as es


@@ -215,6 +221,7 @@ ALGORITHMS = {
     "DDPPO": _import_ddppo,
     "DQN": _import_dqn,
     "Dreamer": _import_dreamer,
+    "DT": _import_dt,
     "IMPALA": _import_impala,
     "APPO": _import_appo,
     "AlphaStar": _import_alpha_star,

@@ -309,6 +316,7 @@ POLICIES = {
     "DQNTFPolicy": "dqn.dqn_tf_policy",
     "DQNTorchPolicy": "dqn.dqn_torch_policy",
     "DreamerTorchPolicy": "dreamer.dreamer_torch_policy",
+    "DTTorchPolicy": "dt.dt_torch_policy",
     "ESTFPolicy": "es.es_tf_policy",
     "ESTorchPolicy": "es.es_torch_policy",
     "ImpalaTF1Policy": "impala.impala_tf_policy",

rllib/examples/inference_and_serving/policy_inference_after_training_with_dt.py (+182 lines, new file)

@@ -0,0 +1,182 @@
"""
Example showing how you can use your trained Decision Transformer (DT) policy for
inference (computing actions) in an environment.
"""
import argparse
from pathlib import Path

import gym
import os

import ray
from ray import air, tune
from ray.rllib.algorithms.dt import DTConfig
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.tune.utils.log import Verbosity

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument(
        "--input-files",
        nargs="+",
        default=[
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "../../tests/data/cartpole/large.json",
            )
        ],
        help="List of paths to offline json files/zips for training.",
    )
    parser.add_argument(
        "--num-episodes-during-inference",
        type=int,
        default=10,
        help="Number of episodes to do inference over after training.",
    )

    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # Bazel makes it hard to find files specified in `args` (and `data`).
    # Look for them here.
    input_files = []
    for input_file in args.input_files:
        if not os.path.exists(input_file):
            # This script runs in the ray/rllib/examples/inference_and_serving dir.
            rllib_dir = Path(__file__).parent.parent.parent
            input_dir = rllib_dir.absolute().joinpath(input_file)
            input_files.append(str(input_dir))
        else:
            input_files.append(input_file)

    # Get max_ep_len.
    env = gym.make("CartPole-v0")
    max_ep_len = env.spec.max_episode_steps
    env.close()

    # Training config.
    config = (
        DTConfig()
        .environment(
            env="CartPole-v0",
            clip_actions=False,
            normalize_actions=False,
        )
        .framework("torch")
        .offline_data(
            input_="dataset",
            input_config={
                "format": "json",
                "paths": input_files,
            },
            actions_in_input_normalized=True,
        )
        .training(
            lr=0.01,
            optimizer={
                "weight_decay": 0.1,
                "betas": [0.9, 0.999],
            },
            train_batch_size=512,
            replay_buffer_config={
                "capacity": 20,
            },
            model={
                "max_seq_len": 3,
            },
            num_layers=1,
            num_heads=1,
            embed_dim=64,
        )
        # Need to do evaluation rollouts for the stopping condition.
        .evaluation(
            target_return=200.0,
            evaluation_interval=1,
            evaluation_num_workers=1,
            evaluation_duration=10,
            evaluation_duration_unit="episodes",
            evaluation_parallel_to_training=False,
            evaluation_config={"input": "sampler", "explore": False},
        )
        .rollouts(
            num_rollout_workers=0,
            # This needs to be specified.
            horizon=max_ep_len,
        )
        .reporting(
            min_train_timesteps_per_iteration=5000,
        )
        .resources(
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        )
    )
    config = config.to_dict()

    # Configure when to stop training.
    # Note that for an offline RL algorithm, we don't do training rollouts;
    # instead we have to rely on evaluation rollouts.
    stop = {
        "evaluation/episode_reward_mean": 200.0,
        "training_iteration": 100,
    }

    print("Training policy until desired reward/iterations. ...")
    tuner = tune.Tuner(
        "DT",
        param_space=config,
        run_config=air.RunConfig(
            stop=stop,
            verbose=Verbosity.V3_TRIAL_DETAILS,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=1,
                checkpoint_at_end=True,
            ),
        ),
    )
    results = tuner.fit()

    print("Training completed. Restoring new Algorithm for action inference.")
    # Get the last checkpoint from the above training run.
    checkpoint = results.get_best_result().checkpoint
    # Create a new Algorithm and restore its state from the last checkpoint.
    algo = get_algorithm_class("DT")(config=config)
    algo.restore(checkpoint)

    # Create the env to do inference in.
    env = gym.make("CartPole-v0")

    obs = env.reset()
    input_dict = algo.get_initial_input_dict(obs)

    num_episodes = 0
    total_rewards = 0.0

    while num_episodes < args.num_episodes_during_inference:
        # Compute an action (`a`).
        a, _, extra = algo.compute_single_action(input_dict=input_dict)
        # Send the computed action `a` to the env.
        obs, reward, done, _ = env.step(a)
        # Add to total rewards.
        total_rewards += reward
        # Is the episode `done`? -> Reset.
        if done:
            print(f"Episode {num_episodes + 1} - return: {total_rewards}")
            obs = env.reset()
            input_dict = algo.get_initial_input_dict(obs)
            num_episodes += 1
            total_rewards = 0.0
        # Episode is still ongoing -> Continue.
        else:
            input_dict = algo.get_next_input_dict(
                input_dict,
                a,
                reward,
                obs,
                extra,
            )

    env.close()
    ray.shutdown()

rllib/tests/data/pendulum/pendulum_expert_sac_50eps.zip (new binary file, not shown)
rllib/tests/data/pendulum/pendulum_medium_sac_50eps.zip (new binary file, not shown)

rllib/tuned_examples/dt/cartpole-v0-dt.yaml (+42 lines, new file)

@@ -0,0 +1,42 @@
cartpole_dt:
    env: 'CartPole-v0'
    run: DT
    stop:
        evaluation/episode_reward_mean: 200
        training_iteration: 100
    config:
        input: 'dataset'
        input_config:
            paths: 'tests/data/cartpole/large.json'
            format: 'json'
        num_workers: 3
        actions_in_input_normalized: True
        clip_actions: False
        # training
        framework: torch
        train_batch_size: 512
        min_train_timesteps_per_iteration: 5000
        target_return: 200
        lr: 0.01
        optimizer:
            weight_decay: 0.1
            betas: [0.9, 0.999]
        replay_buffer_config:
            capacity: 20
        # model
        model:
            max_seq_len: 3
        num_layers: 1
        num_heads: 1
        embed_dim: 64
        # rollout
        horizon: 200
        # evaluation
        evaluation_config:
            explore: False
            input: sampler
        evaluation_duration: 10
        evaluation_duration_unit: episodes
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True

rllib/tuned_examples/dt/pendulum-v1-dt.yaml (+46 lines, new file)

@@ -0,0 +1,46 @@
pendulum_dt:
    env: 'Pendulum-v1'
    run: DT
    stop:
        # We could make this higher, but given that we have 4 CPUs for our tests, we have to settle for -300.
        evaluation/episode_reward_mean: -300
        timesteps_total: 20000000
    config:
        input: 'dataset'
        input_config:
            paths: 'tests/data/pendulum/pendulum_expert_sac_50eps.zip'
            format: 'json'
        num_workers: 3
        actions_in_input_normalized: True
        clip_actions: True
        normalize_actions: True
        # training
        framework: torch
        train_batch_size: 512
        min_train_timesteps_per_iteration: 5000
        target_return: -120.0
        lr: 0.0
        lr_schedule: [[0, 0.0], [10000, 0.01]]
        grad_clip: 1.0
        optimizer:
            weight_decay: 0.1
            betas: [0.9, 0.999]
        replay_buffer_config:
            capacity: 20
        # model
        model:
            max_seq_len: 3
        num_layers: 1
        num_heads: 1
        embed_dim: 64
        # rollout
        horizon: 200
        # evaluation
        evaluation_config:
            explore: False
            input: sampler
        evaluation_duration: 10
        evaluation_duration_unit: episodes
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True

rllib/tuned_examples/dt/pendulum-v1-medium-expert-dt.yaml (+49 lines, new file)

@@ -0,0 +1,49 @@
pendulum_medium_expert_dt:
    env: 'Pendulum-v1'
    run: DT
    stop:
        # We could make this higher, but given that we have 4 CPUs for our tests, we have to settle for -350.
        evaluation/episode_reward_mean: -350
        timesteps_total: 20000000
    config:
        input: 'dataset'
        input_config:
            paths: [
                'tests/data/pendulum/pendulum_expert_sac_50eps.zip',
                'tests/data/pendulum/pendulum_medium_sac_50eps.zip',
            ]
            format: 'json'
        num_workers: 3
        actions_in_input_normalized: True
        clip_actions: True
        normalize_actions: True
        # training
        framework: torch
        train_batch_size: 512
        min_train_timesteps_per_iteration: 5000
        target_return: -120.0
        lr: 0.0
        lr_schedule: [[0, 0.0], [100000, 0.01]]
        grad_clip: 1.0
        optimizer:
            weight_decay: 0.1
            betas: [0.9, 0.999]
        replay_buffer_config:
            capacity: 20
        # model
        model:
            max_seq_len: 3
        num_layers: 1
        num_heads: 1
        embed_dim: 64
        # rollout
        horizon: 200
        # evaluation
        evaluation_config:
            explore: False
            input: sampler
        evaluation_duration: 10
        evaluation_duration_unit: episodes
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True