Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[rllib] A3C Configurations (#1370)
* initial introduction of a3c configs
* fix sample batch
* flake but need to check save
* save, restore
* fix
* pickles
* entropy
* fix
* moving ppo
* results
* jenkins
This commit is contained in:
parent b217a5ef14
commit 4bb5b6bd5b
15 changed files with 164 additions and 113 deletions
@@ -8,85 +8,83 @@ import os
import ray
from ray.rllib.agent import Agent
from ray.rllib.envs import create_and_wrap
from ray.rllib.a3c.runner import RemoteA3CEvaluator
from ray.rllib.a3c.common import get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.rllib.optimizers import AsyncOptimizer
from ray.rllib.a3c.base_evaluator import A3CEvaluator, RemoteA3CEvaluator
from ray.tune.result import TrainingResult


DEFAULT_CONFIG = {
    # Number of workers (excluding master)
    "num_workers": 4,
    "num_batches_per_iteration": 100,

    # Size of rollout batch
    "batch_size": 10,
    "use_lstm": True,
    # Use LSTM model - only applicable for image states
    "use_lstm": False,
    # Use PyTorch as backend - no LSTM support
    "use_pytorch": False,
    # Which observation filter to apply to the observation
    "observation_filter": "NoFilter",
    # Which reward filter to apply to the reward
    "reward_filter": "NoFilter",

    "model": {"grayscale": True,
              "zero_mean": False,
              "dim": 42,
              "channel_major": False}
    # Discount factor of MDP
    "gamma": 0.99,
    # GAE(gamma) parameter
    "lambda": 1.0,
    # Max global norm for each gradient calculated by worker
    "grad_clip": 40.0,
    # Learning rate
    "lr": 0.0001,
    # Value Function Loss coefficient
    "vf_loss_coeff": 0.5,
    # Entropy coefficient
    "entropy_coeff": -0.01,
    # Preprocessing for environment
    "preprocessing": {
        # (Image statespace) - Converts image to Channels = 1
        "grayscale": True,
        # (Image statespace) - Each pixel
        "zero_mean": False,
        # (Image statespace) - Converts image to (dim, dim, C)
        "dim": 80,
        # (Image statespace) - Converts image shape to (C, dim, dim)
        "channel_major": False
    },
    # Configuration for model specification
    "model": {},
    # Arguments to pass to the rllib optimizer
    "optimizer": {
        # Number of gradients applied for each `train` step
        "grads_per_step": 100,
    },
}


class A3CAgent(Agent):
    _agent_name = "A3C"
    _default_config = DEFAULT_CONFIG
    _allow_unknown_subkeys = ["model", "optimizer"]

    def _init(self):
        self.env = create_and_wrap(self.env_creator, self.config["model"])
        policy_cls = get_policy_cls(self.config)
        self.policy = policy_cls(
            self.env.observation_space.shape, self.env.action_space)
        self.obs_filter = get_filter(
            self.config["observation_filter"],
            self.env.observation_space.shape)
        self.rew_filter = get_filter(self.config["reward_filter"], ())
        self.agents = [
        self.local_evaluator = A3CEvaluator(
            self.env_creator, self.config, self.logdir, start_sampler=False)
        self.remote_evaluators = [
            RemoteA3CEvaluator.remote(
                self.env_creator, self.config, self.logdir)
            for i in range(self.config["num_workers"])]
        self.parameters = self.policy.get_weights()
        self.optimizer = AsyncOptimizer(
            self.config["optimizer"], self.local_evaluator,
            self.remote_evaluators)

    def _train(self):
        remote_params = ray.put(self.parameters)
        ray.get([agent.set_weights.remote(remote_params)
                 for agent in self.agents])

        gradient_list = {agent.compute_gradient.remote(): agent
                         for agent in self.agents}
        max_batches = self.config["num_batches_per_iteration"]
        batches_so_far = len(gradient_list)
        while gradient_list:
            [done_id], _ = ray.wait(list(gradient_list))
            gradient, info = ray.get(done_id)
            agent = gradient_list.pop(done_id)
            self.obs_filter.update(info["obs_filter"])
            self.rew_filter.update(info["rew_filter"])
            self.policy.apply_gradients(gradient)
            self.parameters = self.policy.get_weights()

            if batches_so_far < max_batches:
                batches_so_far += 1
                agent.update_filters.remote(
                    obs_filter=self.obs_filter,
                    rew_filter=self.rew_filter)
                agent.set_weights.remote(self.parameters)
                gradient_list[agent.compute_gradient.remote()] = agent
        res = self._fetch_metrics_from_workers()
        self.optimizer.step()
        res = self._fetch_metrics_from_remote_evaluators()
        return res

    def _fetch_metrics_from_workers(self):
    def _fetch_metrics_from_remote_evaluators(self):
        episode_rewards = []
        episode_lengths = []
        metric_lists = [
            a.get_completed_rollout_metrics.remote() for a in self.agents]
        metric_lists = [a.get_completed_rollout_metrics.remote()
                        for a in self.remote_evaluators]
        for metrics in metric_lists:
            for episode in ray.get(metrics):
                episode_lengths.append(episode.episode_length)

@@ -106,22 +104,25 @@ class A3CAgent(Agent):
        return result

    def _save(self):
        # TODO(rliaw): extend to also support saving worker state?
        checkpoint_path = os.path.join(
            self.logdir, "checkpoint-{}".format(self.iteration))
        objects = [self.parameters, self.obs_filter, self.rew_filter]
        pickle.dump(objects, open(checkpoint_path, "wb"))
        # self.saver.save
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
            "local_state": self.local_evaluator.save()}
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        objects = pickle.load(open(checkpoint_path, "rb"))
        self.parameters = objects[0]
        self.obs_filter = objects[1]
        self.rew_filter = objects[2]
        self.policy.set_weights(self.parameters)
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        ray.get(
            [a.restore.remote(o) for a, o in zip(
                self.remote_evaluators, extra_data["remote_state"])])
        self.local_evaluator.restore(extra_data["local_state"])

    # TODO(rliaw): augment to support LSTM
    def compute_action(self, observation):
        obs = self.obs_filter(observation, update=False)
        action, info = self.policy.compute(obs)
        obs = self.local_evaluator.obs_filter(observation, update=False)
        action, info = self.local_evaluator.policy.compute(obs)
        return action
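The hunks above replace the old ad-hoc keys with a larger DEFAULT_CONFIG and restrict unknown sub-keys to "model" and "optimizer" via _allow_unknown_subkeys. A minimal sketch of how a caller might derive a custom config from these defaults; how the resulting dict is then handed to A3CAgent, directly or through ray.tune, is outside this hunk:

# Sketch only: derive a custom A3C config from the defaults in this file.
config = DEFAULT_CONFIG.copy()
config.update({
    "num_workers": 8,
    "batch_size": 20,
    # dict.update replaces nested dicts wholesale, so spell out every
    # preprocessing key you still want.
    "preprocessing": {"grayscale": True, "zero_mean": False,
                      "dim": 42, "channel_major": False},
    # Per _allow_unknown_subkeys, only "model" and "optimizer" may carry
    # keys that are absent from DEFAULT_CONFIG.
    "optimizer": {"grads_per_step": 500},
})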
@@ -2,6 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pickle

import ray
from ray.rllib.envs import create_and_wrap
from ray.rllib.optimizers import Evaluator

@@ -23,16 +25,22 @@ class A3CEvaluator(Evaluator):
        rollouts.
        logdir: Directory for logging.
    """
    def __init__(self, env_creator, config, logdir):
        self.env = env = create_and_wrap(env_creator, config["model"])
    def __init__(self, env_creator, config, logdir, start_sampler=True):
        self.env = env = create_and_wrap(env_creator, config["preprocessing"])
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(env.observation_space.shape, env.action_space)
        obs_filter = get_filter(
        self.policy = policy_cls(
            env.observation_space.shape, env.action_space, config)
        self.config = config

        # Technically not needed when not remote
        self.obs_filter = get_filter(
            config["observation_filter"], env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.sampler = AsyncSampler(env, self.policy, obs_filter,
        self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
                                    config["batch_size"])
        if start_sampler and self.sampler.async:
            self.sampler.start()
        self.logdir = logdir

    def sample(self):

@@ -40,7 +48,10 @@ class A3CEvaluator(Evaluator):
        Returns:
            trajectory (PartialRollout): Experience Samples from evaluator"""
        rollout = self.sampler.get_data()
        return rollout
        samples = process_rollout(
            rollout, self.rew_filter, gamma=self.config["gamma"],
            lambda_=self.config["lambda"], use_gae=True)
        return samples

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

@@ -49,20 +60,16 @@ class A3CEvaluator(Evaluator):
        """
        return self.sampler.get_metrics()

    def compute_gradient(self):
        rollout = self.sampler.get_data()
        obs_filter = self.sampler.get_obs_filter(flush=True)
    def compute_gradients(self, samples):
        gradient, info = self.policy.compute_gradients(samples)
        return gradient

        traj = process_rollout(
            rollout, self.rew_filter, gamma=0.99, lambda_=1.0, use_gae=True)
        gradient, info = self.policy.compute_gradients(traj)
        info["obs_filter"] = obs_filter
        info["rew_filter"] = self.rew_filter
        return gradient, info

    def apply_gradient(self, grads):
    def apply_gradients(self, grads):
        self.policy.apply_gradients(grads)

    def get_weights(self):
        return self.policy.get_weights()

    def set_weights(self, params):
        self.policy.set_weights(params)


@@ -73,5 +80,13 @@ class A3CEvaluator(Evaluator):
        if obs_filter:
            self.sampler.update_obs_filter(obs_filter)

    def save(self):
        weights = self.get_weights()
        return pickle.dumps({"weights": weights})

    def restore(self, objs):
        objs = pickle.loads(objs)
        self.set_weights(objs["weights"])


RemoteA3CEvaluator = ray.remote(A3CEvaluator)
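With this change the evaluator exposes the sample / compute_gradients / apply_gradients / get_weights / set_weights surface that AsyncOptimizer drives. A rough sketch of one asynchronous update spelled out by hand, assuming env_creator, config, and logdir are supplied by the agent as in the a3c.py hunks above:

# Sketch only: the evaluator round trip that AsyncOptimizer automates.
local = A3CEvaluator(env_creator, config, logdir, start_sampler=False)
workers = [RemoteA3CEvaluator.remote(env_creator, config, logdir)
           for _ in range(config["num_workers"])]

samples = workers[0].sample.remote()                  # rollout -> SampleBatch
grad = ray.get(workers[0].compute_gradients.remote(samples))
local.apply_gradients(grad)                           # update the local policy
weights = ray.put(local.get_weights())
for w in workers:
    w.set_weights.remote(weights)                     # broadcast new weights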
@@ -13,13 +13,14 @@ class SharedModel(TFPolicy):
    other_output = ["vf_preds"]
    is_recurrent = False

    def __init__(self, ob_space, ac_space, **kwargs):
        super(SharedModel, self).__init__(ob_space, ac_space, **kwargs)
    def __init__(self, ob_space, ac_space, config, **kwargs):
        super(SharedModel, self).__init__(ob_space, ac_space, config, **kwargs)

    def _setup_graph(self, ob_space, ac_space):
        self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
        self._model = ModelCatalog.get_model(self.x, self.logit_dim)
        self._model = ModelCatalog.get_model(
            self.x, self.logit_dim, self.config["model"])
        self.logits = self._model.outputs
        self.curr_dist = dist_class(self.logits)
        # with tf.variable_scope("vf"):
@@ -13,7 +13,7 @@ class SharedModelLSTM(TFPolicy):
    """
    Attributes:
        other_output (list): Other than `action`, the other return values from
            `compute_gradient`.
            `compute_gradients`.
        is_recurrent (bool): True if is a recurrent network (requires features
            to be tracked).
    """

@@ -21,8 +21,9 @@ class SharedModelLSTM(TFPolicy):
    other_output = ["vf_preds", "features"]
    is_recurrent = True

    def __init__(self, ob_space, ac_space, **kwargs):
        super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs)
    def __init__(self, ob_space, ac_space, config, **kwargs):
        super(SharedModelLSTM, self).__init__(
            ob_space, ac_space, config, **kwargs)

    def _setup_graph(self, ob_space, ac_space):
        self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
@@ -17,14 +17,16 @@ class SharedTorchPolicy(TorchPolicy):
    other_output = ["vf_preds"]
    is_recurrent = False

    def __init__(self, ob_space, ac_space, **kwargs):
    def __init__(self, ob_space, ac_space, config, **kwargs):
        super(SharedTorchPolicy, self).__init__(
            ob_space, ac_space, **kwargs)
            ob_space, ac_space, config, **kwargs)

    def _setup_graph(self, ob_space, ac_space):
        _, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
        self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001)
        self._model = ModelCatalog.get_torch_model(
            ob_space, self.logit_dim, self.config["model"])
        self.optimizer = torch.optim.Adam(
            self._model.parameters(), lr=self.config["lr"])

    def compute(self, ob, *args):
        """Should take in a SINGLE ob"""

@@ -68,6 +70,9 @@ class SharedTorchPolicy(TorchPolicy):
        value_err = 0.5 * (values - rs).pow(2).sum()

        self.optimizer.zero_grad()
        overall_err = 0.5 * value_err + pi_err - entropy * 0.01
        overall_err = (pi_err +
                       value_err * self.config["vf_loss_coeff"] +
                       entropy * self.config["entropy_coeff"])
        overall_err.backward()
        torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)
        torch.nn.utils.clip_grad_norm(
            self._model.parameters(), self.config["grad_clip"])
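Both the Torch and TF policies now build the same scalarized objective from the config: policy loss plus vf_loss_coeff times the value loss plus entropy_coeff times the entropy. Because entropy_coeff defaults to -0.01, the entropy term is subtracted from the loss, so higher policy entropy is rewarded as an exploration bonus, matching the hard-coded `- entropy * 0.01` it replaces. A quick numeric check of the default weighting, with made-up loss values:

# Illustration only; the loss values are invented, the coefficients are the
# DEFAULT_CONFIG values introduced in this commit.
pi_err, value_err, entropy = 1.20, 4.00, 0.90
vf_loss_coeff, entropy_coeff = 0.5, -0.01
overall_err = pi_err + value_err * vf_loss_coeff + entropy * entropy_coeff
print(overall_err)  # 1.20 + 2.00 - 0.009 = 3.191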
@@ -10,8 +10,10 @@ from ray.rllib.a3c.policy import Policy

class TFPolicy(Policy):
    """The policy base class."""
    def __init__(self, ob_space, action_space, name="local", summarize=True):
    def __init__(self, ob_space, action_space, config,
                 name="local", summarize=True):
        self.local_steps = 0
        self.config = config
        self.summarize = summarize
        worker_device = "/job:localhost/replica:0/task:0/cpu:0"
        self.g = tf.Graph()

@@ -52,13 +54,15 @@ class TFPolicy(Policy):
        delta = self.vf - self.r
        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
        self.entropy = tf.reduce_sum(self.curr_dist.entropy())
        self.loss = self.pi_loss + 0.5 * self.vf_loss - self.entropy * 0.01
        self.loss = (self.pi_loss +
                     self.vf_loss * self.config["vf_loss_coeff"] +
                     self.entropy * self.config["entropy_coeff"])

    def setup_gradients(self):
        grads = tf.gradients(self.loss, self.var_list)
        self.grads, _ = tf.clip_by_global_norm(grads, 40.0)
        self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
        grads_and_vars = list(zip(self.grads, self.var_list))
        opt = tf.train.AdamOptimizer(1e-4)
        opt = tf.train.AdamOptimizer(self.config["lr"])
        self._apply_gradients = opt.apply_gradients(grads_and_vars)

    def initialize(self):

@@ -71,6 +75,7 @@ class TFPolicy(Policy):
        tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list))
        self.summary_op = tf.summary.merge_all()

        # TODO(rliaw): Can consider exposing these parameters
        self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
            intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
        self.variables = ray.experimental.TensorFlowVariables(self.loss,
@@ -15,8 +15,10 @@ class TorchPolicy(Policy):
    The model is a separate object than the policy. This could be changed
    in the future."""

    def __init__(self, ob_space, action_space, name="local", summarize=True):
    def __init__(self, ob_space, action_space, config,
                 name="local", summarize=True):
        self.local_steps = 0
        self.config = config
        self.summarize = summarize
        self._setup_graph(ob_space, action_space)
        torch.set_num_threads(2)
@@ -35,8 +35,7 @@ class AsyncOptimizer(Optimizer):
        # Note: can't use wait: https://github.com/ray-project/ray/issues/1128
        while gradient_queue:
            with self.wait_timer:
                fut, e = gradient_queue[0]
                gradient_queue = gradient_queue[1:]
                fut, e = gradient_queue.pop(0)
                gradient = ray.get(fut)

            if gradient is not None:
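This hunk only swaps manual list slicing for list.pop(0). For orientation, the pattern the AsyncOptimizer implements around it looks roughly like the sketch below; this is a simplified paraphrase, not the literal rllib code, and it assumes the evaluator interface shown earlier in this commit.

# Simplified sketch of the asynchronous-gradient loop (not the literal code).
def async_step(local_evaluator, remote_evaluators, grads_per_step):
    weights = ray.put(local_evaluator.get_weights())
    gradient_queue = []
    for e in remote_evaluators:
        e.set_weights.remote(weights)
        gradient_queue.append((e.compute_gradients.remote(e.sample.remote()), e))
    num_launched = len(gradient_queue)

    while gradient_queue:
        fut, e = gradient_queue.pop(0)
        gradient = ray.get(fut)
        if gradient is not None:
            local_evaluator.apply_gradients(gradient)
        if num_launched < grads_per_step:
            num_launched += 1
            e.set_weights.remote(ray.put(local_evaluator.get_weights()))
            gradient_queue.append(
                (e.compute_gradients.remote(e.sample.remote()), e))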
@@ -78,7 +78,10 @@ DEFAULT_CONFIG = {
    # is detected
    "tf_debug_inf_or_nan": False,
    # If True, we write tensorflow logs and checkpoints
    "write_logs": True
    "write_logs": True,
    # Preprocessing for environment
    # TODO(rliaw): Convert to function similar to A#c
    "preprocessing": {}
}


@@ -139,7 +142,7 @@ class PPOAgent(Agent):
            # to guard against the case where all values are equal
            return (value - value.mean()) / max(1e-4, value.std())

        trajectory["advantages"] = standardized(trajectory["advantages"])
        trajectory.data["advantages"] = standardized(trajectory["advantages"])

        rollouts_end = time.time()
        print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +

@@ -147,7 +150,7 @@ class PPOAgent(Agent):
        names = [
            "iter", "total loss", "policy loss", "vf loss", "kl", "entropy"]
        print(("{:>15}" * len(names)).format(*names))
        trajectory = shuffle(trajectory)
        trajectory.data = shuffle(trajectory.data)
        shuffle_end = time.time()
        tuples_per_device = model.load_data(
            trajectory, self.iteration == 0 and config["full_trace_data_load"])
@@ -5,7 +5,7 @@ from __future__ import print_function
import numpy as np
import ray

from ray.rllib.ppo.utils import concatenate
from ray.rllib.optimizers import SampleBatch


def collect_samples(agents,

@@ -37,5 +37,5 @@ def collect_samples(agents,
        trajectories.append(trajectory)
        observation_filter.update(obs_f)
        reward_filter.update(rew_f)
    return (concatenate(trajectories), np.mean(total_rewards),
    return (SampleBatch.concat_samples(trajectories), np.mean(total_rewards),
            np.mean(trajectory_lengths))
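Both PPO call sites now funnel trajectories through the shared SampleBatch container instead of the PPO-specific concatenate helper. A toy illustration of the concat step, with the constructor semantics assumed from how SampleBatch is used elsewhere in this commit:

# Toy data only; keys and shapes are illustrative.
import numpy as np
from ray.rllib.optimizers import SampleBatch

t1 = SampleBatch({"obs": np.zeros((5, 4)), "advantages": np.arange(5.0)})
t2 = SampleBatch({"obs": np.ones((3, 4)), "advantages": np.arange(3.0)})
batch = SampleBatch.concat_samples([t1, t2])  # row-wise concatenation
print(batch["advantages"].shape)              # -> (8,)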
@@ -19,7 +19,7 @@ from ray.rllib.utils.sampler import SyncSampler
from ray.rllib.utils.filter import get_filter, MeanStdFilter
from ray.rllib.utils.process_rollout import process_rollout
from ray.rllib.ppo.loss import ProximalPolicyLoss
from ray.rllib.ppo.utils import concatenate
from ray.rllib.optimizers import SampleBatch


# TODO(pcm): Make sure that both observation_filter and reward_filter

@@ -227,7 +227,7 @@ class Runner(object):
            (c.episode_reward, c.episode_length) for c in metrics])
        updated_obs_filter = self.sampler.get_obs_filter(flush=True)
        return (
            concatenate(trajectories),
            SampleBatch.concat_samples(trajectories),
            total_rewards,
            trajectory_lengths,
            updated_obs_filter,
@@ -2,16 +2,17 @@ pong-a3c-pytorch-cnn:
    env: PongDeterministic-v4
    run: A3C
    resources:
        cpu: 16
        cpu: 17
        driver_cpu_limit: 1
    config:
        num_workers: 16
        num_batches_per_iteration: 1000
        batch_size: 20
        use_lstm: false
        use_pytorch: true
        model:
        preprocessing:
            grayscale: true
            zero_mean: false
            dim: 80
            channel_major: true
        optimizer:
            grads_per_step: 1000
@@ -2,9 +2,15 @@ pong-a3c:
    env: PongDeterministic-v4
    run: A3C
    resources:
        cpu: 16
        cpu: 17
        driver_cpu_limit: 1
    config:
        num_workers: 16
        num_batches_per_iteration: 1000
        batch_size: 20
        use_lstm: true
        use_pytorch: false
        optimizer:
            grads_per_step: 1000
        preprocessing:
            dim: 42
            channel_major: false
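The two tuned examples above reserve one extra CPU for the driver and move the image options under the new preprocessing key. They are YAML spellings of experiment dicts; a rough Python equivalent of the pong-a3c entry, assuming the ray.tune run_experiments entry point of this era, might look like:

# Rough equivalent of the pong-a3c entry above; the run_experiments API
# shape is an assumption, not something this diff shows.
import ray
from ray.tune import run_experiments

ray.init()
run_experiments({
    "pong-a3c": {
        "env": "PongDeterministic-v4",
        "run": "A3C",
        "resources": {"cpu": 17, "driver_cpu_limit": 1},
        "config": {
            "num_workers": 16,
            "batch_size": 20,
            "use_lstm": True,
            "use_pytorch": False,
            "optimizer": {"grads_per_step": 1000},
            "preprocessing": {"dim": 42, "channel_major": False},
        },
    },
})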
@@ -4,6 +4,7 @@ from __future__ import print_function

import numpy as np
import scipy.signal
from ray.rllib.optimizers import SampleBatch


def discount(x, gamma):

@@ -11,7 +12,15 @@ def discount(x, gamma):


def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
    """Given a rollout, compute its value targets and the advantage."""
    """Given a rollout, compute its value targets and the advantage.

    Args:
        rollout (PartialRollout): Partial Rollout Object
        reward_filter (Filter): # TODO(rliaw)

    Returns:
        SampleBatch (SampleBatch): Object with experience from rollout and
            processed rewards."""

    traj = {}
    trajsize = len(rollout.data["actions"])

@@ -35,6 +44,8 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
    for i in range(traj["advantages"].shape[0]):
        traj["advantages"][i] = reward_filter(traj["advantages"][i])

    traj["advantages"] = traj["advantages"].copy()

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return traj
    return SampleBatch(traj)
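process_rollout now documents its arguments and returns a SampleBatch, but the advantage computation itself is unchanged: with use_gae=True it discounts the TD residuals of the value predictions. A self-contained numpy illustration of that calculation on toy data; the discount helper is restated here so the snippet runs on its own:

# Toy illustration of the GAE computation performed by process_rollout.
import numpy as np
import scipy.signal

def discount(x, gamma):
    # Reverse discounted cumulative sum.
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

rewards = np.array([1.0, 0.0, 0.0, 1.0])   # made-up rollout rewards
vpred = np.array([0.5, 0.4, 0.3, 0.2])     # made-up value predictions
last_r = 0.0                               # bootstrap value after the rollout
gamma, lambda_ = 0.99, 1.0                 # DEFAULT_CONFIG values

vpred_t = np.concatenate([vpred, [last_r]])
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]   # TD residuals
advantages = discount(delta_t, gamma * lambda_)          # GAE(gamma, lambda)
value_targets = advantages + vpred                       # value regression targets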
@@ -148,9 +148,10 @@ class AsyncSampler(threading.Thread):
        self.policy = policy
        self._obs_filter = obs_filter
        self._obs_f_lock = threading.Lock()
        self.start()
        self.started = False

    def run(self):
        self.started = True
        try:
            self._run()
        except BaseException as e:

@@ -213,7 +214,7 @@ class AsyncSampler(threading.Thread):
        Returns:
            rollout (PartialRollout): trajectory data (unprocessed)
        """

        assert self.started, "Sampler never started running!"
        rollout = self._pull_batch_from_queue()
        return rollout