mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Make Dataset reader default reader and enable CRR to use dataset (#26304)
Co-authored-by: avnish <avnish@avnishs-MBP.local.meter>
This commit is contained in:
parent
61c9e761f3
commit
1243ed62bf
18 changed files with 277 additions and 216 deletions
22
rllib/BUILD
22
rllib/BUILD
|
@ -254,17 +254,17 @@ py_test(
|
|||
|
||||
# CRR
|
||||
py_test(
|
||||
name = "learning_tests_pendulum_crr",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
# Include an offline json data file as well.
|
||||
data = [
|
||||
"tuned_examples/crr/pendulum-v1-crr.yaml",
|
||||
"tests/data/pendulum/pendulum_replay_v1.1.0.zip",
|
||||
],
|
||||
args = ["--yaml-dir=tuned_examples/crr"]
|
||||
name = "learning_tests_pendulum_crr",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
# Include an offline json data file as well.
|
||||
data = [
|
||||
"tuned_examples/crr/pendulum-v1-crr.yaml",
|
||||
"tests/data/pendulum/pendulum_replay_v1.1.0.zip",
|
||||
],
|
||||
args = ["--yaml-dir=tuned_examples/crr"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
|
|
@ -14,7 +14,6 @@ from ray.rllib.execution.train_ops import (
|
|||
multi_gpu_train_one_step,
|
||||
train_one_step,
|
||||
)
|
||||
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import (
|
||||
|
@ -32,8 +31,8 @@ from ray.rllib.utils.metrics import (
|
|||
NUM_TARGET_UPDATES,
|
||||
TARGET_NET_UPDATE_TIMER,
|
||||
SYNCH_WORKER_WEIGHTS_TIMER,
|
||||
SAMPLE_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
|
||||
from ray.rllib.utils.typing import ResultDict, AlgorithmConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
@ -177,23 +176,11 @@ class CQL(SAC):
|
|||
@override(SAC)
|
||||
def training_step(self) -> ResultDict:
|
||||
# Collect SampleBatches from sample workers.
|
||||
batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
batch = batch.as_multi_agent()
|
||||
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
|
||||
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
|
||||
# Add batch to replay buffer.
|
||||
self.local_replay_buffer.add(batch)
|
||||
|
||||
# Sample training batch from replay buffer.
|
||||
train_batch = sample_min_n_steps_from_buffer(
|
||||
self.local_replay_buffer,
|
||||
self.config["train_batch_size"],
|
||||
count_by_agent_steps=self._by_agent_steps,
|
||||
)
|
||||
|
||||
# Old-style replay buffers return None if learning has not started
|
||||
if not train_batch:
|
||||
return {}
|
||||
with self._timers[SAMPLE_TIMER]:
|
||||
train_batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
train_batch = train_batch.as_multi_agent()
|
||||
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
|
||||
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
|
||||
|
||||
# Postprocess batch before we learn on it.
|
||||
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
|
||||
|
@ -207,14 +194,6 @@ class CQL(SAC):
|
|||
else:
|
||||
train_results = multi_gpu_train_one_step(self, train_batch)
|
||||
|
||||
# Update replay buffer priorities.
|
||||
update_priorities_in_replay_buffer(
|
||||
self.local_replay_buffer,
|
||||
self.config,
|
||||
train_batch,
|
||||
train_results,
|
||||
)
|
||||
|
||||
# Update target network every `target_network_update_freq` training steps.
|
||||
cur_ts = self._counters[
|
||||
NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
|
||||
|
|
|
@ -69,7 +69,8 @@ class TestCQL(unittest.TestCase):
|
|||
evaluation_parallel_to_training=False,
|
||||
evaluation_num_workers=2,
|
||||
)
|
||||
.rollouts(rollout_fragment_length=1)
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.reporting(min_time_s_per_iteration=0.0)
|
||||
)
|
||||
num_iterations = 4
|
||||
|
||||
|
@ -85,7 +86,6 @@ class TestCQL(unittest.TestCase):
|
|||
f"iter={trainer.iteration} "
|
||||
f"R={eval_results['episode_reward_mean']}"
|
||||
)
|
||||
|
||||
check_compute_single_action(trainer)
|
||||
|
||||
# Get policy and model.
|
||||
|
@ -97,9 +97,9 @@ class TestCQL(unittest.TestCase):
|
|||
# Example on how to do evaluation on the trained Trainer
|
||||
# using the data from CQL's global replay buffer.
|
||||
# Get a sample (MultiAgentBatch).
|
||||
multi_agent_batch = trainer.local_replay_buffer.sample(
|
||||
num_items=config.train_batch_size
|
||||
)
|
||||
|
||||
batch = trainer.workers.local_worker().input_reader.next()
|
||||
multi_agent_batch = batch.as_multi_agent()
|
||||
# All experiences have been buffered for `default_policy`
|
||||
batch = multi_agent_batch.policy_batches["default_policy"]
|
||||
|
||||
|
|
|
@ -1,21 +1,21 @@
|
|||
import logging
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import numpy as np
|
||||
import tree
|
||||
|
||||
from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
|
||||
from ray.rllib.execution import synchronous_parallel_sample
|
||||
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
|
||||
from ray.rllib.offline.shuffled_input import ShuffledInput
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.metrics import (
|
||||
LAST_TARGET_UPDATE_TS,
|
||||
NUM_AGENT_STEPS_TRAINED,
|
||||
NUM_ENV_STEPS_TRAINED,
|
||||
NUM_TARGET_UPDATES,
|
||||
TARGET_NET_UPDATE_TIMER,
|
||||
NUM_AGENT_STEPS_SAMPLED,
|
||||
NUM_ENV_STEPS_SAMPLED,
|
||||
SAMPLE_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
|
||||
from ray.rllib.utils.typing import (
|
||||
AlgorithmConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
|
@ -38,19 +38,13 @@ class CRRConfig(AlgorithmConfig):
|
|||
self.advantage_type = "mean"
|
||||
self.n_action_sample = 4
|
||||
self.twin_q = True
|
||||
self.target_update_grad_intervals = 100
|
||||
self.train_batch_size = 128
|
||||
|
||||
# target_network_update_freq by default is 100 * train_batch_size
|
||||
# if target_network_update_freq is not set. See self.setup for code.
|
||||
self.target_network_update_freq = None
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
self.replay_buffer_config = {
|
||||
"type": MultiAgentReplayBuffer,
|
||||
"capacity": 50000,
|
||||
# How many steps of the model to sample before learning starts.
|
||||
"learning_starts": 1000,
|
||||
"replay_batch_size": 32,
|
||||
# The number of contiguous environment steps to replay at once. This
|
||||
# may be set to greater than 1 to support recurrent models.
|
||||
"replay_sequence_length": 1,
|
||||
}
|
||||
self.actor_hiddens = [256, 256]
|
||||
self.actor_hidden_activation = "relu"
|
||||
self.critic_hiddens = [256, 256]
|
||||
|
@ -60,7 +54,10 @@ class CRRConfig(AlgorithmConfig):
|
|||
self.tau = 5e-3
|
||||
|
||||
# overriding the trainer config default
|
||||
self.num_workers = 0 # offline RL does not need rollout workers
|
||||
# If data ingestion/sample_time is slow, increase this
|
||||
self.num_workers = 4
|
||||
self.offline_sampling = True
|
||||
self.min_iter_time_s = 10.0
|
||||
|
||||
def training(
|
||||
self,
|
||||
|
@ -71,8 +68,7 @@ class CRRConfig(AlgorithmConfig):
|
|||
advantage_type: Optional[str] = None,
|
||||
n_action_sample: Optional[int] = None,
|
||||
twin_q: Optional[bool] = None,
|
||||
target_update_grad_intervals: Optional[int] = None,
|
||||
replay_buffer_config: Optional[dict] = None,
|
||||
target_network_update_freq: Optional[int] = None,
|
||||
actor_hiddens: Optional[List[int]] = None,
|
||||
actor_hidden_activation: Optional[str] = None,
|
||||
critic_hiddens: Optional[List[int]] = None,
|
||||
|
@ -110,10 +106,9 @@ class CRRConfig(AlgorithmConfig):
|
|||
a^j)]
|
||||
n_action_sample: the number of actions to sample for v_t estimation.
|
||||
twin_q: if True, uses pessimistic q estimation.
|
||||
target_update_grad_intervals: The frequency at which we update the
|
||||
target_network_update_freq: The frequency at which we update the
|
||||
target copy of the model in terms of the number of gradient updates
|
||||
applied to the main model.
|
||||
replay_buffer_config: The config dictionary for replay buffer.
|
||||
actor_hiddens: The number of hidden units in the actor's fc network.
|
||||
actor_hidden_activation: The activation used in the actor's fc network.
|
||||
critic_hiddens: The number of hidden units in the critic's fc network.
|
||||
|
@ -139,10 +134,8 @@ class CRRConfig(AlgorithmConfig):
|
|||
self.n_action_sample = n_action_sample
|
||||
if twin_q is not None:
|
||||
self.twin_q = twin_q
|
||||
if target_update_grad_intervals is not None:
|
||||
self.target_update_grad_intervals = target_update_grad_intervals
|
||||
if replay_buffer_config is not None:
|
||||
self.replay_buffer_config = replay_buffer_config
|
||||
if target_network_update_freq is not None:
|
||||
self.target_network_update_freq = target_network_update_freq
|
||||
if actor_hiddens is not None:
|
||||
self.actor_hiddens = actor_hiddens
|
||||
if actor_hidden_activation is not None:
|
||||
|
@ -168,44 +161,10 @@ class CRR(Algorithm):
|
|||
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
super().setup(config)
|
||||
# initial setup for handling the offline data in form of a replay buffer
|
||||
# Add the entire dataset to Replay Buffer (global variable)
|
||||
reader = self.workers.local_worker().input_reader
|
||||
|
||||
# For d4rl, add the D4RLReaders' dataset to the buffer.
|
||||
if isinstance(self.config["input"], str) and "d4rl" in self.config["input"]:
|
||||
dataset = reader.dataset
|
||||
self.local_replay_buffer.add(dataset)
|
||||
# For a list of files, add each file's entire content to the buffer.
|
||||
elif isinstance(reader, ShuffledInput):
|
||||
num_batches = 0
|
||||
total_timesteps = 0
|
||||
for batch in reader.child.read_all_files():
|
||||
num_batches += 1
|
||||
total_timesteps += len(batch)
|
||||
# Add NEXT_OBS if not available. This is slightly hacked
|
||||
# as for the very last time step, we will use next-obs=zeros
|
||||
# and therefore force-set DONE=True to avoid this missing
|
||||
# next-obs to cause learning problems.
|
||||
if SampleBatch.NEXT_OBS not in batch:
|
||||
obs = batch[SampleBatch.OBS]
|
||||
batch[SampleBatch.NEXT_OBS] = np.concatenate(
|
||||
[obs[1:], np.zeros_like(obs[0:1])]
|
||||
)
|
||||
batch[SampleBatch.DONES][-1] = True
|
||||
self.local_replay_buffer.add(batch)
|
||||
print(
|
||||
f"Loaded {num_batches} batches ({total_timesteps} ts) into the"
|
||||
" replay buffer, which has capacity "
|
||||
f"{self.local_replay_buffer.capacity}."
|
||||
if self.config.get("target_network_update_freq", None) is None:
|
||||
self.config["target_network_update_freq"] = (
|
||||
self.config["train_batch_size"] * 100
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown offline input! config['input'] must either be list of"
|
||||
" offline files (json) or a D4RL-specific InputReader "
|
||||
"specifier (e.g. 'd4rl.hopper-medium-v0')."
|
||||
)
|
||||
|
||||
# added a counter key for keeping track of number of gradient updates
|
||||
self._counters[NUM_GRADIENT_UPDATES] = 0
|
||||
# if I don't set this here to zero I won't see zero in the logs (defaultdict)
|
||||
|
@ -227,47 +186,39 @@ class CRR(Algorithm):
|
|||
|
||||
@override(Algorithm)
|
||||
def training_step(self) -> ResultDict:
|
||||
with self._timers[SAMPLE_TIMER]:
|
||||
train_batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
train_batch = train_batch.as_multi_agent()
|
||||
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
|
||||
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
|
||||
|
||||
total_transitions = len(self.local_replay_buffer)
|
||||
bsize = self.config["train_batch_size"]
|
||||
n_batches_per_epoch = total_transitions // bsize
|
||||
# Postprocess batch before we learn on it.
|
||||
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
|
||||
train_batch = post_fn(train_batch, self.workers, self.config)
|
||||
|
||||
results = []
|
||||
for batch_iter in range(n_batches_per_epoch):
|
||||
# Sample training batch from replay buffer.
|
||||
train_batch = self.local_replay_buffer.sample(bsize)
|
||||
# Learn on training batch.
|
||||
# Use simple optimizer (only for multi-agent or tf-eager; all other
|
||||
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
|
||||
if self.config.get("simple_optimizer", False):
|
||||
train_results = train_one_step(self, train_batch)
|
||||
else:
|
||||
train_results = multi_gpu_train_one_step(self, train_batch)
|
||||
|
||||
# Postprocess batch before we learn on it.
|
||||
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
|
||||
train_batch = post_fn(train_batch, self.workers, self.config)
|
||||
# update target every few gradient updates
|
||||
# Update target network every `target_network_update_freq` training steps.
|
||||
cur_ts = self._counters[
|
||||
NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
|
||||
]
|
||||
last_update = self._counters[LAST_TARGET_UPDATE_TS]
|
||||
|
||||
# Learn on training batch.
|
||||
# Use simple optimizer (only for multi-agent or tf-eager; all other
|
||||
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
|
||||
if self.config.get("simple_optimizer", False):
|
||||
train_results = train_one_step(self, train_batch)
|
||||
else:
|
||||
train_results = multi_gpu_train_one_step(self, train_batch)
|
||||
if cur_ts - last_update >= self.config["target_network_update_freq"]:
|
||||
with self._timers[TARGET_NET_UPDATE_TIMER]:
|
||||
to_update = self.workers.local_worker().get_policies_to_train()
|
||||
self.workers.local_worker().foreach_policy_to_train(
|
||||
lambda p, pid: pid in to_update and p.update_target()
|
||||
)
|
||||
self._counters[NUM_TARGET_UPDATES] += 1
|
||||
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
|
||||
|
||||
# update target every few gradient updates
|
||||
cur_ts = self._counters[NUM_GRADIENT_UPDATES]
|
||||
last_update = self._counters[LAST_TARGET_UPDATE_TS]
|
||||
|
||||
if cur_ts - last_update >= self.config["target_update_grad_intervals"]:
|
||||
with self._timers[TARGET_NET_UPDATE_TIMER]:
|
||||
to_update = self.workers.local_worker().get_policies_to_train()
|
||||
self.workers.local_worker().foreach_policy_to_train(
|
||||
lambda p, pid: pid in to_update and p.update_target()
|
||||
)
|
||||
self._counters[NUM_TARGET_UPDATES] += 1
|
||||
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
|
||||
|
||||
self._counters[NUM_GRADIENT_UPDATES] += 1
|
||||
|
||||
results.append(train_results)
|
||||
|
||||
summary = tree.map_structure_with_path(
|
||||
lambda path, *v: float(np.mean(v)), *results
|
||||
)
|
||||
|
||||
return summary
|
||||
self._counters[NUM_GRADIENT_UPDATES] += 1
|
||||
return train_results
|
||||
|
|
|
@ -4,8 +4,8 @@ import unittest
|
|||
|
||||
import ray
|
||||
from ray.rllib.algorithms.crr import CRRConfig
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
|
||||
from ray.rllib.utils.test_utils import (
|
||||
check_compute_single_action,
|
||||
check_train_results,
|
||||
|
@ -32,24 +32,31 @@ class TestCRR(unittest.TestCase):
|
|||
print("rllib dir={}".format(rllib_dir))
|
||||
data_file = os.path.join(rllib_dir, "tests/data/pendulum/large.json")
|
||||
print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
|
||||
# Will use the Json Reader in this example until we convert over the example
|
||||
# files over to Parquet, since the dataset json reader cannot handle large
|
||||
# block sizes.
|
||||
|
||||
def input_reading_fn(ioctx):
|
||||
return JsonReader(ioctx.config["input_config"]["paths"], ioctx)
|
||||
|
||||
input_config = {"paths": data_file}
|
||||
|
||||
config = (
|
||||
CRRConfig()
|
||||
.environment(env="Pendulum-v1", clip_actions=True)
|
||||
.framework("torch")
|
||||
.offline_data(input_=[data_file], actions_in_input_normalized=True)
|
||||
.offline_data(
|
||||
input_=input_reading_fn,
|
||||
input_config=input_config,
|
||||
actions_in_input_normalized=True,
|
||||
)
|
||||
.training(
|
||||
twin_q=True,
|
||||
train_batch_size=256,
|
||||
replay_buffer_config={
|
||||
"type": MultiAgentReplayBuffer,
|
||||
"learning_starts": 0,
|
||||
"capacity": 100000,
|
||||
},
|
||||
weight_type="bin",
|
||||
advantage_type="mean",
|
||||
n_action_sample=4,
|
||||
target_update_grad_intervals=10000,
|
||||
target_network_update_freq=10000,
|
||||
tau=1.0,
|
||||
)
|
||||
.evaluation(
|
||||
|
|
|
@ -18,6 +18,7 @@ from ray.rllib.utils.metrics import (
|
|||
NUM_AGENT_STEPS_SAMPLED,
|
||||
NUM_ENV_STEPS_SAMPLED,
|
||||
SYNCH_WORKER_WEIGHTS_TIMER,
|
||||
SAMPLE_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.typing import (
|
||||
ResultDict,
|
||||
|
@ -253,7 +254,8 @@ class MARWIL(Algorithm):
|
|||
@override(Algorithm)
|
||||
def training_step(self) -> ResultDict:
|
||||
# Collect SampleBatches from sample workers.
|
||||
batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
with self._timers[SAMPLE_TIMER]:
|
||||
batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
batch = batch.as_multi_agent()
|
||||
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
|
||||
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
|
||||
|
|
|
@ -863,7 +863,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
>>> print(worker.sample()) # doctest: +SKIP
|
||||
SampleBatch({"obs": [...], "action": [...], ...})
|
||||
"""
|
||||
|
||||
if self.fake_sampler and self.last_batch is not None:
|
||||
return self.last_batch
|
||||
elif self.input_reader is None:
|
||||
|
@ -893,9 +892,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
max_batches = self.num_envs
|
||||
else:
|
||||
max_batches = float("inf")
|
||||
|
||||
while (
|
||||
steps_so_far < self.rollout_fragment_length and len(batches) < max_batches
|
||||
while steps_so_far < self.rollout_fragment_length and (
|
||||
len(batches) < max_batches or self.policy_config.get("offline_sampling")
|
||||
):
|
||||
batch = self.input_reader.next()
|
||||
steps_so_far += (
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from pathlib import Path
|
||||
import re
|
||||
|
||||
import gym
|
||||
import logging
|
||||
import importlib.util
|
||||
|
@ -99,14 +102,59 @@ class WorkerSet:
|
|||
}
|
||||
self._cls = RolloutWorker.as_remote(**self._remote_args).remote
|
||||
self._logdir = logdir
|
||||
|
||||
if _setup:
|
||||
# Force a local worker if num_workers == 0 (no remote workers).
|
||||
# Otherwise, this WorkerSet would be empty.
|
||||
self._local_worker = None
|
||||
if num_workers == 0:
|
||||
local_worker = True
|
||||
|
||||
if (
|
||||
(
|
||||
isinstance(trainer_config["input"], str)
|
||||
or isinstance(trainer_config["input"], list)
|
||||
)
|
||||
and ("d4rl" not in trainer_config["input"])
|
||||
and (not "sampler" == trainer_config["input"])
|
||||
and (not "dataset" == trainer_config["input"])
|
||||
and (
|
||||
not (
|
||||
isinstance(trainer_config["input"], str)
|
||||
and registry_contains_input(trainer_config["input"])
|
||||
)
|
||||
)
|
||||
and (
|
||||
not (
|
||||
isinstance(trainer_config["input"], str)
|
||||
and self._valid_module(trainer_config["input"])
|
||||
)
|
||||
)
|
||||
):
|
||||
paths = trainer_config["input"]
|
||||
if isinstance(paths, str):
|
||||
inputs = Path(paths).absolute()
|
||||
if inputs.is_dir():
|
||||
paths = list(inputs.glob("*.json")) + list(inputs.glob("*.zip"))
|
||||
paths = [str(path) for path in paths]
|
||||
else:
|
||||
paths = [paths]
|
||||
ends_with_zip_or_json = all(
|
||||
re.search("\\.zip$", path) or re.search("\\.json$", path)
|
||||
for path in paths
|
||||
)
|
||||
ends_with_parquet = all(
|
||||
re.search("\\.parquet$", path) for path in paths
|
||||
)
|
||||
trainer_config["input"] = "dataset"
|
||||
input_config = {"paths": paths}
|
||||
if ends_with_zip_or_json:
|
||||
input_config["format"] = "json"
|
||||
elif ends_with_parquet:
|
||||
input_config["format"] = "parquet"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Input path must end with .zip, .parquet, or .json"
|
||||
)
|
||||
trainer_config["input_config"] = input_config
|
||||
self._local_config = merge_dicts(
|
||||
trainer_config,
|
||||
{"tf_session_args": trainer_config["local_tf_session_args"]},
|
||||
|
@ -737,6 +785,26 @@ class WorkerSet:
|
|||
|
||||
return faulty_worker_indices
|
||||
|
||||
@classmethod
|
||||
def _valid_module(cls, class_path):
|
||||
del cls
|
||||
if (
|
||||
isinstance(class_path, str)
|
||||
and not os.path.isfile(class_path)
|
||||
and "." in class_path
|
||||
):
|
||||
module_path, class_name = class_path.rsplit(".", 1)
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_path)
|
||||
if spec is not None:
|
||||
return True
|
||||
except (ModuleNotFoundError, ValueError):
|
||||
print(
|
||||
f"module {module_path} not found while trying to get "
|
||||
f"input {class_path}"
|
||||
)
|
||||
return False
|
||||
|
||||
@Deprecated(new="WorkerSet.foreach_policy_to_train", error=False)
|
||||
def foreach_trainable_policy(self, func):
|
||||
return self.foreach_policy_to_train(func)
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
import logging
|
||||
import math
|
||||
import re
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import ray.data
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.offline.json_reader import from_json_data
|
||||
from ray.rllib.policy.sample_batch import concat_samples
|
||||
from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
|
||||
from typing import List
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -25,11 +27,31 @@ def _get_resource_bundles(config: AlgorithmConfigDict):
|
|||
return [{"CPU": math.ceil(parallelism * cpus_per_task)}]
|
||||
|
||||
|
||||
def _unzip_if_needed(paths: List[str], format: str):
|
||||
"""If a path in paths is a zip file, unzip it and use path of the unzipped file"""
|
||||
ret = []
|
||||
for path in paths:
|
||||
fpath = Path(path).absolute()
|
||||
if not fpath.exists():
|
||||
fpath = Path(__file__).parent.parent / path
|
||||
if not fpath.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
if re.search("\\.zip$", str(fpath)):
|
||||
with zipfile.ZipFile(str(fpath), "r") as zip_ref:
|
||||
zip_ref.extractall(str(fpath.parent))
|
||||
fpath = re.sub("\\.zip$", f".{format}", str(fpath))
|
||||
fpath = str(fpath)
|
||||
ret.append(fpath)
|
||||
return ret
|
||||
|
||||
|
||||
@PublicAPI
|
||||
def get_dataset_and_shards(
|
||||
config: AlgorithmConfigDict, num_workers: int, local_worker: bool
|
||||
) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]):
|
||||
assert config["input"] == "dataset"
|
||||
) -> Tuple[ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]]:
|
||||
assert config["input"] == "dataset", (
|
||||
"Must specify input as dataset if" " calling `get_dataset_and_shards`"
|
||||
)
|
||||
assert (
|
||||
"input_config" in config
|
||||
), "Must specify input_config dict if using Dataset input."
|
||||
|
@ -37,36 +59,47 @@ def get_dataset_and_shards(
|
|||
input_config = config["input_config"]
|
||||
|
||||
format = input_config.get("format")
|
||||
path = input_config.get("path")
|
||||
assert format in ("json", "parquet"), (
|
||||
"Offline input data format must be " "parquet " "or json"
|
||||
)
|
||||
paths = input_config.get("paths")
|
||||
loader_fn = input_config.get("loader_fn")
|
||||
|
||||
if loader_fn and (format or path):
|
||||
if loader_fn and (format or paths):
|
||||
raise ValueError(
|
||||
"When using a `loader_fn`, you cannot specify a `format` or `path`."
|
||||
)
|
||||
|
||||
if not (format and path) and not loader_fn:
|
||||
if not (format and paths) and not loader_fn:
|
||||
raise ValueError(
|
||||
"Must specify format and path, or a loader_fn via input_config key"
|
||||
" when using Ray dataset input."
|
||||
)
|
||||
|
||||
if not isinstance(paths, (list, str)):
|
||||
raise ValueError("Paths must be a list of path strings or a path string")
|
||||
if isinstance(paths, str):
|
||||
paths = [paths]
|
||||
paths = _unzip_if_needed(paths, format)
|
||||
|
||||
parallelism = input_config.get("parallelism", num_workers or 1)
|
||||
cpus_per_task = input_config.get(
|
||||
"num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
|
||||
)
|
||||
|
||||
assert loader_fn or (format and path)
|
||||
|
||||
assert loader_fn or (format and paths), (
|
||||
f"If using a loader_fn: {loader_fn} that constructs a dataset, "
|
||||
"format: {format} and paths: {paths} must be specified. If format and "
|
||||
"paths are specified, a loader_fn must not be specified."
|
||||
)
|
||||
if loader_fn:
|
||||
dataset = loader_fn()
|
||||
elif format == "json":
|
||||
dataset = ray.data.read_json(
|
||||
path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
|
||||
paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
|
||||
)
|
||||
elif format == "parquet":
|
||||
dataset = ray.data.read_parquet(
|
||||
path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
|
||||
paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
|
||||
)
|
||||
else:
|
||||
raise ValueError("Un-supported Ray dataset format: ", format)
|
||||
|
@ -96,7 +129,7 @@ class DatasetReader(InputReader):
|
|||
"format": "json",
|
||||
# A single data file, a directory, or anything
|
||||
# that ray.data.dataset recognizes.
|
||||
"path": "/tmp/sample_batches/",
|
||||
"paths": "/tmp/sample_batches/",
|
||||
# By default, parallelism=num_workers.
|
||||
"parallelism": 3,
|
||||
# Dataset allocates 0.5 CPU for each reader by default.
|
||||
|
@ -114,20 +147,33 @@ class DatasetReader(InputReader):
|
|||
ds: Ray dataset to sample from.
|
||||
"""
|
||||
self._ioctx = ioctx
|
||||
self._default_policy = self.policy_map = None
|
||||
self._dataset = ds
|
||||
self.count = None if not self._dataset else self._dataset.count()
|
||||
# do this to disable the ray data stdout logging
|
||||
ray.data.set_progress_bars(enabled=False)
|
||||
|
||||
# the number of rows to return per call to next()
|
||||
if self._ioctx:
|
||||
self.batch_size = ioctx.config.get("train_batch_size", 1)
|
||||
num_workers = ioctx.config.get("num_workers", 0)
|
||||
self.batch_size = self._ioctx.config.get("train_batch_size", 1)
|
||||
num_workers = self._ioctx.config.get("num_workers", 0)
|
||||
seed = self._ioctx.config.get("seed", None)
|
||||
if num_workers:
|
||||
self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
|
||||
# We allow the creation of a non-functioning None DatasetReader.
|
||||
# It's useful for example for a non-rollout local worker.
|
||||
if ds:
|
||||
if self._ioctx.worker is not None:
|
||||
self._policy_map = self._ioctx.worker.policy_map
|
||||
self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
|
||||
self._dataset.random_shuffle(seed=seed)
|
||||
print(
|
||||
"DatasetReader ", ioctx.worker_index, " has ", ds.count(), " samples."
|
||||
f"DatasetReader {self._ioctx.worker_index} has {ds.count()}, samples."
|
||||
)
|
||||
self._iter = self._dataset.repeat().iter_rows()
|
||||
# TODO: @avnishn make this call seeded.
|
||||
# calling random_shuffle_each_window shuffles the dataset after
|
||||
# each time the whole dataset has been read.
|
||||
self._iter = self._dataset.repeat().random_shuffle_each_window().iter_rows()
|
||||
else:
|
||||
self._iter = None
|
||||
|
||||
|
@ -142,6 +188,22 @@ class DatasetReader(InputReader):
|
|||
# Columns like obs are compressed when written by DatasetWriter.
|
||||
d = from_json_data(d, self._ioctx.worker)
|
||||
count += d.count
|
||||
ret.append(d)
|
||||
ret.append(self._postprocess_if_needed(d))
|
||||
ret = concat_samples(ret)
|
||||
return ret
|
||||
|
||||
def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
|
||||
if not self._ioctx or not self._ioctx.config.get("postprocess_inputs"):
|
||||
return batch
|
||||
|
||||
if isinstance(batch, SampleBatch):
|
||||
out = []
|
||||
for sub_batch in batch.split_by_episode():
|
||||
out.append(self._default_policy.postprocess_trajectory(sub_batch))
|
||||
return SampleBatch.concat_samples(out)
|
||||
else:
|
||||
# TODO(ekl) this is trickier since the alignments between agent
|
||||
# trajectories in the episode are not available any more.
|
||||
raise NotImplementedError(
|
||||
"Postprocessing of multi-agent data not implemented yet."
|
||||
)
|
||||
|
|
|
@ -22,7 +22,7 @@ class TestDatasetReader(unittest.TestCase):
|
|||
print("rllib dir={}".format(rllib_dir))
|
||||
data_file = os.path.join(rllib_dir, "rllib/tests/data/pendulum/large.json")
|
||||
print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
|
||||
input_config = {"format": "json", "path": data_file}
|
||||
input_config = {"format": "json", "paths": data_file}
|
||||
dataset, _ = get_dataset_and_shards(
|
||||
{"input": "dataset", "input_config": input_config}, 0, True
|
||||
)
|
||||
|
|
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
|
@ -106,7 +106,13 @@ class NestedActionSpacesTest(unittest.TestCase):
|
|||
|
||||
# Test, whether offline data can be properly read by
|
||||
# BC, configured accordingly.
|
||||
config["input"] = config["output"]
|
||||
|
||||
# doing this for backwards compatibility until we move to parquet
|
||||
# as default output
|
||||
config["input"] = lambda ioctx: JsonReader(
|
||||
ioctx.config["input_config"]["paths"], ioctx
|
||||
)
|
||||
config["input_config"] = {"paths": config["output"]}
|
||||
del config["output"]
|
||||
bc = BC(config=config)
|
||||
bc.train()
|
||||
|
|
|
@ -7,7 +7,7 @@ pendulum-cql:
|
|||
run: CQL
|
||||
stop:
|
||||
evaluation/episode_reward_mean: -700
|
||||
timesteps_total: 200000
|
||||
timesteps_total: 800000
|
||||
config:
|
||||
# Works for both torch and tf.
|
||||
framework: tf
|
||||
|
@ -21,10 +21,9 @@ pendulum-cql:
|
|||
|
||||
twin_q: true
|
||||
train_batch_size: 2000
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 0
|
||||
bc_iters: 100
|
||||
num_workers: 2
|
||||
min_time_s_per_iteration: 10
|
||||
|
||||
metrics_smoothing_episodes: 5
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ cartpole_crr:
|
|||
config:
|
||||
input:
|
||||
- 'tests/data/cartpole/large.json'
|
||||
num_workers: 3
|
||||
framework: torch
|
||||
gamma: 0.99
|
||||
train_batch_size: 2048
|
||||
|
@ -20,7 +21,7 @@ cartpole_crr:
|
|||
clip_actions: True
|
||||
# Q function update setting
|
||||
twin_q: True
|
||||
target_update_grad_intervals: 1
|
||||
target_network_update_freq: 1
|
||||
tau: 0.0005
|
||||
# evaluation
|
||||
evaluation_config:
|
||||
|
@ -31,11 +32,6 @@ cartpole_crr:
|
|||
evaluation_interval: 1
|
||||
evaluation_num_workers: 1
|
||||
evaluation_parallel_to_training: True
|
||||
# replay buffer
|
||||
replay_buffer_config:
|
||||
type: ray.rllib.utils.replay_buffers.MultiAgentReplayBuffer
|
||||
learning_starts: 0
|
||||
capacity: 100000
|
||||
# specific to CRR
|
||||
temperature: 1.0
|
||||
weight_type: bin
|
||||
|
|
|
@ -8,6 +8,7 @@ cartpole_crr:
|
|||
input:
|
||||
- 'tests/data/cartpole/large.json'
|
||||
framework: torch
|
||||
num_workers: 3
|
||||
gamma: 0.99
|
||||
train_batch_size: 2048
|
||||
critic_hidden_activation: 'tanh'
|
||||
|
@ -20,7 +21,7 @@ cartpole_crr:
|
|||
clip_actions: True
|
||||
# Q function update setting
|
||||
twin_q: True
|
||||
target_update_grad_intervals: 1
|
||||
target_network_update_freq: 1
|
||||
tau: 0.0005
|
||||
# evaluation
|
||||
evaluation_config:
|
||||
|
@ -31,11 +32,6 @@ cartpole_crr:
|
|||
evaluation_interval: 1
|
||||
evaluation_num_workers: 1
|
||||
evaluation_parallel_to_training: True
|
||||
# replay buffer
|
||||
replay_buffer_config:
|
||||
type: ray.rllib.utils.replay_buffers.MultiAgentReplayBuffer
|
||||
learning_starts: 0
|
||||
capacity: 100000
|
||||
# specific to CRR
|
||||
temperature: 1.0
|
||||
weight_type: bin
|
||||
|
|
|
@ -2,12 +2,13 @@ pendulum_crr:
|
|||
env: 'Pendulum-v1'
|
||||
run: CRR
|
||||
stop:
|
||||
evaluation/episode_reward_mean: -200
|
||||
training_iteration: 500
|
||||
# We could make this -200, but given that we have 4 cpus for our tests, we will have to settle for -300.
|
||||
evaluation/episode_reward_mean: -300
|
||||
timesteps_total: 2000000
|
||||
config:
|
||||
input:
|
||||
- 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
|
||||
input: 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
|
||||
framework: torch
|
||||
num_workers: 3
|
||||
gamma: 0.99
|
||||
train_batch_size: 1024
|
||||
critic_hidden_activation: 'relu'
|
||||
|
@ -20,7 +21,7 @@ pendulum_crr:
|
|||
clip_actions: True
|
||||
# Q function update setting
|
||||
twin_q: True
|
||||
target_update_grad_intervals: 1
|
||||
target_network_update_freq: 1
|
||||
tau: 0.0001
|
||||
# evaluation
|
||||
evaluation_config:
|
||||
|
@ -32,10 +33,6 @@ pendulum_crr:
|
|||
evaluation_num_workers: 1
|
||||
evaluation_parallel_to_training: True
|
||||
# replay buffer
|
||||
replay_buffer_config:
|
||||
type: ray.rllib.utils.replay_buffers.MultiAgentReplayBuffer
|
||||
learning_starts: 0
|
||||
capacity: 100000
|
||||
# specific to CRR
|
||||
temperature: 1.0
|
||||
weight_type: exp
|
||||
|
|
Loading…
Add table
Reference in a new issue