mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 11:09:15 +08:00
[rllib] Expose algorithm parameters and tune policy gradient parameters for humanoid (#753)
* parameters for humanoid
* fix
This commit is contained in:
committed by
Robert Nishihara
parent
ade6d80820
commit
d356dd3ec4
@@ -1,6 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
python train.py --env Walker2d-v1 --alg PolicyGradient --upload-dir s3://bucketname/
|
||||
python train.py --env Humanoid-v1 --alg PolicyGradient --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64}' --upload-dir s3://bucketname/
|
||||
python train.py --env PongNoFrameskip-v0 --alg DQN --upload-dir s3://bucketname/
|
||||
python train.py --env PongDeterministic-v0 --alg A3C --upload-dir s3://bucketname/
|
||||
python train.py --env Humanoid-v1 --alg EvolutionStrategies --upload-dir s3://bucketname/
|
||||
|
||||
@@ -18,6 +18,7 @@ parser = argparse.ArgumentParser(
|
||||
description=("Train a reinforcement learning agent."))
|
||||
parser.add_argument("--env", required=True, type=str)
|
||||
parser.add_argument("--alg", required=True, type=str)
|
||||
parser.add_argument("--config", default="{}", type=str)
|
||||
parser.add_argument("--upload-dir", default="file:///tmp/ray", type=str)
|
||||
|
||||
|
||||
@@ -28,17 +29,25 @@ if __name__ == "__main__":
|
||||
|
||||
env_name = args.env
|
||||
if args.alg == "PolicyGradient":
|
||||
config = pg.DEFAULT_CONFIG.copy()
|
||||
config.update(json.loads(args.config))
|
||||
alg = pg.PolicyGradient(
|
||||
env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "EvolutionStrategies":
|
||||
config = es.DEFAULT_CONFIG.copy()
|
||||
config.update(json.loads(args.config))
|
||||
alg = es.EvolutionStrategies(
|
||||
env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "DQN":
|
||||
config = dqn.DEFAULT_CONFIG.copy()
|
||||
config.update(json.loads(args.config))
|
||||
alg = dqn.DQN(
|
||||
env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "A3C":
|
||||
config = a3c.DEFAULT_CONFIG.copy()
|
||||
config.update(json.loads(args.config))
|
||||
alg = a3c.A3C(
|
||||
env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
else:
|
||||
assert False, ("Unknown algorithm, check --alg argument. Valid "
|
||||
"choices are PolicyGradientPolicyGradient, "
|
||||
|
||||
Reference in New Issue
Block a user