[tune] save error msg, cleanup after object checkpoints

Eric Liang 2018-01-29 18:48:45 -08:00 committed by GitHub
parent 0b022c0973
commit 35b1d6189b
9 changed files with 49 additions and 29 deletions
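
In outline, the commit changes the `_save()` contract so that the caller chooses the checkpoint directory (which lets `save_to_object()` write into a temporary directory and delete it afterwards), and it threads an `error_msg` string from `TrialRunner` down to `Trial.stop()` so a failed trial leaves a readable error file in its logdir. The snippet below is a minimal standalone sketch of the new checkpoint contract only; `ToyAgent` and its fields are illustrative and not part of the diff.

import os
import pickle
import tempfile


class ToyAgent(object):
    """Illustrative stand-in for an rllib Agent; not part of the diff."""

    def __init__(self, logdir):
        self.logdir = logdir
        self.iteration = 0
        self.state = {"weights": [0.0, 1.0]}

    def _save(self, checkpoint_dir):
        # New contract: the caller picks the directory, so object checkpoints
        # can be written to a temp dir that is deleted afterwards.
        path = os.path.join(
            checkpoint_dir, "checkpoint-{}".format(self.iteration))
        with open(path, "wb") as f:
            pickle.dump(self.state, f)
        return path

    def save(self, checkpoint_dir=None):
        # Mirrors the new Trainable.save(): default to logdir when no
        # directory is given, preserving the old behavior.
        return self._save(checkpoint_dir or self.logdir)


if __name__ == "__main__":
    agent = ToyAgent(logdir=tempfile.mkdtemp())
    print(agent.save())                    # same location as before the change
    print(agent.save(tempfile.mkdtemp()))  # explicit dir, e.g. a throwaway one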

View file

@@ -114,9 +114,9 @@ class A3CAgent(Agent):
         return result

-    def _save(self):
+    def _save(self, checkpoint_dir):
         checkpoint_path = os.path.join(
-            self.logdir, "checkpoint-{}".format(self.iteration))
+            checkpoint_dir, "checkpoint-{}".format(self.iteration))
         agent_state = ray.get(
             [a.save.remote() for a in self.remote_evaluators])
         extra_data = {

View file

@@ -147,8 +147,8 @@ class _MockAgent(Agent):
             episode_reward_mean=10, episode_len_mean=10,
             timesteps_this_iter=10, info={})

-    def _save(self):
-        path = os.path.join(self.logdir, "mock_agent.pkl")
+    def _save(self, checkpoint_dir):
+        path = os.path.join(checkpoint_dir, "mock_agent.pkl")
         with open(path, 'wb') as f:
             pickle.dump(self.info, f)
         return path

View file

@@ -218,10 +218,10 @@ class DQNAgent(Agent):
         else:
             self.local_evaluator.sample(no_replay=True)

-    def _save(self):
+    def _save(self, checkpoint_dir):
         checkpoint_path = self.saver.save(
             self.local_evaluator.sess,
-            os.path.join(self.logdir, "checkpoint"),
+            os.path.join(checkpoint_dir, "checkpoint"),
             global_step=self.iteration)
         extra_data = [
             self.local_evaluator.save(),

View file

@@ -300,9 +300,9 @@ class ESAgent(Agent):
         return result

-    def _save(self):
+    def _save(self, checkpoint_dir):
         checkpoint_path = os.path.join(
-            self.logdir, "checkpoint-{}".format(self.iteration))
+            checkpoint_dir, "checkpoint-{}".format(self.iteration))
         weights = self.policy.get_weights()
         objects = [
             weights,

View file

@@ -244,10 +244,10 @@ class PPOAgent(Agent):
         return result

-    def _save(self):
+    def _save(self, checkpoint_dir):
         checkpoint_path = self.saver.save(
             self.local_evaluator.sess,
-            os.path.join(self.logdir, "checkpoint"),
+            os.path.join(checkpoint_dir, "checkpoint"),
             global_step=self.iteration)
         agent_state = ray.get(
             [a.save.remote() for a in self.remote_evaluators])

View file

@@ -35,8 +35,8 @@ class MyTrainableClass(Trainable):
         # objectives such as loss or accuracy (see tune/result.py).
         return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1)

-    def _save(self):
-        path = os.path.join(self.logdir, "checkpoint")
+    def _save(self, checkpoint_dir):
+        path = os.path.join(checkpoint_dir, "checkpoint")
         with open(path, "w") as f:
             f.write(json.dumps({"timestep": self.timestep}))
         return path

View file

@@ -141,17 +141,20 @@ class Trainable(object):
         return result

-    def save(self):
+    def save(self, checkpoint_dir=None):
         """Saves the current model state to a checkpoint.

         Subclasses should override ``_save()`` instead to save state.
         This method dumps additional metadata alongside the saved path.

+        Args:
+            checkpoint_dir (str): Optional dir to place the checkpoint.
+
         Returns:
             Checkpoint path that may be passed to restore().
         """

-        checkpoint_path = self._save()
+        checkpoint_path = self._save(checkpoint_dir or self.logdir)
         pickle.dump(
             [self._experiment_id, self._iteration, self._timesteps_total,
              self._time_total],
@@ -166,7 +169,8 @@ class Trainable(object):
             Object holding checkpoint data.
         """

-        checkpoint_prefix = self.save()
+        tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
+        checkpoint_prefix = self.save(tmpdir)

         data = {}
         base_dir = os.path.dirname(checkpoint_prefix)
@@ -185,6 +189,7 @@ class Trainable(object):
                     len(compressed)))
             f.write(compressed)

+        shutil.rmtree(tmpdir)
         return out.getvalue()

     def restore(self, checkpoint_path):
@@ -234,7 +239,7 @@ class Trainable(object):
         raise NotImplementedError

-    def _save(self):
+    def _save(self, checkpoint_dir):
         """Subclasses should override this to implement save()."""

         raise NotImplementedError
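
The `save_to_object()` changes above are the "cleanup after object checkpoints" half of the commit: the checkpoint is written into a `mkdtemp()` directory under the logdir, packed into an in-memory object, and then the directory is removed. The function below is a simplified standalone sketch of that lifecycle; `save_to_bytes`, its packing format, and the `trainable` argument are illustrative, not the real implementation.

import os
import shutil
import tempfile
from io import BytesIO


def save_to_bytes(trainable, logdir):
    """Simplified sketch of the save_to_object() flow after this commit.

    `trainable` is assumed to expose save(checkpoint_dir) returning a path
    prefix; the compression and framing of the real method are omitted.
    """
    # 1. Checkpoint into a throwaway directory instead of logdir itself.
    tmpdir = tempfile.mkdtemp("save_to_object", dir=logdir)
    checkpoint_prefix = trainable.save(tmpdir)

    # 2. Pack every file next to the checkpoint prefix into one buffer.
    out = BytesIO()
    base_dir = os.path.dirname(checkpoint_prefix)
    for name in sorted(os.listdir(base_dir)):
        with open(os.path.join(base_dir, name), "rb") as f:
            out.write(name.encode("utf-8") + b"\n")
            out.write(f.read())

    # 3. New cleanup step: the temp dir (and the checkpoint files inside it)
    #    is deleted, so repeated object checkpoints don't pile up in logdir.
    shutil.rmtree(tmpdir)
    return out.getvalue()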

View file

@@ -20,6 +20,10 @@ DEBUG_PRINT_INTERVAL = 5
 MAX_LEN_IDENTIFIER = 130


+def date_str():
+    return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
+
+
 class Resources(
         namedtuple("Resources", [
             "cpu", "gpu", "driver_cpu_limit", "driver_gpu_limit"])):
@@ -126,7 +130,7 @@ class Trial(object):
         elif self._checkpoint_obj:
             self.restore_from_obj(self._checkpoint_obj)

-    def stop(self, error=False, stop_logger=True):
+    def stop(self, error=False, error_msg=None, stop_logger=True):
         """Stops this trial.

         Stops this trial, releasing all allocating resources. If stopping the
@@ -135,6 +139,8 @@ class Trial(object):
         Args:
             error (bool): Whether to mark this trial as terminated in error.
+            error_msg (str): Optional error message.
             stop_logger (bool): Whether to shut down the trial logger.
         """

         if error:
@@ -143,6 +149,11 @@ class Trial(object):
             self.status = Trial.TERMINATED

         try:
+            if error_msg and self.logdir:
+                error_file = os.path.join(
+                    self.logdir, "error_{}.txt".format(date_str()))
+                with open(error_file, "w") as f:
+                    f.write(error_msg)
             if self.runner:
                 stop_tasks = []
                 stop_tasks.append(self.runner.stop.remote())
@@ -317,8 +328,7 @@ class Trial(object):
             os.makedirs(self.local_dir)
         self.logdir = tempfile.mkdtemp(
             prefix="{}_{}".format(
-                self,
-                datetime.today().strftime("%Y-%m-%d_%H-%M-%S")),
+                str(self)[:MAX_LEN_IDENTIFIER], date_str()),
             dir=self.local_dir)
         self.result_logger = UnifiedLogger(
             self.config, self.logdir, self.upload_dir)
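
With the new `error_msg` argument, `Trial.stop(error=True, error_msg=...)` persists the message as `error_<timestamp>.txt` inside the trial's logdir, using the `date_str()` helper factored out above. Below is a small standalone sketch of that behavior; `write_error_file` is an illustrative name, not a function in the diff.

import os
import tempfile
from datetime import datetime


def date_str():
    # Same timestamp format the commit factors out in trial.py.
    return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")


def write_error_file(logdir, error_msg):
    """Toy version of the new error handling inside Trial.stop()."""
    if error_msg and logdir:
        error_file = os.path.join(logdir, "error_{}.txt".format(date_str()))
        with open(error_file, "w") as f:
            f.write(error_msg)
        return error_file


if __name__ == "__main__":
    logdir = tempfile.mkdtemp(prefix="toy_trial_")
    print(write_error_file(logdir, "Traceback (most recent call last): ..."))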

View file

@@ -195,15 +195,17 @@ class TrialRunner(object):
             trial.start()
             self._running[trial.train_remote()] = trial
         except Exception:
-            print("Error starting runner, retrying:", traceback.format_exc())
+            error_msg = traceback.format_exc()
+            print("Error starting runner, retrying:", error_msg)
             time.sleep(2)
-            trial.stop(error=True)
+            trial.stop(error=True, error_msg=error_msg)
             try:
                 trial.start()
                 self._running[trial.train_remote()] = trial
             except Exception:
-                print("Error starting runner, abort:", traceback.format_exc())
-                trial.stop(error=True)
+                error_msg = traceback.format_exc()
+                print("Error starting runner, abort:", error_msg)
+                trial.stop(error=True, error_msg=error_msg)
                 # note that we don't return the resources, since they may
                 # have been lost
@@ -236,10 +238,11 @@ class TrialRunner(object):
                     assert False, "Invalid scheduling decision: {}".format(
                         decision)
             except Exception:
-                print("Error processing event:", traceback.format_exc())
+                error_msg = traceback.format_exc()
+                print("Error processing event:", error_msg)
                 if trial.status == Trial.RUNNING:
                     self._scheduler_alg.on_trial_error(self, trial)
-                self._stop_trial(trial, error=True)
+                self._stop_trial(trial, error=True, error_msg=error_msg)

     def _get_runnable(self):
         return self._scheduler_alg.choose_trial_to_run(self)
@@ -272,6 +275,7 @@ class TrialRunner(object):
         result for the trial and calls `scheduler.on_trial_complete`
         if RUNNING."""
         error = False
+        error_msg = None

         if trial.status in [Trial.ERROR, Trial.TERMINATED]:
             return
@@ -287,16 +291,17 @@ class TrialRunner(object):
                 trial.update_last_result(result, terminate=True)
                 self._scheduler_alg.on_trial_complete(self, trial, result)
         except Exception:
-            print("Error processing event:", traceback.format_exc())
+            error_msg = traceback.format_exc()
+            print("Error processing event:", error_msg)
             self._scheduler_alg.on_trial_error(self, trial)
             error = True

-        self._stop_trial(trial, error=error)
+        self._stop_trial(trial, error=error, error_msg=error_msg)

-    def _stop_trial(self, trial, error=False):
+    def _stop_trial(self, trial, error=False, error_msg=None):
         """Only returns resources if resources allocated."""
         prior_status = trial.status
-        trial.stop(error=error)
+        trial.stop(error=error, error_msg=error_msg)
         if prior_status == Trial.RUNNING:
             self._return_resources(trial.resources)
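
In `TrialRunner`, each `except` block now captures `traceback.format_exc()` once into `error_msg` and reuses it for both the console print and the `stop()` / `_stop_trial()` call, so the text printed to the console is the same text written to the trial's error file. Below is a standalone sketch of that pattern with stub classes; `run_step` and `_FailingTrial` are illustrative names, not part of the diff.

import traceback


class _FailingTrial(object):
    """Stub trial used only to exercise the sketch below."""

    def train(self):
        raise RuntimeError("boom")

    def stop(self, error=False, error_msg=None):
        print("stopped: error={}, has_error_msg={}".format(
            error, error_msg is not None))


def run_step(trial):
    # Capture the traceback once so the same text is printed and persisted.
    error, error_msg = False, None
    try:
        trial.train()
    except Exception:
        error_msg = traceback.format_exc()
        print("Error processing event:", error_msg)
        error = True
    trial.stop(error=error, error_msg=error_msg)


if __name__ == "__main__":
    run_step(_FailingTrial())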