[rllib] Fix support for mixed discrete and continuous action spaces, add to regression test (#2655)

* fix

* lint

* fix
This commit is contained in:
Eric Liang 2018-08-15 10:19:41 -07:00 committed by GitHub
parent 98fed67b45
commit 53f9755594
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 5 deletions

View file

@ -98,6 +98,7 @@ Here is an example of the basic usage:
import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print
ray.init()
config = ppo.DEFAULT_CONFIG.copy()
@ -108,7 +109,7 @@ Here is an example of the basic usage:
for i in range(1000):
# Perform one iteration of training the policy with PPO
result = agent.train()
print("result: {}".format(result))
print(pretty_print(result))
if i % 100 == 0:
checkpoint = agent.save()

View file

@ -404,7 +404,13 @@ class _MultiAgentEpisode(object):
action = self._agent_to_last_action[agent_id]
# Concatenate tuple actions
if isinstance(action, list):
action = np.concatenate(action, axis=0).flatten()
expanded = []
for a in action:
if len(a.shape) == 1:
expanded.append(np.expand_dims(a, 1))
else:
expanded.append(a)
action = np.concatenate(expanded, axis=1).flatten()
return action
def last_pi_info_for(self, agent_id):

View file

@ -182,7 +182,6 @@ class MultiActionDistribution(ActionDistribution):
def __init__(self, inputs, action_space, child_distributions, input_lens):
self.input_lens = input_lens
inputs = tf.reshape(inputs, [-1, sum(input_lens)])
split_inputs = tf.split(inputs, self.input_lens, axis=1)
child_list = []
for i, distribution in enumerate(child_distributions):
@ -191,11 +190,18 @@ class MultiActionDistribution(ActionDistribution):
def logp(self, x):
"""The log-likelihood of the action distribution."""
split_list = tf.split(x, len(self.input_lens), axis=1)
split_indices = []
for dist in self.child_distributions:
if isinstance(dist, Categorical):
split_indices.append(1)
else:
split_indices.append(tf.shape(dist.sample())[1])
split_list = tf.split(x, split_indices, axis=1)
for i, distribution in enumerate(self.child_distributions):
# Remove extra categorical dimension
if isinstance(distribution, Categorical):
split_list[i] = tf.squeeze(split_list[i], axis=-1)
split_list[i] = tf.cast(
tf.squeeze(split_list[i], axis=-1), tf.int32)
log_list = np.asarray([
distribution.logp(split_x) for distribution, split_x in zip(
self.child_distributions, split_list)

View file

@ -23,6 +23,10 @@ ACTION_SPACES_TO_TEST = {
Box(0.0, 1.0, (5, ), dtype=np.float32),
Box(0.0, 1.0, (5, ), dtype=np.float32)
],
"mixed_tuple": Tuple(
[Discrete(2),
Discrete(3),
Box(0.0, 1.0, (5, ), dtype=np.float32)]),
}
OBSERVATION_SPACES_TO_TEST = {