mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[rllib] Fix DQN checkpoint/restore and enable test in jenkins (#1063)
* fix dqn restore and add test * Update .gitignore * Update test_checkpoint_restore.py * add checkpoint restore
This commit is contained in:
parent
a0d3fb1de1
commit
6ecc899cf2
4 changed files with 17 additions and 8 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -7,6 +7,7 @@
|
|||
/src/thirdparty/arrow
|
||||
/flatbuffers-1.7.1/
|
||||
/src/thirdparty/boost/
|
||||
/src/thirdparty/boost_1_65_1/
|
||||
/src/thirdparty/boost_1_60_0/
|
||||
/src/thirdparty/catapult/
|
||||
/src/thirdparty/flatbuffers/
|
||||
|
|
|
@ -224,7 +224,8 @@ class Actor(object):
|
|||
self.episode_rewards,
|
||||
self.episode_lengths,
|
||||
self.saved_mean_reward,
|
||||
self.obs]
|
||||
self.obs,
|
||||
self.replay_buffer]
|
||||
|
||||
def restore(self, data):
|
||||
self.beta_schedule = data[0]
|
||||
|
@ -233,6 +234,7 @@ class Actor(object):
|
|||
self.episode_lengths = data[3]
|
||||
self.saved_mean_reward = data[4]
|
||||
self.obs = data[5]
|
||||
self.replay_buffer = data[6]
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
@ -367,7 +369,7 @@ class DQNAgent(Agent):
|
|||
global_step=self.num_iterations)
|
||||
extra_data = [
|
||||
self.actor.save(),
|
||||
self.replay_buffer,
|
||||
ray.get([w.save.remote() for w in self.workers]),
|
||||
self.cur_timestep,
|
||||
self.num_iterations,
|
||||
self.num_target_updates,
|
||||
|
@ -376,10 +378,12 @@ class DQNAgent(Agent):
|
|||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
self.saver.restore(self.sess, checkpoint_path)
|
||||
self.saver.restore(self.actor.sess, checkpoint_path)
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
self.actor.restore(extra_data[0])
|
||||
self.replay_buffer = extra_data[1]
|
||||
ray.get([
|
||||
w.restore.remote(d) for (d, w)
|
||||
in zip(extra_data[1], self.workers)])
|
||||
self.cur_timestep = extra_data[2]
|
||||
self.num_iterations = extra_data[3]
|
||||
self.num_target_updates = extra_data[4]
|
||||
|
|
|
@ -11,7 +11,6 @@ import random
|
|||
from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG)
|
||||
from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG)
|
||||
from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG)
|
||||
|
||||
# from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
|
||||
|
||||
|
||||
|
@ -26,11 +25,13 @@ ray.init()
|
|||
for (cls, default_config) in [
|
||||
(DQNAgent, DQN_CONFIG),
|
||||
(PPOAgent, PG_CONFIG),
|
||||
# TODO(ekl) this fails with multiple ES instances in a process
|
||||
(A3CAgent, A3C_CONFIG),
|
||||
# https://github.com/ray-project/ray/issues/1062
|
||||
# (ESAgent, ES_CONFIG),
|
||||
(A3CAgent, A3C_CONFIG)]:
|
||||
]:
|
||||
config = default_config.copy()
|
||||
config["num_sgd_iter"] = 5
|
||||
config["use_lstm"] = False # for a3c
|
||||
config["episodes_per_batch"] = 100
|
||||
config["timesteps_per_batch"] = 1000
|
||||
alg1 = cls("CartPole-v0", config)
|
||||
|
@ -49,4 +50,4 @@ for (cls, default_config) in [
|
|||
a1 = get_mean_action(alg1, obs)
|
||||
a2 = get_mean_action(alg2, obs)
|
||||
print("Checking computed actions", alg1, obs, a1, a2)
|
||||
assert(abs(a1-a2) < .05)
|
||||
assert abs(a1-a2) < .1, (a1, a2)
|
||||
|
|
|
@ -111,3 +111,6 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
|||
--alg PPO \
|
||||
--num-iterations 2 \
|
||||
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/test/test_checkpoint_restore.py
|
||||
|
|
Loading…
Add table
Reference in a new issue