diff --git a/release/air_tests/air_benchmarks/workloads/tensorflow_benchmark.py b/release/air_tests/air_benchmarks/workloads/tensorflow_benchmark.py index b000a9889..724c53ffd 100644 --- a/release/air_tests/air_benchmarks/workloads/tensorflow_benchmark.py +++ b/release/air_tests/air_benchmarks/workloads/tensorflow_benchmark.py @@ -219,6 +219,7 @@ def cli(): @click.option("--cpus-per-worker", type=int, default=8) @click.option("--use-gpu", is_flag=True, default=False) @click.option("--batch-size", type=int, default=64) +@click.option("--smoke-test", is_flag=True, default=False) def run( num_runs: int = 1, num_epochs: int = 4, @@ -228,6 +229,8 @@ def run( batch_size: int = 64, smoke_test: bool = False, ): + # Note: smoke_test is ignored as we just adjust the batch size. + # The parameter is passed by the release test pipeline. import ray from benchmark_util import upload_file_to_all_nodes, run_command_on_all_nodes @@ -330,9 +333,9 @@ def run( with open(test_output_json, "wt") as f: json.dump(result, f) - target_ratio = 1.15 + target_ratio = 1.2 ratio = (times_ray_mean / times_vanilla_mean) if times_vanilla_mean != 0.0 else 1.0 - if ratio > 1.15: + if ratio > target_ratio: raise RuntimeError( f"Training on Ray took an average of {times_ray_mean:.2f} seconds, " f"which is more than {target_ratio:.2f}x of the average vanilla training "