[RLlib] Make JSONReader default, users will have to use the DatasetReader for any speedups. (#26541)

This commit is contained in:
Avnish Narayan 2022-07-14 08:19:38 -07:00 committed by GitHub
parent c168c09281
commit a322ac463c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 17 additions and 57 deletions

View file

@@ -1,6 +1,3 @@
from pathlib import Path
import re
import gym
import logging
import importlib.util
@@ -108,53 +105,6 @@ class WorkerSet:
self._local_worker = None
if num_workers == 0:
local_worker = True
if (
(
isinstance(trainer_config["input"], str)
or isinstance(trainer_config["input"], list)
)
and ("d4rl" not in trainer_config["input"])
and (not "sampler" == trainer_config["input"])
and (not "dataset" == trainer_config["input"])
and (
not (
isinstance(trainer_config["input"], str)
and registry_contains_input(trainer_config["input"])
)
)
and (
not (
isinstance(trainer_config["input"], str)
and self._valid_module(trainer_config["input"])
)
)
):
paths = trainer_config["input"]
if isinstance(paths, str):
inputs = Path(paths).absolute()
if inputs.is_dir():
paths = list(inputs.glob("*.json")) + list(inputs.glob("*.zip"))
paths = [str(path) for path in paths]
else:
paths = [paths]
ends_with_zip_or_json = all(
re.search("\\.zip$", path) or re.search("\\.json$", path)
for path in paths
)
ends_with_parquet = all(
re.search("\\.parquet$", path) for path in paths
)
trainer_config["input"] = "dataset"
input_config = {"paths": paths}
if ends_with_zip_or_json:
input_config["format"] = "json"
elif ends_with_parquet:
input_config["format"] = "parquet"
else:
raise ValueError(
"Input path must end with .zip, .parquet, or .json"
)
trainer_config["input_config"] = input_config
self._local_config = merge_dicts(
trainer_config,
{"tf_session_args": trainer_config["local_tf_session_args"]},

View file

@@ -39,7 +39,7 @@ parser.add_argument(
required=True,
help="The directory in which to find all yamls to test.",
)
parser.add_argument("--num-cpus", type=int, default=6)
parser.add_argument("--num-cpus", type=int, default=8)
parser.add_argument(
"--local-mode",
action="store_true",

View file

@@ -13,7 +13,10 @@ pendulum-cql:
framework: tf
# Use one or more offline files or "input: sampler" for online learning.
input: ["tests/data/pendulum/enormous.zip"]
input: 'dataset'
input_config:
paths: ["tests/data/pendulum/enormous.zip"]
format: 'json'
# Our input file above comes from an SAC run. Actions in there
# are already normalized (produced by SquashedGaussian).
actions_in_input_normalized: true

View file

@@ -5,8 +5,10 @@ cartpole_crr:
evaluation/episode_reward_mean: 200
training_iteration: 100
config:
input:
- 'tests/data/cartpole/large.json'
input: 'dataset'
input_config:
paths: 'tests/data/cartpole/large.json'
format: 'json'
num_workers: 3
framework: torch
gamma: 0.99

View file

@@ -5,8 +5,10 @@ cartpole_crr:
evaluation/episode_reward_mean: 200
training_iteration: 100
config:
input:
- 'tests/data/cartpole/large.json'
input: 'dataset'
input_config:
paths: 'tests/data/cartpole/large.json'
format: 'json'
framework: torch
num_workers: 3
gamma: 0.99

View file

@@ -6,7 +6,10 @@ pendulum_crr:
evaluation/episode_reward_mean: -300
timesteps_total: 2000000
config:
input: 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
input: 'dataset'
input_config:
paths: 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
format: 'json'
framework: torch
num_workers: 3
gamma: 0.99