mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Make JSONReader default, users will have to use the DatasetReader for any speedups. (#26541)
This commit is contained in:
parent
c168c09281
commit
a322ac463c
6 changed files with 17 additions and 57 deletions
|
@ -1,6 +1,3 @@
|
|||
from pathlib import Path
|
||||
import re
|
||||
|
||||
import gym
|
||||
import logging
|
||||
import importlib.util
|
||||
|
@ -108,53 +105,6 @@ class WorkerSet:
|
|||
self._local_worker = None
|
||||
if num_workers == 0:
|
||||
local_worker = True
|
||||
if (
|
||||
(
|
||||
isinstance(trainer_config["input"], str)
|
||||
or isinstance(trainer_config["input"], list)
|
||||
)
|
||||
and ("d4rl" not in trainer_config["input"])
|
||||
and (not "sampler" == trainer_config["input"])
|
||||
and (not "dataset" == trainer_config["input"])
|
||||
and (
|
||||
not (
|
||||
isinstance(trainer_config["input"], str)
|
||||
and registry_contains_input(trainer_config["input"])
|
||||
)
|
||||
)
|
||||
and (
|
||||
not (
|
||||
isinstance(trainer_config["input"], str)
|
||||
and self._valid_module(trainer_config["input"])
|
||||
)
|
||||
)
|
||||
):
|
||||
paths = trainer_config["input"]
|
||||
if isinstance(paths, str):
|
||||
inputs = Path(paths).absolute()
|
||||
if inputs.is_dir():
|
||||
paths = list(inputs.glob("*.json")) + list(inputs.glob("*.zip"))
|
||||
paths = [str(path) for path in paths]
|
||||
else:
|
||||
paths = [paths]
|
||||
ends_with_zip_or_json = all(
|
||||
re.search("\\.zip$", path) or re.search("\\.json$", path)
|
||||
for path in paths
|
||||
)
|
||||
ends_with_parquet = all(
|
||||
re.search("\\.parquet$", path) for path in paths
|
||||
)
|
||||
trainer_config["input"] = "dataset"
|
||||
input_config = {"paths": paths}
|
||||
if ends_with_zip_or_json:
|
||||
input_config["format"] = "json"
|
||||
elif ends_with_parquet:
|
||||
input_config["format"] = "parquet"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Input path must end with .zip, .parquet, or .json"
|
||||
)
|
||||
trainer_config["input_config"] = input_config
|
||||
self._local_config = merge_dicts(
|
||||
trainer_config,
|
||||
{"tf_session_args": trainer_config["local_tf_session_args"]},
|
||||
|
|
|
@ -39,7 +39,7 @@ parser.add_argument(
|
|||
required=True,
|
||||
help="The directory in which to find all yamls to test.",
|
||||
)
|
||||
parser.add_argument("--num-cpus", type=int, default=6)
|
||||
parser.add_argument("--num-cpus", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--local-mode",
|
||||
action="store_true",
|
||||
|
|
|
@ -13,7 +13,10 @@ pendulum-cql:
|
|||
framework: tf
|
||||
|
||||
# Use one or more offline files or "input: sampler" for online learning.
|
||||
input: ["tests/data/pendulum/enormous.zip"]
|
||||
input: 'dataset'
|
||||
input_config:
|
||||
paths: ["tests/data/pendulum/enormous.zip"]
|
||||
format: 'json'
|
||||
# Our input file above comes from an SAC run. Actions in there
|
||||
# are already normalized (produced by SquashedGaussian).
|
||||
actions_in_input_normalized: true
|
||||
|
|
|
@ -5,8 +5,10 @@ cartpole_crr:
|
|||
evaluation/episode_reward_mean: 200
|
||||
training_iteration: 100
|
||||
config:
|
||||
input:
|
||||
- 'tests/data/cartpole/large.json'
|
||||
input: 'dataset'
|
||||
input_config:
|
||||
paths: 'tests/data/cartpole/large.json'
|
||||
format: 'json'
|
||||
num_workers: 3
|
||||
framework: torch
|
||||
gamma: 0.99
|
||||
|
|
|
@ -5,8 +5,10 @@ cartpole_crr:
|
|||
evaluation/episode_reward_mean: 200
|
||||
training_iteration: 100
|
||||
config:
|
||||
input:
|
||||
- 'tests/data/cartpole/large.json'
|
||||
input: 'dataset'
|
||||
input_config:
|
||||
paths: 'tests/data/cartpole/large.json'
|
||||
format: 'json'
|
||||
framework: torch
|
||||
num_workers: 3
|
||||
gamma: 0.99
|
||||
|
|
|
@ -6,7 +6,10 @@ pendulum_crr:
|
|||
evaluation/episode_reward_mean: -300
|
||||
timesteps_total: 2000000
|
||||
config:
|
||||
input: 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
|
||||
input: 'dataset'
|
||||
input_config:
|
||||
paths: 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
|
||||
format: 'json'
|
||||
framework: torch
|
||||
num_workers: 3
|
||||
gamma: 0.99
|
||||
|
|
Loading…
Add table
Reference in a new issue