[rllib] Fix DQN checkpoint/restore and enable test in jenkins (#1063)

* fix dqn restore and add test

* Update .gitignore

* Update test_checkpoint_restore.py

* add checkpoint restore
This commit is contained in:
Eric Liang 2017-10-03 23:17:54 -07:00 committed by Richard Liaw
parent a0d3fb1de1
commit 6ecc899cf2
4 changed files with 17 additions and 8 deletions

1
.gitignore vendored
View file

@@ -7,6 +7,7 @@
/src/thirdparty/arrow
/flatbuffers-1.7.1/
/src/thirdparty/boost/
/src/thirdparty/boost_1_65_1/
/src/thirdparty/boost_1_60_0/
/src/thirdparty/catapult/
/src/thirdparty/flatbuffers/

View file

@@ -224,7 +224,8 @@ class Actor(object):
self.episode_rewards,
self.episode_lengths,
self.saved_mean_reward,
self.obs]
self.obs,
self.replay_buffer]
def restore(self, data):
self.beta_schedule = data[0]
@@ -233,6 +234,7 @@ class Actor(object):
self.episode_lengths = data[3]
self.saved_mean_reward = data[4]
self.obs = data[5]
self.replay_buffer = data[6]
@ray.remote
@@ -367,7 +369,7 @@ class DQNAgent(Agent):
global_step=self.num_iterations)
extra_data = [
self.actor.save(),
self.replay_buffer,
ray.get([w.save.remote() for w in self.workers]),
self.cur_timestep,
self.num_iterations,
self.num_target_updates,
@@ -376,10 +378,12 @@ class DQNAgent(Agent):
return checkpoint_path
def _restore(self, checkpoint_path):
self.saver.restore(self.sess, checkpoint_path)
self.saver.restore(self.actor.sess, checkpoint_path)
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
self.actor.restore(extra_data[0])
self.replay_buffer = extra_data[1]
ray.get([
w.restore.remote(d) for (d, w)
in zip(extra_data[1], self.workers)])
self.cur_timestep = extra_data[2]
self.num_iterations = extra_data[3]
self.num_target_updates = extra_data[4]

View file

@@ -11,7 +11,6 @@ import random
from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG)
from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG)
from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG)
# from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
@@ -26,11 +25,13 @@ ray.init()
for (cls, default_config) in [
(DQNAgent, DQN_CONFIG),
(PPOAgent, PG_CONFIG),
# TODO(ekl) this fails with multiple ES instances in a process
(A3CAgent, A3C_CONFIG),
# https://github.com/ray-project/ray/issues/1062
# (ESAgent, ES_CONFIG),
(A3CAgent, A3C_CONFIG)]:
]:
config = default_config.copy()
config["num_sgd_iter"] = 5
config["use_lstm"] = False # for a3c
config["episodes_per_batch"] = 100
config["timesteps_per_batch"] = 1000
alg1 = cls("CartPole-v0", config)
@@ -49,4 +50,4 @@ for (cls, default_config) in [
a1 = get_mean_action(alg1, obs)
a2 = get_mean_action(alg2, obs)
print("Checking computed actions", alg1, obs, a1, a2)
assert(abs(a1-a2) < .05)
assert abs(a1-a2) < .1, (a1, a2)

View file

@@ -111,3 +111,6 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
--alg PPO \
--num-iterations 2 \
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/test/test_checkpoint_restore.py