[RLlib Testing] Lower --smoke-test "time_total_s" to make sure it doesn't time out. (#18670)

commit ba1c489b79 (parent e7ea1f9a82)
9 changed files with 97 additions and 73 deletions
@@ -12,9 +12,12 @@ python:
 post_build_cmds:
   # Create a couple of soft links so tf 2.4.3 works with cuda 11.2.
-  # TODO(jungong) : remove them once product ray-ml docker gets upgraded to use tf 2.5.0.
+  # TODO(jungong): remove these once product ray-ml docker gets upgraded to use tf 2.5.0.
   - sudo ln -s /usr/local/cuda /usr/local/nvidia
   - sudo ln -s /usr/local/cuda/lib64/libcusolver.so.11 /usr/local/cuda/lib64/libcusolver.so.10
+  - pip install tensorflow==2.5.0
+  # END: TO-DO
+
   - pip uninstall -y ray || true
   - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
   - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }}
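For context on the templated lines above: the `{{ env[...] | default(...) }}` placeholders are resolved before the commands run. A minimal sketch of how such a line expands, assuming Jinja2-style templating (the `jinja2` usage here is illustrative, not the release tooling's actual code):

import os
from jinja2 import Template

# Illustrative only: expand a templated post_build_cmd against os.environ.
cmd = Template('pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}')
# With RAY_WHEELS unset, the default filter falls back to plain "ray".
print(cmd.render(env=os.environ))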
@@ -3,29 +3,28 @@
 Runs Atari/PyBullet benchmarks for all major algorithms.
 """

-import argparse
 import json
 import os
 from pathlib import Path

 from ray.rllib.utils.test_utils import run_learning_tests_from_yaml

-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--smoke-test",
-    action="store_true",
-    default=False,
-    help="Finish quickly for training.")
-args = parser.parse_args()
-
 if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test",
+        action="store_true",
+        default=False,
+        help="Finish quickly for training.")
+    args = parser.parse_args()
+
     # Get path of this very script to look for yaml files.
     abs_yaml_path = Path(__file__).parent
     print("abs_yaml_path={}".format(abs_yaml_path))

-    # This pattern match is kind of hacky. Avoids cluster.yaml to get sucked
-    # into this.
-    yaml_files = abs_yaml_path.rglob("*test*.yaml")
+    yaml_files = abs_yaml_path.rglob("*.yaml")
     yaml_files = sorted(
         map(lambda path: str(path.absolute()), yaml_files), reverse=True)
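The hunk above moves argument parsing under the `__main__` guard: a module-level `parse_args()` runs on import, so importing the script (e.g. from a test harness) would choke on unrelated `sys.argv`. A minimal sketch of the pattern (hypothetical standalone file, not this script verbatim):

import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for training.")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Only parses when run as a script; argparse maps --smoke-test
    # to the attribute args.smoke_test.
    args = parse_args()
    print(args.smoke_test)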
@@ -40,4 +39,4 @@ if __name__ == "__main__":
     with open(test_output_json, "wt") as f:
         json.dump(results, f)

-    print("PASSED.")
+    print("Ok.")
@@ -1,29 +1,4 @@

-a2c-stateless-cartpole:
-    env: ray.rllib.examples.env.stateless_cartpole.StatelessCartPole
-    run: A2C
-    # Minimum reward and total ts (in given time_total_s) to pass this test.
-    pass_criteria:
-        episode_reward_mean: 150.0
-        timesteps_total: 1000000
-    stop:
-        time_total_s: 1200
-    config:
-        num_gpus: 2
-        num_workers: 23
-        lr: 0.0005
-        # Test w/ GTrXL net.
-        model:
-            use_attention: true
-            max_seq_len: 10
-            attention_num_transformer_units: 1
-            attention_dim: 32
-            attention_memory_inference: 10
-            attention_memory_training: 10
-            attention_num_heads: 1
-            attention_head_dim: 32
-            attention_position_wise_mlp_dim: 32
-
 appo-stateless-cartpole-no-vtrace:
     env: ray.rllib.examples.env.stateless_cartpole.StatelessCartPole
     run: APPO
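Each entry in these yaml files follows the same schema: `pass_criteria` is what the test harness checks after the run, while `stop` is handed to Ray Tune as a trial stopping condition. A hedged sketch of how a `stop` block maps onto a Tune call (values illustrative):

from ray import tune

# Stop the trial once 1200s of training time have accumulated,
# mirroring a `stop: time_total_s: 1200` block like the one above.
tune.run(
    "APPO",
    stop={"time_total_s": 1200},
    config={"env": "CartPole-v0", "num_workers": 2})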
@@ -5,14 +5,12 @@
     compute_template: 8gpus_64cpus.yaml

   run:
-    timeout: 21600
+    timeout: 14400
     script: python learning_tests/run.py

   smoke_test:
     run:
-      timeout: 7200
-    cluster:
-      compute_template: 4gpus_64cpus.yaml
+      timeout: 1200

 # 2-GPU learning tests (CartPole and RepeatAfterMeEnv) for major algos.
 - name: multi_gpu_learning_tests
@@ -56,7 +54,7 @@
 #  timeout: 7200
 #  script: bash unit_gpu_tests/run.sh

-# IMPALA large machine stress tests (Atari).
+# IMPALA large machine stress tests (4x Atari).
 - name: stress_tests
   cluster:
     app_config: app_config.yaml
@@ -69,4 +67,4 @@

   smoke_test:
     run:
-      timeout: 1800
+      timeout: 1200
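In these release-test files, a `smoke_test` block overrides fields of the base test when the suite runs in smoke mode. A sketch of the assumed deep-merge semantics (the helper is illustrative, not the release tooling's API):

def deep_merge(base: dict, override: dict) -> dict:
    """Recursively apply `override` on top of `base`."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], value)
        else:
            out[key] = value
    return out


base = {"run": {"timeout": 21600, "script": "python stress_tests/run.py"}}
smoke = {"run": {"timeout": 1200}}
assert deep_merge(base, smoke)["run"]["timeout"] == 1200
assert deep_merge(base, smoke)["run"]["script"] == base["run"]["script"]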
@@ -3,22 +3,23 @@
 Runs IMPALA on 4 GPUs and 100s of CPUs.
 """

-import argparse
 import json
 import os
 from pathlib import Path

 from ray.rllib.utils.test_utils import run_learning_tests_from_yaml

-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--smoke-test",
-    action="store_true",
-    default=False,
-    help="Finish quickly for training.")
-args = parser.parse_args()
-
 if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test",
+        action="store_true",
+        default=False,
+        help="Finish quickly for training.")
+    args = parser.parse_args()
+
     # Get path of this very script to look for yaml files.
     abs_yaml_path = Path(__file__).parent
     print("abs_yaml_path={}".format(abs_yaml_path))
@@ -27,6 +28,7 @@ if __name__ == "__main__":
     yaml_files = sorted(
         map(lambda path: str(path.absolute()), yaml_files), reverse=True)

     # Run all tests in the found yaml files.
     results = run_learning_tests_from_yaml(
         yaml_files=yaml_files,
+        max_num_repeats=1,
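For reference, a hypothetical minimal invocation of the helper both run.py scripts build up to; the yaml path is illustrative, and the `smoke_test` kwarg is inferred from the `run_learning_tests_from_yaml` hunks further down:

from ray.rllib.utils.test_utils import run_learning_tests_from_yaml

results = run_learning_tests_from_yaml(
    yaml_files=["/path/to/some-learning-test.yaml"],  # illustrative path
    max_num_repeats=1,
    smoke_test=True)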
@@ -1,10 +1,49 @@
+import argparse
+import time
+
 import ray

 ray.init(address="auto")

+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "num_nodes",
+    type=int,
+    help="Wait for this number of nodes (includes head)")
+
+parser.add_argument(
+    "max_time_s", type=int, help="Wait for this number of seconds")
+
+parser.add_argument(
+    "--feedback_interval_s",
+    type=int,
+    default=10,
+    help="Wait for this number of seconds")
+
+args = parser.parse_args()
+
 curr_nodes = 0
-while not curr_nodes > 8:
-    print("Waiting for more nodes to come up: {}/{}".format(curr_nodes, 8))
+start = time.time()
+next_feedback = start
+max_time = start + args.max_time_s
+while not curr_nodes >= args.num_nodes:
+    now = time.time()
+
+    if now >= max_time:
+        raise RuntimeError(
+            f"Maximum wait time reached, but only "
+            f"{curr_nodes}/{args.num_nodes} nodes came up. Aborting.")
+
+    if now >= next_feedback:
+        passed = now - start
+        print(f"Waiting for more nodes to come up: "
+              f"{curr_nodes}/{args.num_nodes} "
+              f"({passed:.0f} seconds passed)")
+        next_feedback = now + args.feedback_interval_s
+
+    time.sleep(5)
     curr_nodes = len(ray.nodes())
-    time.sleep(10)

+passed = time.time() - start
+print(f"Cluster is up: {curr_nodes}/{args.num_nodes} nodes online after "
+      f"{passed:.0f} seconds")
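The rewritten script is an instance of a poll-with-deadline loop: check a condition, emit periodic feedback, and fail hard once a deadline passes. The generic shape, as an illustrative helper (not part of Ray):

import time


def wait_until(predicate, timeout_s, poll_s=5.0):
    """Poll `predicate` until it is truthy or `timeout_s` elapses."""
    deadline = time.monotonic() + timeout_s
    while not predicate():
        if time.monotonic() >= deadline:
            raise TimeoutError("Condition not met within deadline.")
        time.sleep(poll_s)

The script above additionally tracks a `next_feedback` timestamp so progress is printed every `--feedback_interval_s` seconds rather than on every poll.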
@@ -1,5 +1,6 @@
 """CQL (derived from SAC).
 """
 import logging
+import numpy as np
 from typing import Optional, Type

@@ -16,10 +17,12 @@ from ray.rllib.offline.shuffled_input import ShuffledInput
 from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils import merge_dicts
-from ray.rllib.utils.framework import try_import_tfp
+from ray.rllib.utils.framework import try_import_tf, try_import_tfp
 from ray.rllib.utils.typing import TrainerConfigDict

+tf1, tf, tfv = try_import_tf()
 tfp = try_import_tfp()
 logger = logging.getLogger(__name__)
+replay_buffer = None

 # yapf: disable
@@ -62,9 +65,12 @@ def validate_config(config: TrainerConfigDict):
         config["simple_optimizer"] = True

     if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
-        raise ModuleNotFoundError(
-            "You need `tensorflow_probability` in order to run CQL with tf! "
-            "Install it via `pip install tensorflow_probability`.")
+        logger.warning(
+            "You need `tensorflow_probability` in order to run CQL! "
+            "Install it via `pip install tensorflow_probability`. Your "
+            f"tf.__version__={tf.__version__ if tf else None}."
+            "Trying to import tfp results in the following error:")
+        try_import_tfp(error=True)


 def execution_plan(workers, config):
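The change relies on a property of RLlib's guarded imports: `try_import_tfp()` returns None on failure by default, but re-raises when called with `error=True`, which is how the new code surfaces the underlying ImportError after logging the warning. A sketch of that contract (simplified; not RLlib's actual implementation):

def try_import_tfp(error: bool = False):
    """Return tensorflow_probability, or None; re-raise if error=True."""
    try:
        import tensorflow_probability as tfp
        return tfp
    except ImportError:
        if error:
            raise
        return None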
@@ -17,9 +17,10 @@ from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
 from ray.rllib.agents.sac.sac_tf_policy import SACTFPolicy
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
-from ray.rllib.utils.framework import try_import_tfp
+from ray.rllib.utils.framework import try_import_tf, try_import_tfp
 from ray.rllib.utils.typing import TrainerConfigDict

+tf1, tf, tfv = try_import_tf()
 tfp = try_import_tfp()

 logger = logging.getLogger(__name__)
@@ -194,9 +195,12 @@ def validate_config(config: TrainerConfigDict) -> None:
         raise ValueError("`grad_clip` value must be > 0.0!")

     if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
-        raise ModuleNotFoundError(
+        logger.warning(
             "You need `tensorflow_probability` in order to run SAC! "
-            "Install it via `pip install tensorflow_probability`.")
+            "Install it via `pip install tensorflow_probability`. Your "
+            f"tf.__version__={tf.__version__ if tf else None}."
+            "Trying to import tfp results in the following error:")
+        try_import_tfp(error=True)


 def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
@@ -11,7 +11,6 @@ import yaml
 import ray
 from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
     try_import_torch
-from ray.rllib.utils.numpy import LARGE_INTEGER
 from ray.tune import run_experiments

 jax, _ = try_import_jax()
@@ -419,12 +418,11 @@ def run_learning_tests_from_yaml(

         # For smoke-tests, we just run for n min.
         if smoke_test:
-            # 4min hardcoded for now.
-            e["stop"]["time_total_s"] = 240
-            # Don't stop smoke tests b/c of any reward received.
-            e["pass_criteria"]["episode_reward_mean"] = float("inf")
-            # Same for timesteps.
-            e["pass_criteria"]["timesteps_total"] = LARGE_INTEGER
+            # 0sec for each(!) experiment/trial.
+            # This is such that if there are many experiments/trials
+            # in a test (e.g. rllib_learning_test), each one can at least
+            # create its trainer and run a first iteration.
+            e["stop"]["time_total_s"] = 0
         else:
             # We also stop early, once we reach the desired reward.
             e["stop"]["episode_reward_mean"] = \
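The lowered `time_total_s` works because Tune evaluates stop criteria only after each reported result, so a threshold of 0 still lets every trial build its trainer and finish one iteration before stopping. Sketched (illustrative, not Tune's code):

def should_stop(result: dict, stop: dict) -> bool:
    # A trial stops once any stop-criterion key reaches its threshold.
    return any(result.get(k, 0) >= v for k, v in stop.items())


# Even with time_total_s=0, the check happens after the first result:
first_result = {"training_iteration": 1, "time_total_s": 31.7}
assert should_stop(first_result, {"time_total_s": 0})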
@@ -459,7 +457,7 @@ def run_learning_tests_from_yaml(
                 "passed": False,
             }
             # This key would break tune.
-            del e["pass_criteria"]
+            e.pop("pass_criteria", None)

             # Print out the actual config.
             print("== Test config ==")
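The `del` to `pop` change is a hardening fix: `del e["pass_criteria"]` raises KeyError when the key is absent, while `pop` with a default is a no-op:

e = {"run": "APPO"}
e.pop("pass_criteria", None)  # Fine even though the key is absent.
# del e["pass_criteria"]      # Would raise KeyError here.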