From d356dd3ec4827ae3b8dd9d8e740fb1066f27aa1e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 19 Jul 2017 23:45:05 +0000 Subject: [PATCH] [rllib] Expose algorithm parameters and tune policy gradient parameters for humanoid (#753) * parameters for humanoid * fix --- python/ray/rllib/test.sh | 1 + python/ray/rllib/train.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/python/ray/rllib/test.sh b/python/ray/rllib/test.sh index 77361c3a9..2e11229cd 100755 --- a/python/ray/rllib/test.sh +++ b/python/ray/rllib/test.sh @@ -1,6 +1,7 @@ #!/bin/bash python train.py --env Walker2d-v1 --alg PolicyGradient --upload-dir s3://bucketname/ +python train.py --env Humanoid-v1 --alg PolicyGradient --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64}' --upload-dir s3://bucketname/ python train.py --env PongNoFrameskip-v0 --alg DQN --upload-dir s3://bucketname/ python train.py --env PongDeterministic-v0 --alg A3C --upload-dir s3://bucketname/ python train.py --env Humanoid-v1 --alg EvolutionStrategies --upload-dir s3://bucketname/ diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index 88daf9582..e2de3ff4f 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -18,6 +18,7 @@ parser = argparse.ArgumentParser( description=("Train a reinforcement learning agent.")) parser.add_argument("--env", required=True, type=str) parser.add_argument("--alg", required=True, type=str) +parser.add_argument("--config", default="{}", type=str) parser.add_argument("--upload-dir", default="file:///tmp/ray", type=str) @@ -28,17 +29,25 @@ if __name__ == "__main__": env_name = args.env if args.alg == "PolicyGradient": + config = pg.DEFAULT_CONFIG.copy() + config.update(json.loads(args.config)) alg = pg.PolicyGradient( - env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir) + env_name, config, upload_dir=args.upload_dir) elif args.alg == "EvolutionStrategies": + config = es.DEFAULT_CONFIG.copy() + config.update(json.loads(args.config)) alg = es.EvolutionStrategies( - env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir) + env_name, config, upload_dir=args.upload_dir) elif args.alg == "DQN": + config = dqn.DEFAULT_CONFIG.copy() + config.update(json.loads(args.config)) alg = dqn.DQN( - env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir) + env_name, config, upload_dir=args.upload_dir) elif args.alg == "A3C": + config = a3c.DEFAULT_CONFIG.copy() + config.update(json.loads(args.config)) alg = a3c.A3C( - env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir) + env_name, config, upload_dir=args.upload_dir) else: assert False, ("Unknown algorithm, check --alg argument. Valid " "choices are PolicyGradientPolicyGradient, "