Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[rllib] Fix DQN checkpoint/restore and enable test in jenkins (#1063)
* fix DQN restore and add test
* Update .gitignore
* Update test_checkpoint_restore.py
* add checkpoint restore
This commit is contained in:
parent
a0d3fb1de1
commit
6ecc899cf2
4 changed files with 17 additions and 8 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -7,6 +7,7 @@
|
||||||
/src/thirdparty/arrow
|
/src/thirdparty/arrow
|
||||||
/flatbuffers-1.7.1/
|
/flatbuffers-1.7.1/
|
||||||
/src/thirdparty/boost/
|
/src/thirdparty/boost/
|
||||||
|
/src/thirdparty/boost_1_65_1/
|
||||||
/src/thirdparty/boost_1_60_0/
|
/src/thirdparty/boost_1_60_0/
|
||||||
/src/thirdparty/catapult/
|
/src/thirdparty/catapult/
|
||||||
/src/thirdparty/flatbuffers/
|
/src/thirdparty/flatbuffers/
|
||||||
|
|
|
@ -224,7 +224,8 @@ class Actor(object):
|
||||||
self.episode_rewards,
|
self.episode_rewards,
|
||||||
self.episode_lengths,
|
self.episode_lengths,
|
||||||
self.saved_mean_reward,
|
self.saved_mean_reward,
|
||||||
self.obs]
|
self.obs,
|
||||||
|
self.replay_buffer]
|
||||||
|
|
||||||
def restore(self, data):
|
def restore(self, data):
|
||||||
self.beta_schedule = data[0]
|
self.beta_schedule = data[0]
|
||||||
|
@ -233,6 +234,7 @@ class Actor(object):
|
||||||
self.episode_lengths = data[3]
|
self.episode_lengths = data[3]
|
||||||
self.saved_mean_reward = data[4]
|
self.saved_mean_reward = data[4]
|
||||||
self.obs = data[5]
|
self.obs = data[5]
|
||||||
|
self.replay_buffer = data[6]
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
|
@ -367,7 +369,7 @@ class DQNAgent(Agent):
|
||||||
global_step=self.num_iterations)
|
global_step=self.num_iterations)
|
||||||
extra_data = [
|
extra_data = [
|
||||||
self.actor.save(),
|
self.actor.save(),
|
||||||
self.replay_buffer,
|
ray.get([w.save.remote() for w in self.workers]),
|
||||||
self.cur_timestep,
|
self.cur_timestep,
|
||||||
self.num_iterations,
|
self.num_iterations,
|
||||||
self.num_target_updates,
|
self.num_target_updates,
|
||||||
|
@ -376,10 +378,12 @@ class DQNAgent(Agent):
|
||||||
return checkpoint_path
|
return checkpoint_path
|
||||||
|
|
||||||
def _restore(self, checkpoint_path):
|
def _restore(self, checkpoint_path):
|
||||||
self.saver.restore(self.sess, checkpoint_path)
|
self.saver.restore(self.actor.sess, checkpoint_path)
|
||||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||||
self.actor.restore(extra_data[0])
|
self.actor.restore(extra_data[0])
|
||||||
self.replay_buffer = extra_data[1]
|
ray.get([
|
||||||
|
w.restore.remote(d) for (d, w)
|
||||||
|
in zip(extra_data[1], self.workers)])
|
||||||
self.cur_timestep = extra_data[2]
|
self.cur_timestep = extra_data[2]
|
||||||
self.num_iterations = extra_data[3]
|
self.num_iterations = extra_data[3]
|
||||||
self.num_target_updates = extra_data[4]
|
self.num_target_updates = extra_data[4]
|
||||||
|
|
|
@ -11,7 +11,6 @@ import random
|
||||||
from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG)
|
from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG)
|
||||||
from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG)
|
from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG)
|
||||||
from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG)
|
from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG)
|
||||||
|
|
||||||
# from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
|
# from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,11 +25,13 @@ ray.init()
|
||||||
for (cls, default_config) in [
|
for (cls, default_config) in [
|
||||||
(DQNAgent, DQN_CONFIG),
|
(DQNAgent, DQN_CONFIG),
|
||||||
(PPOAgent, PG_CONFIG),
|
(PPOAgent, PG_CONFIG),
|
||||||
# TODO(ekl) this fails with multiple ES instances in a process
|
(A3CAgent, A3C_CONFIG),
|
||||||
|
# https://github.com/ray-project/ray/issues/1062
|
||||||
# (ESAgent, ES_CONFIG),
|
# (ESAgent, ES_CONFIG),
|
||||||
(A3CAgent, A3C_CONFIG)]:
|
]:
|
||||||
config = default_config.copy()
|
config = default_config.copy()
|
||||||
config["num_sgd_iter"] = 5
|
config["num_sgd_iter"] = 5
|
||||||
|
config["use_lstm"] = False # for a3c
|
||||||
config["episodes_per_batch"] = 100
|
config["episodes_per_batch"] = 100
|
||||||
config["timesteps_per_batch"] = 1000
|
config["timesteps_per_batch"] = 1000
|
||||||
alg1 = cls("CartPole-v0", config)
|
alg1 = cls("CartPole-v0", config)
|
||||||
|
@ -49,4 +50,4 @@ for (cls, default_config) in [
|
||||||
a1 = get_mean_action(alg1, obs)
|
a1 = get_mean_action(alg1, obs)
|
||||||
a2 = get_mean_action(alg2, obs)
|
a2 = get_mean_action(alg2, obs)
|
||||||
print("Checking computed actions", alg1, obs, a1, a2)
|
print("Checking computed actions", alg1, obs, a1, a2)
|
||||||
assert(abs(a1-a2) < .05)
|
assert abs(a1-a2) < .1, (a1, a2)
|
||||||
|
|
|
@ -111,3 +111,6 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||||
--alg PPO \
|
--alg PPO \
|
||||||
--num-iterations 2 \
|
--num-iterations 2 \
|
||||||
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
|
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
|
||||||
|
|
||||||
|
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||||
|
python /ray/python/ray/rllib/test/test_checkpoint_restore.py
|
||||||
|
|
Loading…
Add table
Reference in a new issue