[tune] improve s3 log sync (#1511)
parent 89db7841d2
commit 41007722f9
6 changed files with 92 additions and 23 deletions
doc/source/tune.rst

@@ -53,6 +53,7 @@ Getting Started
             "alpha": grid_search([0.2, 0.4, 0.6]),
             "beta": grid_search([1, 2]),
         },
+        "upload_dir": "s3://your_bucket/path",
     }
 })
 

@@ -72,7 +73,7 @@ This script runs a small grid search over the ``my_func`` function using Ray Tune
     - my_func_4_alpha=0.4,beta=2:  RUNNING [pid=6800], 209 s, 41204 ts, 70.1 acc
     - my_func_5_alpha=0.6,beta=2:  TERMINATED [pid=6809], 10 s, 2164 ts, 100 acc
 
-In order to report incremental progress, ``my_func`` periodically calls the ``reporter`` function passed in by Ray Tune to return the current timestep and other metrics as defined in `ray.tune.result.TrainingResult <https://github.com/ray-project/ray/blob/master/python/ray/tune/result.py>`__.
+In order to report incremental progress, ``my_func`` periodically calls the ``reporter`` function passed in by Ray Tune to return the current timestep and other metrics as defined in `ray.tune.result.TrainingResult <https://github.com/ray-project/ray/blob/master/python/ray/tune/result.py>`__. Incremental results will be saved to local disk and optionally uploaded to the specified ``upload_dir`` (e.g. S3 path).
 
 Visualizing Results
 -------------------
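
For concreteness, here is a minimal sketch of a trainable function that reports progress the way the sentence above describes. The exact fields the reporter accepts are defined in ray.tune.result.TrainingResult; timesteps_total and mean_accuracy are used here purely as illustrative examples, and the loop itself is a toy stand-in for real training.

def my_func(config, reporter):
    # Toy "training" loop: accuracy grows with the sampled alpha/beta values.
    accuracy, timestep = 0.0, 0
    while accuracy < 100:
        accuracy += config["alpha"]
        timestep += config["beta"]
        # Periodically report progress back to Ray Tune; with an upload_dir
        # set, these incremental results are also synced to S3.
        reporter(timesteps_total=timestep, mean_accuracy=accuracy)
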

python/ray/autoscaler/aws/config.py

@@ -14,8 +14,8 @@ from botocore.config import Config
 from ray.ray_constants import BOTO_MAX_RETRIES
 
 RAY = "ray-autoscaler"
-DEFAULT_RAY_INSTANCE_PROFILE = RAY
-DEFAULT_RAY_IAM_ROLE = RAY
+DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1"
+DEFAULT_RAY_IAM_ROLE = RAY + "-v1"
 SECURITY_GROUP_TEMPLATE = RAY + "-{}"
 
 assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \

@@ -92,6 +92,8 @@ def _configure_iam_role(config):
         assert role is not None, "Failed to create role"
         role.attach_policy(
             PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess")
+        role.attach_policy(
+            PolicyArn="arn:aws:iam::aws:policy/AmazonS3FullAccess")
         profile.add_role(RoleName=role.name)
         time.sleep(15)  # wait for propagation
 

python/ray/tune/config_parser.py

@@ -69,7 +69,7 @@ def make_parser(**kwargs):
             DEFAULT_RESULTS_DIR))
     parser.add_argument(
         "--upload-dir", default="", type=str,
-        help="Optional URI to upload training results to.")
+        help="Optional URI to sync training results to (e.g. s3://bucket).")
     parser.add_argument(
         "--checkpoint-freq", default=0, type=int,
         help="How many training iterations between checkpoints. "

python/ray/tune/log_sync.py (new file, 71 lines)

@@ -0,0 +1,71 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import distutils.spawn
import os
import subprocess
import time

from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR


# Map from (logdir, remote_dir) -> syncer
_syncers = {}


def get_syncer(local_dir, remote_dir):
    if not remote_dir.startswith("s3://"):
        raise TuneError("Upload uri must start with s3://")

    if not distutils.spawn.find_executable("aws"):
        raise TuneError("Upload uri requires awscli tool to be installed")

    if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
        rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
        remote_dir = os.path.join(remote_dir, rel_path)

    key = (local_dir, remote_dir)
    if key not in _syncers:
        _syncers[key] = _S3LogSyncer(local_dir, remote_dir)

    return _syncers[key]


def wait_for_log_sync():
    for syncer in _syncers.values():
        syncer.wait()


class _S3LogSyncer(object):
    def __init__(self, local_dir, remote_dir):
        self.local_dir = local_dir
        self.remote_dir = remote_dir
        self.last_sync_time = 0
        self.sync_process = None
        print("Created S3LogSyncer for {} -> {}".format(local_dir, remote_dir))

    def sync_if_needed(self):
        if time.time() - self.last_sync_time > 300:
            self.sync_now()

    def sync_now(self, force=False):
        print(
            "Syncing files from {} -> {}".format(
                self.local_dir, self.remote_dir))
        self.last_sync_time = time.time()
        if self.sync_process:
            self.sync_process.poll()
            if self.sync_process.returncode is None:
                if force:
                    self.sync_process.kill()
                else:
                    print("Warning: last sync is still in progress, skipping")
                    return
        self.sync_process = subprocess.Popen(
            ["aws", "s3", "sync", self.local_dir, self.remote_dir])

    def wait(self):
        if self.sync_process:
            self.sync_process.wait()
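
For orientation, a minimal sketch of how this new module is meant to be driven; the logger changes below wire it up in essentially this way. The local directory and bucket URI are illustrative placeholders, and get_syncer requires the aws CLI to be installed and an s3:// upload URI.

from ray.tune.log_sync import get_syncer, wait_for_log_sync

# Illustrative paths; any local results dir and s3:// URI work the same way.
syncer = get_syncer("/tmp/ray_results/my_experiment", "s3://your_bucket/path")

syncer.sync_if_needed()      # no-op unless >300s have passed since the last sync
syncer.sync_now(force=True)  # kill any in-flight `aws s3 sync` and start a fresh one

wait_for_log_sync()          # block until every registered syncer's subprocess exits
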

python/ray/tune/logger.py

@@ -6,14 +6,9 @@ import csv
 import json
 import numpy as np
 import os
-import sys
 
 from ray.tune.result import TrainingResult
-
-if sys.version_info[0] == 2:
-    import cStringIO as StringIO
-elif sys.version_info[0] == 3:
-    import io as StringIO
+from ray.tune.log_sync import get_syncer
 
 try:
     import tensorflow as tf

@@ -54,7 +49,9 @@ class Logger(object):
 
 
 class UnifiedLogger(Logger):
-    """Unified result logger for TensorBoard, rllab/viskit, plain json."""
+    """Unified result logger for TensorBoard, rllab/viskit, plain json.
+
+    This class also periodically syncs output to the given upload uri."""
 
     def _init(self):
         self._loggers = []

@@ -63,14 +60,22 @@ class UnifiedLogger(Logger):
                 print("TF not installed - cannot log with {}...".format(cls))
                 continue
             self._loggers.append(cls(self.config, self.logdir, self.uri))
+        if self.uri:
+            self._log_syncer = get_syncer(self.logdir, self.uri)
+        else:
+            self._log_syncer = None
 
     def on_result(self, result):
         for logger in self._loggers:
             logger.on_result(result)
+        if self._log_syncer:
+            self._log_syncer.sync_if_needed()
 
     def close(self):
         for logger in self._loggers:
             logger.close()
+        if self._log_syncer:
+            self._log_syncer.sync_now(force=True)
 
 
 class NoopLogger(Logger):

@@ -85,10 +90,6 @@ class _JsonLogger(Logger):
             json.dump(self.config, f, sort_keys=True, cls=_CustomEncoder)
         local_file = os.path.join(self.logdir, "result.json")
         self.local_out = open(local_file, "w")
-        if self.uri:
-            self.result_buffer = StringIO.StringIO()
-            import smart_open
-            self.smart_open = smart_open.smart_open
 
     def on_result(self, result):
         json.dump(result._asdict(), self, cls=_CustomEncoder)

@@ -97,14 +98,6 @@ class _JsonLogger(Logger):
     def write(self, b):
         self.local_out.write(b)
         self.local_out.flush()
-        # TODO(pcm): At the moment we are writing the whole results output from
-        # the beginning in each iteration. This will write O(n^2) bytes where n
-        # is the number of bytes printed so far. Fix this! This should at least
-        # only write the last 5MBs (S3 chunksize).
-        if self.uri:
-            with self.smart_open(self.uri, "w") as f:
-                self.result_buffer.write(b)
-                f.write(self.result_buffer.getvalue())
 
     def close(self):
         self.local_out.close()

python/ray/tune/tune.py

@@ -8,6 +8,7 @@ from ray.tune import TuneError
 from ray.tune.hyperband import HyperBandScheduler
 from ray.tune.median_stopping_rule import MedianStoppingRule
 from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
+from ray.tune.log_sync import wait_for_log_sync
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.trial_scheduler import FIFOScheduler
 from ray.tune.web_server import TuneServer

@@ -62,4 +63,5 @@ def run_experiments(experiments, scheduler=None, with_server=False,
         if trial.status != Trial.TERMINATED:
             raise TuneError("Trial did not complete", trial)
 
+    wait_for_log_sync()
     return runner.get_trials()