diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py
index 0c3032178..cd6935f8a 100644
--- a/python/ray/rllib/evaluation/sampler.py
+++ b/python/ray/rllib/evaluation/sampler.py
@@ -200,6 +200,31 @@ class AsyncSampler(threading.Thread, SamplerInput):
         return extra
 
 
+def clip_action(action, space):
+    """Clips an action to the valid range of the given action space.
+
+    Arguments:
+        action: Single action to clip.
+        space: Action space the action should lie in.
+
+    Returns:
+        The clipped action.
+    """
+
+    if isinstance(space, gym.spaces.Box):
+        return np.clip(action, space.low, space.high)
+    elif isinstance(space, gym.spaces.Tuple):
+        if type(action) not in (tuple, list):
+            raise ValueError("Expected tuple space for actions {}: {}".format(
+                action, space))
+        out = []
+        for a, s in zip(action, space.spaces):
+            out.append(clip_action(a, s))
+        return out
+    else:
+        return action
+
+
 def _env_runner(base_env,
                 extra_batch_callback,
                 policies,
@@ -526,7 +551,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
             env_id = eval_data[i].env_id
             agent_id = eval_data[i].agent_id
             if clip_actions:
-                actions_to_send[env_id][agent_id] = _clip_actions(
+                actions_to_send[env_id][agent_id] = clip_action(
                     action, policy.action_space)
             else:
                 actions_to_send[env_id][agent_id] = action
@@ -563,31 +588,6 @@ def _fetch_atari_metrics(base_env):
     return atari_out
 
 
-def _clip_actions(actions, space):
-    """Called to clip actions to the specified range of this policy.
-
-    Arguments:
-        actions: Single action.
-        space: Action space the actions should be present in.
-
-    Returns:
-        Clipped batch of actions.
-    """
-
-    if isinstance(space, gym.spaces.Box):
-        return np.clip(actions, space.low, space.high)
-    elif isinstance(space, gym.spaces.Tuple):
-        if type(actions) not in (tuple, list):
-            raise ValueError("Expected tuple space for actions {}: {}".format(
-                actions, space))
-        out = []
-        for a, s in zip(actions, space.spaces):
-            out.append(_clip_actions(a, s))
-        return out
-    else:
-        return actions
-
-
 def _unbatch_tuple_actions(action_batch):
     # convert list of batches -> batch of lists
     if isinstance(action_batch, TupleActions):
diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py
index 70b00eb63..0bd364583 100755
--- a/python/ray/rllib/rollout.py
+++ b/python/ray/rllib/rollout.py
@@ -12,6 +12,7 @@ import pickle
 import gym
 import ray
 from ray.rllib.agents.registry import get_agent_class
+from ray.rllib.evaluation.sampler import clip_action
 from ray.tune.util import merge_dicts
 
 EXAMPLE_USAGE = """
@@ -153,7 +154,11 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
             else:
                 action = agent.compute_action(state)
 
-            next_state, reward, done, _ = env.step(action)
+            if agent.config["clip_actions"]:
+                clipped_action = clip_action(action, env.action_space)
+                next_state, reward, done, _ = env.step(clipped_action)
+            else:
+                next_state, reward, done, _ = env.step(action)
 
             if multiagent:
                 done = done["__all__"]
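
For reference, a minimal sketch of how the now-public clip_action helper behaves once this patch is applied. The spaces and values below are illustrative only, not part of the patch:

    import gym
    import numpy as np

    from ray.rllib.evaluation.sampler import clip_action

    # Box space: each component is clipped into [low, high].
    box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))
    print(clip_action(np.array([2.0, -3.0]), box))  # -> [1.0, -1.0]

    # Tuple space: each sub-action is clipped recursively against its
    # sub-space; non-Box sub-actions (e.g. Discrete) pass through unchanged.
    tup = gym.spaces.Tuple([box, gym.spaces.Discrete(4)])
    print(clip_action((np.array([5.0, 0.5]), 2), tup))  # -> [array([1., 0.5]), 2]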