[rllib] bug fix: merging --config params with params.pkl (#4336)

This commit is contained in:
Ameer Haj Ali 2019-03-13 11:26:55 -07:00 committed by Eric Liang
parent 87bfa1cf82
commit 8a6403c26e

View file

@ -12,6 +12,7 @@ import pickle
import gym import gym
import ray import ray
from ray.rllib.agents.registry import get_agent_class from ray.rllib.agents.registry import get_agent_class
from ray.tune.util import merge_dicts
EXAMPLE_USAGE = """ EXAMPLE_USAGE = """
Example Usage via RLlib CLI: Example Usage via RLlib CLI:
@ -69,22 +70,23 @@ def create_parser(parser_creator=None):
def run(args, parser): def run(args, parser):
config = args.config config = {}
if not config: # Load configuration from file
# Load configuration from file config_dir = os.path.dirname(args.checkpoint)
config_dir = os.path.dirname(args.checkpoint) config_path = os.path.join(config_dir, "params.pkl")
config_path = os.path.join(config_dir, "params.pkl") if not os.path.exists(config_path):
if not os.path.exists(config_path): config_path = os.path.join(config_dir, "../params.pkl")
config_path = os.path.join(config_dir, "../params.pkl") if not os.path.exists(config_path):
if not os.path.exists(config_path): if not args.config:
raise ValueError( raise ValueError(
"Could not find params.pkl in either the checkpoint dir or " "Could not find params.pkl in either the checkpoint dir or "
"its parent directory.") "its parent directory.")
else:
with open(config_path, 'rb') as f: with open(config_path, 'rb') as f:
config = pickle.load(f) config = pickle.load(f)
if "num_workers" in config: if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"]) config["num_workers"] = min(2, config["num_workers"])
config = merge_dicts(config, args.config)
if not args.env: if not args.env:
if not config.get("env"): if not config.get("env"):
parser.error("the following arguments are required: --env") parser.error("the following arguments are required: --env")