[tune] Add PyTorch MNIST Example + Misc. Tweaks (#2708)
This commit is contained in:
parent 224d38cbb2
commit 0347e6418b
8 changed files with 492 additions and 17 deletions
@@ -114,8 +114,7 @@ def make_parser(parser_creator=None, **kwargs):
         "A value of 0 (default) disables checkpointing.")
     parser.add_argument(
         "--checkpoint-at-end",
-        default=False,
-        type=bool,
+        action="store_true",
         help="Whether to checkpoint at the end of the experiment. "
         "Default is False.")
     parser.add_argument(
@@ -152,11 +151,12 @@ def to_argv(config):
     for k, v in config.items():
         if "-" in k:
             raise ValueError("Use '_' instead of '-' in `{}`".format(k))
-        argv.append("--{}".format(k.replace("_", "-")))
+        if not isinstance(v, bool) or v:  # for argparse flags
+            argv.append("--{}".format(k.replace("_", "-")))
         if isinstance(v, string_types):
             argv.append(v)
         elif isinstance(v, bool):
-            argv.append(v)
+            pass
         else:
             argv.append(json.dumps(v, cls=_SafeFallbackEncoder))
     return argv
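Note: with this change, boolean config values are treated as argparse-style flags rather than serialized values: a True key is emitted as a bare "--flag" with no value, and a False key is dropped entirely. A minimal sketch of the expected output (assuming to_argv is imported from ray.tune's config parser module; the key ordering below just follows dict insertion order):

    from ray.tune.config_parser import to_argv

    to_argv({"smoke_test": True, "lr": 0.01})   # ['--smoke-test', '--lr', '0.01']
    to_argv({"smoke_test": False, "lr": 0.01})  # ['--lr', '0.01']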
python/ray/tune/examples/mnist_pytorch.py  (new file, 191 lines)
@@ -0,0 +1,191 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
from __future__ import print_function

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument(
    '--batch-size',
    type=int,
    default=64,
    metavar='N',
    help='input batch size for training (default: 64)')
parser.add_argument(
    '--test-batch-size',
    type=int,
    default=1000,
    metavar='N',
    help='input batch size for testing (default: 1000)')
parser.add_argument(
    '--epochs',
    type=int,
    default=10,
    metavar='N',
    help='number of epochs to train (default: 10)')
parser.add_argument(
    '--lr',
    type=float,
    default=0.01,
    metavar='LR',
    help='learning rate (default: 0.01)')
parser.add_argument(
    '--momentum',
    type=float,
    default=0.5,
    metavar='M',
    help='SGD momentum (default: 0.5)')
parser.add_argument(
    '--no-cuda',
    action='store_true',
    default=False,
    help='disables CUDA training')
parser.add_argument(
    '--seed',
    type=int,
    default=1,
    metavar='S',
    help='random seed (default: 1)')
parser.add_argument(
    '--smoke-test', action="store_true", help="Finish quickly for testing")


def train_mnist(args, config, reporter):
    vars(args).update(config)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '~/data',
            train=True,
            download=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '~/data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    model = Net()
    if args.cuda:
        model.cuda()

    optimizer = optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum)

    def train(epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

    def test():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)
            output = model(data)
            test_loss += F.nll_loss(
                output, target,
                size_average=False).data[0]  # sum up batch loss
            pred = output.data.max(
                1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

        test_loss = test_loss.item() / len(test_loader.dataset)
        accuracy = correct.item() / len(test_loader.dataset)
        reporter(mean_loss=test_loss, mean_accuracy=accuracy)

    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test()


if __name__ == '__main__':
    datasets.MNIST('~/data', train=True, download=True)
    args = parser.parse_args()

    import numpy as np
    import ray
    from ray import tune
    from ray.tune.schedulers import AsyncHyperBandScheduler

    ray.init()
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration",
        reward_attr="neg_mean_loss",
        max_t=400,
        grace_period=20)
    tune.register_trainable("train_mnist",
                            lambda cfg, rprtr: train_mnist(args, cfg, rprtr))
    tune.run_experiments(
        {
            "exp": {
                "stop": {
                    "mean_accuracy": 0.98,
                    "training_iteration": 1 if args.smoke_test else 20
                },
                "trial_resources": {
                    "cpu": 3
                },
                "run": "train_mnist",
                "num_samples": 1 if args.smoke_test else 10,
                "config": {
                    "lr": lambda spec: np.random.uniform(0.001, 0.1),
                    "momentum": lambda spec: np.random.uniform(0.1, 0.9),
                }
            }
        },
        verbose=0,
        scheduler=sched)
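Note: the example wraps the training loop in train_mnist(args, config, reporter), reports mean_loss and mean_accuracy through the Tune reporter, and searches over lr and momentum with AsyncHyperBand. A quick local smoke run (the same flag the Jenkins test added later in this commit uses; this assumes torchvision can download MNIST to ~/data, which the __main__ block triggers first):

    python python/ray/tune/examples/mnist_pytorch.py --smoke-test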
python/ray/tune/examples/mnist_pytorch_trainable.py  (new file, 203 lines)
@@ -0,0 +1,203 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
from __future__ import print_function

import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from ray.tune import Trainable

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument(
    '--batch-size',
    type=int,
    default=64,
    metavar='N',
    help='input batch size for training (default: 64)')
parser.add_argument(
    '--test-batch-size',
    type=int,
    default=1000,
    metavar='N',
    help='input batch size for testing (default: 1000)')
parser.add_argument(
    '--epochs',
    type=int,
    default=10,
    metavar='N',
    help='number of epochs to train (default: 10)')
parser.add_argument(
    '--lr',
    type=float,
    default=0.01,
    metavar='LR',
    help='learning rate (default: 0.01)')
parser.add_argument(
    '--momentum',
    type=float,
    default=0.5,
    metavar='M',
    help='SGD momentum (default: 0.5)')
parser.add_argument(
    '--no-cuda',
    action='store_true',
    default=False,
    help='disables CUDA training')
parser.add_argument(
    '--seed',
    type=int,
    default=1,
    metavar='S',
    help='random seed (default: 1)')
parser.add_argument(
    '--smoke-test', action="store_true", help="Finish quickly for testing")


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class TrainMNIST(Trainable):
    def _setup(self):
        args = self.config.pop("args")
        vars(args).update(self.config)
        args.cuda = not args.no_cuda and torch.cuda.is_available()

        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

        kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
        self.train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                '~/data',
                train=True,
                download=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        self.test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                '~/data',
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

        self.model = Net()
        if args.cuda:
            self.model.cuda()

        self.optimizer = optim.SGD(
            self.model.parameters(), lr=args.lr, momentum=args.momentum)
        self.args = args

    def _train_iteration(self):
        self.model.train()
        for batch_idx, (data, target) in enumerate(self.train_loader):
            if self.args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            self.optimizer.step()

    def _test(self):
        self.model.eval()
        test_loss = 0
        correct = 0
        for data, target in self.test_loader:
            if self.args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)
            output = self.model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, size_average=False).data[0]

            # get the index of the max log-probability
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

        test_loss = test_loss.item() / len(self.test_loader.dataset)
        accuracy = correct.item() / len(self.test_loader.dataset)
        return {"mean_loss": test_loss, "mean_accuracy": accuracy}

    def _train(self):
        self._train_iteration()
        return self._test()

    def _save(self, path):
        torch.save(self.model.state_dict(), os.path.join(path, "model.pth"))
        return path

    def _restore(self, path):
        self.model.load_state_dict(os.path.join(path, "model.pth"))


if __name__ == '__main__':
    datasets.MNIST('~/data', train=True, download=True)
    args = parser.parse_args()

    import numpy as np
    import ray
    from ray import tune
    from ray.tune.schedulers import HyperBandScheduler

    ray.init()
    sched = HyperBandScheduler(
        time_attr="training_iteration", reward_attr="neg_mean_loss")
    tune.run_experiments(
        {
            "exp": {
                "stop": {
                    "mean_accuracy": 0.95,
                    "training_iteration": 1 if args.smoke_test else 20,
                },
                "trial_resources": {
                    "cpu": 3
                },
                "run": TrainMNIST,
                "num_samples": 1 if args.smoke_test else 20,
                "checkpoint_at_end": True,
                "config": {
                    "args": args,
                    "lr": lambda spec: np.random.uniform(0.001, 0.1),
                    "momentum": lambda spec: np.random.uniform(0.1, 0.9),
                }
            }
        },
        verbose=0,
        scheduler=sched)
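Note: this variant expresses the same model as a Tune Trainable, so the scheduler can pause and resume trials through _save/_restore, and checkpoint_at_end writes a final checkpoint for each trial. It can be smoke-tested the same way (same assumption about downloading MNIST to ~/data):

    python python/ray/tune/examples/mnist_pytorch_trainable.py --smoke-test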
@@ -115,11 +115,12 @@ class FunctionRunner(Trainable):
             time.sleep(1)
             result = self._status_reporter._get_and_clear_status()

-        curr_ts_total = result.get(TIMESTEPS_TOTAL,
-                                   self._last_reported_timestep)
-        result.update(
-            timesteps_this_iter=(curr_ts_total - self._last_reported_timestep))
-        self._last_reported_timestep = curr_ts_total
+        curr_ts_total = result.get(TIMESTEPS_TOTAL)
+        if curr_ts_total is not None:
+            result.update(
+                timesteps_this_iter=(
+                    curr_ts_total - self._last_reported_timestep))
+            self._last_reported_timestep = curr_ts_total

         return result
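Note: with the new guard, a function-based trial that never reports timesteps no longer has timesteps_this_iter back-filled from a default of 0, and TIMESTEPS_TOTAL stays None, as the testReportTimeStep case below asserts. A minimal sketch of a trial that exercises this path (hypothetical user function):

    def train(config, reporter):
        for i in range(100):
            reporter(mean_accuracy=5)  # no timesteps_total reported

    # after run_experiments, trial.last_result[TIMESTEPS_TOTAL] is None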
@@ -383,6 +383,36 @@ class TrainableFunctionApiTest(unittest.TestCase):
         self.assertEqual(trial.status, Trial.TERMINATED)
         self.assertEqual(trial.last_result['mean_accuracy'], float('inf'))

+    def testReportTimeStep(self):
+        def train(config, reporter):
+            for i in range(100):
+                reporter(mean_accuracy=5)
+
+        [trial] = run_experiments({
+            "foo": {
+                "run": train,
+                "config": {
+                    "script_min_iter_time_s": 0,
+                },
+            }
+        })
+        self.assertIsNone(trial.last_result[TIMESTEPS_TOTAL])
+
+        def train3(config, reporter):
+            for i in range(10):
+                reporter(timesteps_total=5)
+
+        [trial3] = run_experiments({
+            "foo": {
+                "run": train3,
+                "config": {
+                    "script_min_iter_time_s": 0,
+                },
+            }
+        })
+        self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 5)
+        self.assertEqual(trial3.last_result["timesteps_this_iter"], 0)
+

 class RunExperimentTest(unittest.TestCase):
     def setUp(self):
@@ -505,6 +535,24 @@ class RunExperimentTest(unittest.TestCase):
         for trial in trials:
             self.assertEqual(trial.status, Trial.TERMINATED)

+    def testCheckpointAtEnd(self):
+        class train(Trainable):
+            def _train(self):
+                return dict(timesteps_this_iter=1, done=True)
+
+            def _save(self, path):
+                return path
+
+        trials = run_experiments({
+            "foo": {
+                "run": train,
+                "checkpoint_at_end": True
+            }
+        })
+        for trial in trials:
+            self.assertEqual(trial.status, Trial.TERMINATED)
+            self.assertTrue(trial.has_checkpoint())
+

 class VariantGeneratorTest(unittest.TestCase):
     def setUp(self):
@@ -74,7 +74,7 @@ class Trainable(object):

         self._iteration = 0
         self._time_total = 0.0
-        self._timesteps_total = 0
+        self._timesteps_total = None
         self._setup()
         self._initialize_ok = True
         self._local_ip = ray.services.get_node_ip_address()
@@ -150,9 +150,15 @@ class Trainable(object):
         time_this_iter = time.time() - start
         self._time_total += time_this_iter

-        self._timesteps_total += result.get(TIMESTEPS_THIS_ITER, 0)
-
         result.setdefault(DONE, False)
+
+        # self._timesteps_total should only be tracked if increments provided
+        if result.get(TIMESTEPS_THIS_ITER):
+            if self._timesteps_total is None:
+                self._timesteps_total = 0
+            self._timesteps_total += result[TIMESTEPS_THIS_ITER]
+
+        # self._timesteps_total should not override user-provided total
         result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)

         # Provides auto-filled neg_mean_loss for avoiding regressions
@@ -278,12 +284,26 @@ class Trainable(object):
         raise NotImplementedError

     def _save(self, checkpoint_dir):
-        """Subclasses should override this to implement save()."""
+        """Subclasses should override this to implement save().
+
+        Args:
+            checkpoint_dir (str): The directory where the checkpoint
+                can be stored.
+
+        Returns:
+            Checkpoint path that may be passed to restore(). Typically
+                would default to `checkpoint_dir`.
+        """

         raise NotImplementedError

     def _restore(self, checkpoint_path):
-        """Subclasses should override this to implement restore()."""
+        """Subclasses should override this to implement restore().
+
+        Args:
+            checkpoint_path (str): The directory where the checkpoint
+                is stored.
+        """

         raise NotImplementedError
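Note: the expanded docstrings spell out the checkpoint contract. _save receives a directory, persists state under it, and returns a path; _restore is later called with exactly that return value. A minimal subclass sketch honoring the contract (illustrative only; the pickle-based format and names here are assumptions, not part of the API):

    import os
    import pickle

    from ray.tune import Trainable


    class MyTrainable(Trainable):
        def _setup(self):
            self.steps = 0

        def _train(self):
            self.steps += 1
            return {"mean_accuracy": self.steps, "done": self.steps >= 10}

        def _save(self, checkpoint_dir):
            # Persist state under checkpoint_dir; the return value is what
            # _restore will receive later.
            path = os.path.join(checkpoint_dir, "state.pkl")
            with open(path, "wb") as f:
                pickle.dump(self.steps, f)
            return path

        def _restore(self, checkpoint_path):
            with open(checkpoint_path, "rb") as f:
                self.steps = pickle.load(f)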
@@ -16,7 +16,7 @@ from ray.tune.logger import pretty_print, UnifiedLogger
 # have been defined yet. See https://github.com/ray-project/ray/issues/1716.
 import ray.tune.registry
 from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
-                             TIME_TOTAL_S, TRAINING_ITERATION)
+                             TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL)
 from ray.utils import random_string, binary_to_hex

 DEBUG_PRINT_INTERVAL = 5

@@ -237,8 +237,12 @@ class Trial(object):
                 int(self.last_result.get(TIME_TOTAL_S)))
         ]

-        if self.last_result.get("timesteps_total") is not None:
-            pieces.append('{} ts'.format(self.last_result["timesteps_total"]))
+        if self.last_result.get(TRAINING_ITERATION) is not None:
+            pieces.append('{} iter'.format(
+                self.last_result[TRAINING_ITERATION]))
+
+        if self.last_result.get(TIMESTEPS_TOTAL) is not None:
+            pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL]))

         if self.last_result.get("episode_reward_mean") is not None:
             pieces.append('{} rew'.format(
@@ -266,6 +266,14 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/tune/examples/tune_mnist_keras.py \
     --smoke-test

+docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/tune/examples/mnist_pytorch.py \
+    --smoke-test
+
+docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \
+    --smoke-test
+
 docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py