Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[rllib] Added missing action clipping for rollout example script (#4413)
* Added action clipping for rollout example script
* Used action_clipping from sampler
* Fixed and improved naming
This commit is contained in:
parent 59d74d5e92
commit b21c20c9a6

2 changed files with 32 additions and 27 deletions
python/ray/rllib/evaluation/sampler.py
@@ -200,6 +200,31 @@ class AsyncSampler(threading.Thread, SamplerInput):
         return extra
 
 
+def clip_action(action, space):
+    """Called to clip actions to the specified range of this policy.
+
+    Arguments:
+        action: Single action.
+        space: Action space the actions should be present in.
+
+    Returns:
+        Clipped batch of actions.
+    """
+
+    if isinstance(space, gym.spaces.Box):
+        return np.clip(action, space.low, space.high)
+    elif isinstance(space, gym.spaces.Tuple):
+        if type(action) not in (tuple, list):
+            raise ValueError("Expected tuple space for actions {}: {}".format(
+                action, space))
+        out = []
+        for a, s in zip(action, space.spaces):
+            out.append(clip_action(a, s))
+        return out
+    else:
+        return action
+
+
 def _env_runner(base_env,
                 extra_batch_callback,
                 policies,
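For reference, a minimal sketch of how the new clip_action helper behaves, assuming a Ray build that contains this commit; the space instances below are purely illustrative and not part of the change:

import gym
import numpy as np

from ray.rllib.evaluation.sampler import clip_action

# Box actions are clipped elementwise into [low, high].
box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))
print(clip_action(np.array([2.0, -3.0]), box))      # [ 1. -1.]

# Tuple spaces are clipped recursively, component by component.
tup = gym.spaces.Tuple([box, gym.spaces.Discrete(4)])
print(clip_action((np.array([5.0, 0.5]), 3), tup))  # [array([1. , 0.5]), 3]

# Any other space (e.g. Discrete) passes through unchanged.
print(clip_action(3, gym.spaces.Discrete(4)))       # 3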
@@ -526,7 +551,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
             env_id = eval_data[i].env_id
             agent_id = eval_data[i].agent_id
             if clip_actions:
-                actions_to_send[env_id][agent_id] = _clip_actions(
+                actions_to_send[env_id][agent_id] = clip_action(
                     action, policy.action_space)
             else:
                 actions_to_send[env_id][agent_id] = action
@@ -563,31 +588,6 @@ def _fetch_atari_metrics(base_env):
     return atari_out
 
 
-def _clip_actions(actions, space):
-    """Called to clip actions to the specified range of this policy.
-
-    Arguments:
-        actions: Single action.
-        space: Action space the actions should be present in.
-
-    Returns:
-        Clipped batch of actions.
-    """
-
-    if isinstance(space, gym.spaces.Box):
-        return np.clip(actions, space.low, space.high)
-    elif isinstance(space, gym.spaces.Tuple):
-        if type(actions) not in (tuple, list):
-            raise ValueError("Expected tuple space for actions {}: {}".format(
-                actions, space))
-        out = []
-        for a, s in zip(actions, space.spaces):
-            out.append(_clip_actions(a, s))
-        return out
-    else:
-        return actions
-
-
 def _unbatch_tuple_actions(action_batch):
     # convert list of batches -> batch of lists
     if isinstance(action_batch, TupleActions):
python/ray/rllib/rollout.py
@@ -12,6 +12,7 @@ import pickle
 import gym
 import ray
 from ray.rllib.agents.registry import get_agent_class
+from ray.rllib.evaluation.sampler import clip_action
 from ray.tune.util import merge_dicts
 
 EXAMPLE_USAGE = """
@@ -153,7 +154,11 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
             else:
                 action = agent.compute_action(state)
 
-            next_state, reward, done, _ = env.step(action)
+            if agent.config["clip_actions"]:
+                clipped_action = clip_action(action, env.action_space)
+                next_state, reward, done, _ = env.step(clipped_action)
+            else:
+                next_state, reward, done, _ = env.step(action)
 
             if multiagent:
                 done = done["__all__"]
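With this change the rollout script clips actions the same way the sampler does during training, so values passed to env.step() always respect a Box space's bounds. A self-contained sketch of the same pattern, with a random out-of-range policy standing in for agent.compute_action; the environment choice, scaling factor, and the clip_box_action helper are illustrative only:

import gym
import numpy as np

def clip_box_action(action, space):
    # Same idea as sampler.clip_action, restricted to Box spaces for brevity.
    if isinstance(space, gym.spaces.Box):
        return np.clip(action, space.low, space.high)
    return action

env = gym.make("Pendulum-v0")
state = env.reset()
done = False
while not done:
    # Stand-in for agent.compute_action(state); deliberately exceeds the bounds.
    action = env.action_space.sample() * 2.0
    clipped_action = clip_box_action(action, env.action_space)
    state, reward, done, _ = env.step(clipped_action)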