ray/python/ray/rllib/es/es.py

336 lines
12 KiB
Python
Raw Normal View History

# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import numpy as np
import os
import pickle
import time
import ray
from ray.rllib import agent
from ray.tune.trial import Resources
from ray.rllib.es import optimizers
from ray.rllib.es import policies
from ray.rllib.es import tabular_logger as tlogger
from ray.rllib.es import utils
# Per-task payload returned by Worker.do_rollouts. The ``noisy_*`` fields
# hold one entry per sampled noise index, each a [positive, negative] pair;
# the ``eval_*`` fields hold one entry per unperturbed evaluation episode.
Result = namedtuple(
    "Result",
    "noise_indices noisy_returns sign_noisy_returns noisy_lengths "
    "eval_returns eval_lengths")
# Default hyperparameters for ESAgent.
DEFAULT_CONFIG = dict(
    # Coefficient of the L2 term added to the (negated) gradient each step.
    l2_coeff=0.005,
    # Scale applied to a noise-table slice to form a parameter perturbation.
    noise_stdev=0.02,
    # An iteration keeps collecting rollouts until BOTH of these minimums
    # (perturbed episodes / perturbed timesteps) are reached.
    episodes_per_batch=1000,
    timesteps_per_batch=10000,
    # Per-step probability that a worker runs an unperturbed evaluation.
    eval_prob=0.003,
    # Return post-processing; only "centered_rank" is implemented.
    return_proc_mode="centered_rank",
    # Number of Worker actors to create.
    num_workers=10,
    # Step size handed to the Adam optimizer.
    stepsize=0.01,
    # Observation filter name passed to GenericPolicy.
    observation_filter="MeanStdFilter",
    # Length of the shared noise table.
    noise_size=250000000,
    # Passed verbatim to the environment creator.
    env_config={},
)
@ray.remote
def create_shared_noise(count):
    """Create a large array of noise to be shared by all workers.

    Draws ``count`` standard-normal samples from a RandomState with the
    fixed seed 123, so the table is identical from run to run, and returns
    them as a float32 array.
    """
    rng = np.random.RandomState(123)
    samples = rng.randn(count)
    return samples.astype(np.float32)
class SharedNoiseTable(object):
    """Thin wrapper around the shared float32 noise array.

    A perturbation of size ``dim`` is identified only by its start offset
    into the array, so workers and the driver can exchange integer indices
    instead of full noise vectors.
    """

    def __init__(self, noise):
        self.noise = noise
        assert self.noise.dtype == np.float32

    def get(self, i, dim):
        """Return the ``dim`` consecutive noise values starting at ``i``."""
        stop = i + dim
        return self.noise[i:stop]

    def sample_index(self, dim):
        """Draw a uniform random start offset at which ``dim`` values fit."""
        last_valid_start = len(self.noise) - dim
        # randint's upper bound is exclusive, hence the +1.
        return np.random.randint(0, last_valid_start + 1)
