mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Fix support for mixed discrete and continuous action spaces, add to regression test (#2655)
* fix * lint * fix
This commit is contained in:
parent
98fed67b45
commit
53f9755594
4 changed files with 22 additions and 5 deletions
|
@ -98,6 +98,7 @@ Here is an example of the basic usage:
|
|||
|
||||
import ray
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
from ray.tune.logger import pretty_print
|
||||
|
||||
ray.init()
|
||||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
|
@ -108,7 +109,7 @@ Here is an example of the basic usage:
|
|||
for i in range(1000):
|
||||
# Perform one iteration of training the policy with PPO
|
||||
result = agent.train()
|
||||
print("result: {}".format(result))
|
||||
print(pretty_print(result))
|
||||
|
||||
if i % 100 == 0:
|
||||
checkpoint = agent.save()
|
||||
|
|
|
@ -404,7 +404,13 @@ class _MultiAgentEpisode(object):
|
|||
action = self._agent_to_last_action[agent_id]
|
||||
# Concatenate tuple actions
|
||||
if isinstance(action, list):
|
||||
action = np.concatenate(action, axis=0).flatten()
|
||||
expanded = []
|
||||
for a in action:
|
||||
if len(a.shape) == 1:
|
||||
expanded.append(np.expand_dims(a, 1))
|
||||
else:
|
||||
expanded.append(a)
|
||||
action = np.concatenate(expanded, axis=1).flatten()
|
||||
return action
|
||||
|
||||
def last_pi_info_for(self, agent_id):
|
||||
|
|
|
@ -182,7 +182,6 @@ class MultiActionDistribution(ActionDistribution):
|
|||
|
||||
def __init__(self, inputs, action_space, child_distributions, input_lens):
|
||||
self.input_lens = input_lens
|
||||
inputs = tf.reshape(inputs, [-1, sum(input_lens)])
|
||||
split_inputs = tf.split(inputs, self.input_lens, axis=1)
|
||||
child_list = []
|
||||
for i, distribution in enumerate(child_distributions):
|
||||
|
@ -191,11 +190,18 @@ class MultiActionDistribution(ActionDistribution):
|
|||
|
||||
def logp(self, x):
|
||||
"""The log-likelihood of the action distribution."""
|
||||
split_list = tf.split(x, len(self.input_lens), axis=1)
|
||||
split_indices = []
|
||||
for dist in self.child_distributions:
|
||||
if isinstance(dist, Categorical):
|
||||
split_indices.append(1)
|
||||
else:
|
||||
split_indices.append(tf.shape(dist.sample())[1])
|
||||
split_list = tf.split(x, split_indices, axis=1)
|
||||
for i, distribution in enumerate(self.child_distributions):
|
||||
# Remove extra categorical dimension
|
||||
if isinstance(distribution, Categorical):
|
||||
split_list[i] = tf.squeeze(split_list[i], axis=-1)
|
||||
split_list[i] = tf.cast(
|
||||
tf.squeeze(split_list[i], axis=-1), tf.int32)
|
||||
log_list = np.asarray([
|
||||
distribution.logp(split_x) for distribution, split_x in zip(
|
||||
self.child_distributions, split_list)
|
||||
|
|
|
@ -23,6 +23,10 @@ ACTION_SPACES_TO_TEST = {
|
|||
Box(0.0, 1.0, (5, ), dtype=np.float32),
|
||||
Box(0.0, 1.0, (5, ), dtype=np.float32)
|
||||
],
|
||||
"mixed_tuple": Tuple(
|
||||
[Discrete(2),
|
||||
Discrete(3),
|
||||
Box(0.0, 1.0, (5, ), dtype=np.float32)]),
|
||||
}
|
||||
|
||||
OBSERVATION_SPACES_TO_TEST = {
|
||||
|
|
Loading…
Add table
Reference in a new issue