From 6ecc899cf2a846ce91d1208cb4f715b96d61d404 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 3 Oct 2017 23:17:54 -0700 Subject: [PATCH] [rllib] Fix DQN checkpoint/restore and enable test in jenkins (#1063) * fix dqn restore and add test * Update .gitignore * Update test_checkpoint_restore.py * add checkpoint restore --- .gitignore | 1 + python/ray/rllib/dqn/dqn.py | 12 ++++++++---- python/ray/rllib/test/test_checkpoint_restore.py | 9 +++++---- test/jenkins_tests/run_multi_node_tests.sh | 3 +++ 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 91c4789a5..98eb3766a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ /src/thirdparty/arrow /flatbuffers-1.7.1/ /src/thirdparty/boost/ +/src/thirdparty/boost_1_65_1/ /src/thirdparty/boost_1_60_0/ /src/thirdparty/catapult/ /src/thirdparty/flatbuffers/ diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 5b9726b94..60f4a5647 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -224,7 +224,8 @@ class Actor(object): self.episode_rewards, self.episode_lengths, self.saved_mean_reward, - self.obs] + self.obs, + self.replay_buffer] def restore(self, data): self.beta_schedule = data[0] @@ -233,6 +234,7 @@ class Actor(object): self.episode_lengths = data[3] self.saved_mean_reward = data[4] self.obs = data[5] + self.replay_buffer = data[6] @ray.remote @@ -367,7 +369,7 @@ class DQNAgent(Agent): global_step=self.num_iterations) extra_data = [ self.actor.save(), - self.replay_buffer, + ray.get([w.save.remote() for w in self.workers]), self.cur_timestep, self.num_iterations, self.num_target_updates, @@ -376,10 +378,12 @@ class DQNAgent(Agent): return checkpoint_path def _restore(self, checkpoint_path): - self.saver.restore(self.sess, checkpoint_path) + self.saver.restore(self.actor.sess, checkpoint_path) extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) self.actor.restore(extra_data[0]) - self.replay_buffer = extra_data[1] + ray.get([ + w.restore.remote(d) for (d, w) + in zip(extra_data[1], self.workers)]) self.cur_timestep = extra_data[2] self.num_iterations = extra_data[3] self.num_target_updates = extra_data[4] diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index 29e5c14af..83a07e3d9 100755 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -11,7 +11,6 @@ import random from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG) from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG) from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG) - # from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG) @@ -26,11 +25,13 @@ ray.init() for (cls, default_config) in [ (DQNAgent, DQN_CONFIG), (PPOAgent, PG_CONFIG), - # TODO(ekl) this fails with multiple ES instances in a process + (A3CAgent, A3C_CONFIG), + # https://github.com/ray-project/ray/issues/1062 # (ESAgent, ES_CONFIG), - (A3CAgent, A3C_CONFIG)]: + ]: config = default_config.copy() config["num_sgd_iter"] = 5 + config["use_lstm"] = False # for a3c config["episodes_per_batch"] = 100 config["timesteps_per_batch"] = 1000 alg1 = cls("CartPole-v0", config) @@ -49,4 +50,4 @@ for (cls, default_config) in [ a1 = get_mean_action(alg1, obs) a2 = get_mean_action(alg2, obs) print("Checking computed actions", alg1, obs, a1, a2) - assert(abs(a1-a2) < .05) + assert abs(a1-a2) < .1, (a1, a2) diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 852148091..5efed3860 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -111,3 +111,6 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \ --alg PPO \ --num-iterations 2 \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}' + +docker run --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/test/test_checkpoint_restore.py