mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] bug fix: merging --config params with params.pkl (#4336)
This commit is contained in:
parent
87bfa1cf82
commit
8a6403c26e
1 changed files with 13 additions and 11 deletions
|
@ -12,6 +12,7 @@ import pickle
|
||||||
import gym
|
import gym
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib.agents.registry import get_agent_class
|
from ray.rllib.agents.registry import get_agent_class
|
||||||
|
from ray.tune.util import merge_dicts
|
||||||
|
|
||||||
EXAMPLE_USAGE = """
|
EXAMPLE_USAGE = """
|
||||||
Example Usage via RLlib CLI:
|
Example Usage via RLlib CLI:
|
||||||
|
@ -69,22 +70,23 @@ def create_parser(parser_creator=None):
|
||||||
|
|
||||||
|
|
||||||
def run(args, parser):
|
def run(args, parser):
|
||||||
config = args.config
|
config = {}
|
||||||
if not config:
|
# Load configuration from file
|
||||||
# Load configuration from file
|
config_dir = os.path.dirname(args.checkpoint)
|
||||||
config_dir = os.path.dirname(args.checkpoint)
|
config_path = os.path.join(config_dir, "params.pkl")
|
||||||
config_path = os.path.join(config_dir, "params.pkl")
|
if not os.path.exists(config_path):
|
||||||
if not os.path.exists(config_path):
|
config_path = os.path.join(config_dir, "../params.pkl")
|
||||||
config_path = os.path.join(config_dir, "../params.pkl")
|
if not os.path.exists(config_path):
|
||||||
if not os.path.exists(config_path):
|
if not args.config:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not find params.pkl in either the checkpoint dir or "
|
"Could not find params.pkl in either the checkpoint dir or "
|
||||||
"its parent directory.")
|
"its parent directory.")
|
||||||
|
else:
|
||||||
with open(config_path, 'rb') as f:
|
with open(config_path, 'rb') as f:
|
||||||
config = pickle.load(f)
|
config = pickle.load(f)
|
||||||
if "num_workers" in config:
|
if "num_workers" in config:
|
||||||
config["num_workers"] = min(2, config["num_workers"])
|
config["num_workers"] = min(2, config["num_workers"])
|
||||||
|
config = merge_dicts(config, args.config)
|
||||||
if not args.env:
|
if not args.env:
|
||||||
if not config.get("env"):
|
if not config.get("env"):
|
||||||
parser.error("the following arguments are required: --env")
|
parser.error("the following arguments are required: --env")
|
||||||
|
|
Loading…
Add table
Reference in a new issue