[rllib] Propagate model options correctly in ARS / ES, to action dist of PPO (#2974)

* fix

* fix

* fix it

* propagate conf to action dist

* move carla example too

* rr

* Update policies.py

* wip

* lint
Eric Liang 2018-10-01 12:49:39 -07:00 committed by GitHub
parent e4bea8d10e
commit b45bed4bce
29 changed files with 322 additions and 377 deletions
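
In practical terms, the change threads the "model" options dict through to ModelCatalog.get_action_dist and ModelCatalog.get_model in ARS, ES, PG and PPO instead of dropping it. A rough sketch of the new calling convention (the action space, placeholder tensor and option values below are illustrative, not taken from the diff):

import numpy as np
import tensorflow as tf
from gym.spaces import Box

from ray.rllib.models import ModelCatalog

# Illustrative stand-ins: a 2-d continuous action space and a flat
# 4-d observation placeholder.
action_space = Box(-1.0, 1.0, shape=(2, ), dtype=np.float32)
obs = tf.placeholder(tf.float32, shape=[None, 4])

# The model config is now a required argument of both calls, so model
# options also reach the action distribution instead of being ignored.
model_config = {"fcnet_hiddens": [32, 32]}
dist_class, dist_dim = ModelCatalog.get_action_dist(
    action_space, model_config, dist_type="deterministic")
model = ModelCatalog.get_model(obs, dist_dim, model_config)
dist = dist_class(model.outputs)
action = dist.sample()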

View file

@ -13,7 +13,7 @@ View the `code for this example`_.
.. note::
For an overview of Ray's reinforcement learning library, see `Ray RLlib <http://ray.readthedocs.io/en/latest/rllib.html>`__.
For an overview of Ray's reinforcement learning library, see `RLlib <http://ray.readthedocs.io/en/latest/rllib.html>`__.
To run the application, first install **ray** and then some dependencies:

View file

@ -6,7 +6,7 @@ View the `code for this example`_.
.. note::
For an overview of Ray's reinforcement learning library, see `Ray RLlib <http://ray.readthedocs.io/en/latest/rllib.html>`__.
For an overview of Ray's reinforcement learning library, see `RLlib <http://ray.readthedocs.io/en/latest/rllib.html>`__.
To run this example, you will need to install `TensorFlow with GPU support`_ (at

View file

@ -77,7 +77,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learning
.. toctree::
:maxdepth: 1
:caption: Ray RLlib
:caption: RLlib
rllib.rst
rllib-training.rst

View file

@ -50,7 +50,7 @@ In the above example, note that the ``env_creator`` function takes in an ``env_c
OpenAI Gym
----------
RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition <https://github.com/openai/gym/blob/master/gym/core.py>`__. You may also find the `SimpleCorridor <https://github.com/ray-project/ray/blob/master/examples/custom_env/custom_env.py>`__ and `Carla simulator <https://github.com/ray-project/ray/blob/master/examples/carla/env.py>`__ example env implementations useful as a reference.
RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition <https://github.com/openai/gym/blob/master/gym/core.py>`__. You may also find the `SimpleCorridor <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__ and `Carla simulator <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/carla/env.py>`__ example env implementations useful as a reference.
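For instance, a minimal custom env can be registered in a few lines (a sketch loosely following the SimpleCorridor example referenced above; the reward scheme is illustrative):

import gym
import numpy as np
from gym.spaces import Box, Discrete

from ray.tune.registry import register_env


class CorridorEnv(gym.Env):
    """Walk right along a corridor of length `corridor_length` for +1."""

    def __init__(self, config):
        self.end_pos = config.get("corridor_length", 10)
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(
            0.0, self.end_pos, shape=(1, ), dtype=np.float32)

    def reset(self):
        self.cur_pos = 0
        return [self.cur_pos]

    def step(self, action):
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        return [self.cur_pos], 1.0 if done else 0.0, done, {}


register_env("corridor", lambda env_config: CorridorEnv(env_config))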
Performance
~~~~~~~~~~~

View file

@ -46,7 +46,7 @@ Custom models should subclass the common RLlib `model class <https://github.com/
},
})
For a full example of a custom model in code, see the `Carla RLlib model <https://github.com/ray-project/ray/blob/master/examples/carla/models.py>`__ and associated `training scripts <https://github.com/ray-project/ray/tree/master/examples/carla>`__. The ``CarlaModel`` class defined there operates over a composite (Tuple) observation space including both images and scalar measurements.
For a full example of a custom model in code, see the `Carla RLlib model <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/carla/models.py>`__ and associated `training scripts <https://github.com/ray-project/ray/tree/master/python/ray/rllib/examples/carla>`__. The ``CarlaModel`` class defined there operates over a composite (Tuple) observation space including both images and scalar measurements.
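At the other extreme, a custom model can be a single layer; the sketch below is adapted from the LinearNetwork helper that this commit removes from the ES utilities (the registration name "my_linear" is illustrative):

import tensorflow as tf
import tensorflow.contrib.slim as slim

from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.misc import normc_initializer


class LinearNetwork(Model):
    """A single fully connected layer with no activation."""

    def _build_layers(self, inputs, num_outputs, options):
        with tf.name_scope("linear"):
            output = slim.fully_connected(
                inputs,
                num_outputs,
                weights_initializer=normc_initializer(0.01),
                activation_fn=None)
        return output, inputs


# Select it from any agent config via "model": {"custom_model": "my_linear"}.
ModelCatalog.register_custom_model("my_linear", LinearNetwork)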
Custom Preprocessors
--------------------

View file

@ -10,7 +10,7 @@ Learn more about RLlib's design by reading the `ICML paper <https://arxiv.org/ab
Installation
------------
RLlib has extra dependencies on top of ``ray``. First, you'll need to install either `PyTorch <http://pytorch.org/>`__ or `TensorFlow <https://www.tensorflow.org>`__. Then, install the Ray RLlib module:
RLlib has extra dependencies on top of ``ray``. First, you'll need to install either `PyTorch <http://pytorch.org/>`__ or `TensorFlow <https://www.tensorflow.org>`__. Then, install the RLlib module:
.. code-block:: bash

View file

@ -1,119 +0,0 @@
"""Collection of Carla scenarios, including those from the CoRL 2017 paper."""
TEST_WEATHERS = [0, 2, 5, 7, 9, 10, 11, 12, 13]
TRAIN_WEATHERS = [1, 3, 4, 6, 8, 14]
def build_scenario(
city, start, end, vehicles, pedestrians, max_steps, weathers):
return {
"city": city,
"num_vehicles": vehicles,
"num_pedestrians": pedestrians,
"weather_distribution": weathers,
"start_pos_id": start,
"end_pos_id": end,
"max_steps": max_steps,
}
# Simple scenario for Town02 that involves driving down a road
DEFAULT_SCENARIO = build_scenario(
city="Town02", start=36, end=40, vehicles=20, pedestrians=40,
max_steps=200, weathers=[0])
# Simple scenario for Town02 that involves driving down a road
LANE_KEEP = build_scenario(
city="Town02", start=36, end=40, vehicles=0, pedestrians=0,
max_steps=2000, weathers=[0])
# Scenarios from the CoRL2017 paper
POSES_TOWN1_STRAIGHT = [
[36, 40], [39, 35], [110, 114], [7, 3], [0, 4],
[68, 50], [61, 59], [47, 64], [147, 90], [33, 87],
[26, 19], [80, 76], [45, 49], [55, 44], [29, 107],
[95, 104], [84, 34], [53, 67], [22, 17], [91, 148],
[20, 107], [78, 70], [95, 102], [68, 44], [45, 69]]
POSES_TOWN1_ONE_CURVE = [
[138, 17], [47, 16], [26, 9], [42, 49], [140, 124],
[85, 98], [65, 133], [137, 51], [76, 66], [46, 39],
[40, 60], [0, 29], [4, 129], [121, 140], [2, 129],
[78, 44], [68, 85], [41, 102], [95, 70], [68, 129],
[84, 69], [47, 79], [110, 15], [130, 17], [0, 17]]
POSES_TOWN1_NAV = [
[105, 29], [27, 130], [102, 87], [132, 27], [24, 44],
[96, 26], [34, 67], [28, 1], [140, 134], [105, 9],
[148, 129], [65, 18], [21, 16], [147, 97], [42, 51],
[30, 41], [18, 107], [69, 45], [102, 95], [18, 145],
[111, 64], [79, 45], [84, 69], [73, 31], [37, 81]]
POSES_TOWN2_STRAIGHT = [
[38, 34], [4, 2], [12, 10], [62, 55], [43, 47],
[64, 66], [78, 76], [59, 57], [61, 18], [35, 39],
[12, 8], [0, 18], [75, 68], [54, 60], [45, 49],
[46, 42], [53, 46], [80, 29], [65, 63], [0, 81],
[54, 63], [51, 42], [16, 19], [17, 26], [77, 68]]
POSES_TOWN2_ONE_CURVE = [
[37, 76], [8, 24], [60, 69], [38, 10], [21, 1],
[58, 71], [74, 32], [44, 0], [71, 16], [14, 24],
[34, 11], [43, 14], [75, 16], [80, 21], [3, 23],
[75, 59], [50, 47], [11, 19], [77, 34], [79, 25],
[40, 63], [58, 76], [79, 55], [16, 61], [27, 11]]
POSES_TOWN2_NAV = [
[19, 66], [79, 14], [19, 57], [23, 1],
[53, 76], [42, 13], [31, 71], [33, 5],
[54, 30], [10, 61], [66, 3], [27, 12],
[79, 19], [2, 29], [16, 14], [5, 57],
[70, 73], [46, 67], [57, 50], [61, 49], [21, 12],
[51, 81], [77, 68], [56, 65], [43, 54]]
TOWN1_STRAIGHT = [
build_scenario("Town01", start, end, 0, 0, 300, TEST_WEATHERS)
for (start, end) in POSES_TOWN1_STRAIGHT]
TOWN1_ONE_CURVE = [
build_scenario("Town01", start, end, 0, 0, 600, TEST_WEATHERS)
for (start, end) in POSES_TOWN1_ONE_CURVE]
TOWN1_NAVIGATION = [
build_scenario("Town01", start, end, 0, 0, 900, TEST_WEATHERS)
for (start, end) in POSES_TOWN1_NAV]
TOWN1_NAVIGATION_DYNAMIC = [
build_scenario("Town01", start, end, 20, 50, 900, TEST_WEATHERS)
for (start, end) in POSES_TOWN1_NAV]
TOWN2_STRAIGHT = [
build_scenario("Town02", start, end, 0, 0, 300, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_STRAIGHT]
TOWN2_STRAIGHT_DYNAMIC = [
build_scenario("Town02", start, end, 20, 50, 300, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_STRAIGHT]
TOWN2_ONE_CURVE = [
build_scenario("Town02", start, end, 0, 0, 600, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_ONE_CURVE]
TOWN2_NAVIGATION = [
build_scenario("Town02", start, end, 0, 0, 900, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_NAV]
TOWN2_NAVIGATION_DYNAMIC = [
build_scenario("Town02", start, end, 20, 50, 900, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_NAV]
TOWN1_ALL = (
TOWN1_STRAIGHT + TOWN1_ONE_CURVE + TOWN1_NAVIGATION +
TOWN1_NAVIGATION_DYNAMIC)
TOWN2_ALL = (
TOWN2_STRAIGHT + TOWN2_ONE_CURVE + TOWN2_NAVIGATION +
TOWN2_NAVIGATION_DYNAMIC)

View file

@ -1 +0,0 @@
Example of using a custom gym env with RLlib.

View file

@ -25,19 +25,17 @@ Result = namedtuple("Result", [
])
DEFAULT_CONFIG = with_common_config({
'noise_stdev': 0.02, # std deviation of parameter noise
'num_rollouts': 32, # number of perturbs to try
'rollouts_used': 32, # number of perturbs to keep in gradient estimate
'num_workers': 2,
'sgd_stepsize': 0.01, # sgd step-size
'observation_filter': "MeanStdFilter",
'noise_size': 250000000,
'eval_prob': 0.03, # probability of evaluating the parameter rewards
'report_length': 10, # how many of the last rewards we average over
'env_config': {},
'offset': 0,
'policy_type': "LinearPolicy", # ["LinearPolicy", "MLPPolicy"]
"fcnet_hiddens": [32, 32], # fcnet structure of MLPPolicy
"noise_stdev": 0.02, # std deviation of parameter noise
"num_rollouts": 32, # number of perturbs to try
"rollouts_used": 32, # number of perturbs to keep in gradient estimate
"num_workers": 2,
"sgd_stepsize": 0.01, # sgd step-size
"observation_filter": "MeanStdFilter",
"noise_size": 250000000,
"eval_prob": 0.03, # probability of evaluating the parameter rewards
"report_length": 10, # how many of the last rewards we average over
"env_config": {},
"offset": 0,
})
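With the policy_type and fcnet_hiddens keys gone, the network structure is now chosen entirely through the common "model" options. A sketch of launching ARS through Tune with a linear policy, mirroring the tuned Swimmer example at the end of this diff (hyperparameter values are copied from that example and otherwise illustrative):

import ray
from ray.tune import run_experiments

ray.init()

run_experiments({
    "swimmer-ars": {
        "run": "ARS",
        "env": "Swimmer-v2",
        "config": {
            "num_workers": 1,
            "sgd_stepsize": 0.02,
            "eval_prob": 0.2,
            "observation_filter": "NoFilter",
            "report_length": 3,
            "model": {
                "fcnet_hiddens": [],  # empty list => linear policy
            },
        },
    },
})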
@ -67,15 +65,9 @@ class SharedNoiseTable(object):
@ray.remote
class Worker(object):
def __init__(self,
config,
policy_params,
env_creator,
noise,
min_task_runtime=0.2):
def __init__(self, config, env_creator, noise, min_task_runtime=0.2):
self.min_task_runtime = min_task_runtime
self.config = config
self.policy_params = policy_params
self.noise = SharedNoiseTable(noise)
self.env = env_creator(config["env_config"])
@ -83,15 +75,9 @@ class Worker(object):
self.preprocessor = models.ModelCatalog.get_preprocessor(self.env)
self.sess = utils.make_session(single_threaded=True)
if config["policy_type"] == "LinearPolicy":
self.policy = policies.LinearPolicy(
self.sess, self.env.action_space, self.preprocessor,
config["observation_filter"], **policy_params)
else:
self.policy = policies.MLPPolicy(
self.sess, self.env.action_space, self.preprocessor,
config["observation_filter"], config["fcnet_hiddens"],
**policy_params)
self.policy = policies.GenericPolicy(
self.sess, self.env.action_space, self.preprocessor,
config["observation_filter"], config["model"])
def rollout(self, timestep_limit, add_noise=False):
rollout_rewards, rollout_length = policies.rollout(
@ -160,25 +146,14 @@ class ARSAgent(Agent):
return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])
def _init(self):
policy_params = {"action_noise_std": 0.0}
# register the linear network
utils.register_linear_network()
env = self.env_creator(self.config["env_config"])
from ray.rllib import models
preprocessor = models.ModelCatalog.get_preprocessor(env)
self.sess = utils.make_session(single_threaded=False)
if self.config["policy_type"] == "LinearPolicy":
self.policy = policies.LinearPolicy(
self.sess, env.action_space, preprocessor,
self.config["observation_filter"], **policy_params)
else:
self.policy = policies.MLPPolicy(
self.sess, env.action_space, preprocessor,
self.config["observation_filter"],
self.config["fcnet_hiddens"], **policy_params)
self.policy = policies.GenericPolicy(
self.sess, env.action_space, preprocessor,
self.config["observation_filter"], self.config["model"])
self.optimizer = optimizers.SGD(self.policy,
self.config["sgd_stepsize"])
@ -194,8 +169,8 @@ class ARSAgent(Agent):
# Create the actors.
print("Creating actors.")
self.workers = [
Worker.remote(self.config, policy_params, self.env_creator,
noise_id) for _ in range(self.config["num_workers"])
Worker.remote(self.config, self.env_creator, noise_id)
for _ in range(self.config["num_workers"])
]
self.episodes_so_far = 0

View file

@ -11,7 +11,6 @@ import tensorflow as tf
import ray
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.models import ModelCatalog
@ -59,14 +58,8 @@ class GenericPolicy(object):
action_space,
preprocessor,
observation_filter,
action_noise_std,
options={}):
if len(preprocessor.shape) > 1:
raise UnsupportedSpaceException(
"Observation space {} is not supported with ARS.".format(
preprocessor.shape))
model_config,
action_noise_std=0.0):
self.sess = sess
self.action_space = action_space
self.action_noise_std = action_noise_std
@ -78,9 +71,9 @@ class GenericPolicy(object):
# Policy network.
dist_class, dist_dim = ModelCatalog.get_action_dist(
action_space, dist_type="deterministic")
action_space, model_config, dist_type="deterministic")
model = ModelCatalog.get_model(self.inputs, dist_dim, options=options)
model = ModelCatalog.get_model(self.inputs, dist_dim, model_config)
dist = dist_class(model.outputs)
self.sampler = dist.sample()
@ -106,31 +99,3 @@ class GenericPolicy(object):
def get_weights(self):
return self.variables.get_flat()
class LinearPolicy(GenericPolicy):
def __init__(self, sess, action_space, preprocessor, observation_filter,
action_noise_std):
options = {"custom_model": "LinearNetwork"}
GenericPolicy.__init__(
self,
sess,
action_space,
preprocessor,
observation_filter,
action_noise_std,
options=options)
class MLPPolicy(GenericPolicy):
def __init__(self, sess, action_space, preprocessor, observation_filter,
fcnet_hiddens, action_noise_std):
options = {"fcnet_hiddens": fcnet_hiddens}
GenericPolicy.__init__(
self,
sess,
action_space,
preprocessor,
observation_filter,
action_noise_std,
options=options)
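Callers now build a single GenericPolicy and hand it the model options directly; roughly, the agent-side construction after this change looks like the following (module paths are assumed from the surrounding diff):

import gym

from ray.rllib import models
from ray.rllib.agents.ars import policies, utils

env = gym.make("Swimmer-v2")
preprocessor = models.ModelCatalog.get_preprocessor(env)
sess = utils.make_session(single_threaded=False)

# The observation filter name and model dict both come straight from the
# agent config; {"fcnet_hiddens": []} reproduces the old LinearPolicy.
policy = policies.GenericPolicy(
    sess, env.action_space, preprocessor,
    "MeanStdFilter", {"fcnet_hiddens": [32, 32]})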

View file

@ -7,9 +7,6 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from ray.rllib.models import ModelCatalog, Model
import tensorflow.contrib.slim as slim
from ray.rllib.models.misc import normc_initializer
def compute_ranks(x):
@ -62,21 +59,3 @@ def batched_weighted_sum(weights, vecs, batch_size):
np.asarray(batch_vecs, dtype=np.float32))
num_items_summed += len(batch_weights)
return total, num_items_summed
class LinearNetwork(Model):
"""Generic linear network."""
def _build_layers(self, inputs, num_outputs, _):
with tf.name_scope("linear"):
output = slim.fully_connected(
inputs,
num_outputs,
weights_initializer=normc_initializer(0.01),
activation_fn=None,
)
return output, inputs
def register_linear_network():
ModelCatalog.register_custom_model("LinearNetwork", LinearNetwork)

View file

@ -10,7 +10,7 @@ import numpy as np
import time
import ray
from ray.rllib.agents import Agent
from ray.rllib.agents import Agent, with_common_config
from ray.tune.trial import Resources
from ray.rllib.agents.es import optimizers
@ -24,7 +24,7 @@ Result = namedtuple("Result", [
"eval_returns", "eval_lengths"
])
DEFAULT_CONFIG = {
DEFAULT_CONFIG = with_common_config({
"l2_coeff": 0.005,
"noise_stdev": 0.02,
"episodes_per_batch": 1000,
@ -38,7 +38,8 @@ DEFAULT_CONFIG = {
"report_length": 10,
"env": None,
"env_config": {},
}
"model": {},
})
@ray.remote
@ -81,7 +82,7 @@ class Worker(object):
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.GenericPolicy(
self.sess, self.env.action_space, self.preprocessor,
config["observation_filter"], **policy_params)
config["observation_filter"], config["model"], **policy_params)
def rollout(self, timestep_limit, add_noise=True):
rollout_rewards, rollout_length = policies.rollout(
@ -161,7 +162,8 @@ class ESAgent(Agent):
self.sess = utils.make_session(single_threaded=False)
self.policy = policies.GenericPolicy(
self.sess, env.action_space, preprocessor,
self.config["observation_filter"], **policy_params)
self.config["observation_filter"], self.config["model"],
**policy_params)
self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])
self.report_length = self.config["report_length"]
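Since "model" now sits in the common ES config, the network can be overridden the same way as for any other agent; a sketch using the Python agent API (the ESAgent constructor usage is assumed for this Ray version, and the values are illustrative):

import ray
from ray.rllib.agents import es

ray.init()

agent = es.ESAgent(
    env="Pendulum-v0",
    config={
        "num_workers": 2,
        "episodes_per_batch": 10,  # kept tiny for illustration
        "model": {"fcnet_hiddens": [64, 64]},
    })
print(agent.train())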

View file

@ -39,7 +39,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
class GenericPolicy(object):
def __init__(self, sess, action_space, preprocessor, observation_filter,
action_noise_std):
model_options, action_noise_std):
self.sess = sess
self.action_space = action_space
self.action_noise_std = action_noise_std
@ -51,8 +51,8 @@ class GenericPolicy(object):
# Policy network.
dist_class, dist_dim = ModelCatalog.get_action_dist(
self.action_space, dist_type="deterministic")
model = ModelCatalog.get_model(self.inputs, dist_dim)
self.action_space, model_options, dist_type="deterministic")
model = ModelCatalog.get_model(self.inputs, dist_dim, model_options)
dist = dist_class(model.outputs)
self.sampler = dist.sample()

View file

@ -24,8 +24,8 @@ class PGPolicyGraph(TFPolicyGraph):
obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_model(
obs, self.logit_dim, options=self.config["model"])
self.model = ModelCatalog.get_model(obs, self.logit_dim,
self.config["model"])
action_dist = dist_class(self.model.outputs) # logit for each action
# Setup policy loss

View file

@ -54,7 +54,7 @@ class PPOLoss(object):
vf_loss_coeff (float): Coefficient of the value function loss
use_gae (bool): If true, use the Generalized Advantage Estimator.
"""
dist_cls, _ = ModelCatalog.get_action_dist(action_space)
dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
prev_dist = dist_cls(logits)
# Make loss functions.
logp_ratio = tf.exp(
@ -108,7 +108,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.config = config
self.kl_coeff_val = self.config["kl_coeff"]
self.kl_target = self.config["kl_target"]
dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space)
dist_cls, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
if existing_inputs:
obs_ph, value_targets_ph, adv_ph, act_ph, \

View file

@ -31,7 +31,6 @@ run_experiments({
"carla-a3c": {
"run": "A3C",
"env": "carla_env",
"trial_resources": {"cpu": 4, "gpu": 1},
"config": {
"env_config": env_config,
"model": {

View file

@ -31,7 +31,6 @@ run_experiments({
"carla-dqn": {
"run": "DQN",
"env": "carla_env",
"trial_resources": {"cpu": 4, "gpu": 1},
"config": {
"env_config": env_config,
"model": {
@ -49,9 +48,6 @@ run_experiments({
"learning_starts": 1000,
"schedule_max_timesteps": 100000,
"gamma": 0.8,
"tf_session_args": {
"gpu_options": {"allow_growth": True},
},
},
},
})

View file

@ -33,8 +33,8 @@ if CARLA_OUT_PATH and not os.path.exists(CARLA_OUT_PATH):
os.makedirs(CARLA_OUT_PATH)
# Set this to the path of your Carla binary
SERVER_BINARY = os.environ.get(
"CARLA_SERVER", os.path.expanduser("~/CARLA_0.7.0/CarlaUE4.sh"))
SERVER_BINARY = os.environ.get("CARLA_SERVER",
os.path.expanduser("~/CARLA_0.7.0/CarlaUE4.sh"))
assert os.path.exists(SERVER_BINARY)
if "CARLA_PY_PATH" in os.environ:
@ -97,7 +97,6 @@ ENV_CONFIG = {
"squash_action_logits": False,
}
DISCRETE_ACTIONS = {
# coast
0: [0.0, 0.0],
@ -119,7 +118,6 @@ DISCRETE_ACTIONS = {
8: [-0.5, 0.5],
}
live_carla_processes = set()
@ -133,7 +131,6 @@ atexit.register(cleanup)
class CarlaEnv(gym.Env):
def __init__(self, config=ENV_CONFIG):
self.config = config
self.city = self.config["server_map"].split("/")[-1]
@ -143,21 +140,27 @@ class CarlaEnv(gym.Env):
if config["discrete_actions"]:
self.action_space = Discrete(len(DISCRETE_ACTIONS))
else:
self.action_space = Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
self.action_space = Box(-1.0, 1.0, shape=(2, ), dtype=np.float32)
if config["use_depth_camera"]:
image_space = Box(
-1.0, 1.0, shape=(
config["y_res"], config["x_res"],
1 * config["framestack"]), dtype=np.float32)
-1.0,
1.0,
shape=(config["y_res"], config["x_res"],
1 * config["framestack"]),
dtype=np.float32)
else:
image_space = Box(
0, 255, shape=(
config["y_res"], config["x_res"],
3 * config["framestack"]), dtype=np.uint8)
0,
255,
shape=(config["y_res"], config["x_res"],
3 * config["framestack"]),
dtype=np.uint8)
self.observation_space = Tuple( # forward_speed, dist to goal
[image_space,
Discrete(len(COMMANDS_ENUM)), # next_command
Box(-128.0, 128.0, shape=(2,), dtype=np.float32)])
[
image_space,
Discrete(len(COMMANDS_ENUM)), # next_command
Box(-128.0, 128.0, shape=(2, ), dtype=np.float32)
])
# TODO(ekl) this isn't really a proper gym spec
self._spec = lambda: None
@ -185,11 +188,13 @@ class CarlaEnv(gym.Env):
# Create a new server process and start the client.
self.server_port = random.randint(10000, 60000)
self.server_process = subprocess.Popen(
[SERVER_BINARY, self.config["server_map"],
"-windowed", "-ResX=400", "-ResY=300",
"-carla-server",
"-carla-world-port={}".format(self.server_port)],
preexec_fn=os.setsid, stdout=open(os.devnull, "w"))
[
SERVER_BINARY, self.config["server_map"], "-windowed",
"-ResX=400", "-ResY=300", "-carla-server",
"-carla-world-port={}".format(self.server_port)
],
preexec_fn=os.setsid,
stdout=open(os.devnull, "w"))
live_carla_processes.add(os.getpgid(self.server_process.pid))
for i in range(RETRIES_ON_ERROR):
@ -257,14 +262,14 @@ class CarlaEnv(gym.Env):
if self.config["use_depth_camera"]:
camera1 = Camera("CameraDepth", PostProcessing="Depth")
camera1.set_image_size(
self.config["render_x_res"], self.config["render_y_res"])
camera1.set_image_size(self.config["render_x_res"],
self.config["render_y_res"])
camera1.set_position(30, 0, 130)
settings.add_sensor(camera1)
camera2 = Camera("CameraRGB")
camera2.set_image_size(
self.config["render_x_res"], self.config["render_y_res"])
camera2.set_image_size(self.config["render_x_res"],
self.config["render_y_res"])
camera2.set_position(30, 0, 130)
settings.add_sensor(camera2)
@ -274,13 +279,14 @@ class CarlaEnv(gym.Env):
self.start_pos = positions[self.scenario["start_pos_id"]]
self.end_pos = positions[self.scenario["end_pos_id"]]
self.start_coord = [
self.start_pos.location.x // 100, self.start_pos.location.y // 100]
self.start_pos.location.x // 100, self.start_pos.location.y // 100
]
self.end_coord = [
self.end_pos.location.x // 100, self.end_pos.location.y // 100]
print(
"Start pos {} ({}), end {} ({})".format(
self.scenario["start_pos_id"], self.start_coord,
self.scenario["end_pos_id"], self.end_coord))
self.end_pos.location.x // 100, self.end_pos.location.y // 100
]
print("Start pos {} ({}), end {} ({})".format(
self.scenario["start_pos_id"], self.start_coord,
self.scenario["end_pos_id"], self.end_coord))
# Notify the server that we want to start the episode at the
# player_start index. This function blocks until the server is ready
@ -300,11 +306,10 @@ class CarlaEnv(gym.Env):
prev_image = image
if self.config["framestack"] == 2:
image = np.concatenate([prev_image, image], axis=2)
obs = (
image,
COMMAND_ORDINAL[py_measurements["next_command"]],
[py_measurements["forward_speed"],
py_measurements["distance_to_goal"]])
obs = (image, COMMAND_ORDINAL[py_measurements["next_command"]], [
py_measurements["forward_speed"],
py_measurements["distance_to_goal"]
])
self.last_obs = obs
return obs
@ -313,9 +318,8 @@ class CarlaEnv(gym.Env):
obs = self._step(action)
return obs
except Exception:
print(
"Error during step, terminating episode early",
traceback.format_exc())
print("Error during step, terminating episode early",
traceback.format_exc())
self.clear_server_state()
return (self.last_obs, 0.0, True, {})
@ -336,12 +340,14 @@ class CarlaEnv(gym.Env):
hand_brake = False
if self.config["verbose"]:
print(
"steer", steer, "throttle", throttle, "brake", brake,
"reverse", reverse)
print("steer", steer, "throttle", throttle, "brake", brake,
"reverse", reverse)
self.client.send_control(
steer=steer, throttle=throttle, brake=brake, hand_brake=hand_brake,
steer=steer,
throttle=throttle,
brake=brake,
hand_brake=hand_brake,
reverse=reverse)
# Process observations
@ -359,15 +365,14 @@ class CarlaEnv(gym.Env):
"reverse": reverse,
"hand_brake": hand_brake,
}
reward = compute_reward(
self, self.prev_measurement, py_measurements)
reward = compute_reward(self, self.prev_measurement, py_measurements)
self.total_reward += reward
py_measurements["reward"] = reward
py_measurements["total_reward"] = self.total_reward
done = (self.num_steps > self.scenario["max_steps"] or
py_measurements["next_command"] == "REACH_GOAL" or
(self.config["early_terminate_on_collision"] and
collided_done(py_measurements)))
done = (self.num_steps > self.scenario["max_steps"]
or py_measurements["next_command"] == "REACH_GOAL"
or (self.config["early_terminate_on_collision"]
and collided_done(py_measurements)))
py_measurements["done"] = done
self.prev_measurement = py_measurements
@ -377,8 +382,7 @@ class CarlaEnv(gym.Env):
self.measurements_file = open(
os.path.join(
CARLA_OUT_PATH,
"measurements_{}.json".format(self.episode_id)),
"w")
"measurements_{}.json".format(self.episode_id)), "w")
self.measurements_file.write(json.dumps(py_measurements))
self.measurements_file.write("\n")
if done:
@ -389,9 +393,8 @@ class CarlaEnv(gym.Env):
self.num_steps += 1
image = self.preprocess_image(image)
return (
self.encode_obs(image, py_measurements), reward, done,
py_measurements)
return (self.encode_obs(image, py_measurements), reward, done,
py_measurements)
def images_to_video(self):
videos_dir = os.path.join(CARLA_OUT_PATH, "Videos")
@ -413,15 +416,15 @@ class CarlaEnv(gym.Env):
if self.config["use_depth_camera"]:
assert self.config["use_depth_camera"]
data = (image.data - 0.5) * 2
data = data.reshape(
self.config["render_y_res"], self.config["render_x_res"], 1)
data = data.reshape(self.config["render_y_res"],
self.config["render_x_res"], 1)
data = cv2.resize(
data, (self.config["x_res"], self.config["y_res"]),
interpolation=cv2.INTER_AREA)
data = np.expand_dims(data, 2)
else:
data = image.data.reshape(
self.config["render_y_res"], self.config["render_x_res"], 3)
data = image.data.reshape(self.config["render_y_res"],
self.config["render_x_res"], 3)
data = cv2.resize(
data, (self.config["x_res"], self.config["y_res"]),
interpolation=cv2.INTER_AREA)
@ -448,36 +451,39 @@ class CarlaEnv(gym.Env):
cur = measurements.player_measurements
if self.config["enable_planner"]:
next_command = COMMANDS_ENUM[
self.planner.get_next_command(
[cur.transform.location.x, cur.transform.location.y,
GROUND_Z],
[cur.transform.orientation.x, cur.transform.orientation.y,
GROUND_Z],
[self.end_pos.location.x, self.end_pos.location.y,
GROUND_Z],
[self.end_pos.orientation.x, self.end_pos.orientation.y,
GROUND_Z])
]
next_command = COMMANDS_ENUM[self.planner.get_next_command(
[cur.transform.location.x, cur.transform.location.y, GROUND_Z],
[
cur.transform.orientation.x, cur.transform.orientation.y,
GROUND_Z
],
[self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [
self.end_pos.orientation.x, self.end_pos.orientation.y,
GROUND_Z
])]
else:
next_command = "LANE_FOLLOW"
if next_command == "REACH_GOAL":
distance_to_goal = 0.0 # avoids crash in planner
elif self.config["enable_planner"]:
distance_to_goal = self.planner.get_shortest_path_distance(
[cur.transform.location.x, cur.transform.location.y, GROUND_Z],
[cur.transform.orientation.x, cur.transform.orientation.y,
GROUND_Z],
[self.end_pos.location.x, self.end_pos.location.y, GROUND_Z],
[self.end_pos.orientation.x, self.end_pos.orientation.y,
GROUND_Z]) / 100
distance_to_goal = self.planner.get_shortest_path_distance([
cur.transform.location.x, cur.transform.location.y, GROUND_Z
], [
cur.transform.orientation.x, cur.transform.orientation.y,
GROUND_Z
], [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [
self.end_pos.orientation.x, self.end_pos.orientation.y,
GROUND_Z
]) / 100
else:
distance_to_goal = -1
distance_to_goal_euclidean = float(np.linalg.norm(
[cur.transform.location.x - self.end_pos.location.x,
cur.transform.location.y - self.end_pos.location.y]) / 100)
distance_to_goal_euclidean = float(
np.linalg.norm([
cur.transform.location.x - self.end_pos.location.x,
cur.transform.location.y - self.end_pos.location.y
]) / 100)
py_measurements = {
"episode_id": self.episode_id,
@ -513,8 +519,8 @@ class CarlaEnv(gym.Env):
if not os.path.exists(out_dir):
os.makedirs(out_dir)
out_file = os.path.join(
out_dir,
"{}_{:>04}.jpg".format(self.episode_id, self.num_steps))
out_dir, "{}_{:>04}.jpg".format(self.episode_id,
self.num_steps))
scipy.misc.imsave(out_file, image.data)
assert observation is not None, sensor_data
@ -621,8 +627,7 @@ REWARD_FUNCTIONS = {
def compute_reward(env, prev, current):
return REWARD_FUNCTIONS[env.config["reward_function"]](
env, prev, current)
return REWARD_FUNCTIONS[env.config["reward_function"]](env, prev, current)
def print_measurements(measurements):
@ -654,9 +659,8 @@ def sigmoid(x):
def collided_done(py_measurements):
m = py_measurements
collided = (
m["collision_vehicles"] > 0 or m["collision_pedestrians"] > 0 or
m["collision_other"] > 0)
collided = (m["collision_vehicles"] > 0 or m["collision_pedestrians"] > 0
or m["collision_other"] > 0)
return bool(collided or m["total_reward"] < -100)

View file

@ -43,8 +43,8 @@ class CarlaModel(Model):
(inputs.shape.as_list()[1:], expected_shape)
# Reshape the input vector back into its components
vision_in = tf.reshape(
inputs[:, :image_size], [tf.shape(inputs)[0]] + image_shape)
vision_in = tf.reshape(inputs[:, :image_size],
[tf.shape(inputs)[0]] + image_shape)
metrics_in = inputs[:, image_size:]
print("Vision in shape", vision_in)
print("Metrics in shape", metrics_in)
@ -53,18 +53,26 @@ class CarlaModel(Model):
with tf.name_scope("carla_vision"):
for i, (out_size, kernel, stride) in enumerate(convs[:-1], 1):
vision_in = slim.conv2d(
vision_in, out_size, kernel, stride,
vision_in,
out_size,
kernel,
stride,
scope="conv{}".format(i))
out_size, kernel, stride = convs[-1]
vision_in = slim.conv2d(
vision_in, out_size, kernel, stride,
padding="VALID", scope="conv_out")
vision_in,
out_size,
kernel,
stride,
padding="VALID",
scope="conv_out")
vision_in = tf.squeeze(vision_in, [1, 2])
# Setup metrics layer
with tf.name_scope("carla_metrics"):
metrics_in = slim.fully_connected(
metrics_in, 64,
metrics_in,
64,
weights_initializer=xavier_initializer(),
activation_fn=activation,
scope="metrics_out")
@ -79,15 +87,18 @@ class CarlaModel(Model):
print("Shape of concatenated out is", last_layer.shape)
for size in hiddens:
last_layer = slim.fully_connected(
last_layer, size,
last_layer,
size,
weights_initializer=xavier_initializer(),
activation_fn=activation,
scope="fc{}".format(i))
i += 1
output = slim.fully_connected(
last_layer, num_outputs,
last_layer,
num_outputs,
weights_initializer=normc_initializer(0.01),
activation_fn=None, scope="fc_out")
activation_fn=None,
scope="fc_out")
return output, last_layer

View file

@ -31,7 +31,6 @@ run_experiments({
"carla-ppo": {
"run": "PPO",
"env": "carla_env",
"trial_resources": {"cpu": 4, "gpu": 1},
"config": {
"env_config": env_config,
"model": {
@ -55,7 +54,9 @@ run_experiments({
"sgd_batchsize": 32,
"devices": ["/gpu:0"],
"tf_session_args": {
"gpu_options": {"allow_growth": True}
"gpu_options": {
"allow_growth": True
}
}
},
},

View file

@ -0,0 +1,131 @@
"""Collection of Carla scenarios, including those from the CoRL 2017 paper."""
TEST_WEATHERS = [0, 2, 5, 7, 9, 10, 11, 12, 13]
TRAIN_WEATHERS = [1, 3, 4, 6, 8, 14]
def build_scenario(city, start, end, vehicles, pedestrians, max_steps,
weathers):
return {
"city": city,
"num_vehicles": vehicles,
"num_pedestrians": pedestrians,
"weather_distribution": weathers,
"start_pos_id": start,
"end_pos_id": end,
"max_steps": max_steps,
}
# Simple scenario for Town02 that involves driving down a road
DEFAULT_SCENARIO = build_scenario(
city="Town02",
start=36,
end=40,
vehicles=20,
pedestrians=40,
max_steps=200,
weathers=[0])
# Simple scenario for Town02 that involves driving down a road
LANE_KEEP = build_scenario(
city="Town02",
start=36,
end=40,
vehicles=0,
pedestrians=0,
max_steps=2000,
weathers=[0])
# Scenarios from the CoRL2017 paper
POSES_TOWN1_STRAIGHT = [[36, 40], [39, 35], [110, 114], [7, 3], [0, 4], [
68, 50
], [61, 59], [47, 64], [147, 90], [33, 87], [26, 19], [80, 76], [45, 49], [
55, 44
], [29, 107], [95, 104], [84, 34], [53, 67], [22, 17], [91, 148], [20, 107],
[78, 70], [95, 102], [68, 44], [45, 69]]
POSES_TOWN1_ONE_CURVE = [[138, 17], [47, 16], [26, 9], [42, 49], [140, 124], [
85, 98
], [65, 133], [137, 51], [76, 66], [46, 39], [40, 60], [0, 29], [4, 129], [
121, 140
], [2, 129], [78, 44], [68, 85], [41, 102], [95, 70], [68, 129], [84, 69],
[47, 79], [110, 15], [130, 17], [0, 17]]
POSES_TOWN1_NAV = [[105, 29], [27, 130], [102, 87], [132, 27], [24, 44], [
96, 26
], [34, 67], [28, 1], [140, 134], [105, 9], [148, 129], [65, 18], [21, 16], [
147, 97
], [42, 51], [30, 41], [18, 107], [69, 45], [102, 95], [18, 145], [111, 64],
[79, 45], [84, 69], [73, 31], [37, 81]]
POSES_TOWN2_STRAIGHT = [[38, 34], [4, 2], [12, 10], [62, 55], [43, 47], [
64, 66
], [78, 76], [59, 57], [61, 18], [35, 39], [12, 8], [0, 18], [75, 68], [
54, 60
], [45, 49], [46, 42], [53, 46], [80, 29], [65, 63], [0, 81], [54, 63],
[51, 42], [16, 19], [17, 26], [77, 68]]
POSES_TOWN2_ONE_CURVE = [[37, 76], [8, 24], [60, 69], [38, 10], [21, 1], [
58, 71
], [74, 32], [44, 0], [71, 16], [14, 24], [34, 11], [43, 14], [75, 16], [
80, 21
], [3, 23], [75, 59], [50, 47], [11, 19], [77, 34], [79, 25], [40, 63],
[58, 76], [79, 55], [16, 61], [27, 11]]
POSES_TOWN2_NAV = [[19, 66], [79, 14], [19, 57], [23, 1], [53, 76], [42, 13], [
31, 71
], [33, 5], [54, 30], [10, 61], [66, 3], [27, 12], [79, 19], [2, 29], [16, 14],
[5, 57], [70, 73], [46, 67], [57, 50], [61, 49], [21, 12],
[51, 81], [77, 68], [56, 65], [43, 54]]
TOWN1_STRAIGHT = [
build_scenario("Town01", start, end, 0, 0, 300, TEST_WEATHERS)
for (start, end) in POSES_TOWN1_STRAIGHT
]
TOWN1_ONE_CURVE = [
build_scenario("Town01", start, end, 0, 0, 600, TEST_WEATHERS)
for (start, end) in POSES_TOWN1_ONE_CURVE
]
TOWN1_NAVIGATION = [
build_scenario("Town01", start, end, 0, 0, 900, TEST_WEATHERS)
for (start, end) in POSES_TOWN1_NAV
]
TOWN1_NAVIGATION_DYNAMIC = [
build_scenario("Town01", start, end, 20, 50, 900, TEST_WEATHERS)
for (start, end) in POSES_TOWN1_NAV
]
TOWN2_STRAIGHT = [
build_scenario("Town02", start, end, 0, 0, 300, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_STRAIGHT
]
TOWN2_STRAIGHT_DYNAMIC = [
build_scenario("Town02", start, end, 20, 50, 300, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_STRAIGHT
]
TOWN2_ONE_CURVE = [
build_scenario("Town02", start, end, 0, 0, 600, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_ONE_CURVE
]
TOWN2_NAVIGATION = [
build_scenario("Town02", start, end, 0, 0, 900, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_NAV
]
TOWN2_NAVIGATION_DYNAMIC = [
build_scenario("Town02", start, end, 20, 50, 900, TRAIN_WEATHERS)
for (start, end) in POSES_TOWN2_NAV
]
TOWN1_ALL = (TOWN1_STRAIGHT + TOWN1_ONE_CURVE + TOWN1_NAVIGATION +
TOWN1_NAVIGATION_DYNAMIC)
TOWN2_ALL = (TOWN2_STRAIGHT + TOWN2_ONE_CURVE + TOWN2_NAVIGATION +
TOWN2_NAVIGATION_DYNAMIC)

View file

@ -32,7 +32,6 @@ run_experiments({
"carla-a3c": {
"run": "A3C",
"env": "carla_env",
"trial_resources": {"cpu": 5, "extra_gpu": 2},
"config": {
"env_config": env_config,
"use_gpu_for_workers": True,

View file

@ -25,21 +25,26 @@ register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
def shape_out(spec):
return (spec.config.env_config.framestack *
(spec.config.env_config.use_depth_camera and 1 or 3))
run_experiments({
"carla-dqn": {
"run": "DQN",
"env": "carla_env",
"trial_resources": {"cpu": 4, "gpu": 1},
"config": {
"env_config": env_config,
"model": {
"custom_model": "carla",
"custom_options": {
"image_shape": [
80, 80,
lambda spec: spec.config.env_config.framestack * (
spec.config.env_config.use_depth_camera and 1 or 3
),
80,
80,
shape_out,
],
},
"conv_filters": [
@ -53,7 +58,9 @@ run_experiments({
"schedule_max_timesteps": 100000,
"gamma": 0.8,
"tf_session_args": {
"gpu_options": {"allow_growth": True},
"gpu_options": {
"allow_growth": True
},
},
},
},

View file

@ -28,14 +28,14 @@ run_experiments({
"carla": {
"run": "PPO",
"env": "carla_env",
"trial_resources": {"cpu": 4, "gpu": 1},
"config": {
"env_config": env_config,
"model": {
"custom_model": "carla",
"custom_options": {
"image_shape": [
env_config["x_res"], env_config["y_res"], 6],
env_config["x_res"], env_config["y_res"], 6
],
},
"conv_filters": [
[16, [8, 8], 4],
@ -44,17 +44,14 @@ run_experiments({
],
},
"num_workers": 1,
"timesteps_per_batch": 2000,
"min_steps_per_task": 100,
"train_batch_size": 2000,
"sample_batch_size": 100,
"lambda": 0.95,
"clip_param": 0.2,
"num_sgd_iter": 20,
"sgd_stepsize": 0.0001,
"sgd_batchsize": 32,
"devices": ["/gpu:0"],
"tf_session_args": {
"gpu_options": {"allow_growth": True}
}
"lr": 0.0001,
"sgd_minibatch_size": 32,
"num_gpus": 1,
},
},
})

View file

@ -24,7 +24,7 @@ class SimpleCorridor(gym.Env):
self.cur_pos = 0
self.action_space = Discrete(2)
self.observation_space = Box(
0.0, self.end_pos, shape=(1,), dtype=np.float32)
0.0, self.end_pos, shape=(1, ), dtype=np.float32)
self._spec = EnvSpec("SimpleCorridor-{}-v0".format(self.end_pos))
def reset(self):
@ -32,7 +32,7 @@ class SimpleCorridor(gym.Env):
return [self.cur_pos]
def step(self, action):
assert action in [0, 1]
assert action in [0, 1], action
if action == 0 and self.cur_pos > 0:
self.cur_pos -= 1
elif action == 1:

View file

@ -51,14 +51,15 @@ class ModelCatalog(object):
>>> prep = ModelCatalog.get_preprocessor(env)
>>> observation = prep.transform(raw_observation)
>>> dist_cls, dist_dim = ModelCatalog.get_action_dist(env.action_space)
>>> model = ModelCatalog.get_model(inputs, dist_dim)
>>> dist_cls, dist_dim = ModelCatalog.get_action_dist(
env.action_space, {})
>>> model = ModelCatalog.get_model(inputs, dist_dim, options)
>>> dist = dist_cls(model.outputs)
>>> action = dist.sample()
"""
@staticmethod
def get_action_dist(action_space, config=None, dist_type=None):
def get_action_dist(action_space, config, dist_type=None):
"""Returns action distribution class and size for the given action space.
Args:
@ -90,7 +91,8 @@ class ModelCatalog(object):
child_dist = []
input_lens = []
for action in action_space.spaces:
dist, action_size = ModelCatalog.get_action_dist(action)
dist, action_size = ModelCatalog.get_action_dist(
action, config)
child_dist.append(dist)
input_lens.append(action_size)
return partial(
@ -139,11 +141,7 @@ class ModelCatalog(object):
" not supported".format(action_space))
@staticmethod
def get_model(inputs,
num_outputs,
options=None,
state_in=None,
seq_lens=None):
def get_model(inputs, num_outputs, options, state_in=None, seq_lens=None):
"""Returns a suitable model conforming to given input and output specs.
Args:
@ -157,7 +155,6 @@ class ModelCatalog(object):
model (Model): Neural network model.
"""
options = options or {}
model = ModelCatalog._get_model(inputs, num_outputs, options, state_in,
seq_lens)

View file

@ -69,12 +69,13 @@ class ModelCatalogTest(unittest.TestCase):
ray.init()
with tf.variable_scope("test1"):
p1 = ModelCatalog.get_model(np.zeros((10, 3), dtype=np.float32), 5)
p1 = ModelCatalog.get_model(
np.zeros((10, 3), dtype=np.float32), 5, {})
self.assertEqual(type(p1), FullyConnectedNetwork)
with tf.variable_scope("test2"):
p2 = ModelCatalog.get_model(
np.zeros((10, 84, 84, 3), dtype=np.float32), 5)
np.zeros((10, 84, 84, 3), dtype=np.float32), 5, {})
self.assertEqual(type(p2), VisionNetwork)
def testCustomModel(self):

View file

@ -1,4 +1,3 @@
# can expect improvement to -140 reward in ~300-500k timesteps
swimmer-ars:
env: Swimmer-v2
run: ARS
@ -9,8 +8,9 @@ swimmer-ars:
num_workers: 1
sgd_stepsize: 0.02
noise_size: 250000000
policy_type: LinearPolicy
eval_prob: 0.2
offset: 0
observation_filter: NoFilter
report_length: 3
model:
fcnet_hiddens: [] # a linear policy