@ray.remote
class Worker(object):
    """Ray actor that evaluates mirrored parameter perturbations.

    Each worker builds its own environment, observation preprocessor,
    single-threaded session (via ``utils.make_session``) and
    ``GenericPolicy``. It wraps the shared noise array in a
    ``SharedNoiseTable``, so perturbations are reported back as offsets.
    """

    def __init__(self, registry, config, policy_params, env_creator, noise,
                 min_task_runtime=0.2):
        self.min_task_runtime = min_task_runtime
        self.config = config
        self.policy_params = policy_params
        self.noise = SharedNoiseTable(noise)

        self.env = env_creator(config["env_config"])
        from ray.rllib import models
        self.preprocessor = models.ModelCatalog.get_preprocessor(
            registry, self.env)

        self.sess = utils.make_session(single_threaded=True)
        self.policy = policies.GenericPolicy(
            registry, self.sess, self.env.action_space, self.preprocessor,
            config["observation_filter"], **policy_params)

    def rollout(self, timestep_limit, add_noise=True):
        """Run one episode with the policy's current weights.

        Returns the ``(rewards, length)`` pair produced by
        ``policies.rollout``; ``add_noise`` is forwarded unchanged.
        """
        rewards, length = policies.rollout(
            self.policy, self.env, timestep_limit=timestep_limit,
            add_noise=add_noise)
        return rewards, length

    def do_rollouts(self, params, timestep_limit=None):
        """Evaluate perturbations of ``params`` for a minimum wall time.

        On each step, with probability ``config["eval_prob"]``, one
        unperturbed evaluation episode is run. Otherwise a noise offset is
        sampled and two episodes are run, at ``params + sigma * eps`` and
        ``params - sigma * eps`` (``sigma`` = ``config["noise_stdev"]``).
        Sampling continues until ``min_task_runtime`` seconds have passed,
        and always until at least one perturbed pair has been collected.

        Returns:
            Result: noisy fields hold one [positive, negative] entry per
            sampled offset; eval fields hold one entry per eval episode.
        """
        # Set the network weights.
        self.policy.set_weights(params)

        indices, noisy_rets, noisy_signs, noisy_lens = [], [], [], []
        eval_rets, eval_lens = [], []

        started = time.time()
        while not indices or time.time() - started < self.min_task_runtime:
            if np.random.uniform() < self.config["eval_prob"]:
                # Evaluation run: the weights may still hold the previous
                # perturbation, so restore the unperturbed parameters first.
                self.policy.set_weights(params)
                eval_rewards, eval_len = self.rollout(
                    timestep_limit, add_noise=False)
                eval_rets.append(eval_rewards.sum())
                eval_lens.append(eval_len)
                continue

            # Mirrored (antithetic) pair sharing one noise offset. These two
            # rollouts could run on different actors in parallel.
            num_params = self.policy.num_params
            offset = self.noise.sample_index(num_params)
            delta = self.config["noise_stdev"] * self.noise.get(
                offset, num_params)

            self.policy.set_weights(params + delta)
            pos_rewards, pos_len = self.rollout(timestep_limit)
            self.policy.set_weights(params - delta)
            neg_rewards, neg_len = self.rollout(timestep_limit)

            indices.append(offset)
            noisy_rets.append([pos_rewards.sum(), neg_rewards.sum()])
            noisy_signs.append(
                [np.sign(pos_rewards).sum(), np.sign(neg_rewards).sum()])
            noisy_lens.append([pos_len, neg_len])

        return Result(
            noise_indices=indices,
            noisy_returns=noisy_rets,
            sign_noisy_returns=noisy_signs,
            noisy_lengths=noisy_lens,
            eval_returns=eval_rets,
            eval_lengths=eval_lens)
