[RLlib] Issue 8889: action clipping bug ppo not learning mujoco (#8898)
Parent: 002e1e7c8d
Commit: a69715e9aa
2 changed files with 95 additions and 4 deletions
rllib/evaluation/sampler.py (7 additions, 4 deletions)

```diff
@@ -888,15 +888,18 @@ def _process_policy_eval_results(*, to_eval, eval_results, active_episodes,
             pi_info_cols["state_out_{}".format(f_i)] = column

         policy = _get_or_raise(policies, policy_id)
-        # Clip if necessary (while action components are still batched).
-        if clip_actions:
-            actions = clip_action(actions, policy.action_space_struct)
         # Split action-component batches into single action rows.
         actions = unbatch(actions)
         for i, action in enumerate(actions):
             env_id = eval_data[i].env_id
             agent_id = eval_data[i].agent_id
-            actions_to_send[env_id][agent_id] = action
+            # Clip if necessary.
+            if clip_actions:
+                clipped_action = clip_action(action,
+                                             policy.action_space_struct)
+            else:
+                clipped_action = action
+            actions_to_send[env_id][agent_id] = clipped_action
             episode = active_episodes[env_id]
             episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
             episode._set_last_pi_info(
```
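The net effect of the hunk is to stop clipping the whole batched action struct up front and instead clip each action only at the point where it is handed to the environment (`actions_to_send`), which suggests the intent is to keep the raw policy output for training while sending only a clipped copy to the env. A plausible mechanism behind the issue title is that training PPO on clipped actions makes the stored actions inconsistent with the log-probabilities recorded at sampling time on bounded MuJoCo action spaces. Below is a minimal, self-contained sketch of that effect; it uses plain NumPy and a made-up Gaussian policy, not RLlib code:

```python
# Illustrative sketch only (not RLlib code): why the action stored for
# training should be the raw sample, not the clipped one.
import numpy as np

low, high = -1.0, 1.0      # Box bounds, e.g. a MuJoCo joint torque
mean, std = 0.9, 0.5       # hypothetical policy output for some state


def gaussian_log_prob(a, mean, std):
    # Log-density of the diagonal Gaussian the policy samples from.
    return (-0.5 * ((a - mean) / std) ** 2
            - np.log(std) - 0.5 * np.log(2.0 * np.pi))


sampled = 1.4                           # raw sample, outside the Box
clipped = np.clip(sampled, low, high)   # what the env should receive

# Old behavior: the clipped action was also the one kept for training, so the
# PPO importance ratio was evaluated at a point the policy never sampled.
print(gaussian_log_prob(clipped, mean, std))
# New behavior: train on the raw sample; only the env-facing copy is clipped.
print(gaussian_log_prob(sampled, mean, std))
```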
rllib/tuned_examples/debug_learning_failure_git_bisect.py (new file, 88 lines)
```python
"""Example of testing, whether RLlib can still learn with a certain config.

Can be used with git bisect to find the faulty commit responsible for a
learning failure. Produces an error if the given reward is not reached within
the stopping criteria (training iters or timesteps) allowing git bisect to
properly analyze and find the faulty commit.

Run as follows using a simple command line config:
$ python debug_learning_failure_git_bisect.py --config '{...}'
  --env CartPole-v0 --run PPO --stop-reward=180 --stop-iters=100

With a yaml file:
$ python debug_learning_failure_git_bisect.py -f [yaml file] --stop-reward=180
  --stop-iters=100

Within git bisect:
$ git bisect start
$ git bisect bad
$ git bisect good [some previous commit we know was good]
$ git bisect run python debug_learning_failure_git_bisect.py [... options]
"""
import argparse
import json
import yaml

import ray
from ray import tune
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default=None)
parser.add_argument("--torch", action="store_true")
parser.add_argument("--stop-iters", type=int, default=None)
parser.add_argument("--stop-timesteps", type=int, default=None)
parser.add_argument("--stop-reward", type=float, default=None)
parser.add_argument("-f", type=str, default=None)
parser.add_argument("--config", type=str, default=None)
parser.add_argument("--env", type=str, default=None)

if __name__ == "__main__":
    run = None

    args = parser.parse_args()

    # Explicit yaml config file.
    if args.f:
        with open(args.f, "r") as fp:
            experiment_config = yaml.load(fp)
        experiment_config = experiment_config[next(
            iter(experiment_config))]
        config = experiment_config.get("config", {})
        config["env"] = experiment_config.get("env")
        run = experiment_config.pop("run")
    # JSON string on command line.
    else:
        config = json.loads(args.config)
        assert args.env
        config["env"] = args.env

    # Explicit run.
    if args.run:
        run = args.run

    # Explicit --torch framework.
    if args.torch:
        config["framework"] = "torch"
    # Framework not specified in config, try to infer it.
    if "framework" not in config:
        config["framework"] = "torch" if args.torch else "tf"

    ray.init()

    stop = {}
    if args.stop_iters:
        stop["training_iteration"] = args.stop_iters
    if args.stop_timesteps:
        stop["timesteps_total"] = args.stop_timesteps
    if args.stop_reward:
        stop["episode_reward_mean"] = args.stop_reward

    results = tune.run(run, stop=stop, config=config)

    check_learning_achieved(results, args.stop_reward)

    ray.shutdown()
```
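The script turns a learning failure into a non-zero exit status via `check_learning_achieved`, which is what lets `git bisect run` classify each commit as good or bad automatically. A rough sketch of the contract the script relies on is shown below; this is an illustration, not the actual `ray.rllib.utils.test_utils.check_learning_achieved` implementation, and it assumes the Tune `ExperimentAnalysis.trials` / `Trial.last_result` attributes:

```python
# Illustrative sketch of the contract the bisect script depends on:
# raise (-> non-zero exit -> "bad" commit for git bisect) if no trial
# reached the target mean episode reward.
def check_learning_achieved_sketch(tune_results, min_reward):
    best = max(
        trial.last_result.get("episode_reward_mean", float("-inf"))
        for trial in tune_results.trials)
    if best < min_reward:
        raise ValueError(
            "`stop-reward` of {} not reached (best was {})!".format(
                min_reward, best))
    print("Learning achieved: {} >= {}".format(best, min_reward))
```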