[tune] save error msg, cleanup after object checkpoints
parent 0b022c0973
commit 35b1d6189b
9 changed files with 49 additions and 29 deletions
@@ -114,9 +114,9 @@ class A3CAgent(Agent):

         return result

-    def _save(self):
+    def _save(self, checkpoint_dir):
         checkpoint_path = os.path.join(
-            self.logdir, "checkpoint-{}".format(self.iteration))
+            checkpoint_dir, "checkpoint-{}".format(self.iteration))
         agent_state = ray.get(
             [a.save.remote() for a in self.remote_evaluators])
         extra_data = {
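The same signature change, from _save(self) to _save(self, checkpoint_dir), is applied to every built-in agent and example below (the mock agent, DQN, ES, PPO, and the tune example), so a checkpoint lands in whatever directory the caller passes in rather than always in self.logdir. A minimal sketch of the new contract, using a hypothetical stand-alone class rather than Ray's actual Agent base:

# Minimal sketch of the new _save() contract. ToyTrainable is a hypothetical
# stand-alone class used for illustration, not Ray's Agent base class.
import os
import pickle


class ToyTrainable(object):
    def __init__(self, logdir):
        self.logdir = logdir      # default location when no directory is given
        self.iteration = 0
        self.state = {"weights": [0.0, 0.0]}

    def _save(self, checkpoint_dir):
        # Write into the directory chosen by the caller instead of self.logdir.
        path = os.path.join(
            checkpoint_dir, "checkpoint-{}".format(self.iteration))
        with open(path, "wb") as f:
            pickle.dump(self.state, f)
        return path

With this shape, the caller can point checkpoints either at the trial's logdir or at a temporary directory, which is what the Trainable and save_to_object changes further down rely on.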
@@ -147,8 +147,8 @@ class _MockAgent(Agent):
             episode_reward_mean=10, episode_len_mean=10,
             timesteps_this_iter=10, info={})

-    def _save(self):
-        path = os.path.join(self.logdir, "mock_agent.pkl")
+    def _save(self, checkpoint_dir):
+        path = os.path.join(checkpoint_dir, "mock_agent.pkl")
         with open(path, 'wb') as f:
             pickle.dump(self.info, f)
         return path
@@ -218,10 +218,10 @@ class DQNAgent(Agent):
         else:
             self.local_evaluator.sample(no_replay=True)

-    def _save(self):
+    def _save(self, checkpoint_dir):
         checkpoint_path = self.saver.save(
             self.local_evaluator.sess,
-            os.path.join(self.logdir, "checkpoint"),
+            os.path.join(checkpoint_dir, "checkpoint"),
             global_step=self.iteration)
         extra_data = [
             self.local_evaluator.save(),
@@ -300,9 +300,9 @@ class ESAgent(Agent):

         return result

-    def _save(self):
+    def _save(self, checkpoint_dir):
         checkpoint_path = os.path.join(
-            self.logdir, "checkpoint-{}".format(self.iteration))
+            checkpoint_dir, "checkpoint-{}".format(self.iteration))
         weights = self.policy.get_weights()
         objects = [
             weights,
@@ -244,10 +244,10 @@ class PPOAgent(Agent):

         return result

-    def _save(self):
+    def _save(self, checkpoint_dir):
         checkpoint_path = self.saver.save(
             self.local_evaluator.sess,
-            os.path.join(self.logdir, "checkpoint"),
+            os.path.join(checkpoint_dir, "checkpoint"),
             global_step=self.iteration)
         agent_state = ray.get(
             [a.save.remote() for a in self.remote_evaluators])
@@ -35,8 +35,8 @@ class MyTrainableClass(Trainable):
         # objectives such as loss or accuracy (see tune/result.py).
         return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1)

-    def _save(self):
-        path = os.path.join(self.logdir, "checkpoint")
+    def _save(self, checkpoint_dir):
+        path = os.path.join(checkpoint_dir, "checkpoint")
         with open(path, "w") as f:
             f.write(json.dumps({"timestep": self.timestep}))
         return path
@@ -141,17 +141,20 @@ class Trainable(object):

         return result

-    def save(self):
+    def save(self, checkpoint_dir=None):
         """Saves the current model state to a checkpoint.

         Subclasses should override ``_save()`` instead to save state.
         This method dumps additional metadata alongside the saved path.

+        Args:
+            checkpoint_dir (str): Optional dir to place the checkpoint.
+
         Returns:
             Checkpoint path that may be passed to restore().
         """

-        checkpoint_path = self._save()
+        checkpoint_path = self._save(checkpoint_dir or self.logdir)
         pickle.dump(
             [self._experiment_id, self._iteration, self._timesteps_total,
              self._time_total],
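This base-class change is what makes the per-agent edits work: save() now accepts an optional checkpoint_dir, falls back to self.logdir when none is given, and pickles run metadata next to whatever path _save() returns. A rough sketch of that dispatch; the ".tune_metadata" file suffix is an assumption made for this example, since the diff only shows that metadata is pickled after _save() returns:

# Rough sketch of the save()/_save() split, with an assumed metadata suffix.
import os
import pickle


class SketchTrainable(object):
    def __init__(self, logdir):
        self.logdir = logdir
        self._iteration = 0

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            f.write("model state goes here")
        return path

    def save(self, checkpoint_dir=None):
        # Fall back to the trial's logdir when no directory is supplied.
        checkpoint_path = self._save(checkpoint_dir or self.logdir)
        # Dump extra run metadata alongside the checkpoint itself.
        with open(checkpoint_path + ".tune_metadata", "wb") as f:
            pickle.dump([self._iteration], f)
        return checkpoint_path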
@@ -166,7 +169,8 @@ class Trainable(object):
             Object holding checkpoint data.
         """

-        checkpoint_prefix = self.save()
+        tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
+        checkpoint_prefix = self.save(tmpdir)

         data = {}
         base_dir = os.path.dirname(checkpoint_prefix)
@@ -185,6 +189,7 @@ class Trainable(object):
                         len(compressed)))
                 f.write(compressed)

+        shutil.rmtree(tmpdir)
         return out.getvalue()

     def restore(self, checkpoint_path):
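Together these two hunks are the "cleanup after object checkpoints" part of the commit: save_to_object() now checkpoints into a temporary directory inside the logdir, packs the files into an in-memory object, and removes the temporary directory afterwards instead of leaving stale checkpoint files behind. A stand-alone sketch of that flow; the packing format below (a pickled dict of file contents) is illustrative only, not Ray's actual format:

# Sketch of the checkpoint-to-object flow: save into a throwaway directory,
# pack the files into one blob, then delete the directory.
import os
import pickle
import shutil
import tempfile


def save_to_blob(trainable):
    tmpdir = tempfile.mkdtemp("save_to_object", dir=trainable.logdir)
    checkpoint_prefix = trainable.save(tmpdir)

    base_dir = os.path.dirname(checkpoint_prefix)
    files = {}
    for name in os.listdir(base_dir):
        with open(os.path.join(base_dir, name), "rb") as f:
            files[name] = f.read()

    blob = pickle.dumps({
        "prefix": os.path.basename(checkpoint_prefix),
        "files": files,
    })
    shutil.rmtree(tmpdir)  # the on-disk copy is no longer needed
    return blob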
@@ -234,7 +239,7 @@ class Trainable(object):

         raise NotImplementedError

-    def _save(self):
+    def _save(self, checkpoint_dir):
         """Subclasses should override this to implement save()."""

         raise NotImplementedError
@@ -20,6 +20,10 @@ DEBUG_PRINT_INTERVAL = 5
 MAX_LEN_IDENTIFIER = 130


+def date_str():
+    return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
+
+
 class Resources(
         namedtuple("Resources", [
             "cpu", "gpu", "driver_cpu_limit", "driver_gpu_limit"])):
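date_str() centralizes the timestamp format that was previously built inline for log directory names and is now also needed for error file names. Its output contains no spaces or colons, so it is safe inside file and directory names:

# The helper from the hunk above, shown with an example of its output.
from datetime import datetime


def date_str():
    return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")


print(date_str())  # e.g. 2018-01-25_14-03-07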
@@ -126,7 +130,7 @@ class Trial(object):
         elif self._checkpoint_obj:
             self.restore_from_obj(self._checkpoint_obj)

-    def stop(self, error=False, stop_logger=True):
+    def stop(self, error=False, error_msg=None, stop_logger=True):
         """Stops this trial.

         Stops this trial, releasing all allocating resources. If stopping the
@@ -135,6 +139,8 @@ class Trial(object):

         Args:
             error (bool): Whether to mark this trial as terminated in error.
+            error_msg (str): Optional error message.
             stop_logger (bool): Whether to shut down the trial logger.
         """

         if error:
@@ -143,6 +149,11 @@ class Trial(object):
             self.status = Trial.TERMINATED

         try:
+            if error_msg and self.logdir:
+                error_file = os.path.join(
+                    self.logdir, "error_{}.txt".format(date_str()))
+                with open(error_file, "w") as f:
+                    f.write(error_msg)
             if self.runner:
                 stop_tasks = []
                 stop_tasks.append(self.runner.stop.remote())
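This is the "save error msg" half of the commit: when a trial is stopped with an error message, the text is written to an error_<timestamp>.txt file inside the trial's logdir so the failure survives after the driver exits. A stand-alone sketch of that write; write_error_file() is an illustrative helper, not part of the Trial class:

# Illustrative helper mirroring the error-file write added to Trial.stop().
import os
from datetime import datetime


def write_error_file(logdir, error_msg):
    if not (error_msg and logdir):
        return None
    timestamp = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
    error_file = os.path.join(logdir, "error_{}.txt".format(timestamp))
    with open(error_file, "w") as f:
        f.write(error_msg)
    return error_file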
@@ -317,8 +328,7 @@ class Trial(object):
                 os.makedirs(self.local_dir)
             self.logdir = tempfile.mkdtemp(
                 prefix="{}_{}".format(
-                    self,
-                    datetime.today().strftime("%Y-%m-%d_%H-%M-%S")),
+                    str(self)[:MAX_LEN_IDENTIFIER], date_str()),
                 dir=self.local_dir)
             self.result_logger = UnifiedLogger(
                 self.config, self.logdir, self.upload_dir)
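The logdir prefix now reuses date_str() and caps the trial's string form at MAX_LEN_IDENTIFIER characters, which keeps directory names for trials with long hyperparameter configurations under the file-name length limits of common filesystems. A small illustration of the truncation:

# Truncating a long trial identifier before using it as a logdir prefix.
import tempfile

MAX_LEN_IDENTIFIER = 130

trial_id = "my_experiment_" + "lr=0.001_batch=128_" * 20   # deliberately long
prefix = "{}_{}".format(trial_id[:MAX_LEN_IDENTIFIER], "2018-01-25_14-03-07")
logdir = tempfile.mkdtemp(prefix=prefix, dir=tempfile.gettempdir())
print(len(prefix), logdir)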
@@ -195,15 +195,17 @@ class TrialRunner(object):
             trial.start()
             self._running[trial.train_remote()] = trial
         except Exception:
-            print("Error starting runner, retrying:", traceback.format_exc())
+            error_msg = traceback.format_exc()
+            print("Error starting runner, retrying:", error_msg)
             time.sleep(2)
-            trial.stop(error=True)
+            trial.stop(error=True, error_msg=error_msg)
             try:
                 trial.start()
                 self._running[trial.train_remote()] = trial
             except Exception:
-                print("Error starting runner, abort:", traceback.format_exc())
-                trial.stop(error=True)
+                error_msg = traceback.format_exc()
+                print("Error starting runner, abort:", error_msg)
+                trial.stop(error=True, error_msg=error_msg)
                 # note that we don't return the resources, since they may
                 # have been lost
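In the launch path the traceback text is now captured into error_msg before it is printed, so the exact failure can be handed to trial.stop() on both the retry and the abort. A condensed sketch of that flow; start_trial is a hypothetical stand-in for the launch call:

# Capture the traceback once, log it, and forward it to Trial.stop() so it
# ends up in the trial's error file. start_trial is a hypothetical stand-in.
import time
import traceback


def launch_with_retry(trial, start_trial):
    try:
        start_trial(trial)
    except Exception:
        error_msg = traceback.format_exc()
        print("Error starting runner, retrying:", error_msg)
        time.sleep(2)
        trial.stop(error=True, error_msg=error_msg)
        try:
            start_trial(trial)
        except Exception:
            error_msg = traceback.format_exc()
            print("Error starting runner, abort:", error_msg)
            trial.stop(error=True, error_msg=error_msg)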
@@ -236,10 +238,11 @@ class TrialRunner(object):
                 assert False, "Invalid scheduling decision: {}".format(
                     decision)
         except Exception:
-            print("Error processing event:", traceback.format_exc())
+            error_msg = traceback.format_exc()
+            print("Error processing event:", error_msg)
             if trial.status == Trial.RUNNING:
                 self._scheduler_alg.on_trial_error(self, trial)
-                self._stop_trial(trial, error=True)
+                self._stop_trial(trial, error=True, error_msg=error_msg)

     def _get_runnable(self):
         return self._scheduler_alg.choose_trial_to_run(self)
@@ -272,6 +275,7 @@ class TrialRunner(object):
         result for the trial and calls `scheduler.on_trial_complete`
         if RUNNING."""
         error = False
+        error_msg = None

         if trial.status in [Trial.ERROR, Trial.TERMINATED]:
             return
@@ -287,16 +291,17 @@ class TrialRunner(object):
                 trial.update_last_result(result, terminate=True)
                 self._scheduler_alg.on_trial_complete(self, trial, result)
         except Exception:
-            print("Error processing event:", traceback.format_exc())
+            error_msg = traceback.format_exc()
+            print("Error processing event:", error_msg)
             self._scheduler_alg.on_trial_error(self, trial)
             error = True

-        self._stop_trial(trial, error=error)
+        self._stop_trial(trial, error=error, error_msg=error_msg)

-    def _stop_trial(self, trial, error=False):
+    def _stop_trial(self, trial, error=False, error_msg=None):
         """Only returns resources if resources allocated."""
         prior_status = trial.status
-        trial.stop(error=error)
+        trial.stop(error=error, error_msg=error_msg)
         if prior_status == Trial.RUNNING:
             self._return_resources(trial.resources)
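The event-processing path follows the same pattern: error and error_msg are initialized up front, filled in when an exception is caught, and then travel together through _stop_trial() into Trial.stop(). A compact sketch of that hand-off; fetch_result and stop_trial are hypothetical stand-ins for the runner internals:

# How the error flag and message travel together from event processing down
# to the trial. fetch_result and stop_trial are hypothetical stand-ins.
import traceback


def process_trial(trial, fetch_result, stop_trial):
    error = False
    error_msg = None
    try:
        fetch_result(trial)
    except Exception:
        error_msg = traceback.format_exc()
        print("Error processing event:", error_msg)
        error = True
    # On success error_msg stays None, so no error file is written downstream.
    stop_trial(trial, error=error, error_msg=error_msg)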