diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile
index 382da881b..6316b4eeb 100644
--- a/docker/examples/Dockerfile
+++ b/docker/examples/Dockerfile
@@ -1,8 +1,10 @@
 # The examples Docker image adds dependencies needed to run the examples
 FROM ray-project/deploy
-RUN conda install -y -c conda-forge tensorflow
+
+# Update numpy to 1.14 to silence errors from other libraries
+RUN conda install -y numpy
 RUN apt-get install -y zlib1g-dev
-RUN pip install gym[atari] opencv-python==3.2.0.8
+RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow
 RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
-# RUN conda install -y -q pytorch torchvision -c soumith
+RUN conda install -y pytorch-cpu torchvision-cpu -c pytorch
diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py
index 15f5aa187..569b50c44 100644
--- a/python/ray/rllib/a3c/a3c.py
+++ b/python/ray/rllib/a3c/a3c.py
@@ -15,7 +15,6 @@ from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator, \
 from ray.tune.result import TrainingResult
 from ray.tune.trial import Resources
 
-
 DEFAULT_CONFIG = {
     # Number of workers (excluding master)
     "num_workers": 4,
@@ -52,7 +51,7 @@ DEFAULT_CONFIG = {
         # (Image statespace) - Converts image to (dim, dim, C)
         "dim": 80,
         # (Image statespace) - Converts image shape to (C, dim, dim)
-        "channel_major": False
+        "channel_major": False,
     },
     # Arguments to pass to the rllib optimizer
     "optimizer": {
@@ -73,46 +72,53 @@ class A3CAgent(Agent):
     def default_resource_request(cls, config):
         cf = dict(cls._default_config, **config)
         return Resources(
-            cpu=1, gpu=0,
+            cpu=1,
+            gpu=0,
             extra_cpu=cf["num_workers"],
             extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)
 
     def _init(self):
         self.local_evaluator = A3CEvaluator(
-            self.registry, self.env_creator, self.config, self.logdir,
+            self.registry,
+            self.env_creator,
+            self.config,
+            self.logdir,
             start_sampler=False)
         if self.config["use_gpu_for_workers"]:
             remote_cls = GPURemoteA3CEvaluator
         else:
             remote_cls = RemoteA3CEvaluator
         self.remote_evaluators = [
-            remote_cls.remote(
-                self.registry, self.env_creator, self.config, self.logdir)
-            for i in range(self.config["num_workers"])]
-        self.optimizer = AsyncOptimizer(
-            self.config["optimizer"], self.local_evaluator,
-            self.remote_evaluators)
+            remote_cls.remote(self.registry, self.env_creator, self.config,
+                              self.logdir)
+            for i in range(self.config["num_workers"])
+        ]
+        self.optimizer = AsyncOptimizer(self.config["optimizer"],
+                                        self.local_evaluator,
+                                        self.remote_evaluators)
 
     def _train(self):
         self.optimizer.step()
-        FilterManager.synchronize(
-            self.local_evaluator.filters, self.remote_evaluators)
+        FilterManager.synchronize(self.local_evaluator.filters,
+                                  self.remote_evaluators)
         res = self._fetch_metrics_from_remote_evaluators()
         return res
 
     def _fetch_metrics_from_remote_evaluators(self):
         episode_rewards = []
         episode_lengths = []
-        metric_lists = [a.get_completed_rollout_metrics.remote()
-                        for a in self.remote_evaluators]
+        metric_lists = [
+            a.get_completed_rollout_metrics.remote()
+            for a in self.remote_evaluators
+        ]
         for metrics in metric_lists:
             for episode in ray.get(metrics):
                 episode_lengths.append(episode.episode_length)
                 episode_rewards.append(episode.episode_reward)
-        avg_reward = (
-            np.mean(episode_rewards) if episode_rewards else float('nan'))
-        avg_length = (
-            np.mean(episode_lengths) if episode_lengths else float('nan'))
+        avg_reward = (np.mean(episode_rewards)
+                      if episode_rewards else float('nan'))
+        avg_length = (np.mean(episode_lengths)
+                      if episode_lengths else float('nan'))
         timesteps = np.sum(episode_lengths) if episode_lengths else 0
 
         result = TrainingResult(
@@ -129,21 +135,23 @@ class A3CAgent(Agent):
             ev.__ray_terminate__.remote()
 
     def _save(self, checkpoint_dir):
-        checkpoint_path = os.path.join(
-            checkpoint_dir, "checkpoint-{}".format(self.iteration))
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "checkpoint-{}".format(self.iteration))
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
-            "local_state": self.local_evaluator.save()}
+            "local_state": self.local_evaluator.save()
+        }
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
-        ray.get(
-            [a.restore.remote(o) for a, o in zip(
-                self.remote_evaluators, extra_data["remote_state"])])
+        ray.get([
+            a.restore.remote(o)
+            for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
+        ])
        self.local_evaluator.restore(extra_data["local_state"])

    def compute_action(self, observation):
diff --git a/python/ray/rllib/a3c/shared_torch_policy.py b/python/ray/rllib/a3c/shared_torch_policy.py
index 36b39dcfc..d98a2f6dc 100644
--- a/python/ray/rllib/a3c/shared_torch_policy.py
+++ b/python/ray/rllib/a3c/shared_torch_policy.py
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function
 
 import torch
-from torch.autograd import Variable
 import torch.nn.functional as F
 
 from ray.rllib.a3c.torchpolicy import TorchPolicy
@@ -18,8 +17,8 @@ class SharedTorchPolicy(TorchPolicy):
     is_recurrent = False
 
     def __init__(self, registry, ob_space, ac_space, config, **kwargs):
-        super(SharedTorchPolicy, self).__init__(
-            registry, ob_space, ac_space, config, **kwargs)
+        super(SharedTorchPolicy, self).__init__(registry, ob_space, ac_space,
+                                                config, **kwargs)
 
     def _setup_graph(self, ob_space, ac_space):
         _, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
@@ -31,32 +30,36 @@ class SharedTorchPolicy(TorchPolicy):
     def compute(self, ob, *args):
         """Should take in a SINGLE ob"""
         with self.lock:
-            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            ob = torch.from_numpy(ob).float().unsqueeze(0)
             logits, values = self._model(ob)
-            samples = self._model.probs(logits).multinomial().squeeze()
-            values = values.squeeze(0)
-            return var_to_np(samples), {"vf_preds": var_to_np(values)}
+            # TODO(alok): Support non-categorical distributions. Multinomial
+            # is only for categorical.
+            sampled_actions = F.softmax(logits, dim=1).multinomial(1).squeeze()
+            values = values.squeeze()
+            return var_to_np(sampled_actions), {"vf_preds": var_to_np(values)}
 
     def compute_logits(self, ob, *args):
         with self.lock:
-            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            ob = torch.from_numpy(ob).float().unsqueeze(0)
             res = self._model.hidden_layers(ob)
             return var_to_np(self._model.logits(res))
 
     def value(self, ob, *args):
         with self.lock:
-            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            ob = torch.from_numpy(ob).float().unsqueeze(0)
             res = self._model.hidden_layers(ob)
             res = self._model.value_branch(res)
-            res = res.squeeze(0)
+            res = res.squeeze()
             return var_to_np(res)
 
     def _evaluate(self, obs, actions):
         """Passes in multiple obs."""
         logits, values = self._model(obs)
-        log_probs = F.log_softmax(logits)
-        probs = self._model.probs(logits)
+        log_probs = F.log_softmax(logits, dim=1)
+        probs = F.softmax(logits, dim=1)
         action_log_probs = log_probs.gather(1, actions.view(-1, 1))
+        # TODO(alok): set distribution based on action space and use its
+        # `.entropy()` method to calculate entropy automatically
         entropy = -(log_probs * probs).sum(-1).sum()
         return values, action_log_probs, entropy
 
@@ -64,15 +67,19 @@ class SharedTorchPolicy(TorchPolicy):
         """Loss is encoded in here. Defining a new loss function
         would start by rewriting this function"""
 
-        states, acs, advs, rs, _ = convert_batch(batch)
-        values, ac_logprobs, entropy = self._evaluate(states, acs)
-        pi_err = -(advs * ac_logprobs).sum()
-        value_err = 0.5 * (values - rs).pow(2).sum()
+        states, actions, advs, rs, _ = convert_batch(batch)
+        values, action_log_probs, entropy = self._evaluate(states, actions)
+        pi_err = -advs.dot(action_log_probs.reshape(-1))
+        value_err = F.mse_loss(values.reshape(-1), rs)
 
         self.optimizer.zero_grad()
-        overall_err = (pi_err +
-                       value_err * self.config["vf_loss_coeff"] +
-                       entropy * self.config["entropy_coeff"])
+
+        overall_err = sum([
+            pi_err,
+            self.config["vf_loss_coeff"] * value_err,
+            self.config["entropy_coeff"] * entropy,
+        ])
+
         overall_err.backward()
-        torch.nn.utils.clip_grad_norm(
-            self._model.parameters(), self.config["grad_clip"])
+        torch.nn.utils.clip_grad_norm_(self._model.parameters(),
+                                       self.config["grad_clip"])
diff --git a/python/ray/rllib/a3c/torchpolicy.py b/python/ray/rllib/a3c/torchpolicy.py
index 8c7d86a08..8c6a28256 100644
--- a/python/ray/rllib/a3c/torchpolicy.py
+++ b/python/ray/rllib/a3c/torchpolicy.py
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function
 
 import torch
-from torch.autograd import Variable
 
 from ray.rllib.a3c.policy import Policy
 from threading import Lock
@@ -15,8 +14,13 @@ class TorchPolicy(Policy):
     The model is a separate object than the policy. This could be
     changed in the future."""
 
-    def __init__(self, registry, ob_space, action_space, config,
-                 name="local", summarize=True):
+    def __init__(self,
+                 registry,
+                 ob_space,
+                 action_space,
+                 config,
+                 name="local",
+                 summarize=True):
         self.registry = registry
         self.local_steps = 0
         self.config = config
@@ -28,7 +32,7 @@ class TorchPolicy(Policy):
     def apply_gradients(self, grads):
         self.optimizer.zero_grad()
         for g, p in zip(grads, self._model.parameters()):
-            p.grad = Variable(torch.from_numpy(g))
+            p.grad = torch.from_numpy(g)
         self.optimizer.step()
 
     def get_weights(self):
@@ -69,7 +73,7 @@ class TorchPolicy(Policy):
     def _backward(self, batch):
         """Implements the loss function and calculates the gradient.
 
-        Pytorch automatically generates a backward trace for each variable.
+        PyTorch automatically generates a backward trace for each tensor.
         Assumption right now is that variables are moved, so the
         backward trace is lost.
diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py
index 532c138ba..603073dbf 100644
--- a/python/ray/rllib/models/catalog.py
+++ b/python/ray/rllib/models/catalog.py
@@ -180,11 +180,14 @@ class ModelCatalog(object):
             return registry.get(RLLIB_MODEL, model)(
                 input_shape, num_outputs, options)
 
+        # TODO(alok): fix to handle Discrete(n) state spaces
         obs_rank = len(input_shape) - 1
 
         if obs_rank > 1:
             return PyTorchVisionNet(input_shape, num_outputs, options)
 
+        # TODO(alok): overhaul PyTorchFCNet so it can just
+        # take input shape directly
         return PyTorchFCNet(input_shape[0], num_outputs, options)
 
     @staticmethod
diff --git a/python/ray/rllib/models/pytorch/fcnet.py b/python/ray/rllib/models/pytorch/fcnet.py
index b67f1365b..4c8163001 100644
--- a/python/ray/rllib/models/pytorch/fcnet.py
+++ b/python/ray/rllib/models/pytorch/fcnet.py
@@ -9,6 +9,7 @@ import torch.nn as nn
 
 class FullyConnectedNetwork(Model):
     """TODO(rliaw): Logits, Value should both be contained here"""
+
     def _init(self, inputs, num_outputs, options):
         assert type(inputs) is int
         hiddens = options.get("fcnet_hiddens", [256, 256])
@@ -23,26 +24,29 @@ class FullyConnectedNetwork(Model):
         layers = []
         last_layer_size = inputs
         for size in hiddens:
-            layers.append(SlimFC(
-                last_layer_size, size,
-                initializer=normc_initializer(1.0),
-                activation_fn=activation))
+            layers.append(
+                SlimFC(
+                    in_size=last_layer_size,
+                    out_size=size,
+                    initializer=normc_initializer(1.0),
+                    activation_fn=activation))
             last_layer_size = size
 
         self.hidden_layers = nn.Sequential(*layers)
 
         self.logits = SlimFC(
-            last_layer_size, num_outputs,
+            in_size=last_layer_size,
+            out_size=num_outputs,
             initializer=normc_initializer(0.01),
             activation_fn=None)
-        self.probs = nn.Softmax()
         self.value_branch = SlimFC(
-            last_layer_size, 1,
+            in_size=last_layer_size,
+            out_size=1,
             initializer=normc_initializer(1.0),
             activation_fn=None)
 
     def forward(self, obs):
-        """ Internal method - pass in Variables, not numpy arrays
+        """ Internal method - pass in torch tensors, not numpy arrays
 
         Args:
             obs: observations and features
@@ -52,5 +56,5 @@ class FullyConnectedNetwork(Model):
             value: value function for each state"""
         res = self.hidden_layers(obs)
         logits = self.logits(res)
-        value = self.value_branch(res)
+        value = self.value_branch(res).reshape(-1)
         return logits, value
diff --git a/python/ray/rllib/models/pytorch/misc.py b/python/ray/rllib/models/pytorch/misc.py
index 5cb5a4718..dc725265c 100644
--- a/python/ray/rllib/models/pytorch/misc.py
+++ b/python/ray/rllib/models/pytorch/misc.py
@@ -5,31 +5,24 @@ from __future__ import print_function
 import numpy as np
 
 import torch
-from torch.autograd import Variable
 
 
 def convert_batch(trajectory, has_features=False):
     """Convert trajectory from numpy to PT variable"""
-    states = Variable(torch.from_numpy(
-        trajectory["observations"]).float())
-    acs = Variable(torch.from_numpy(
-        trajectory["actions"]))
-    advs = Variable(torch.from_numpy(
-        trajectory["advantages"].copy()).float())
-    advs = advs.view(-1, 1)
-    rs = Variable(torch.from_numpy(
-        trajectory["value_targets"]).float())
-    rs = rs.view(-1, 1)
+    states = torch.from_numpy(trajectory["obs"]).float()
+    acs = torch.from_numpy(trajectory["actions"])
+    advs = torch.from_numpy(
+        trajectory["advantages"].copy()).float().reshape(-1)
+    rs = torch.from_numpy(trajectory["rewards"]).float().reshape(-1)
     if has_features:
-        features = [Variable(torch.from_numpy(f))
-                    for f in trajectory["features"]]
+        features = [torch.from_numpy(f) for f in trajectory["features"]]
     else:
         features = trajectory["features"]
     return states, acs, advs, rs, features
 
 
 def var_to_np(var):
-    return var.data.numpy()[0]
+    return var.detach().numpy()
 
 
 def normc_initializer(std=1.0):
@@ -37,6 +30,7 @@ def normc_initializer(std=1.0):
         tensor.data.normal_(0, 1)
         tensor.data *= std / torch.sqrt(
             tensor.data.pow(2).sum(1, keepdim=True))
+
     return initializer
diff --git a/python/ray/rllib/models/pytorch/model.py b/python/ray/rllib/models/pytorch/model.py
index fd1577f33..876196741 100644
--- a/python/ray/rllib/models/pytorch/model.py
+++ b/python/ray/rllib/models/pytorch/model.py
@@ -29,9 +29,15 @@ class Model(nn.Module):
 class SlimConv2d(nn.Module):
     """Simple mock of tf.slim Conv2d"""
 
-    def __init__(self, in_channels, out_channels, kernel, stride, padding,
-                 initializer=nn.init.xavier_uniform,
-                 activation_fn=nn.ReLU, bias_init=0):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel,
+                 stride,
+                 padding,
+                 initializer=nn.init.xavier_uniform_,
+                 activation_fn=nn.ReLU,
+                 bias_init=0):
         super(SlimConv2d, self).__init__()
         layers = []
         if padding:
@@ -39,7 +45,7 @@ class SlimConv2d(nn.Module):
         conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
         if initializer:
             initializer(conv.weight)
-        nn.init.constant(conv.bias, bias_init)
+        nn.init.constant_(conv.bias, bias_init)
 
         layers.append(conv)
         if activation_fn:
@@ -53,14 +59,18 @@ class SlimConv2d(nn.Module):
 class SlimFC(nn.Module):
     """Simple PyTorch of `linear` function"""
 
-    def __init__(self, in_size, size, initializer=None,
-                 activation_fn=None, bias_init=0):
+    def __init__(self,
+                 in_size,
+                 out_size,
+                 initializer=None,
+                 activation_fn=None,
+                 bias_init=0):
         super(SlimFC, self).__init__()
         layers = []
-        linear = nn.Linear(in_size, size)
+        linear = nn.Linear(in_size, out_size)
         if initializer:
             initializer(linear.weight)
-        nn.init.constant(linear.bias, bias_init)
+        nn.init.constant_(linear.bias, bias_init)
         layers.append(linear)
         if activation_fn:
             layers.append(activation_fn())
diff --git a/python/ray/rllib/models/pytorch/visionnet.py b/python/ray/rllib/models/pytorch/visionnet.py
index 99786a8d4..0fc862069 100644
--- a/python/ray/rllib/models/pytorch/visionnet.py
+++ b/python/ray/rllib/models/pytorch/visionnet.py
@@ -21,32 +21,31 @@ class VisionNetwork(Model):
         filters = options.get("conv_filters", [
             [16, [8, 8], 4],
             [32, [4, 4], 2],
-            [512, [10, 10], 1]
+            [512, [10, 10], 1],
         ])
 
         layers = []
         in_channels, in_size = inputs[0], inputs[1:]
 
         for out_channels, kernel, stride in filters[:-1]:
-            padding, out_size = valid_padding(
-                in_size, kernel, [stride, stride])
-            layers.append(SlimConv2d(
-                in_channels, out_channels, kernel, stride, padding))
+            padding, out_size = valid_padding(in_size, kernel,
+                                              [stride, stride])
+            layers.append(
+                SlimConv2d(in_channels, out_channels, kernel, stride, padding))
             in_channels = out_channels
             in_size = out_size
 
         out_channels, kernel, stride = filters[-1]
-        layers.append(SlimConv2d(
-            in_channels, out_channels, kernel, stride, None))
+        layers.append(
+            SlimConv2d(in_channels, out_channels, kernel, stride, None))
         self._convs = nn.Sequential(*layers)
 
         self.logits = SlimFC(
-            out_channels, num_outputs, initializer=nn.init.xavier_uniform)
-        self.probs = nn.Softmax()
+            out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
         self.value_branch = SlimFC(
             out_channels, 1, initializer=normc_initializer())
 
     def hidden_layers(self, obs):
-        """ Internal method - pass in Variables, not numpy arrays
+        """ Internal method - pass in torch tensors, not numpy arrays
 
         args:
             obs: observations and features"""
diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh
index c732be535..8bd010c3c 100755
--- a/test/jenkins_tests/run_multi_node_tests.sh
+++ b/test/jenkins_tests/run_multi_node_tests.sh
@@ -146,12 +146,19 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     --stop '{"training_iteration": 2}' \
     --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
 
-# docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
-#     python /ray/python/ray/rllib/train.py \
-#     --env PongDeterministic-v4 \
-#     --run A3C \
-#     --stop '{"training_iteration": 2}' \
-#     --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env PongDeterministic-v4 \
+    --run A3C \
+    --stop '{"training_iteration": 2}' \
+    --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
+
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env CartPole-v1 \
+    --run A3C \
+    --stop '{"training_iteration": 2}' \
+    --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true}'
 
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
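Notes on the PyTorch 0.4 migration, for reviewers:

The `from torch.autograd import Variable` imports are dropped because 0.4 merged `Variable` into `Tensor`: anything returned by `torch.from_numpy` participates in autograd directly, and `var_to_np` now goes through `.detach()` instead of `.data`. A minimal sketch of the new semantics (illustrative only, not part of the patch):

```python
import numpy as np
import torch

ob = np.zeros(4, dtype=np.float32)
x = torch.from_numpy(ob).float().unsqueeze(0)  # plain Tensor, no Variable wrapper
w = torch.ones(4, requires_grad=True)

y = (x * w).sum()
y.backward()               # autograd runs on tensors directly in 0.4+
print(w.grad)              # dy/dw, shape (4,)
print(y.detach().numpy())  # replaces the old var.data.numpy() pattern
```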
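In `SharedTorchPolicy.compute`, sampling now goes through `F.softmax(...).multinomial(1)` since the model no longer owns a `probs` module; 0.4 also expects an explicit `dim` for softmax and an explicit sample count for `multinomial`. A sketch of the idiom (the action count of 6 is arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 6)               # batch of one observation, 6 actions
probs = F.softmax(logits, dim=1)         # (1, 6); each row sums to 1
action = probs.multinomial(1).squeeze()  # draw one action id, squeeze to 0-d
print(int(action))
```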
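The `clip_grad_norm_`, `nn.init.constant_`, and `nn.init.xavier_uniform_` renames follow the 0.4 convention that in-place operations carry a trailing underscore; the old names still work but emit deprecation warnings. For example (40.0 here just mirrors a typical `grad_clip` setting):

```python
import torch
import torch.nn as nn

linear = nn.Linear(8, 2)
nn.init.xavier_uniform_(linear.weight)  # was nn.init.xavier_uniform
nn.init.constant_(linear.bias, 0)       # was nn.init.constant

loss = linear(torch.randn(3, 8)).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(linear.parameters(), 40.0)  # was clip_grad_norm
```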
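The rewritten `_backward` computes the policy term as a dot product over flattened log-probs and the value term with `F.mse_loss`, matching the 1-D tensors `convert_batch` now returns. A self-contained sketch with random stand-in tensors (the coefficients stand in for `config["vf_loss_coeff"]` and `config["entropy_coeff"]`):

```python
import torch
import torch.nn.functional as F

advs = torch.randn(5)                     # advantages, shape (5,)
action_log_probs = torch.randn(5, 1, requires_grad=True)
values = torch.randn(5, 1, requires_grad=True)
rs = torch.randn(5)                       # value targets, shape (5,)

pi_err = -advs.dot(action_log_probs.reshape(-1))  # policy-gradient term
value_err = F.mse_loss(values.reshape(-1), rs)    # critic regression term
entropy = torch.tensor(1.0)                       # placeholder entropy bonus

vf_loss_coeff, entropy_coeff = 0.5, -0.01         # stand-in config values
overall_err = pi_err + vf_loss_coeff * value_err + entropy_coeff * entropy
overall_err.backward()
```

One behavioral note: `F.mse_loss` averages over the batch by default, whereas the old `0.5 * (values - rs).pow(2).sum()` summed, so the effective scale of the value loss relative to the policy loss changes with batch size.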
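`convert_batch` also switches to the `obs` and `rewards` keys of the sample batch and returns flat 1-D advantage and target tensors instead of `(-1, 1)` views. A hedged sketch of the resulting shapes, using a fabricated five-step trajectory dict purely for illustration:

```python
import numpy as np
import torch

# Fabricated sample batch; the keys mirror those read in misc.py.
trajectory = {
    "obs": np.zeros((5, 4), dtype=np.float32),
    "actions": np.zeros(5, dtype=np.int64),
    "advantages": np.ones(5, dtype=np.float32),
    "rewards": np.ones(5, dtype=np.float32),
    "features": [],
}

states = torch.from_numpy(trajectory["obs"]).float()
advs = torch.from_numpy(trajectory["advantages"].copy()).float().reshape(-1)
rs = torch.from_numpy(trajectory["rewards"]).float().reshape(-1)
print(states.shape, advs.shape, rs.shape)  # (5, 4), (5,), (5,)
```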