class ESAgent(agent.Agent):
    """Evolution Strategies agent driving a pool of ``Worker`` actors.

    The driver keeps the current flat parameter vector in a local
    ``GenericPolicy``. Each iteration it broadcasts those weights, collects
    mirrored-perturbation returns from the workers, converts them into a
    gradient estimate using the shared noise table, and applies it with
    ``optimizers.Adam``.
    """

    _agent_name = "ES"
    _default_config = DEFAULT_CONFIG
    _allow_unknown_subkeys = ["env_config"]

    @classmethod
    def default_resource_request(cls, config):
        """Request one driver CPU plus one extra CPU per configured worker."""
        cf = dict(cls._default_config, **config)
        return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])

    def _init(self):
        """Build the driver policy, optimizer, noise table and worker pool."""
        policy_params = {
            "action_noise_std": 0.01
        }

        env = self.env_creator(self.config["env_config"])
        from ray.rllib import models
        preprocessor = models.ModelCatalog.get_preprocessor(
            self.registry, env)

        self.sess = utils.make_session(single_threaded=False)
        self.policy = policies.GenericPolicy(
            self.registry, self.sess, env.action_space, preprocessor,
            self.config["observation_filter"], **policy_params)
        self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])

        # Create the shared noise table. The driver keeps a local copy; the
        # workers receive the object-store ID of the same array.
        print("Creating shared noise table.")
        noise_id = create_shared_noise.remote(self.config["noise_size"])
        self.noise = SharedNoiseTable(ray.get(noise_id))

        # Create the actors.
        print("Creating actors.")
        self.workers = [
            Worker.remote(
                self.registry, self.config, policy_params, self.env_creator,
                noise_id)
            for _ in range(self.config["num_workers"])]

        self.episodes_so_far = 0
        self.timesteps_so_far = 0
        self.tstart = time.time()

    def _collect_results(self, theta_id, min_episodes, min_timesteps):
        """Gather worker rollouts until both minimums are reached.

        Each round calls ``do_rollouts`` once on every worker. Only perturbed
        episodes count toward the minimums; evaluation episodes do not.

        Returns:
            (results, num_episodes, num_timesteps)
        """
        num_episodes, num_timesteps = 0, 0
        results = []
        while num_episodes < min_episodes or num_timesteps < min_timesteps:
            print(
                "Collected {} episodes {} timesteps so far this iter".format(
                    num_episodes, num_timesteps))
            rollout_ids = [worker.do_rollouts.remote(theta_id)
                           for worker in self.workers]
            # Get the results of the rollouts.
            for result in ray.get(rollout_ids):
                results.append(result)
                # Update the number of episodes and the number of timesteps
                # keeping in mind that result.noisy_lengths is a list of
                # lists, where the inner lists have length 2.
                num_episodes += sum(len(pair) for pair
                                    in result.noisy_lengths)
                num_timesteps += sum(sum(pair) for pair
                                     in result.noisy_lengths)
        return results, num_episodes, num_timesteps

    def _train(self):
        """Run one ES iteration and return a ``TrainingResult``."""
        config = self.config

        step_tstart = time.time()
        theta = self.policy.get_weights()
        assert theta.dtype == np.float32

        # Put the current policy weights in the object store.
        theta_id = ray.put(theta)
        # Use the actors to do rollouts, note that we pass in the ID of the
        # policy weights.
        results, num_episodes, num_timesteps = self._collect_results(
            theta_id,
            config["episodes_per_batch"],
            config["timesteps_per_batch"])

        all_noise_indices = []
        all_training_returns = []
        all_training_lengths = []
        all_eval_returns = []
        all_eval_lengths = []

        # Loop over the results.
        for result in results:
            all_eval_returns += result.eval_returns
            all_eval_lengths += result.eval_lengths

            all_noise_indices += result.noise_indices
            all_training_returns += result.noisy_returns
            all_training_lengths += result.noisy_lengths

        assert len(all_eval_returns) == len(all_eval_lengths)
        assert (len(all_noise_indices) == len(all_training_returns) ==
                len(all_training_lengths))

        self.episodes_so_far += num_episodes
        self.timesteps_so_far += num_timesteps

        # Assemble the results.
        eval_returns = np.array(all_eval_returns)
        eval_lengths = np.array(all_eval_lengths)
        noise_indices = np.array(all_noise_indices)
        noisy_returns = np.array(all_training_returns)
        noisy_lengths = np.array(all_training_lengths)

        # Process the returns.
        if config["return_proc_mode"] == "centered_rank":
            proc_noisy_returns = utils.compute_centered_ranks(noisy_returns)
        else:
            raise NotImplementedError(config["return_proc_mode"])

        # Compute and take a step. Column 0 / column 1 hold the processed
        # returns of the +noise / -noise rollouts, so each noise vector is
        # weighted by their difference.
        g, count = utils.batched_weighted_sum(
            proc_noisy_returns[:, 0] - proc_noisy_returns[:, 1],
            (self.noise.get(index, self.policy.num_params)
             for index in noise_indices),
            batch_size=500)
        g /= noisy_returns.size
        assert (
            g.shape == (self.policy.num_params,) and
            g.dtype == np.float32 and
            count == len(noise_indices))
        # Compute the new weights theta. The estimate is negated and the L2
        # penalty added before it is handed to the optimizer (presumably a
        # descent step -- see optimizers.Adam).
        theta, update_ratio = self.optimizer.update(
            -g + config["l2_coeff"] * theta)
        # Set the new weights in the local copy of the policy.
        self.policy.set_weights(theta)

        step_tend = time.time()

        # eval_prob is small, so an iteration can finish with no evaluation
        # episode at all. numpy would return NaN for mean/std of an empty
        # array but also emit a RuntimeWarning; report NaN explicitly.
        if eval_returns.size > 0:
            eval_rew_mean = eval_returns.mean()
            eval_rew_std = eval_returns.std()
            eval_len_mean = eval_lengths.mean()
        else:
            eval_rew_mean = eval_rew_std = eval_len_mean = float("nan")

        tlogger.record_tabular("EvalEpRewMean", eval_rew_mean)
        tlogger.record_tabular("EvalEpRewStd", eval_rew_std)
        tlogger.record_tabular("EvalEpLenMean", eval_len_mean)

        tlogger.record_tabular("EpRewMean", noisy_returns.mean())
        tlogger.record_tabular("EpRewStd", noisy_returns.std())
        tlogger.record_tabular("EpLenMean", noisy_lengths.mean())

        tlogger.record_tabular("Norm", float(np.square(theta).sum()))
        tlogger.record_tabular("GradNorm", float(np.square(g).sum()))
        tlogger.record_tabular("UpdateRatio", float(update_ratio))

        tlogger.record_tabular("EpisodesThisIter", noisy_lengths.size)
        tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far)
        tlogger.record_tabular("TimestepsThisIter", noisy_lengths.sum())
        tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far)

        tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
        tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
        tlogger.dump_tabular()

        info = {
            "weights_norm": np.square(theta).sum(),
            "grad_norm": np.square(g).sum(),
            "update_ratio": update_ratio,
            "episodes_this_iter": noisy_lengths.size,
            "episodes_so_far": self.episodes_so_far,
            "timesteps_this_iter": noisy_lengths.sum(),
            "timesteps_so_far": self.timesteps_so_far,
            "time_elapsed_this_iter": step_tend - step_tstart,
            "time_elapsed": step_tend - self.tstart
        }

        result = ray.tune.result.TrainingResult(
            episode_reward_mean=eval_rew_mean,
            episode_len_mean=eval_len_mean,
            timesteps_this_iter=noisy_lengths.sum(),
            info=info)

        return result

    def _stop(self):
        """Terminate all worker actors."""
        # workaround for https://github.com/ray-project/ray/issues/1516
        for w in self.workers:
            w.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        """Pickle [weights, episodes_so_far, timesteps_so_far] to disk.

        Returns the path of the written checkpoint file.
        """
        checkpoint_path = os.path.join(
            checkpoint_dir, "checkpoint-{}".format(self.iteration))
        weights = self.policy.get_weights()
        objects = [
            weights,
            self.episodes_so_far,
            self.timesteps_so_far]
        # Use a context manager so the file handle is closed (and flushed)
        # deterministically instead of whenever it is garbage-collected.
        with open(checkpoint_path, "wb") as f:
            pickle.dump(objects, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        """Load a checkpoint written by ``_save`` and apply it."""
        # NOTE: pickle.load can execute arbitrary code; only restore
        # checkpoints from trusted sources.
        with open(checkpoint_path, "rb") as f:
            objects = pickle.load(f)
        self.policy.set_weights(objects[0])
        self.episodes_so_far = objects[1]
        self.timesteps_so_far = objects[2]

    def compute_action(self, observation):
        """Return the policy's action for ``observation`` without updating
        the observation filter."""
        return self.policy.compute(observation, update=False)[0]