Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[rllib] Moved clip_action into policy_graph; Clip actions in compute_single_action (#4459)

* Moved clip_action into policy_graph; clip actions in compute_single_action
* Update policy_graph.py
* Changed formatting
* Updated codebase for convenience
commit f4b313eaad (parent 5133b10700)

4 changed files with 46 additions and 34 deletions
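In short: action clipping is now applied inside PolicyGraph.compute_single_action whenever the trainer's "clip_actions" config flag is set, instead of in the sampler or the rollout script. A minimal standalone sketch of that call path follows; the stub classes and the hard-coded action below are illustrative stand-ins, not RLlib's actual Agent/PolicyGraph implementations.

# Illustrative stand-ins for the call path this commit introduces; the real
# Agent and PolicyGraph classes in RLlib do far more than this sketch.
import gym
import numpy as np


def clip_action(action, space):
    # Same idea as the helper added to policy_graph.py (Box case only here).
    if isinstance(space, gym.spaces.Box):
        return np.clip(action, space.low, space.high)
    return action


class StubPolicyGraph(object):
    def __init__(self, action_space):
        self.action_space = action_space

    def compute_single_action(self, obs, state, clip_actions=False):
        action = np.array([2.5], dtype=np.float32)  # pretend model output
        if clip_actions:
            action = clip_action(action, self.action_space)
        return action, state, {}


class StubAgent(object):
    def __init__(self, policy, config):
        self.policy = policy
        self.config = config

    def compute_action(self, obs):
        # clip_actions is read from the trainer config and forwarded.
        return self.policy.compute_single_action(
            obs, [], clip_actions=self.config["clip_actions"])[0]


space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))
agent = StubAgent(StubPolicyGraph(space), {"clip_actions": True})
print(agent.compute_action(obs=None))  # [1.] -- clipped to the Box bounds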
@@ -448,9 +448,19 @@ class Agent(Trainable):
             preprocessed, update=False)
         if state:
             return self.get_policy(policy_id).compute_single_action(
-                filtered_obs, state, prev_action, prev_reward, info)
+                filtered_obs,
+                state,
+                prev_action,
+                prev_reward,
+                info,
+                clip_actions=self.config["clip_actions"])
         return self.get_policy(policy_id).compute_single_action(
-            filtered_obs, state, prev_action, prev_reward, info)[0]
+            filtered_obs,
+            state,
+            prev_action,
+            prev_reward,
+            info,
+            clip_actions=self.config["clip_actions"])[0]
 
     @property
     def iteration(self):
@@ -2,6 +2,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+import gym
+
 from ray.rllib.utils.annotations import DeveloperAPI
 
 

@@ -81,6 +84,7 @@ class PolicyGraph(object):
                               prev_reward=None,
                               info=None,
                               episode=None,
+                              clip_actions=False,
                               **kwargs):
         """Unbatched version of compute_actions.
 

@@ -93,6 +97,7 @@ class PolicyGraph(object):
             episode (MultiAgentEpisode): this provides access to all of the
                 internal episode state, which may be useful for model-based or
                 multi-agent algorithms.
+            clip_actions (bool): should the action be clipped
             kwargs: forward compatibility placeholder
 
         Returns:

@@ -119,6 +124,8 @@ class PolicyGraph(object):
             prev_reward_batch=prev_reward_batch,
             info_batch=info_batch,
             episodes=episodes)
+        if clip_actions:
+            action = clip_action(action, self.action_space)
         return action, [s[0] for s in state_out], \
             {k: v[0] for k, v in info.items()}
 

@@ -263,3 +270,28 @@ class PolicyGraph(object):
             export_dir (str): Local writable directory.
         """
         raise NotImplementedError
+
+
+def clip_action(action, space):
+    """Called to clip actions to the specified range of this policy.
+
+    Arguments:
+        action: Single action.
+        space: Action space the actions should be present in.
+
+    Returns:
+        Clipped batch of actions.
+    """
+
+    if isinstance(space, gym.spaces.Box):
+        return np.clip(action, space.low, space.high)
+    elif isinstance(space, gym.spaces.Tuple):
+        if type(action) not in (tuple, list):
+            raise ValueError("Expected tuple space for actions {}: {}".format(
+                action, space))
+        out = []
+        for a, s in zip(action, space.spaces):
+            out.append(clip_action(a, s))
+        return out
+    else:
+        return action
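A quick usage sketch of the helper at its new location (assumes a Ray checkout that includes this change, imported via the same path the sampler diff below uses; the example values are made up for illustration):

# Usage sketch for the relocated clip_action helper.
import gym
import numpy as np
from ray.rllib.evaluation.policy_graph import clip_action

box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))
print(clip_action(np.array([3.0, -7.0]), box))          # [ 1. -1.]

# Tuple spaces are clipped element-wise; non-Box members pass through as-is,
# and the result comes back as a list.
nested = gym.spaces.Tuple([box, gym.spaces.Discrete(4)])
print(clip_action((np.array([3.0, -7.0]), 2), nested))  # [array([ 1., -1.]), 2]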
@@ -2,7 +2,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import gym
 from collections import defaultdict, namedtuple
 import logging
 import numpy as np

@@ -21,6 +20,7 @@ from ray.rllib.offline import InputReader
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.debug import log_once, summarize
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
+from ray.rllib.evaluation.policy_graph import clip_action
 
 logger = logging.getLogger(__name__)
 

@@ -224,31 +224,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
         return extra
 
 
-def clip_action(action, space):
-    """Called to clip actions to the specified range of this policy.
-
-    Arguments:
-        action: Single action.
-        space: Action space the actions should be present in.
-
-    Returns:
-        Clipped batch of actions.
-    """
-
-    if isinstance(space, gym.spaces.Box):
-        return np.clip(action, space.low, space.high)
-    elif isinstance(space, gym.spaces.Tuple):
-        if type(action) not in (tuple, list):
-            raise ValueError("Expected tuple space for actions {}: {}".format(
-                action, space))
-        out = []
-        for a, s in zip(action, space.spaces):
-            out.append(clip_action(a, s))
-        return out
-    else:
-        return action
-
-
 def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                 unroll_length, horizon, preprocessors, obs_filters,
                 clip_rewards, clip_actions, pack, callbacks, tf_sess,
@@ -13,7 +13,6 @@ import gym
 import ray
 from ray.rllib.agents.registry import get_agent_class
 from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.evaluation.sampler import clip_action
 from ray.tune.util import merge_dicts
 
 EXAMPLE_USAGE = """

@@ -155,11 +154,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
             else:
                 action = agent.compute_action(state)
 
-            if agent.config["clip_actions"]:
-                clipped_action = clip_action(action, env.action_space)
-                next_state, reward, done, _ = env.step(clipped_action)
-            else:
-                next_state, reward, done, _ = env.step(action)
+            next_state, reward, done, _ = env.step(action)
 
             if multiagent:
                 done = done["__all__"]
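With clipping handled inside compute_action, the rollout loop now forwards the returned action straight to env.step. A self-contained sketch of that simplified step follows; the stub env and agent are stand-ins so the snippet runs on its own and are not part of this commit.

# Post-change rollout step: the action goes straight to env.step, relying on
# compute_action to have clipped it when config["clip_actions"] is set.
import gym
import numpy as np


class StubEnv(object):
    action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def step(self, action):
        # A real env may misbehave on out-of-range actions; here we just echo it.
        return np.zeros(1), 0.0, True, {"received": action}


class StubAgent(object):
    config = {"clip_actions": True}

    def compute_action(self, state):
        raw = np.array([5.0], dtype=np.float32)  # out-of-range model output
        return np.clip(raw, -1.0, 1.0) if self.config["clip_actions"] else raw


agent, env = StubAgent(), StubEnv()
state = np.zeros(1)
action = agent.compute_action(state)            # already clipped: [1.]
next_state, reward, done, _ = env.step(action)  # no manual clip_action needed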