Mirror of https://github.com/vale981/ray, synced 2025-03-05 10:01:43 -05:00
[rllib] Fix A3C PyTorch implementation (#2036)
* Use F.softmax instead of a pointless network layer. Stateless functions should not be network layers.
* Use correct pytorch functions
* Rename argument name to out_size. Matches in_size and makes more sense.
* Fix shapes of tensors. Advantages and rewards both should be scalars, and therefore a list of them should be 1D.
* Fmt
* Replace deprecated function
* Remove unnecessary Variable wrapper
* Remove all use of torch Variables. Torch does this for us now.
* Ensure that values are a flat list
* Fix shape error in conv nets
* Fmt
* Fix shape errors. Reshaping the action before stepping in the env fixes a few errors.
* Add TODO
* Use correct filter size. Works when `self.config['model']['channel_major'] = True`.
* Add missing channel major
* Revert reshape of action. This should be handled by the agent, or at least in a cleaner way that doesn't break existing envs.
* Squeeze action
* Squeeze actions along first dimension. This should deal with some cases such as CartPole where actions are scalars, while leaving alone cases where actions are arrays (some robotics tasks).
* Try adding pytorch tests
* Typo
* Fix up docker messages
* Fix A3C for some envs. Pendulum doesn't work since it's an edge case (expects singleton arrays, which `.squeeze()` collapses to scalars).
* Fmt
* Nit: flake
* Small lint
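To illustrate the first point above, here is a minimal, self-contained sketch (not taken verbatim from the patch, shapes invented) of sampling an action with the stateless `F.softmax` call that replaces the former `nn.Softmax()` layer:

```python
import torch
import torch.nn.functional as F

# Softmax is a stateless function, so it is applied at call time rather than
# being stored on the model as an nn.Softmax module.
logits = torch.randn(1, 4)                   # one observation, 4 discrete actions (made-up shape)
action = F.softmax(logits, dim=1).multinomial(1).squeeze()
log_probs = F.log_softmax(logits, dim=1)     # used later for the policy-gradient loss
print(action.item(), log_probs.shape)
```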
This commit is contained in:
parent
ac1e5a7d15
commit
fd234e3171
10 changed files with 138 additions and 100 deletions
@@ -1,8 +1,10 @@
 # The examples Docker image adds dependencies needed to run the examples
 FROM ray-project/deploy
 RUN conda install -y -c conda-forge tensorflow
 # This updates numpy to 1.14 and mutes errors from other libraries
 RUN conda install -y numpy
 RUN apt-get install -y zlib1g-dev
-RUN pip install gym[atari] opencv-python==3.2.0.8
+RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow
 RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
+# RUN conda install -y -q pytorch torchvision -c soumith
+RUN conda install pytorch-cpu torchvision-cpu -c pytorch
@@ -15,7 +15,6 @@ from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator, \
 from ray.tune.result import TrainingResult
 from ray.tune.trial import Resources

-
 DEFAULT_CONFIG = {
     # Number of workers (excluding master)
     "num_workers": 4,
@@ -52,7 +51,7 @@ DEFAULT_CONFIG = {
         # (Image statespace) - Converts image to (dim, dim, C)
         "dim": 80,
         # (Image statespace) - Converts image shape to (C, dim, dim)
-        "channel_major": False
+        "channel_major": False,
     },
     # Arguments to pass to the rllib optimizer
     "optimizer": {
@@ -73,46 +72,53 @@ class A3CAgent(Agent):
     def default_resource_request(cls, config):
         cf = dict(cls._default_config, **config)
         return Resources(
-            cpu=1, gpu=0,
+            cpu=1,
+            gpu=0,
             extra_cpu=cf["num_workers"],
             extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

     def _init(self):
         self.local_evaluator = A3CEvaluator(
-            self.registry, self.env_creator, self.config, self.logdir,
+            self.registry,
+            self.env_creator,
+            self.config,
+            self.logdir,
             start_sampler=False)
         if self.config["use_gpu_for_workers"]:
             remote_cls = GPURemoteA3CEvaluator
         else:
             remote_cls = RemoteA3CEvaluator
         self.remote_evaluators = [
-            remote_cls.remote(
-                self.registry, self.env_creator, self.config, self.logdir)
-            for i in range(self.config["num_workers"])]
-        self.optimizer = AsyncOptimizer(
-            self.config["optimizer"], self.local_evaluator,
-            self.remote_evaluators)
+            remote_cls.remote(self.registry, self.env_creator, self.config,
+                              self.logdir)
+            for i in range(self.config["num_workers"])
+        ]
+        self.optimizer = AsyncOptimizer(self.config["optimizer"],
+                                        self.local_evaluator,
+                                        self.remote_evaluators)

     def _train(self):
         self.optimizer.step()
-        FilterManager.synchronize(
-            self.local_evaluator.filters, self.remote_evaluators)
+        FilterManager.synchronize(self.local_evaluator.filters,
+                                  self.remote_evaluators)
         res = self._fetch_metrics_from_remote_evaluators()
         return res

     def _fetch_metrics_from_remote_evaluators(self):
         episode_rewards = []
         episode_lengths = []
-        metric_lists = [a.get_completed_rollout_metrics.remote()
-                        for a in self.remote_evaluators]
+        metric_lists = [
+            a.get_completed_rollout_metrics.remote()
+            for a in self.remote_evaluators
+        ]
         for metrics in metric_lists:
             for episode in ray.get(metrics):
                 episode_lengths.append(episode.episode_length)
                 episode_rewards.append(episode.episode_reward)
-        avg_reward = (
-            np.mean(episode_rewards) if episode_rewards else float('nan'))
-        avg_length = (
-            np.mean(episode_lengths) if episode_lengths else float('nan'))
+        avg_reward = (np.mean(episode_rewards)
+                      if episode_rewards else float('nan'))
+        avg_length = (np.mean(episode_lengths)
+                      if episode_lengths else float('nan'))
         timesteps = np.sum(episode_lengths) if episode_lengths else 0

         result = TrainingResult(
@@ -129,21 +135,23 @@ class A3CAgent(Agent):
             ev.__ray_terminate__.remote()

     def _save(self, checkpoint_dir):
-        checkpoint_path = os.path.join(
-            checkpoint_dir, "checkpoint-{}".format(self.iteration))
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "checkpoint-{}".format(self.iteration))
         agent_state = ray.get(
             [a.save.remote() for a in self.remote_evaluators])
         extra_data = {
             "remote_state": agent_state,
-            "local_state": self.local_evaluator.save()}
+            "local_state": self.local_evaluator.save()
+        }
         pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
         return checkpoint_path

     def _restore(self, checkpoint_path):
         extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
-        ray.get(
-            [a.restore.remote(o) for a, o in zip(
-                self.remote_evaluators, extra_data["remote_state"])])
+        ray.get([
+            a.restore.remote(o)
+            for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
+        ])
         self.local_evaluator.restore(extra_data["local_state"])

     def compute_action(self, observation):
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function

 import torch
-from torch.autograd import Variable
 import torch.nn.functional as F

 from ray.rllib.a3c.torchpolicy import TorchPolicy
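As a hedged aside on the `Variable` removal above (not part of the patch): since the Tensor/Variable merge in PyTorch 0.4, plain tensors carry autograd state themselves, so the wrapper adds nothing.

```python
import numpy as np
import torch

# Plain tensors participate in autograd directly; no Variable(...) wrapper needed.
ob = torch.from_numpy(np.zeros((1, 4), dtype=np.float32))
w = torch.ones(4, 2, requires_grad=True)
loss = (ob @ w).sum()
loss.backward()        # gradients flow without wrapping ob or w in Variable
print(w.grad.shape)    # torch.Size([4, 2])
```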
@@ -18,8 +17,8 @@ class SharedTorchPolicy(TorchPolicy):
     is_recurrent = False

     def __init__(self, registry, ob_space, ac_space, config, **kwargs):
-        super(SharedTorchPolicy, self).__init__(
-            registry, ob_space, ac_space, config, **kwargs)
+        super(SharedTorchPolicy, self).__init__(registry, ob_space, ac_space,
+                                                config, **kwargs)

     def _setup_graph(self, ob_space, ac_space):
         _, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
@@ -31,32 +30,36 @@ class SharedTorchPolicy(TorchPolicy):
     def compute(self, ob, *args):
         """Should take in a SINGLE ob"""
         with self.lock:
-            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            ob = torch.from_numpy(ob).float().unsqueeze(0)
             logits, values = self._model(ob)
-            samples = self._model.probs(logits).multinomial().squeeze()
-            values = values.squeeze(0)
-            return var_to_np(samples), {"vf_preds": var_to_np(values)}
+            # TODO(alok): Support non-categorical distributions. Multinomial
+            # is only for categorical.
+            sampled_actions = F.softmax(logits, dim=1).multinomial(1).squeeze()
+            values = values.squeeze()
+            return var_to_np(sampled_actions), {"vf_preds": var_to_np(values)}

     def compute_logits(self, ob, *args):
         with self.lock:
-            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            ob = torch.from_numpy(ob).float().unsqueeze(0)
             res = self._model.hidden_layers(ob)
             return var_to_np(self._model.logits(res))

     def value(self, ob, *args):
         with self.lock:
-            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            ob = torch.from_numpy(ob).float().unsqueeze(0)
             res = self._model.hidden_layers(ob)
             res = self._model.value_branch(res)
-            res = res.squeeze(0)
+            res = res.squeeze()
             return var_to_np(res)

     def _evaluate(self, obs, actions):
         """Passes in multiple obs."""
         logits, values = self._model(obs)
-        log_probs = F.log_softmax(logits)
-        probs = self._model.probs(logits)
+        log_probs = F.log_softmax(logits, dim=1)
+        probs = F.softmax(logits, dim=1)
         action_log_probs = log_probs.gather(1, actions.view(-1, 1))
+        # TODO(alok): set distribution based on action space and use its
+        # `.entropy()` method to calculate automatically
         entropy = -(log_probs * probs).sum(-1).sum()
         return values, action_log_probs, entropy

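The `.squeeze()` calls above are what the commit message's Pendulum caveat refers to; a small illustration (values made up, not from the patch):

```python
import torch

# .squeeze() drops every size-1 dimension. For a Discrete env (e.g. CartPole)
# that turns a (1, 1) sample into the scalar the env expects, but for an env
# whose action space is a singleton Box (e.g. Pendulum) it collapses too far.
cartpole_action = torch.tensor([[1]])
pendulum_action = torch.tensor([[0.37]])
print(cartpole_action.squeeze().shape)   # torch.Size([])  -> fine as a scalar
print(pendulum_action.squeeze().shape)   # torch.Size([])  -> env wanted shape (1,)
```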
@@ -64,15 +67,19 @@ class SharedTorchPolicy(TorchPolicy):
         """Loss is encoded in here. Defining a new loss function
         would start by rewriting this function"""

-        states, acs, advs, rs, _ = convert_batch(batch)
-        values, ac_logprobs, entropy = self._evaluate(states, acs)
-        pi_err = -(advs * ac_logprobs).sum()
-        value_err = 0.5 * (values - rs).pow(2).sum()
+        states, actions, advs, rs, _ = convert_batch(batch)
+        values, action_log_probs, entropy = self._evaluate(states, actions)
+        pi_err = -advs.dot(action_log_probs.reshape(-1))
+        value_err = F.mse_loss(values.reshape(-1), rs)

         self.optimizer.zero_grad()
-        overall_err = (pi_err +
-                       value_err * self.config["vf_loss_coeff"] +
-                       entropy * self.config["entropy_coeff"])
+
+        overall_err = sum([
+            pi_err,
+            self.config["vf_loss_coeff"] * value_err,
+            self.config["entropy_coeff"] * entropy,
+        ])
+
         overall_err.backward()
-        torch.nn.utils.clip_grad_norm(
-            self._model.parameters(), self.config["grad_clip"])
+        torch.nn.utils.clip_grad_norm_(self._model.parameters(),
+                                       self.config["grad_clip"])
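For reference, a standalone sketch of the reshaped loss above, with invented tensors and placeholder numbers standing in for `config["vf_loss_coeff"]` and `config["entropy_coeff"]`:

```python
import torch
import torch.nn.functional as F

# Advantages and value targets are flat 1-D tensors; the critic output and the
# gathered log-probs are flattened to match (the shape fix this commit makes).
advs = torch.tensor([0.5, -0.2, 1.0])
rs = torch.tensor([1.0, 0.0, 2.0])
values = torch.tensor([[0.9], [0.1], [1.8]])
action_log_probs = torch.tensor([[-0.1], [-1.2], [-0.4]])
entropy = torch.tensor(1.5)

pi_err = -advs.dot(action_log_probs.reshape(-1))   # policy-gradient term
value_err = F.mse_loss(values.reshape(-1), rs)     # critic regression term
overall_err = pi_err + 0.5 * value_err + 0.01 * entropy   # 0.5 / 0.01 are made-up coefficients
```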
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function

 import torch
-from torch.autograd import Variable

 from ray.rllib.a3c.policy import Policy
 from threading import Lock
@@ -15,8 +14,13 @@ class TorchPolicy(Policy):
     The model is a separate object than the policy. This could be changed
     in the future."""

-    def __init__(self, registry, ob_space, action_space, config,
-                 name="local", summarize=True):
+    def __init__(self,
+                 registry,
+                 ob_space,
+                 action_space,
+                 config,
+                 name="local",
+                 summarize=True):
         self.registry = registry
         self.local_steps = 0
         self.config = config
@@ -28,7 +32,7 @@ class TorchPolicy(Policy):
     def apply_gradients(self, grads):
         self.optimizer.zero_grad()
         for g, p in zip(grads, self._model.parameters()):
-            p.grad = Variable(torch.from_numpy(g))
+            p.grad = torch.from_numpy(g)
         self.optimizer.step()

     def get_weights(self):
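A minimal sketch of the `apply_gradients` pattern above (the model and gradients are invented for illustration):

```python
import numpy as np
import torch
import torch.nn as nn

# Numpy gradients computed by a remote worker are written straight into each
# parameter's .grad field before the local optimizer takes a step.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grads = [np.ones(p.shape, dtype=np.float32) for p in model.parameters()]

optimizer.zero_grad()
for g, p in zip(grads, model.parameters()):
    p.grad = torch.from_numpy(g)
optimizer.step()
```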
@@ -69,7 +73,7 @@ class TorchPolicy(Policy):

     def _backward(self, batch):
         """Implements the loss function and calculates the gradient.
-        Pytorch automatically generates a backward trace for each variable.
+        Pytorch automatically generates a backward trace for each tensor.
         Assumption right now is that variables are moved, so the backward
         trace is lost.

@@ -180,11 +180,14 @@ class ModelCatalog(object):
             return registry.get(RLLIB_MODEL, model)(
                 input_shape, num_outputs, options)

+        # TODO(alok): fix to handle Discrete(n) state spaces
         obs_rank = len(input_shape) - 1

         if obs_rank > 1:
             return PyTorchVisionNet(input_shape, num_outputs, options)

+        # TODO(alok): overhaul PyTorchFCNet so it can just
+        # take input shape directly
         return PyTorchFCNet(input_shape[0], num_outputs, options)

     @staticmethod
@@ -9,6 +9,7 @@ import torch.nn as nn

 class FullyConnectedNetwork(Model):
     """TODO(rliaw): Logits, Value should both be contained here"""

     def _init(self, inputs, num_outputs, options):
+        assert type(inputs) is int
         hiddens = options.get("fcnet_hiddens", [256, 256])
@@ -23,26 +24,29 @@ class FullyConnectedNetwork(Model):
         layers = []
         last_layer_size = inputs
         for size in hiddens:
-            layers.append(SlimFC(
-                last_layer_size, size,
-                initializer=normc_initializer(1.0),
-                activation_fn=activation))
+            layers.append(
+                SlimFC(
+                    in_size=last_layer_size,
+                    out_size=size,
+                    initializer=normc_initializer(1.0),
+                    activation_fn=activation))
             last_layer_size = size

         self.hidden_layers = nn.Sequential(*layers)

         self.logits = SlimFC(
-            last_layer_size, num_outputs,
+            in_size=last_layer_size,
+            out_size=num_outputs,
             initializer=normc_initializer(0.01),
             activation_fn=None)
-        self.probs = nn.Softmax()
         self.value_branch = SlimFC(
-            last_layer_size, 1,
+            in_size=last_layer_size,
+            out_size=1,
             initializer=normc_initializer(1.0),
             activation_fn=None)

     def forward(self, obs):
-        """ Internal method - pass in Variables, not numpy arrays
+        """ Internal method - pass in torch tensors, not numpy arrays

         Args:
             obs: observations and features
@@ -52,5 +56,5 @@ class FullyConnectedNetwork(Model):
             value: value function for each state"""
         res = self.hidden_layers(obs)
         logits = self.logits(res)
-        value = self.value_branch(res)
+        value = self.value_branch(res).reshape(-1)
         return logits, value
@@ -5,31 +5,24 @@ from __future__ import print_function

 import numpy as np
 import torch
-from torch.autograd import Variable


 def convert_batch(trajectory, has_features=False):
     """Convert trajectory from numpy to PT variable"""
-    states = Variable(torch.from_numpy(
-        trajectory["observations"]).float())
-    acs = Variable(torch.from_numpy(
-        trajectory["actions"]))
-    advs = Variable(torch.from_numpy(
-        trajectory["advantages"].copy()).float())
-    advs = advs.view(-1, 1)
-    rs = Variable(torch.from_numpy(
-        trajectory["value_targets"]).float())
-    rs = rs.view(-1, 1)
+    states = torch.from_numpy(trajectory["obs"]).float()
+    acs = torch.from_numpy(trajectory["actions"])
+    advs = torch.from_numpy(
+        trajectory["advantages"].copy()).float().reshape(-1)
+    rs = torch.from_numpy(trajectory["rewards"]).float().reshape(-1)
     if has_features:
-        features = [Variable(torch.from_numpy(f))
-                    for f in trajectory["features"]]
+        features = [torch.from_numpy(f) for f in trajectory["features"]]
     else:
         features = trajectory["features"]
     return states, acs, advs, rs, features


 def var_to_np(var):
-    return var.data.numpy()[0]
+    return var.detach().numpy()


 def normc_initializer(std=1.0):
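To make the conversions above concrete, a small sketch (arrays invented) of turning a numpy rollout field into a flat 1-D tensor and back:

```python
import numpy as np
import torch

advantages = np.array([[0.5], [1.0], [-0.2]], dtype=np.float32)
advs = torch.from_numpy(advantages.copy()).float().reshape(-1)  # shape (3,), flat as the fix requires
back = advs.detach().numpy()                                    # detach() instead of .data.numpy()[0]
print(advs.shape, back.shape)
```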
@@ -37,6 +30,7 @@ def normc_initializer(std=1.0):
         tensor.data.normal_(0, 1)
         tensor.data *= std / torch.sqrt(
             tensor.data.pow(2).sum(1, keepdim=True))
+
     return initializer


@@ -29,9 +29,15 @@ class Model(nn.Module):
 class SlimConv2d(nn.Module):
     """Simple mock of tf.slim Conv2d"""

-    def __init__(self, in_channels, out_channels, kernel, stride, padding,
-                 initializer=nn.init.xavier_uniform,
-                 activation_fn=nn.ReLU, bias_init=0):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel,
+                 stride,
+                 padding,
+                 initializer=nn.init.xavier_uniform_,
+                 activation_fn=nn.ReLU,
+                 bias_init=0):
         super(SlimConv2d, self).__init__()
         layers = []
         if padding:
@@ -39,7 +45,7 @@ class SlimConv2d(nn.Module):
         conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
         if initializer:
             initializer(conv.weight)
-        nn.init.constant(conv.bias, bias_init)
+        nn.init.constant_(conv.bias, bias_init)

         layers.append(conv)
         if activation_fn:
|
|||
class SlimFC(nn.Module):
|
||||
"""Simple PyTorch of `linear` function"""
|
||||
|
||||
def __init__(self, in_size, size, initializer=None,
|
||||
activation_fn=None, bias_init=0):
|
||||
def __init__(self,
|
||||
in_size,
|
||||
out_size,
|
||||
initializer=None,
|
||||
activation_fn=None,
|
||||
bias_init=0):
|
||||
super(SlimFC, self).__init__()
|
||||
layers = []
|
||||
linear = nn.Linear(in_size, size)
|
||||
linear = nn.Linear(in_size, out_size)
|
||||
if initializer:
|
||||
initializer(linear.weight)
|
||||
nn.init.constant(linear.bias, bias_init)
|
||||
nn.init.constant_(linear.bias, bias_init)
|
||||
layers.append(linear)
|
||||
if activation_fn:
|
||||
layers.append(activation_fn())
|
||||
|
|
|
@@ -21,32 +21,31 @@ class VisionNetwork(Model):
         filters = options.get("conv_filters", [
             [16, [8, 8], 4],
             [32, [4, 4], 2],
-            [512, [10, 10], 1]
+            [512, [10, 10], 1],
         ])
         layers = []
         in_channels, in_size = inputs[0], inputs[1:]

         for out_channels, kernel, stride in filters[:-1]:
-            padding, out_size = valid_padding(
-                in_size, kernel, [stride, stride])
-            layers.append(SlimConv2d(
-                in_channels, out_channels, kernel, stride, padding))
+            padding, out_size = valid_padding(in_size, kernel,
+                                              [stride, stride])
+            layers.append(
+                SlimConv2d(in_channels, out_channels, kernel, stride, padding))
             in_channels = out_channels
             in_size = out_size

         out_channels, kernel, stride = filters[-1]
-        layers.append(SlimConv2d(
-            in_channels, out_channels, kernel, stride, None))
+        layers.append(
+            SlimConv2d(in_channels, out_channels, kernel, stride, None))
         self._convs = nn.Sequential(*layers)

         self.logits = SlimFC(
-            out_channels, num_outputs, initializer=nn.init.xavier_uniform)
-        self.probs = nn.Softmax()
+            out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
         self.value_branch = SlimFC(
             out_channels, 1, initializer=normc_initializer())

     def hidden_layers(self, obs):
-        """ Internal method - pass in Variables, not numpy arrays
+        """ Internal method - pass in torch tensors, not numpy arrays

         args:
             obs: observations and features"""
@@ -146,12 +146,19 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     --stop '{"training_iteration": 2}' \
     --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'

-# docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
-#     python /ray/python/ray/rllib/train.py \
-#     --env PongDeterministic-v4 \
-#     --run A3C \
-#     --stop '{"training_iteration": 2}' \
-#     --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env PongDeterministic-v4 \
+    --run A3C \
+    --stop '{"training_iteration": 2}' \
+    --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
+
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env CartPole-v1 \
+    --run A3C \
+    --stop '{"training_iteration": 2}' \
+    --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true}'

 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \