mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:08:16 +08:00
[rllib] bug fix: merging --config params with params.pkl (#4336)
This commit is contained in:
committed by
Eric Liang
parent
87bfa1cf82
commit
8a6403c26e
+13
-11
@@ -12,6 +12,7 @@ import pickle
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.tune.util import merge_dicts
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
Example Usage via RLlib CLI:
|
||||
@@ -69,22 +70,23 @@ def create_parser(parser_creator=None):
|
||||
|
||||
|
||||
def run(args, parser):
|
||||
config = args.config
|
||||
if not config:
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
config_path = os.path.join(config_dir, "params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
config_path = os.path.join(config_dir, "../params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
config = {}
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
config_path = os.path.join(config_dir, "params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
config_path = os.path.join(config_dir, "../params.pkl")
|
||||
if not os.path.exists(config_path):
|
||||
if not args.config:
|
||||
raise ValueError(
|
||||
"Could not find params.pkl in either the checkpoint dir or "
|
||||
"its parent directory.")
|
||||
else:
|
||||
with open(config_path, 'rb') as f:
|
||||
config = pickle.load(f)
|
||||
if "num_workers" in config:
|
||||
config["num_workers"] = min(2, config["num_workers"])
|
||||
|
||||
if "num_workers" in config:
|
||||
config["num_workers"] = min(2, config["num_workers"])
|
||||
config = merge_dicts(config, args.config)
|
||||
if not args.env:
|
||||
if not config.get("env"):
|
||||
parser.error("the following arguments are required: --env")
|
||||
|
||||
Reference in New Issue
Block a user