From 8a6403c26e1bbc66994ffb27ecd925918c8203e1 Mon Sep 17 00:00:00 2001 From: Ameer Haj Ali Date: Wed, 13 Mar 2019 11:26:55 -0700 Subject: [PATCH] [rllib] bug fix: merging --config params with params.pkl (#4336) --- python/ray/rllib/rollout.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 68f2d6456..70b00eb63 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -12,6 +12,7 @@ import pickle import gym import ray from ray.rllib.agents.registry import get_agent_class +from ray.tune.util import merge_dicts EXAMPLE_USAGE = """ Example Usage via RLlib CLI: @@ -69,22 +70,23 @@ def create_parser(parser_creator=None): def run(args, parser): - config = args.config - if not config: - # Load configuration from file - config_dir = os.path.dirname(args.checkpoint) - config_path = os.path.join(config_dir, "params.pkl") - if not os.path.exists(config_path): - config_path = os.path.join(config_dir, "../params.pkl") - if not os.path.exists(config_path): + config = {} + # Load configuration from file + config_dir = os.path.dirname(args.checkpoint) + config_path = os.path.join(config_dir, "params.pkl") + if not os.path.exists(config_path): + config_path = os.path.join(config_dir, "../params.pkl") + if not os.path.exists(config_path): + if not args.config: raise ValueError( "Could not find params.pkl in either the checkpoint dir or " "its parent directory.") + else: with open(config_path, 'rb') as f: config = pickle.load(f) - if "num_workers" in config: - config["num_workers"] = min(2, config["num_workers"]) - + if "num_workers" in config: + config["num_workers"] = min(2, config["num_workers"]) + config = merge_dicts(config, args.config) if not args.env: if not config.get("env"): parser.error("the following arguments are required: --env")