[rllib] Added missing action clipping for rollout example script (#4413)

* Added action clipping for rollout example script

* Used action_clipping from sampler

* Fixed and improved naming
This commit is contained in:
Leon Sievers 2019-03-22 08:51:27 +01:00 committed by Eric Liang
parent 59d74d5e92
commit b21c20c9a6
2 changed files with 32 additions and 27 deletions

View file

@ -200,6 +200,31 @@ class AsyncSampler(threading.Thread, SamplerInput):
return extra
def clip_action(action, space):
"""Called to clip actions to the specified range of this policy.
Arguments:
action: Single action.
space: Action space the actions should be present in.
Returns:
Clipped batch of actions.
"""
if isinstance(space, gym.spaces.Box):
return np.clip(action, space.low, space.high)
elif isinstance(space, gym.spaces.Tuple):
if type(action) not in (tuple, list):
raise ValueError("Expected tuple space for actions {}: {}".format(
action, space))
out = []
for a, s in zip(action, space.spaces):
out.append(clip_action(a, s))
return out
else:
return action
def _env_runner(base_env,
extra_batch_callback,
policies,
@ -526,7 +551,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
env_id = eval_data[i].env_id
agent_id = eval_data[i].agent_id
if clip_actions:
actions_to_send[env_id][agent_id] = _clip_actions(
actions_to_send[env_id][agent_id] = clip_action(
action, policy.action_space)
else:
actions_to_send[env_id][agent_id] = action
@ -563,31 +588,6 @@ def _fetch_atari_metrics(base_env):
return atari_out
def _clip_actions(actions, space):
"""Called to clip actions to the specified range of this policy.
Arguments:
actions: Single action.
space: Action space the actions should be present in.
Returns:
Clipped batch of actions.
"""
if isinstance(space, gym.spaces.Box):
return np.clip(actions, space.low, space.high)
elif isinstance(space, gym.spaces.Tuple):
if type(actions) not in (tuple, list):
raise ValueError("Expected tuple space for actions {}: {}".format(
actions, space))
out = []
for a, s in zip(actions, space.spaces):
out.append(_clip_actions(a, s))
return out
else:
return actions
def _unbatch_tuple_actions(action_batch):
# convert list of batches -> batch of lists
if isinstance(action_batch, TupleActions):

View file

@ -12,6 +12,7 @@ import pickle
import gym
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.evaluation.sampler import clip_action
from ray.tune.util import merge_dicts
EXAMPLE_USAGE = """
@ -153,7 +154,11 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
else:
action = agent.compute_action(state)
next_state, reward, done, _ = env.step(action)
if agent.config["clip_actions"]:
clipped_action = clip_action(action, env.action_space)
next_state, reward, done, _ = env.step(clipped_action)
else:
next_state, reward, done, _ = env.step(action)
if multiagent:
done = done["__all__"]