pytorch version

This commit is contained in:
Kolesnikov Sergey
2017-11-15 22:18:46 +03:00
parent 34993abdf7
commit 7401266fe7
49 changed files with 5435 additions and 1 deletions
+237
View File
@@ -0,0 +1,237 @@
import argparse
import os
import json
import copy
import torch
import torch.multiprocessing as mp
from multiprocessing import Value
from common.misc_util import boolean_flag, str2params, create_if_need
from common.env_wrappers import create_env
from common.torch_util import activations, hard_update
from ddpg.model import create_model, create_act_update_fns, train_multi_thread, \
train_single_thread, play_single_thread
def parse_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--difficulty', type=int, default=2)
parser.add_argument('--max-obstacles', type=int, default=3)
parser.add_argument('--logdir', type=str, default="./logs")
parser.add_argument('--num-threads', type=int, default=1)
parser.add_argument('--num-train-threads', type=int, default=1)
boolean_flag(parser, "ddpg-wrapper", default=False)
parser.add_argument('--skip-frames', type=int, default=1)
parser.add_argument('--fail-reward', type=float, default=0.0)
parser.add_argument('--reward-scale', type=float, default=1.)
boolean_flag(parser, "flip-state-action", default=False)
for agent in ["actor", "critic"]:
parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64")
parser.add_argument('--{}-activation'.format(agent), type=str, default="relu")
boolean_flag(parser, "{}-layer-norm".format(agent), default=False)
boolean_flag(parser, "{}-parameters-noise".format(agent), default=False)
boolean_flag(parser, "{}-parameters-noise-factorised".format(agent), default=False)
parser.add_argument('--{}-lr'.format(agent), type=float, default=1e-3)
parser.add_argument('--{}-lr-end'.format(agent), type=float, default=5e-5)
parser.add_argument('--restore-{}-from'.format(agent), type=str, default=None)
parser.add_argument('--gamma', type=float, default=0.96)
parser.add_argument('--loss-type', type=str, default="quadric-linear")
parser.add_argument('--grad-clip', type=float, default=10.)
parser.add_argument('--tau', default=0.01, type=float)
parser.add_argument('--train-steps', type=int, default=int(1e4))
parser.add_argument('--batch-size', type=int, default=256) # per worker
parser.add_argument('--buffer-size', type=int, default=int(1e6))
boolean_flag(parser, "prioritized-replay", default=False)
parser.add_argument('--prioritized-replay-alpha', default=0.6, type=float)
parser.add_argument('--prioritized-replay-beta0', default=0.4, type=float)
parser.add_argument('--initial-epsilon', default=1., type=float)
parser.add_argument('--final-epsilon', default=0.01, type=float)
parser.add_argument('--max-episodes', default=int(1e4), type=int)
parser.add_argument('--max-update-steps', default=int(5e6), type=int)
parser.add_argument('--epsilon-cycle-len', default=int(2e2), type=int)
parser.add_argument('--max-train-days', default=int(1e1), type=int)
parser.add_argument('--rp-type', default="ornstein-uhlenbeck", type=str)
parser.add_argument('--rp-theta', default=0.15, type=float)
parser.add_argument('--rp-sigma', default=0.2, type=float)
parser.add_argument('--rp-sigma-min', default=0.15, type=float)
parser.add_argument('--rp-mu', default=0.0, type=float)
parser.add_argument('--clip-delta', type=int, default=10)
parser.add_argument('--save-step', type=int, default=int(1e4))
parser.add_argument('--restore-args-from', type=str, default=None)
return parser.parse_args()
def restore_args(args):
with open(args.restore_args_from, "r") as fin:
params = json.load(fin)
del params["seed"]
del params["difficulty"]
del params["max_obstacles"]
del params["logdir"]
del params["num_threads"]
del params["num_train_threads"]
del params["skip_frames"]
for agent in ["actor", "critic"]:
del params["{}_lr".format(agent)]
del params["{}_lr_end".format(agent)]
del params["restore_{}_from".format(agent)]
del params["grad_clip"]
del params["tau"]
del params["train_steps"]
del params["batch_size"]
del params["buffer_size"]
del params["prioritized_replay"]
del params["prioritized_replay_alpha"]
del params["prioritized_replay_beta0"]
del params["initial_epsilon"]
del params["final_epsilon"]
del params["max_episodes"]
del params["max_update_steps"]
del params["epsilon_cycle_len"]
del params["max_train_days"]
del params["rp_type"]
del params["rp_theta"]
del params["rp_sigma"]
del params["rp_sigma_min"]
del params["rp_mu"]
del params["clip_delta"]
del params["save_step"]
del params["restore_args_from"]
for key, value in params.items():
setattr(args, key, value)
return args
def train(args, model_fn, act_update_fns, multi_thread, train_single, play_single):
create_if_need(args.logdir)
if args.restore_args_from is not None:
args = restore_args(args)
with open("{}/args.json".format(args.logdir), "w") as fout:
json.dump(vars(args), fout, indent=4, ensure_ascii=False, sort_keys=True)
env = create_env(args)
if args.flip_state_action and hasattr(env, "state_transform"):
args.flip_states = env.state_transform.flip_states
args.batch_size = args.batch_size // 2
args.n_action = env.action_space.shape[0]
args.n_observation = env.observation_space.shape[0]
args.actor_layers = str2params(args.actor_layers)
args.critic_layers = str2params(args.critic_layers)
args.actor_activation = activations[args.actor_activation]
args.critic_activation = activations[args.critic_activation]
actor, critic = model_fn(args)
if args.restore_actor_from is not None:
actor.load_state_dict(torch.load(args.restore_actor_from))
if args.restore_critic_from is not None:
critic.load_state_dict(torch.load(args.restore_critic_from))
actor.train()
critic.train()
actor.share_memory()
critic.share_memory()
target_actor = copy.deepcopy(actor)
target_critic = copy.deepcopy(critic)
hard_update(target_actor, actor)
hard_update(target_critic, critic)
target_actor.train()
target_critic.train()
target_actor.share_memory()
target_critic.share_memory()
_, _, save_fn = act_update_fns(actor, critic, target_actor, target_critic, args)
processes = []
best_reward = Value("f", 0.0)
try:
if args.num_threads == args.num_train_threads:
for rank in range(args.num_threads):
args.thread = rank
p = mp.Process(
target=multi_thread,
args=(actor, critic, target_actor, target_critic, args, act_update_fns,
best_reward))
p.start()
processes.append(p)
else:
global_episode = Value("i", 0)
global_update_step = Value("i", 0)
episodes_queue = mp.Queue()
for rank in range(args.num_threads):
args.thread = rank
if rank < args.num_train_threads:
p = mp.Process(
target=train_single,
args=(actor, critic, target_actor, target_critic, args, act_update_fns,
global_episode, global_update_step, episodes_queue))
else:
p = mp.Process(
target=play_single,
args=(actor, critic, target_actor, target_critic, args, act_update_fns,
global_episode, global_update_step, episodes_queue,
best_reward))
p.start()
processes.append(p)
for p in processes:
p.join()
except KeyboardInterrupt:
pass
save_fn()
if __name__ == '__main__':
os.environ['OMP_NUM_THREADS'] = '1'
torch.set_num_threads(1)
args = parse_args()
train(args,
create_model,
create_act_update_fns,
train_multi_thread,
train_single_thread,
play_single_thread)