[rllib] Moved clip_action into policy_graph; Clip actions in compute_single_action (#4459)

* Moved clip_action into policy_graph; Clip actions in compute_single_action

* Update policy_graph.py

* Changed formatting

* Updated codebase for convenience
Leon Sievers 2019-03-29 21:26:07 +01:00 committed by Eric Liang
parent 5133b10700
commit f4b313eaad
4 changed files with 46 additions and 34 deletions
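
What this means for users: with config["clip_actions"] enabled, the action returned by Agent.compute_action() (and by PolicyGraph.compute_single_action() underneath) is clipped to the environment's action space before it reaches the caller. A minimal usage sketch, assuming the PPOAgent API of this era and Pendulum-v0; the agent class, env, and config values here are illustrative, not part of this diff:

    import gym
    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    agent = PPOAgent(env="Pendulum-v0", config={"clip_actions": True})

    env = gym.make("Pendulum-v0")
    obs = env.reset()
    # compute_action() forwards clip_actions from the config (see the hunk
    # below), so the returned action already lies inside env.action_space.
    action = agent.compute_action(obs)
    next_obs, reward, done, _ = env.step(action)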


@@ -448,9 +448,19 @@ class Agent(Trainable):
             preprocessed, update=False)
         if state:
             return self.get_policy(policy_id).compute_single_action(
-                filtered_obs, state, prev_action, prev_reward, info)
+                filtered_obs,
+                state,
+                prev_action,
+                prev_reward,
+                info,
+                clip_actions=self.config["clip_actions"])
         return self.get_policy(policy_id).compute_single_action(
-            filtered_obs, state, prev_action, prev_reward, info)[0]
+            filtered_obs,
+            state,
+            prev_action,
+            prev_reward,
+            info,
+            clip_actions=self.config["clip_actions"])[0]

     @property
     def iteration(self):
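
For direct policy-level access, the flag can also be passed explicitly. A rough sketch, assuming agent is an initialized Agent and filtered_obs has already gone through the same preprocessing and filtering as in compute_action() above (variable names are illustrative):

    policy = agent.get_policy()  # default policy graph
    action, state_out, extra = policy.compute_single_action(
        filtered_obs, state=[], clip_actions=True)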


@@ -2,6 +2,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import numpy as np
+import gym
+
 from ray.rllib.utils.annotations import DeveloperAPI
@@ -81,6 +84,7 @@ class PolicyGraph(object):
                               prev_reward=None,
                               info=None,
                               episode=None,
+                              clip_actions=False,
                               **kwargs):
         """Unbatched version of compute_actions.
@@ -93,6 +97,7 @@ class PolicyGraph(object):
             episode (MultiAgentEpisode): this provides access to all of the
                 internal episode state, which may be useful for model-based or
                 multi-agent algorithms.
+            clip_actions (bool): should the action be clipped
             kwargs: forward compatibility placeholder

         Returns:
@@ -119,6 +124,8 @@ class PolicyGraph(object):
             prev_reward_batch=prev_reward_batch,
             info_batch=info_batch,
             episodes=episodes)
+        if clip_actions:
+            action = clip_action(action, self.action_space)
         return action, [s[0] for s in state_out], \
             {k: v[0] for k, v in info.items()}
@@ -263,3 +270,28 @@ class PolicyGraph(object):
             export_dir (str): Local writable directory.
         """
         raise NotImplementedError
+
+
+def clip_action(action, space):
+    """Called to clip actions to the specified range of this policy.
+
+    Arguments:
+        action: Single action.
+        space: Action space the action should be present in.
+
+    Returns:
+        Clipped action.
+    """
+
+    if isinstance(space, gym.spaces.Box):
+        return np.clip(action, space.low, space.high)
+    elif isinstance(space, gym.spaces.Tuple):
+        if type(action) not in (tuple, list):
+            raise ValueError("Expected tuple space for actions {}: {}".format(
+                action, space))
+        out = []
+        for a, s in zip(action, space.spaces):
+            out.append(clip_action(a, s))
+        return out
+    else:
+        return action
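
A quick illustration of what the relocated helper does (a sketch, not part of the commit; the spaces below are invented for the example): Box actions are clipped element-wise to [low, high], Tuple actions are clipped component by component, and anything else (e.g. Discrete) passes through unchanged.

    import numpy as np
    from gym.spaces import Box, Discrete, Tuple

    box = Box(low=-1.0, high=1.0, shape=(2,))
    clip_action(np.array([1.7, -3.0]), box)      # -> array([ 1., -1.])

    nested = Tuple([Box(low=0.0, high=5.0, shape=(1,)), Discrete(4)])
    clip_action((np.array([9.0]), 3), nested)    # -> [array([5.]), 3]

    clip_action(2, Discrete(4))                  # -> 2, returned unchanged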


@@ -2,7 +2,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import gym
 from collections import defaultdict, namedtuple
 import logging
 import numpy as np
@@ -21,6 +20,7 @@ from ray.rllib.offline import InputReader
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.debug import log_once, summarize
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
+from ray.rllib.evaluation.policy_graph import clip_action

 logger = logging.getLogger(__name__)
@@ -224,31 +224,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
         return extra


-def clip_action(action, space):
-    """Called to clip actions to the specified range of this policy.
-
-    Arguments:
-        action: Single action.
-        space: Action space the actions should be present in.
-
-    Returns:
-        Clipped batch of actions.
-    """
-
-    if isinstance(space, gym.spaces.Box):
-        return np.clip(action, space.low, space.high)
-    elif isinstance(space, gym.spaces.Tuple):
-        if type(action) not in (tuple, list):
-            raise ValueError("Expected tuple space for actions {}: {}".format(
-                action, space))
-        out = []
-        for a, s in zip(action, space.spaces):
-            out.append(clip_action(a, s))
-        return out
-    else:
-        return action
-
-
 def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                 unroll_length, horizon, preprocessors, obs_filters,
                 clip_rewards, clip_actions, pack, callbacks, tf_sess,
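
Code outside the sampler that needs the helper should now import it from its new location; the import added above is what sampler.py itself uses from here on:

    from ray.rllib.evaluation.policy_graph import clip_action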


@@ -13,7 +13,6 @@ import gym
 import ray
 from ray.rllib.agents.registry import get_agent_class
 from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.evaluation.sampler import clip_action
 from ray.tune.util import merge_dicts

 EXAMPLE_USAGE = """
@@ -155,11 +154,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
             else:
                 action = agent.compute_action(state)

-            if agent.config["clip_actions"]:
-                clipped_action = clip_action(action, env.action_space)
-                next_state, reward, done, _ = env.step(clipped_action)
-            else:
-                next_state, reward, done, _ = env.step(action)
+            next_state, reward, done, _ = env.step(action)

             if multiagent:
                 done = done["__all__"]
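
In other words, the rollout script no longer clips by hand: with clip_actions set in the agent's config, the action coming out of compute_action() is already safe to pass to the environment. A condensed sketch of the loop after this change (variable names as in rollout() above):

    action = agent.compute_action(state)
    # already clipped inside compute_action() when config["clip_actions"] is set
    next_state, reward, done, _ = env.step(action)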