[rllib] PyTorch Models for A3C (#1187)

* fixing policy

* Compute Action is singular, fixed weird issue with arrays

* remove vestige

* extraneous ipdb

* Can Drop in Pytorch Model

* lint

* introducing models

* fix base policy

* Missed this from last time

* lint

* removedolds

* getting vision working

* LINT

* trying to fix test dependencies

* requirements

* try

* tryconda

* yes

* shutup

* flake_passes

* changes

* removing weight initializer for lstm for now

* unused

* adam

* clip

* zero

* properscaling

* weight

* try

* fix up pytorch visionnet

* bias correction

* fix model

* same visionnet

* matching_bad_things

* test

* try locking

* fixing_linear

* naming

* lint

* FORJENKINS

* clouds

* lint

* Lint + removed dependencies

* removed dependencies

* format
This commit is contained in:
Richard Liaw 2017-11-12 00:20:33 -08:00 committed by GitHub
parent 9a6a056609
commit afdc87323f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 462 additions and 16 deletions

View file

@ -4,3 +4,4 @@ FROM ray-project/deploy
RUN conda install -y -c conda-forge tensorflow
RUN apt-get install -y zlib1g-dev
RUN pip install gym[atari] opencv-python==3.2.0.8 smart_open
RUN conda install -y -q pytorch torchvision -c soumith

View file

@ -20,10 +20,11 @@ DEFAULT_CONFIG = {
"num_batches_per_iteration": 100,
"batch_size": 10,
"use_lstm": True,
"use_pytorch": False,
"model": {"grayscale": True,
"zero_mean": False,
"dim": 42,
"channel_major": True}
"channel_major": False}
}
@ -35,6 +36,9 @@ class A3CAgent(Agent):
self.env = create_and_wrap(self.env_creator, self.config["model"])
if self.config["use_lstm"]:
policy_cls = SharedModelLSTM
elif self.config["use_pytorch"]:
from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy
policy_cls = SharedTorchPolicy
else:
policy_cls = SharedModel
self.policy = policy_cls(

View file

@ -20,9 +20,6 @@ class Policy(object):
def compute_gradients(self, batch):
raise NotImplementedError
def get_vf_loss(self):
raise NotImplementedError
def compute_action(self, observations):
"""Compute action for a _single_ observation"""
raise NotImplementedError

View file

@ -12,7 +12,7 @@ class SharedModel(TFPolicy):
def __init__(self, ob_space, ac_space, **kwargs):
super(SharedModel, self).__init__(ob_space, ac_space, **kwargs)
def setup_graph(self, ob_space, ac_space):
def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
self._model = ModelCatalog.get_model(self.x, self.logit_dim)

View file

@ -14,7 +14,7 @@ class SharedModelLSTM(TFPolicy):
def __init__(self, ob_space, ac_space, **kwargs):
super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs)
def setup_graph(self, ob_space, ac_space):
def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
self._model = LSTM(self.x, self.logit_dim, {})

View file

@ -0,0 +1,73 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from ray.rllib.a3c.torchpolicy import TorchPolicy
from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
from ray.rllib.models.catalog import ModelCatalog
class SharedTorchPolicy(TorchPolicy):
    """A3C policy in which one PyTorch model yields both logits and values.

    Assumes the model is non-recurrent.
    """

    def __init__(self, ob_space, ac_space, **kwargs):
        super(SharedTorchPolicy, self).__init__(
            ob_space, ac_space, **kwargs)

    @staticmethod
    def _to_batch_variable(ob):
        """Wrap one numpy observation as a float Variable with a new dim 0."""
        return Variable(torch.from_numpy(ob).float().unsqueeze(0))

    def _setup_graph(self, ob_space, ac_space):
        """Create the model from the catalog together with its optimizer."""
        _dist_cls, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
        self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
        self.optimizer = torch.optim.Adam(
            self._model.parameters(), lr=0.0001)

    def compute_action(self, ob, *args):
        """Sample an action for a SINGLE observation.

        Returns:
            The sampled action and the value estimate, each passed
            through `var_to_np`.
        """
        with self.lock:
            batched = self._to_batch_variable(ob)
            logits, values = self._model(batched)
            action_probs = self._model.probs(logits)
            sampled = action_probs.multinomial().squeeze()
            values = values.squeeze(0)
            return var_to_np(sampled), var_to_np(values)

    def compute_logits(self, ob, *args):
        """Return the action logits for a single observation."""
        with self.lock:
            hidden = self._model.hidden_layers(self._to_batch_variable(ob))
            return var_to_np(self._model.logits(hidden))

    def value(self, ob, *args):
        """Return the value estimate for a single observation."""
        with self.lock:
            hidden = self._model.hidden_layers(self._to_batch_variable(ob))
            estimate = self._model.value_branch(hidden).squeeze(0)
            return var_to_np(estimate)

    def _evaluate(self, obs, actions):
        """Run the model on multiple observations.

        Returns:
            values: value output of the model.
            action_log_probs: log-probabilities gathered at `actions`.
            entropy: `-(log_probs * probs)` summed over every entry.
        """
        logits, values = self._model(obs)
        probs = self._model.probs(logits)
        log_probs = F.log_softmax(logits)
        chosen = actions.view(-1, 1)
        action_log_probs = log_probs.gather(1, chosen)
        entropy = -(log_probs * probs).sum(-1).sum()
        return values, action_log_probs, entropy

    def _backward(self, batch):
        """Compute the loss for `batch` and backpropagate it.

        The loss is defined here, so a new loss function starts by
        rewriting this method. Gradients are left on the model parameters
        (clipped with max norm 40); no optimizer step is taken here.
        """
        states, acs, advs, rs, _ = convert_batch(batch)
        values, ac_logprobs, entropy = self._evaluate(states, acs)
        policy_loss = -(advs * ac_logprobs).sum()
        value_loss = 0.5 * (values - rs).pow(2).sum()

        self.optimizer.zero_grad()
        total_loss = 0.5 * value_loss + policy_loss - entropy * 0.01
        total_loss.backward()
        torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)

    def get_initial_features(self):
        """Return the placeholder feature list used by this policy."""
        return [None]

View file

@ -17,7 +17,7 @@ class TFPolicy(Policy):
self.g = tf.Graph()
with self.g.as_default(), tf.device(worker_device):
with tf.variable_scope(name):
self.setup_graph(ob_space, action_space)
self._setup_graph(ob_space, action_space)
assert all([hasattr(self, attr)
for attr in ["vf", "logits", "x", "var_list"]])
print("Setting up loss")
@ -25,7 +25,7 @@ class TFPolicy(Policy):
self.setup_gradients()
self.initialize()
def setup_graph(self):
def _setup_graph(self):
raise NotImplementedError
def setup_loss(self, action_space):
@ -92,9 +92,6 @@ class TFPolicy(Policy):
def compute_gradients(self, batch):
raise NotImplementedError
def get_vf_loss(self):
raise NotImplementedError
def compute_action(self, observations):
raise NotImplementedError

View file

@ -0,0 +1,78 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
from torch.autograd import Variable
from ray.rllib.a3c.policy import Policy
from threading import Lock
class TorchPolicy(Policy):
    """The policy base class for Torch.

    The model is a separate object from the policy. This could be changed
    in the future.

    Subclasses must implement `_setup_graph` (which is expected to set
    `self._model` and `self.optimizer`, both used below) and `_backward`.
    """

    def __init__(self, ob_space, action_space, name="local", summarize=True):
        """Build the model via `_setup_graph` and create the policy lock.

        Args:
            ob_space: Observation space, forwarded to `_setup_graph`.
            action_space: Action space, forwarded to `_setup_graph`.
            name: Accepted but not used by this constructor.
            summarize: Stored on `self.summarize`.
        """
        self.local_steps = 0
        self.summarize = summarize
        self._setup_graph(ob_space, action_space)
        torch.set_num_threads(2)
        self.lock = Lock()

    def apply_gradients(self, grads):
        """Install externally computed gradients and take an optimizer step.

        Args:
            grads: Iterable of numpy arrays, paired in order with
                `self._model.parameters()`.
        """
        self.optimizer.zero_grad()
        for g, p in zip(grads, self._model.parameters()):
            p.grad = Variable(torch.from_numpy(g))
        self.optimizer.step()

    def get_weights(self):
        """Return the model's state dict."""
        # !! This only returns references to the data.
        return self._model.state_dict()

    def set_weights(self, weights):
        """Load `weights` (a state dict) into the model under the lock."""
        with self.lock:
            self._model.load_state_dict(weights)

    def compute_gradients(self, batch):
        """Compute gradients for `batch`.

        `_backward` generates the gradient in each model parameter; the
        gradients are then read back out of the parameters.

        Args:
            batch: Batch of data needed for gradient calculation.

        Return:
            gradients (list of np arrays): List of gradients
            info (dict): Extra information (user-defined)
        """
        with self.lock:
            self._backward(batch)
            # Note that return values are just references;
            # calling zero_grad will modify the values
            return [p.grad.data.numpy()
                    for p in self._model.parameters()], {}

    def model_update(self, batch):
        """Implements compute + apply.

        TODO(rliaw): Pytorch has nice caching property that doesn't require
        full batch to be passed in. Can exploit that later
        """
        with self.lock:
            self._backward(batch)
            self.optimizer.step()

    def _setup_graph(self, ob_space, action_space):
        """Create the model and optimizer; must be overridden."""
        # `self` was missing from the original signature, so calling this
        # stub as a bound method raised TypeError instead of
        # NotImplementedError.
        raise NotImplementedError

    def _backward(self, batch):
        """Implements the loss function and calculates the gradient.

        Pytorch automatically generates a backward trace for each variable.
        Assumption right now is that variables are moved, so the backward
        trace is lost.

        This function regenerates the backward trace and
        calculates the gradient.
        """
        raise NotImplementedError

    def get_initial_features(self):
        """Return the initial features; empty for this base class."""
        return []

View file

@ -85,6 +85,30 @@ class ModelCatalog(object):
return FullyConnectedNetwork(inputs, num_outputs, options)
@staticmethod
def get_torch_model(input_shape, num_outputs, options=None):
    """Returns a PyTorch suitable model.

    A vision network is chosen when `len(input_shape) - 1 > 1`; otherwise
    a fully connected network is built from `input_shape[0]`.

    Args:
        input_shape (tup): The input shape to the model.
        num_outputs (int): The size of the output vector of the model.
        options (dict): Optional args to pass to the model constructor.
            Defaults to an empty dict.

    Returns:
        model (Model): Neural network model.
    """
    # Imported lazily so torch is only required when this is called.
    from ray.rllib.models.pytorch.fcnet import (
        FullyConnectedNetwork as PyTorchFCNet)
    from ray.rllib.models.pytorch.visionnet import (
        VisionNetwork as PyTorchVisionNet)

    # A fresh dict per call avoids sharing one mutable default object
    # across every invocation.
    if options is None:
        options = {}

    obs_rank = len(input_shape) - 1

    if obs_rank > 1:
        return PyTorchVisionNet(input_shape, num_outputs, options)

    return PyTorchFCNet(input_shape[0], num_outputs, options)
@classmethod
def get_preprocessor(cls, env, options=dict()):
"""Returns a suitable processor for the given environment.

View file

@ -30,15 +30,15 @@ class AtariPixelPreprocessor(Preprocessor):
self._grayscale = self._options.get("grayscale", False)
self._zero_mean = self._options.get("zero_mean", True)
self._dim = self._options.get("dim", 80)
self._pytorch = self._options.get("pytorch", False)
self._channel_major = self._options.get("channel_major", False)
if self._grayscale:
self.shape = (self._dim, self._dim, 1)
else:
self.shape = (self._dim, self._dim, 3)
# pytorch requires (# in-channels, row dim, col dim)
if self._pytorch:
self.shape = self.shape[::-1]
# channel_major requires (# in-channels, row dim, col dim)
if self._channel_major:
self.shape = self.shape[-1:] + self.shape[:-1]
# TODO(ekl) why does this need to return an extra size-1 dim (the [None])
def transform(self, observation):
@ -59,7 +59,7 @@ class AtariPixelPreprocessor(Preprocessor):
scaled = (scaled - 128) / 128
else:
scaled *= 1.0 / 255.0
if self._pytorch:
if self._channel_major:
scaled = np.reshape(scaled, self.shape)
return scaled

View file

@ -0,0 +1,56 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.models.pytorch.model import Model, SlimFC
from ray.rllib.models.pytorch.misc import normc_initializer
import torch.nn as nn
class FullyConnectedNetwork(Model):
    """Fully connected torso with a logits head and a value head.

    TODO(rliaw): Logits, Value should both be contained here
    """

    def _init(self, inputs, num_outputs, options):
        """Build the hidden stack and both output heads.

        Args:
            inputs (int): Size of the input vector.
            num_outputs (int): Size of the logits output.
            options (dict): Reads "fcnet_hiddens" (default [256, 256]) and
                "fcnet_activation" (default "tanh"; "relu" also recognized).
        """
        assert type(inputs) is int
        hiddens = options.get("fcnet_hiddens", [256, 256])
        fcnet_activation = options.get("fcnet_activation", "tanh")

        # Any unrecognized name leaves the hidden layers without activation.
        activation = None
        for known_name, module_cls in (("tanh", nn.Tanh), ("relu", nn.ReLU)):
            if fcnet_activation == known_name:
                activation = module_cls
                break
        print("Constructing fcnet {} {}".format(hiddens, activation))

        # Consecutive size pairs describe each hidden layer's in/out width.
        widths = [inputs] + list(hiddens)
        stack = [
            SlimFC(
                fan_in, fan_out,
                initializer=normc_initializer(1.0),
                activation_fn=activation)
            for fan_in, fan_out in zip(widths, widths[1:])
        ]
        last_layer_size = widths[-1]

        self.hidden_layers = nn.Sequential(*stack)

        self.logits = SlimFC(
            last_layer_size, num_outputs,
            initializer=normc_initializer(0.01),
            activation_fn=None)
        self.probs = nn.Softmax()
        self.value_branch = SlimFC(
            last_layer_size, 1,
            initializer=normc_initializer(1.0),
            activation_fn=None)

    def forward(self, obs):
        """Internal method - pass in Variables, not numpy arrays.

        Args:
            obs: observations and features

        Return:
            logits: logits to be sampled from for each state
            value: value function for each state
        """
        hidden = self.hidden_layers(obs)
        return self.logits(hidden), self.value_branch(hidden)

View file

@ -0,0 +1,69 @@
""" Code adapted from https://github.com/ikostrikov/pytorch-a3c"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
from torch.autograd import Variable
def convert_batch(batch, has_features=False):
    """Convert the numpy fields of `batch` into PyTorch Variables.

    Args:
        batch: Object exposing `si`, `a`, `adv`, `r` and `features`.
        has_features (bool): When True, each entry of `batch.features` is
            wrapped as a Variable; otherwise it is returned untouched.

    Returns:
        (states, acs, advs, rs, features); `advs` and `rs` are float
        Variables viewed as shape (-1, 1).
    """
    def _float_column(array):
        # Copy first, then cast to float and view as a single column.
        return Variable(torch.from_numpy(array.copy()).float()).view(-1, 1)

    observations = Variable(torch.from_numpy(batch.si).float())
    actions = Variable(torch.from_numpy(batch.a))
    advantages = _float_column(batch.adv)
    returns = _float_column(batch.r)

    features = batch.features
    if has_features:
        features = [Variable(torch.from_numpy(f)) for f in features]
    return observations, actions, advantages, returns, features
def var_to_np(var):
    """Return the first entry of `var`'s data converted to numpy."""
    as_numpy = var.data.numpy()
    return as_numpy[0]
def normc_initializer(std=1.0):
    """Build an in-place initializer that rescales random weights.

    The returned function fills `tensor.data` with standard normal samples,
    then scales it so the L2 norm taken along dim 1 equals `std`.
    """
    def initializer(tensor):
        tensor.data.normal_(0, 1)
        norms_along_dim1 = torch.sqrt(
            tensor.data.pow(2).sum(1, keepdim=True))
        tensor.data *= std / norms_along_dim1

    return initializer
def valid_padding(in_size, filter_size, stride_size):
    """Compute the zero padding that mimics TF conv2d `same` padding.

    Note: Padding is added to match TF conv2d `same` padding. See
    www.tensorflow.org/versions/r0.12/api_docs/python/nn/convolution

    Params:
        in_size (tuple): Rows (Height), Column (Width) for input
        filter_size (tuple): Rows (Height), Column (Width) for filter
        stride_size (tuple): Rows (Height), Column (Width) for stride

    Output:
        padding (tuple): (left, right, top, bottom) for input into
            torch.nn.ZeroPad2d
        output (tuple): Output shape (height, width) after padding and
            convolution, as ints
    """
    in_height, in_width = in_size
    filter_height, filter_width = filter_size
    stride_height, stride_width = stride_size

    # Output extents are whole numbers, so return them as ints rather than
    # the numpy floats np.ceil produces.
    out_height = int(np.ceil(float(in_height) / float(stride_height)))
    out_width = int(np.ceil(float(in_width) / float(stride_width)))

    # The TF rule clamps total padding at zero; without the clamp a filter
    # smaller than the stride can yield negative padding.
    pad_along_height = max(
        int((out_height - 1) * stride_height + filter_height - in_height), 0)
    pad_along_width = max(
        int((out_width - 1) * stride_width + filter_width - in_width), 0)

    # When the total is odd, the extra row/column goes to the bottom/right.
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left

    padding = (pad_left, pad_right, pad_top, pad_bottom)
    output = (out_height, out_width)
    return padding, output

View file

@ -0,0 +1,70 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn as nn
class Model(nn.Module):
    """Base class for the PyTorch models; subclasses supply the layers."""

    def __init__(self, obs_space, ac_space, options):
        """Initialize the module, then delegate construction to `_init`."""
        super(Model, self).__init__()
        self._init(obs_space, ac_space, options)

    def _init(self, inputs, num_outputs, options):
        """Build the network layers. Must be overridden."""
        raise NotImplementedError

    def forward(self, obs):
        """Forward pass for the model. Must be overridden.

        Internal method - should only be passed PyTorch Tensors.

        PyTorch automatically overloads the given model with this
        function, so calling `model(obs)` is recommended over
        `model.forward(obs)`. See
        https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690
        """
        raise NotImplementedError
class SlimConv2d(nn.Module):
    """Simple mock of tf.slim Conv2d.

    Chains an optional `nn.ZeroPad2d`, an `nn.Conv2d`, and an optional
    activation inside one `nn.Sequential`.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, padding,
                 initializer=nn.init.xavier_uniform,
                 activation_fn=nn.ReLU, bias_init=0):
        """
        Args:
            in_channels, out_channels, kernel, stride: Passed to `nn.Conv2d`.
            padding: Argument for `nn.ZeroPad2d`; skipped when falsy.
            initializer: Callable applied to the conv weight; skipped
                when falsy.
            activation_fn: Zero-argument factory for the activation module;
                skipped when falsy.
            bias_init: Constant the conv bias is filled with.
        """
        super(SlimConv2d, self).__init__()

        convolution = nn.Conv2d(in_channels, out_channels, kernel, stride)
        if initializer:
            initializer(convolution.weight)
        nn.init.constant(convolution.bias, bias_init)

        # Padding (if any) comes first so the conv sees the padded input.
        stages = [nn.ZeroPad2d(padding)] if padding else []
        stages.append(convolution)
        if activation_fn:
            stages.append(activation_fn())
        self._model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply padding, convolution and activation to `x`."""
        return self._model(x)
class SlimFC(nn.Module):
    """Simple PyTorch version of the `linear` function.

    Wraps an `nn.Linear` plus an optional activation in an `nn.Sequential`.
    """

    def __init__(self, in_size, size, initializer=None,
                 activation_fn=None, bias_init=0):
        """
        Args:
            in_size: Input width of the linear layer.
            size: Output width of the linear layer.
            initializer: Callable applied to the weight; skipped when falsy.
            activation_fn: Zero-argument factory for the activation module;
                skipped when falsy.
            bias_init: Constant the bias is filled with.
        """
        super(SlimFC, self).__init__()

        dense = nn.Linear(in_size, size)
        if initializer:
            initializer(dense.weight)
        nn.init.constant(dense.bias, bias_init)

        stages = [dense]
        if activation_fn:
            stages.append(activation_fn())
        self._model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the linear layer and optional activation to `x`."""
        return self._model(x)

View file

@ -0,0 +1,70 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn as nn
from ray.rllib.models.pytorch.model import Model, SlimConv2d, SlimFC
from ray.rllib.models.pytorch.misc import normc_initializer, valid_padding
class VisionNetwork(Model):
    """Generic vision network"""

    def _init(self, inputs, num_outputs, options):
        """TF visionnet in PyTorch.

        Params:
            inputs (tuple): (channels, rows/height, cols/width)
            num_outputs (int): logits size
            options (dict): Reads "conv_filters", a list of
                [out_channels, kernel, stride] entries.
        """
        default_filters = [
            [16, [8, 8], 4],
            [32, [4, 4], 2],
            [512, [10, 10], 1]
        ]
        filters = options.get("conv_filters", default_filters)

        channels = inputs[0]
        spatial = inputs[1:]
        conv_stack = []
        # Every filter except the last gets TF-`same`-style zero padding.
        for n_out, kernel, stride in filters[:-1]:
            pad, spatial_after = valid_padding(
                spatial, kernel, [stride, stride])
            conv_stack.append(
                SlimConv2d(channels, n_out, kernel, stride, pad))
            channels, spatial = n_out, spatial_after

        # The final filter is applied without padding.
        out_channels, kernel, stride = filters[-1]
        conv_stack.append(
            SlimConv2d(channels, out_channels, kernel, stride, None))
        self._convs = nn.Sequential(*conv_stack)

        self.logits = SlimFC(
            out_channels, num_outputs, initializer=nn.init.xavier_uniform)
        self.probs = nn.Softmax()
        self.value_branch = SlimFC(
            out_channels, 1, initializer=normc_initializer())

    def hidden_layers(self, obs):
        """Internal method - pass in Variables, not numpy arrays.

        Args:
            obs: observations and features
        """
        # Drop dims 3 and 2 of the conv output; presumably both are size 1
        # after the final unpadded conv -- TODO confirm for custom filters.
        return self._convs(obs).squeeze(3).squeeze(2)

    def forward(self, obs):
        """Internal method. Implements the forward pass.

        Args:
            obs (PyTorch): observations and features

        Return:
            logits (PyTorch): logits to be sampled from for each state
            value (PyTorch): value function for each state
        """
        hidden = self.hidden_layers(obs)
        return self.logits(hidden), self.value_branch(hidden)

View file

@ -125,6 +125,13 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
--stop '{"training_iteration": 2}' \
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env PongDeterministic-v4 \
--alg A3C \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/test/test_checkpoint_restore.py