mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] bug fix: merging --config params with params.pkl (#4336)
This commit is contained in:
parent
87bfa1cf82
commit
8a6403c26e
1 changed files with 13 additions and 11 deletions
|
@ -12,6 +12,7 @@ import pickle
|
|||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.tune.util import merge_dicts
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
Example Usage via RLlib CLI:
|
||||
|
@ -69,22 +70,23 @@ def create_parser(parser_creator=None):
|
|||
|
||||
|
||||
def run(args, parser):
|
||||
config = args.config
|
||||
if not config:
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
config_path = os.path.join(config_dir, "params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
config_path = os.path.join(config_dir, "../params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
config = {}
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
config_path = os.path.join(config_dir, "params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
config_path = os.path.join(config_dir, "../params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
if not args.config:
|
||||
raise ValueError(
|
||||
"Could not find params.pkl in either the checkpoint dir or "
|
||||
"its parent directory.")
|
||||
else:
|
||||
with open(config_path, 'rb') as f:
|
||||
config = pickle.load(f)
|
||||
if "num_workers" in config:
|
||||
config["num_workers"] = min(2, config["num_workers"])
|
||||
|
||||
if "num_workers" in config:
|
||||
config["num_workers"] = min(2, config["num_workers"])
|
||||
config = merge_dicts(config, args.config)
|
||||
if not args.env:
|
||||
if not config.get("env"):
|
||||
parser.error("the following arguments are required: --env")
|
||||
|
|
Loading…
Add table
Reference in a new issue