[rllib] Added missing action clipping for rollout example script (#4413)
* Added action clipping for rollout example script
* Used action clipping from sampler
* Fixed and improved naming
commit b21c20c9a6
parent 59d74d5e92
2 changed files with 32 additions and 27 deletions
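In short: the sampler's module-private _clip_actions helper becomes the public clip_action, and the rollout example script reuses it, so example rollouts clip continuous actions to the environment's action-space bounds the same way the sampler does during training. A minimal sketch of the helper's behavior (assuming gym and numpy are installed; the actual definition is in the first hunk below):

import gym
import numpy as np

from ray.rllib.evaluation.sampler import clip_action

# Box actions are clipped elementwise into [low, high].
box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))
print(clip_action(np.array([2.5, -3.0]), box))  # [ 1. -1.]

# Tuple spaces are clipped recursively; non-Box members (e.g. Discrete)
# pass through unchanged.
tup = gym.spaces.Tuple([box, gym.spaces.Discrete(4)])
print(clip_action([np.array([9.0, 0.5]), 3], tup))  # [array([1. , 0.5]), 3]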
ray/rllib/evaluation/sampler.py

@@ -200,6 +200,31 @@ class AsyncSampler(threading.Thread, SamplerInput):
         return extra
 
 
+def clip_action(action, space):
+    """Called to clip actions to the specified range of this policy.
+
+    Arguments:
+        action: Single action.
+        space: Action space the actions should be present in.
+
+    Returns:
+        Clipped batch of actions.
+    """
+
+    if isinstance(space, gym.spaces.Box):
+        return np.clip(action, space.low, space.high)
+    elif isinstance(space, gym.spaces.Tuple):
+        if type(action) not in (tuple, list):
+            raise ValueError("Expected tuple space for actions {}: {}".format(
+                action, space))
+        out = []
+        for a, s in zip(action, space.spaces):
+            out.append(clip_action(a, s))
+        return out
+    else:
+        return action
+
+
 def _env_runner(base_env,
                 extra_batch_callback,
                 policies,
@@ -526,7 +551,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
         env_id = eval_data[i].env_id
         agent_id = eval_data[i].agent_id
         if clip_actions:
-            actions_to_send[env_id][agent_id] = _clip_actions(
+            actions_to_send[env_id][agent_id] = clip_action(
                 action, policy.action_space)
         else:
             actions_to_send[env_id][agent_id] = action
@@ -563,31 +588,6 @@ def _fetch_atari_metrics(base_env):
     return atari_out
 
 
-def _clip_actions(actions, space):
-    """Called to clip actions to the specified range of this policy.
-
-    Arguments:
-        actions: Single action.
-        space: Action space the actions should be present in.
-
-    Returns:
-        Clipped batch of actions.
-    """
-
-    if isinstance(space, gym.spaces.Box):
-        return np.clip(actions, space.low, space.high)
-    elif isinstance(space, gym.spaces.Tuple):
-        if type(actions) not in (tuple, list):
-            raise ValueError("Expected tuple space for actions {}: {}".format(
-                actions, space))
-        out = []
-        for a, s in zip(actions, space.spaces):
-            out.append(_clip_actions(a, s))
-        return out
-    else:
-        return actions
-
-
 def _unbatch_tuple_actions(action_batch):
     # convert list of batches -> batch of lists
     if isinstance(action_batch, TupleActions):
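One behavioral detail the renamed helper preserves: for a Tuple space, passing anything other than a tuple or list raises a ValueError rather than silently passing the action through. A quick illustration (a sketch, assuming gym and numpy):

import gym
import numpy as np

from ray.rllib.evaluation.sampler import clip_action

space = gym.spaces.Tuple([gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))])
try:
    clip_action(np.array([5.0]), space)  # an ndarray, not a tuple/list
except ValueError as err:
    print(err)  # Expected tuple space for actions [5.]: ...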
ray/rllib/rollout.py

@@ -12,6 +12,7 @@ import pickle
 import gym
 import ray
 from ray.rllib.agents.registry import get_agent_class
+from ray.rllib.evaluation.sampler import clip_action
 from ray.tune.util import merge_dicts
 
 EXAMPLE_USAGE = """
@@ -153,7 +154,11 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
         else:
             action = agent.compute_action(state)
 
-        next_state, reward, done, _ = env.step(action)
+        if agent.config["clip_actions"]:
+            clipped_action = clip_action(action, env.action_space)
+            next_state, reward, done, _ = env.step(clipped_action)
+        else:
+            next_state, reward, done, _ = env.step(action)
 
         if multiagent:
             done = done["__all__"]
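Why the rollout change matters: a continuous-control policy (e.g. a Gaussian action head) can emit actions outside the Box bounds an environment expects, and env.step may then error or misbehave. With clip_actions enabled in the agent's config, the script now clips before stepping, matching training-time behavior. A standalone illustration (a sketch assuming the classic gym Pendulum-v0 of this era, whose torque is bounded to [-2, 2]):

import gym
import numpy as np

env = gym.make("Pendulum-v0")
raw_action = np.array([3.7], dtype=np.float32)  # e.g. sampled from an unbounded Gaussian

# Same elementwise clipping that clip_action applies for Box spaces.
clipped = np.clip(raw_action, env.action_space.low, env.action_space.high)
print(raw_action, "->", clipped)  # [3.7] -> [2.]

env.reset()
next_state, reward, done, _ = env.step(clipped)  # now within the env's bounds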