[rllib] Fix A3C PyTorch implementation (#2036)

* Use F.softmax instead of a pointless network layer

Stateless functions should not be network layers.
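
Roughly, the change is the difference between these two calls (a minimal sketch, not lifted from the diff):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 6)  # batch of 4 observations, 6 action logits each

# Before: a stateless op stored on the model as a layer.
probs_as_layer = nn.Softmax(dim=1)(logits)

# After: call the functional form at the point of use.
probs_as_fn = F.softmax(logits, dim=1)

assert torch.allclose(probs_as_layer, probs_as_fn)
```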

* Use correct pytorch functions

* Rename argument to out_size

Matches in_size and makes more sense.

* Fix shapes of tensors

Advantages and rewards are both scalars, so a list of them should be 1D.
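
A sketch of why the flattening matters (values are made up): a `(N, 1)` column of per-step scalars silently broadcasts against `(N,)` tensors.

```python
import torch

advantages = torch.tensor([[0.5], [-0.2], [1.3]])  # (3, 1): one scalar per step
log_probs = torch.tensor([-0.1, -0.7, -0.3])        # (3,)

# (3, 1) * (3,) broadcasts to (3, 3), so the summed loss is wrong.
bad = (advantages * log_probs).sum()

# Flattening to 1D keeps the product elementwise.
good = (advantages.reshape(-1) * log_probs).sum()
```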

* Fmt

* replace deprecated function
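
The pattern, sketched (the 40.0 clip value is illustrative): the old non-underscore spellings are deprecated in favor of the in-place, underscore-suffixed variants.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

# Deprecated: nn.init.xavier_uniform, nn.init.constant,
# torch.nn.utils.clip_grad_norm.
nn.init.xavier_uniform_(layer.weight)
nn.init.constant_(layer.bias, 0)

layer(torch.randn(3, 4)).sum().backward()
torch.nn.utils.clip_grad_norm_(layer.parameters(), 40.0)
```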

* rm unnecessary Variable wrapper

* rm all use of torch Variables

Torch does this for us now.
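
In practice (a small sketch, assuming PyTorch >= 0.4, where `Variable` and `Tensor` were merged):

```python
import numpy as np
import torch

ob = np.zeros(4, dtype=np.float32)

# Old: from torch.autograd import Variable
#      ob_t = Variable(torch.from_numpy(ob).float().unsqueeze(0))
# New: plain tensors carry autograd state themselves.
ob_t = torch.from_numpy(ob).float().unsqueeze(0)

# Gradients still flow without any wrapper.
w = torch.ones(4, requires_grad=True)
(ob_t @ w).sum().backward()
```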

* Ensure that values are flat list

* Fix shape error in conv nets

* fmt

* Fix shape errors

Reshaping the action before stepping in the env fixes a few errors.

* Add TODO

* Use correct filter size

Works when `self.config['model']['channel_major'] = True`.
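
For reference, the Docker test added at the bottom of this diff turns that path on with roughly this config (sketched here as a Python dict):

```python
config = {
    "num_workers": 2,
    "use_lstm": False,
    "use_pytorch": True,
    "model": {
        "grayscale": True,
        "zero_mean": False,
        "dim": 80,
        "channel_major": True,  # observations become (C, dim, dim)
    },
}
```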

* Add missing channel major

* Revert reshape of action

This should be handled by the agent or at least in a cleaner way that doesn't
break existing envs.

* Squeeze action

* Squeeze actions along first dimension

This should handle cases such as CartPole, where actions are scalars, while
leaving alone cases where actions are arrays (some robotics tasks).
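
Sketch of the idea (action values are made up): squeezing only axis 0 strips the batch dimension without touching real action vectors.

```python
import numpy as np

cartpole_action = np.array([1])         # Discrete env: scalar action in a length-1 batch
robot_action = np.array([[0.1, -0.3]])  # Box env: action vector in a length-1 batch

print(np.squeeze(cartpole_action, axis=0))  # -> 1, the scalar CartPole expects
print(np.squeeze(robot_action, axis=0))     # -> [ 0.1 -0.3], the vector is untouched
```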

* try adding pytorch tests

* typo

* fixup docker messages

* Fix A3C for some envs

Pendulum doesn't work since it's an edge case (expects singleton arrays, which
`.squeeze()` collapses to scalars).
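
Minimal sketch of the edge case (the 0.7 is made up): a bare `.squeeze()` drops every size-1 axis, so Pendulum's singleton action array collapses to a scalar.

```python
import numpy as np

pendulum_action = np.array([[0.7]])  # Box(1,): one-element action in a length-1 batch

print(pendulum_action.squeeze())            # -> 0.7, a scalar; Pendulum rejects this
print(np.squeeze(pendulum_action, axis=0))  # -> [0.7], still a singleton array
```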

* fmt

* nit flake

* small lint
Alok Singh 2018-05-30 10:48:11 -07:00 committed by Richard Liaw
parent ac1e5a7d15
commit fd234e3171
10 changed files with 138 additions and 100 deletions

View file

@@ -1,8 +1,10 @@
# The examples Docker image adds dependencies needed to run the examples
FROM ray-project/deploy
RUN conda install -y -c conda-forge tensorflow
# This updates numpy to 1.14 and mutes errors from other libraries
RUN conda install -y numpy
RUN apt-get install -y zlib1g-dev
RUN pip install gym[atari] opencv-python==3.2.0.8
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
# RUN conda install -y -q pytorch torchvision -c soumith
RUN conda install pytorch-cpu torchvision-cpu -c pytorch

View file

@@ -15,7 +15,6 @@ from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator, \
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources
DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 4,
@@ -52,7 +51,7 @@ DEFAULT_CONFIG = {
# (Image statespace) - Converts image to (dim, dim, C)
"dim": 80,
# (Image statespace) - Converts image shape to (C, dim, dim)
"channel_major": False
"channel_major": False,
},
# Arguments to pass to the rllib optimizer
"optimizer": {
@@ -73,46 +72,53 @@ class A3CAgent(Agent):
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1, gpu=0,
cpu=1,
gpu=0,
extra_cpu=cf["num_workers"],
extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)
def _init(self):
self.local_evaluator = A3CEvaluator(
self.registry, self.env_creator, self.config, self.logdir,
self.registry,
self.env_creator,
self.config,
self.logdir,
start_sampler=False)
if self.config["use_gpu_for_workers"]:
remote_cls = GPURemoteA3CEvaluator
else:
remote_cls = RemoteA3CEvaluator
self.remote_evaluators = [
remote_cls.remote(
self.registry, self.env_creator, self.config, self.logdir)
for i in range(self.config["num_workers"])]
self.optimizer = AsyncOptimizer(
self.config["optimizer"], self.local_evaluator,
self.remote_evaluators)
remote_cls.remote(self.registry, self.env_creator, self.config,
self.logdir)
for i in range(self.config["num_workers"])
]
self.optimizer = AsyncOptimizer(self.config["optimizer"],
self.local_evaluator,
self.remote_evaluators)
def _train(self):
self.optimizer.step()
FilterManager.synchronize(
self.local_evaluator.filters, self.remote_evaluators)
FilterManager.synchronize(self.local_evaluator.filters,
self.remote_evaluators)
res = self._fetch_metrics_from_remote_evaluators()
return res
def _fetch_metrics_from_remote_evaluators(self):
episode_rewards = []
episode_lengths = []
metric_lists = [a.get_completed_rollout_metrics.remote()
for a in self.remote_evaluators]
metric_lists = [
a.get_completed_rollout_metrics.remote()
for a in self.remote_evaluators
]
for metrics in metric_lists:
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
avg_reward = (
np.mean(episode_rewards) if episode_rewards else float('nan'))
avg_length = (
np.mean(episode_lengths) if episode_lengths else float('nan'))
avg_reward = (np.mean(episode_rewards)
if episode_rewards else float('nan'))
avg_length = (np.mean(episode_lengths)
if episode_lengths else float('nan'))
timesteps = np.sum(episode_lengths) if episode_lengths else 0
result = TrainingResult(
@@ -129,21 +135,23 @@ class A3CAgent(Agent):
ev.__ray_terminate__.remote()
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(
checkpoint_dir, "checkpoint-{}".format(self.iteration))
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = {
"remote_state": agent_state,
"local_state": self.local_evaluator.save()}
"local_state": self.local_evaluator.save()
}
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
ray.get(
[a.restore.remote(o) for a, o in zip(
self.remote_evaluators, extra_data["remote_state"])])
ray.get([
a.restore.remote(o)
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
])
self.local_evaluator.restore(extra_data["local_state"])
def compute_action(self, observation):

View file

@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from ray.rllib.a3c.torchpolicy import TorchPolicy
@@ -18,8 +17,8 @@ class SharedTorchPolicy(TorchPolicy):
is_recurrent = False
def __init__(self, registry, ob_space, ac_space, config, **kwargs):
super(SharedTorchPolicy, self).__init__(
registry, ob_space, ac_space, config, **kwargs)
super(SharedTorchPolicy, self).__init__(registry, ob_space, ac_space,
config, **kwargs)
def _setup_graph(self, ob_space, ac_space):
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
@@ -31,32 +30,36 @@ class SharedTorchPolicy(TorchPolicy):
def compute(self, ob, *args):
"""Should take in a SINGLE ob"""
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
logits, values = self._model(ob)
samples = self._model.probs(logits).multinomial().squeeze()
values = values.squeeze(0)
return var_to_np(samples), {"vf_preds": var_to_np(values)}
# TODO(alok): Support non-categorical distributions. Multinomial
# is only for categorical.
sampled_actions = F.softmax(logits, dim=1).multinomial(1).squeeze()
values = values.squeeze()
return var_to_np(sampled_actions), {"vf_preds": var_to_np(values)}
def compute_logits(self, ob, *args):
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
res = self._model.hidden_layers(ob)
return var_to_np(self._model.logits(res))
def value(self, ob, *args):
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
res = self._model.hidden_layers(ob)
res = self._model.value_branch(res)
res = res.squeeze(0)
res = res.squeeze()
return var_to_np(res)
def _evaluate(self, obs, actions):
"""Passes in multiple obs."""
logits, values = self._model(obs)
log_probs = F.log_softmax(logits)
probs = self._model.probs(logits)
log_probs = F.log_softmax(logits, dim=1)
probs = F.softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
# TODO(alok): set distribution based on action space and use its
# `.entropy()` method to calculate automatically
entropy = -(log_probs * probs).sum(-1).sum()
return values, action_log_probs, entropy
@@ -64,15 +67,19 @@ class SharedTorchPolicy(TorchPolicy):
"""Loss is encoded in here. Defining a new loss function
would start by rewriting this function"""
states, acs, advs, rs, _ = convert_batch(batch)
values, ac_logprobs, entropy = self._evaluate(states, acs)
pi_err = -(advs * ac_logprobs).sum()
value_err = 0.5 * (values - rs).pow(2).sum()
states, actions, advs, rs, _ = convert_batch(batch)
values, action_log_probs, entropy = self._evaluate(states, actions)
pi_err = -advs.dot(action_log_probs.reshape(-1))
value_err = F.mse_loss(values.reshape(-1), rs)
self.optimizer.zero_grad()
overall_err = (pi_err +
value_err * self.config["vf_loss_coeff"] +
entropy * self.config["entropy_coeff"])
overall_err = sum([
pi_err,
self.config["vf_loss_coeff"] * value_err,
self.config["entropy_coeff"] * entropy,
])
overall_err.backward()
torch.nn.utils.clip_grad_norm(
self._model.parameters(), self.config["grad_clip"])
torch.nn.utils.clip_grad_norm_(self._model.parameters(),
self.config["grad_clip"])

View file

@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import torch
from torch.autograd import Variable
from ray.rllib.a3c.policy import Policy
from threading import Lock
@@ -15,8 +14,13 @@ class TorchPolicy(Policy):
The model is a separate object than the policy. This could be changed
in the future."""
def __init__(self, registry, ob_space, action_space, config,
name="local", summarize=True):
def __init__(self,
registry,
ob_space,
action_space,
config,
name="local",
summarize=True):
self.registry = registry
self.local_steps = 0
self.config = config
@@ -28,7 +32,7 @@ class TorchPolicy(Policy):
def apply_gradients(self, grads):
self.optimizer.zero_grad()
for g, p in zip(grads, self._model.parameters()):
p.grad = Variable(torch.from_numpy(g))
p.grad = torch.from_numpy(g)
self.optimizer.step()
def get_weights(self):
@@ -69,7 +73,7 @@ class TorchPolicy(Policy):
def _backward(self, batch):
"""Implements the loss function and calculates the gradient.
Pytorch automatically generates a backward trace for each variable.
Pytorch automatically generates a backward trace for each tensor.
Assumption right now is that variables are moved, so the backward
trace is lost.

View file

@@ -180,11 +180,14 @@ class ModelCatalog(object):
return registry.get(RLLIB_MODEL, model)(
input_shape, num_outputs, options)
# TODO(alok): fix to handle Discrete(n) state spaces
obs_rank = len(input_shape) - 1
if obs_rank > 1:
return PyTorchVisionNet(input_shape, num_outputs, options)
# TODO(alok): overhaul PyTorchFCNet so it can just
# take input shape directly
return PyTorchFCNet(input_shape[0], num_outputs, options)
@staticmethod

View file

@@ -9,6 +9,7 @@ import torch.nn as nn
class FullyConnectedNetwork(Model):
"""TODO(rliaw): Logits, Value should both be contained here"""
def _init(self, inputs, num_outputs, options):
assert type(inputs) is int
hiddens = options.get("fcnet_hiddens", [256, 256])
@@ -23,26 +24,29 @@ class FullyConnectedNetwork(Model):
layers = []
last_layer_size = inputs
for size in hiddens:
layers.append(SlimFC(
last_layer_size, size,
initializer=normc_initializer(1.0),
activation_fn=activation))
layers.append(
SlimFC(
in_size=last_layer_size,
out_size=size,
initializer=normc_initializer(1.0),
activation_fn=activation))
last_layer_size = size
self.hidden_layers = nn.Sequential(*layers)
self.logits = SlimFC(
last_layer_size, num_outputs,
in_size=last_layer_size,
out_size=num_outputs,
initializer=normc_initializer(0.01),
activation_fn=None)
self.probs = nn.Softmax()
self.value_branch = SlimFC(
last_layer_size, 1,
in_size=last_layer_size,
out_size=1,
initializer=normc_initializer(1.0),
activation_fn=None)
def forward(self, obs):
""" Internal method - pass in Variables, not numpy arrays
""" Internal method - pass in torch tensors, not numpy arrays
Args:
obs: observations and features
@@ -52,5 +56,5 @@ class FullyConnectedNetwork(Model):
value: value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res)
value = self.value_branch(res).reshape(-1)
return logits, value

View file

@@ -5,31 +5,24 @@ from __future__ import print_function
import numpy as np
import torch
from torch.autograd import Variable
def convert_batch(trajectory, has_features=False):
"""Convert trajectory from numpy to PT variable"""
states = Variable(torch.from_numpy(
trajectory["observations"]).float())
acs = Variable(torch.from_numpy(
trajectory["actions"]))
advs = Variable(torch.from_numpy(
trajectory["advantages"].copy()).float())
advs = advs.view(-1, 1)
rs = Variable(torch.from_numpy(
trajectory["value_targets"]).float())
rs = rs.view(-1, 1)
states = torch.from_numpy(trajectory["obs"]).float()
acs = torch.from_numpy(trajectory["actions"])
advs = torch.from_numpy(
trajectory["advantages"].copy()).float().reshape(-1)
rs = torch.from_numpy(trajectory["rewards"]).float().reshape(-1)
if has_features:
features = [Variable(torch.from_numpy(f))
for f in trajectory["features"]]
features = [torch.from_numpy(f) for f in trajectory["features"]]
else:
features = trajectory["features"]
return states, acs, advs, rs, features
def var_to_np(var):
return var.data.numpy()[0]
return var.detach().numpy()
def normc_initializer(std=1.0):
@@ -37,6 +30,7 @@ def normc_initializer(std=1.0):
tensor.data.normal_(0, 1)
tensor.data *= std / torch.sqrt(
tensor.data.pow(2).sum(1, keepdim=True))
return initializer

View file

@@ -29,9 +29,15 @@ class Model(nn.Module):
class SlimConv2d(nn.Module):
"""Simple mock of tf.slim Conv2d"""
def __init__(self, in_channels, out_channels, kernel, stride, padding,
initializer=nn.init.xavier_uniform,
activation_fn=nn.ReLU, bias_init=0):
def __init__(self,
in_channels,
out_channels,
kernel,
stride,
padding,
initializer=nn.init.xavier_uniform_,
activation_fn=nn.ReLU,
bias_init=0):
super(SlimConv2d, self).__init__()
layers = []
if padding:
@@ -39,7 +45,7 @@ class SlimConv2d(nn.Module):
conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
if initializer:
initializer(conv.weight)
nn.init.constant(conv.bias, bias_init)
nn.init.constant_(conv.bias, bias_init)
layers.append(conv)
if activation_fn:
@@ -53,14 +59,18 @@ class SlimConv2d(nn.Module):
class SlimFC(nn.Module):
"""Simple PyTorch of `linear` function"""
def __init__(self, in_size, size, initializer=None,
activation_fn=None, bias_init=0):
def __init__(self,
in_size,
out_size,
initializer=None,
activation_fn=None,
bias_init=0):
super(SlimFC, self).__init__()
layers = []
linear = nn.Linear(in_size, size)
linear = nn.Linear(in_size, out_size)
if initializer:
initializer(linear.weight)
nn.init.constant(linear.bias, bias_init)
nn.init.constant_(linear.bias, bias_init)
layers.append(linear)
if activation_fn:
layers.append(activation_fn())

View file

@@ -21,32 +21,31 @@ class VisionNetwork(Model):
filters = options.get("conv_filters", [
[16, [8, 8], 4],
[32, [4, 4], 2],
[512, [10, 10], 1]
[512, [10, 10], 1],
])
layers = []
in_channels, in_size = inputs[0], inputs[1:]
for out_channels, kernel, stride in filters[:-1]:
padding, out_size = valid_padding(
in_size, kernel, [stride, stride])
layers.append(SlimConv2d(
in_channels, out_channels, kernel, stride, padding))
padding, out_size = valid_padding(in_size, kernel,
[stride, stride])
layers.append(
SlimConv2d(in_channels, out_channels, kernel, stride, padding))
in_channels = out_channels
in_size = out_size
out_channels, kernel, stride = filters[-1]
layers.append(SlimConv2d(
in_channels, out_channels, kernel, stride, None))
layers.append(
SlimConv2d(in_channels, out_channels, kernel, stride, None))
self._convs = nn.Sequential(*layers)
self.logits = SlimFC(
out_channels, num_outputs, initializer=nn.init.xavier_uniform)
self.probs = nn.Softmax()
out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
self.value_branch = SlimFC(
out_channels, 1, initializer=normc_initializer())
def hidden_layers(self, obs):
""" Internal method - pass in Variables, not numpy arrays
""" Internal method - pass in torch tensors, not numpy arrays
args:
obs: observations and features"""

View file

@@ -146,12 +146,19 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
--stop '{"training_iteration": 2}' \
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
# docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
# python /ray/python/ray/rllib/train.py \
# --env PongDeterministic-v4 \
# --run A3C \
# --stop '{"training_iteration": 2}' \
# --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env PongDeterministic-v4 \
--run A3C \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env CartPole-v1 \
--run A3C \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true}'
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \