mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[rllib] PyTorch Models for A3C (#1187)
* fixing policy * Compute Action is singular, fixed weird issue with arrays * remove vestige * extraneous ipdb * Can Drop in Pytorch Model * lint * introducing models * fix base policy * Missed this from last time * lint * removedolds * getting vision working * LINT * trying to fix test dependencies * requiremnets * try * tryconda * yes * shutup * flake_passes * changes * removing weight initializer for lstm for now * unused * adam * clip * zero * properscaling * weight * try * fix up pytorch visionnet * bias correction * fix model * same visionnet * matching_bad_things * test * try locking * fixing_linear * naming * lint * FORJENKINS * clouds * lint * Lint + removed dependencies * removed dependencies * format
This commit is contained in:
parent
9a6a056609
commit
afdc87323f
16 changed files with 462 additions and 16 deletions
|
@ -4,3 +4,4 @@ FROM ray-project/deploy
|
|||
RUN conda install -y -c conda-forge tensorflow
|
||||
RUN apt-get install -y zlib1g-dev
|
||||
RUN pip install gym[atari] opencv-python==3.2.0.8 smart_open
|
||||
RUN conda install -y -q pytorch torchvision -c soumith
|
||||
|
|
|
@ -20,10 +20,11 @@ DEFAULT_CONFIG = {
|
|||
"num_batches_per_iteration": 100,
|
||||
"batch_size": 10,
|
||||
"use_lstm": True,
|
||||
"use_pytorch": False,
|
||||
"model": {"grayscale": True,
|
||||
"zero_mean": False,
|
||||
"dim": 42,
|
||||
"channel_major": True}
|
||||
"channel_major": False}
|
||||
}
|
||||
|
||||
|
||||
|
@ -35,6 +36,9 @@ class A3CAgent(Agent):
|
|||
self.env = create_and_wrap(self.env_creator, self.config["model"])
|
||||
if self.config["use_lstm"]:
|
||||
policy_cls = SharedModelLSTM
|
||||
elif self.config["use_pytorch"]:
|
||||
from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy
|
||||
policy_cls = SharedTorchPolicy
|
||||
else:
|
||||
policy_cls = SharedModel
|
||||
self.policy = policy_cls(
|
||||
|
|
|
@ -20,9 +20,6 @@ class Policy(object):
|
|||
def compute_gradients(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_vf_loss(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_action(self, observations):
|
||||
"""Compute action for a _single_ observation"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -12,7 +12,7 @@ class SharedModel(TFPolicy):
|
|||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
super(SharedModel, self).__init__(ob_space, ac_space, **kwargs)
|
||||
|
||||
def setup_graph(self, ob_space, ac_space):
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
self._model = ModelCatalog.get_model(self.x, self.logit_dim)
|
||||
|
|
|
@ -14,7 +14,7 @@ class SharedModelLSTM(TFPolicy):
|
|||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs)
|
||||
|
||||
def setup_graph(self, ob_space, ac_space):
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
self._model = LSTM(self.x, self.logit_dim, {})
|
||||
|
|
73
python/ray/rllib/a3c/shared_torch_policy.py
Normal file
73
python/ray/rllib/a3c/shared_torch_policy.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ray.rllib.a3c.torchpolicy import TorchPolicy
|
||||
from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
|
||||
|
||||
class SharedTorchPolicy(TorchPolicy):
|
||||
"""Assumes nonrecurrent."""
|
||||
|
||||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
super(SharedTorchPolicy, self).__init__(
|
||||
ob_space, ac_space, **kwargs)
|
||||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
|
||||
self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001)
|
||||
|
||||
def compute_action(self, ob, *args):
|
||||
"""Should take in a SINGLE ob"""
|
||||
with self.lock:
|
||||
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
|
||||
logits, values = self._model(ob)
|
||||
samples = self._model.probs(logits).multinomial().squeeze()
|
||||
values = values.squeeze(0)
|
||||
return var_to_np(samples), var_to_np(values)
|
||||
|
||||
def compute_logits(self, ob, *args):
|
||||
with self.lock:
|
||||
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
|
||||
res = self._model.hidden_layers(ob)
|
||||
return var_to_np(self._model.logits(res))
|
||||
|
||||
def value(self, ob, *args):
|
||||
with self.lock:
|
||||
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
|
||||
res = self._model.hidden_layers(ob)
|
||||
res = self._model.value_branch(res)
|
||||
res = res.squeeze(0)
|
||||
return var_to_np(res)
|
||||
|
||||
def _evaluate(self, obs, actions):
|
||||
"""Passes in multiple obs."""
|
||||
logits, values = self._model(obs)
|
||||
log_probs = F.log_softmax(logits)
|
||||
probs = self._model.probs(logits)
|
||||
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
|
||||
entropy = -(log_probs * probs).sum(-1).sum()
|
||||
return values, action_log_probs, entropy
|
||||
|
||||
def _backward(self, batch):
|
||||
"""Loss is encoded in here. Defining a new loss function
|
||||
would start by rewriting this function"""
|
||||
|
||||
states, acs, advs, rs, _ = convert_batch(batch)
|
||||
values, ac_logprobs, entropy = self._evaluate(states, acs)
|
||||
pi_err = -(advs * ac_logprobs).sum()
|
||||
value_err = 0.5 * (values - rs).pow(2).sum()
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
overall_err = 0.5 * value_err + pi_err - entropy * 0.01
|
||||
overall_err.backward()
|
||||
torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)
|
||||
|
||||
def get_initial_features(self):
|
||||
return [None]
|
|
@ -17,7 +17,7 @@ class TFPolicy(Policy):
|
|||
self.g = tf.Graph()
|
||||
with self.g.as_default(), tf.device(worker_device):
|
||||
with tf.variable_scope(name):
|
||||
self.setup_graph(ob_space, action_space)
|
||||
self._setup_graph(ob_space, action_space)
|
||||
assert all([hasattr(self, attr)
|
||||
for attr in ["vf", "logits", "x", "var_list"]])
|
||||
print("Setting up loss")
|
||||
|
@ -25,7 +25,7 @@ class TFPolicy(Policy):
|
|||
self.setup_gradients()
|
||||
self.initialize()
|
||||
|
||||
def setup_graph(self):
|
||||
def _setup_graph(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def setup_loss(self, action_space):
|
||||
|
@ -92,9 +92,6 @@ class TFPolicy(Policy):
|
|||
def compute_gradients(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_vf_loss(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_action(self, observations):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
from ray.rllib.a3c.policy import Policy
|
||||
from threading import Lock
|
||||
|
||||
|
||||
class TorchPolicy(Policy):
|
||||
"""The policy base class for Torch.
|
||||
|
||||
The model is a separate object than the policy. This could be changed
|
||||
in the future."""
|
||||
|
||||
def __init__(self, ob_space, action_space, name="local", summarize=True):
|
||||
self.local_steps = 0
|
||||
self.summarize = summarize
|
||||
self._setup_graph(ob_space, action_space)
|
||||
torch.set_num_threads(2)
|
||||
self.lock = Lock()
|
||||
|
||||
def apply_gradients(self, grads):
|
||||
self.optimizer.zero_grad()
|
||||
for g, p in zip(grads, self._model.parameters()):
|
||||
p.grad = Variable(torch.from_numpy(g))
|
||||
self.optimizer.step()
|
||||
|
||||
def get_weights(self):
|
||||
# !! This only returns references to the data.
|
||||
return self._model.state_dict()
|
||||
|
||||
def set_weights(self, weights):
|
||||
with self.lock:
|
||||
self._model.load_state_dict(weights)
|
||||
|
||||
def compute_gradients(self, batch):
|
||||
"""_backward generates the gradient in each model parameter.
|
||||
This is taken out.
|
||||
|
||||
Args:
|
||||
batch: Batch of data needed for gradient calculation.
|
||||
|
||||
Return:
|
||||
gradients (list of np arrays): List of gradients
|
||||
info (dict): Extra information (user-defined)"""
|
||||
with self.lock:
|
||||
self._backward(batch)
|
||||
# Note that return values are just references;
|
||||
# calling zero_grad will modify the values
|
||||
return [p.grad.data.numpy() for p in self._model.parameters()], {}
|
||||
|
||||
def model_update(self, batch):
|
||||
"""Implements compute + apply
|
||||
|
||||
TODO(rliaw): Pytorch has nice caching property that doesn't require
|
||||
full batch to be passed in. Can exploit that later"""
|
||||
with self.lock:
|
||||
self._backward(batch)
|
||||
self.optimizer.step()
|
||||
|
||||
def _setup_graph(ob_space, action_space):
|
||||
raise NotImplementedError
|
||||
|
||||
def _backward(self, batch):
|
||||
"""Implements the loss function and calculates the gradient.
|
||||
Pytorch automatically generates a backward trace for each variable.
|
||||
Assumption right now is that variables are moved, so the backward
|
||||
trace is lost.
|
||||
|
||||
This function regenerates the backward trace and
|
||||
caluclates the gradient."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_initial_features(self):
|
||||
return []
|
|
@ -85,6 +85,30 @@ class ModelCatalog(object):
|
|||
|
||||
return FullyConnectedNetwork(inputs, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_torch_model(input_shape, num_outputs, options=dict()):
|
||||
"""Returns a PyTorch suitable model.
|
||||
|
||||
Args:
|
||||
input_shape (tup): The input shape to the model.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
model (Model): Neural network model.
|
||||
"""
|
||||
from ray.rllib.models.pytorch.fcnet import (
|
||||
FullyConnectedNetwork as PyTorchFCNet)
|
||||
from ray.rllib.models.pytorch.visionnet import (
|
||||
VisionNetwork as PyTorchVisionNet)
|
||||
|
||||
obs_rank = len(input_shape) - 1
|
||||
|
||||
if obs_rank > 1:
|
||||
return PyTorchVisionNet(input_shape, num_outputs, options)
|
||||
|
||||
return PyTorchFCNet(input_shape[0], num_outputs, options)
|
||||
|
||||
@classmethod
|
||||
def get_preprocessor(cls, env, options=dict()):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
|
|
|
@ -30,15 +30,15 @@ class AtariPixelPreprocessor(Preprocessor):
|
|||
self._grayscale = self._options.get("grayscale", False)
|
||||
self._zero_mean = self._options.get("zero_mean", True)
|
||||
self._dim = self._options.get("dim", 80)
|
||||
self._pytorch = self._options.get("pytorch", False)
|
||||
self._channel_major = self._options.get("channel_major", False)
|
||||
if self._grayscale:
|
||||
self.shape = (self._dim, self._dim, 1)
|
||||
else:
|
||||
self.shape = (self._dim, self._dim, 3)
|
||||
|
||||
# pytorch requires (# in-channels, row dim, col dim)
|
||||
if self._pytorch:
|
||||
self.shape = self.shape[::-1]
|
||||
# channel_major requires (# in-channels, row dim, col dim)
|
||||
if self._channel_major:
|
||||
self.shape = self.shape[-1:] + self.shape[:-1]
|
||||
|
||||
# TODO(ekl) why does this need to return an extra size-1 dim (the [None])
|
||||
def transform(self, observation):
|
||||
|
@ -59,7 +59,7 @@ class AtariPixelPreprocessor(Preprocessor):
|
|||
scaled = (scaled - 128) / 128
|
||||
else:
|
||||
scaled *= 1.0 / 255.0
|
||||
if self._pytorch:
|
||||
if self._channel_major:
|
||||
scaled = np.reshape(scaled, self.shape)
|
||||
return scaled
|
||||
|
||||
|
|
0
python/ray/rllib/models/pytorch/__init__.py
Normal file
0
python/ray/rllib/models/pytorch/__init__.py
Normal file
56
python/ray/rllib/models/pytorch/fcnet.py
Normal file
56
python/ray/rllib/models/pytorch/fcnet.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.models.pytorch.model import Model, SlimFC
|
||||
from ray.rllib.models.pytorch.misc import normc_initializer
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class FullyConnectedNetwork(Model):
|
||||
"""TODO(rliaw): Logits, Value should both be contained here"""
|
||||
def _init(self, inputs, num_outputs, options):
|
||||
assert type(inputs) is int
|
||||
hiddens = options.get("fcnet_hiddens", [256, 256])
|
||||
fcnet_activation = options.get("fcnet_activation", "tanh")
|
||||
activation = None
|
||||
if fcnet_activation == "tanh":
|
||||
activation = nn.Tanh
|
||||
elif fcnet_activation == "relu":
|
||||
activation = nn.ReLU
|
||||
print("Constructing fcnet {} {}".format(hiddens, activation))
|
||||
|
||||
layers = []
|
||||
last_layer_size = inputs
|
||||
for size in hiddens:
|
||||
layers.append(SlimFC(
|
||||
last_layer_size, size,
|
||||
initializer=normc_initializer(1.0),
|
||||
activation_fn=activation))
|
||||
last_layer_size = size
|
||||
|
||||
self.hidden_layers = nn.Sequential(*layers)
|
||||
|
||||
self.logits = SlimFC(
|
||||
last_layer_size, num_outputs,
|
||||
initializer=normc_initializer(0.01),
|
||||
activation_fn=None)
|
||||
self.probs = nn.Softmax()
|
||||
self.value_branch = SlimFC(
|
||||
last_layer_size, 1,
|
||||
initializer=normc_initializer(1.0),
|
||||
activation_fn=None)
|
||||
|
||||
def forward(self, obs):
|
||||
""" Internal method - pass in Variables, not numpy arrays
|
||||
|
||||
Args:
|
||||
obs: observations and features
|
||||
|
||||
Return:
|
||||
logits: logits to be sampled from for each state
|
||||
value: value function for each state"""
|
||||
res = self.hidden_layers(obs)
|
||||
logits = self.logits(res)
|
||||
value = self.value_branch(res)
|
||||
return logits, value
|
69
python/ray/rllib/models/pytorch/misc.py
Normal file
69
python/ray/rllib/models/pytorch/misc.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
""" Code adapted from https://github.com/ikostrikov/pytorch-a3c"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def convert_batch(batch, has_features=False):
|
||||
"""Convert batch from numpy to PT variable"""
|
||||
states = Variable(torch.from_numpy(batch.si).float())
|
||||
acs = Variable(torch.from_numpy(batch.a))
|
||||
advs = Variable(torch.from_numpy(batch.adv.copy()).float())
|
||||
advs = advs.view(-1, 1)
|
||||
rs = Variable(torch.from_numpy(batch.r.copy()).float())
|
||||
rs = rs.view(-1, 1)
|
||||
if has_features:
|
||||
features = [Variable(torch.from_numpy(f))
|
||||
for f in batch.features]
|
||||
else:
|
||||
features = batch.features
|
||||
return states, acs, advs, rs, features
|
||||
|
||||
|
||||
def var_to_np(var):
|
||||
return var.data.numpy()[0]
|
||||
|
||||
|
||||
def normc_initializer(std=1.0):
|
||||
def initializer(tensor):
|
||||
tensor.data.normal_(0, 1)
|
||||
tensor.data *= std / torch.sqrt(
|
||||
tensor.data.pow(2).sum(1, keepdim=True))
|
||||
return initializer
|
||||
|
||||
|
||||
def valid_padding(in_size, filter_size, stride_size):
|
||||
"""Note: Padding is added to match TF conv2d `same` padding. See
|
||||
www.tensorflow.org/versions/r0.12/api_docs/python/nn/convolution
|
||||
|
||||
Params:
|
||||
in_size (tuple): Rows (Height), Column (Width) for input
|
||||
stride_size (tuple): Rows (Height), Column (Width) for stride
|
||||
filter_size (tuple): Rows (Height), Column (Width) for filter
|
||||
|
||||
Output:
|
||||
padding (tuple): For input into torch.nn.ZeroPad2d
|
||||
output (tuple): Output shape after padding and convolution
|
||||
"""
|
||||
in_height, in_width = in_size
|
||||
filter_height, filter_width = filter_size
|
||||
stride_height, stride_width = stride_size
|
||||
|
||||
out_height = np.ceil(float(in_height) / float(stride_height))
|
||||
out_width = np.ceil(float(in_width) / float(stride_width))
|
||||
|
||||
pad_along_height = int(
|
||||
((out_height - 1) * stride_height + filter_height - in_height))
|
||||
pad_along_width = int(
|
||||
((out_width - 1) * stride_width + filter_width - in_width))
|
||||
pad_top = pad_along_height // 2
|
||||
pad_bottom = pad_along_height - pad_top
|
||||
pad_left = pad_along_width // 2
|
||||
pad_right = pad_along_width - pad_left
|
||||
padding = (pad_left, pad_right, pad_top, pad_bottom)
|
||||
output = (out_height, out_width)
|
||||
return padding, output
|
70
python/ray/rllib/models/pytorch/model.py
Normal file
70
python/ray/rllib/models/pytorch/model.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, obs_space, ac_space, options):
|
||||
super(Model, self).__init__()
|
||||
self._init(obs_space, ac_space, options)
|
||||
|
||||
def _init(self, inputs, num_outputs, options):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, obs):
|
||||
"""Forward pass for the model. Internal method - should only
|
||||
be passed PyTorch Tensors.
|
||||
|
||||
PyTorch automatically overloads the given model
|
||||
with this function. Recommended that model(obs)
|
||||
is used instead of model.forward(obs). See
|
||||
https://discuss.pytorch.org/t/any-different-between-model
|
||||
-input-and-model-forward-input/3690
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlimConv2d(nn.Module):
|
||||
"""Simple mock of tf.slim Conv2d"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel, stride, padding,
|
||||
initializer=nn.init.xavier_uniform,
|
||||
activation_fn=nn.ReLU, bias_init=0):
|
||||
super(SlimConv2d, self).__init__()
|
||||
layers = []
|
||||
if padding:
|
||||
layers.append(nn.ZeroPad2d(padding))
|
||||
conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
|
||||
if initializer:
|
||||
initializer(conv.weight)
|
||||
nn.init.constant(conv.bias, bias_init)
|
||||
|
||||
layers.append(conv)
|
||||
if activation_fn:
|
||||
layers.append(activation_fn())
|
||||
self._model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self._model(x)
|
||||
|
||||
|
||||
class SlimFC(nn.Module):
|
||||
"""Simple PyTorch of `linear` function"""
|
||||
|
||||
def __init__(self, in_size, size, initializer=None,
|
||||
activation_fn=None, bias_init=0):
|
||||
super(SlimFC, self).__init__()
|
||||
layers = []
|
||||
linear = nn.Linear(in_size, size)
|
||||
if initializer:
|
||||
initializer(linear.weight)
|
||||
nn.init.constant(linear.bias, bias_init)
|
||||
layers.append(linear)
|
||||
if activation_fn:
|
||||
layers.append(activation_fn())
|
||||
self._model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self._model(x)
|
70
python/ray/rllib/models/pytorch/visionnet.py
Normal file
70
python/ray/rllib/models/pytorch/visionnet.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.pytorch.model import Model, SlimConv2d, SlimFC
|
||||
from ray.rllib.models.pytorch.misc import normc_initializer, valid_padding
|
||||
|
||||
|
||||
class VisionNetwork(Model):
|
||||
"""Generic vision network"""
|
||||
|
||||
def _init(self, inputs, num_outputs, options):
|
||||
"""TF visionnet in PyTorch.
|
||||
|
||||
Params:
|
||||
inputs (tuple): (channels, rows/height, cols/width)
|
||||
num_outputs (int): logits size
|
||||
"""
|
||||
filters = options.get("conv_filters", [
|
||||
[16, [8, 8], 4],
|
||||
[32, [4, 4], 2],
|
||||
[512, [10, 10], 1]
|
||||
])
|
||||
layers = []
|
||||
in_channels, in_size = inputs[0], inputs[1:]
|
||||
|
||||
for out_channels, kernel, stride in filters[:-1]:
|
||||
padding, out_size = valid_padding(
|
||||
in_size, kernel, [stride, stride])
|
||||
layers.append(SlimConv2d(
|
||||
in_channels, out_channels, kernel, stride, padding))
|
||||
in_channels = out_channels
|
||||
in_size = out_size
|
||||
|
||||
out_channels, kernel, stride = filters[-1]
|
||||
layers.append(SlimConv2d(
|
||||
in_channels, out_channels, kernel, stride, None))
|
||||
self._convs = nn.Sequential(*layers)
|
||||
|
||||
self.logits = SlimFC(
|
||||
out_channels, num_outputs, initializer=nn.init.xavier_uniform)
|
||||
self.probs = nn.Softmax()
|
||||
self.value_branch = SlimFC(
|
||||
out_channels, 1, initializer=normc_initializer())
|
||||
|
||||
def hidden_layers(self, obs):
|
||||
""" Internal method - pass in Variables, not numpy arrays
|
||||
|
||||
args:
|
||||
obs: observations and features"""
|
||||
res = self._convs(obs)
|
||||
res = res.squeeze(3)
|
||||
res = res.squeeze(2)
|
||||
return res
|
||||
|
||||
def forward(self, obs):
|
||||
"""Internal method. Implements the
|
||||
|
||||
Args:
|
||||
obs (PyTorch): observations and features
|
||||
|
||||
Return:
|
||||
logits (PyTorch): logits to be sampled from for each state
|
||||
value (PyTorch): value function for each state"""
|
||||
res = self.hidden_layers(obs)
|
||||
logits = self.logits(res)
|
||||
value = self.value_branch(res)
|
||||
return logits, value
|
|
@ -125,6 +125,13 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
|||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env PongDeterministic-v4 \
|
||||
--alg A3C \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/test/test_checkpoint_restore.py
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue