From 8d9e2024c1cbe9b785fee5b6f73843889903827d Mon Sep 17 00:00:00 2001 From: wassname Date: Fri, 19 Jan 2018 14:09:23 +0800 Subject: [PATCH] add debug arg --- ddpg/train.py | 81 ++++++++++++++++++++++++++++----------------------- 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/ddpg/train.py b/ddpg/train.py index 538836c..a4f0965 100644 --- a/ddpg/train.py +++ b/ddpg/train.py @@ -30,6 +30,7 @@ def parse_args(): parser.add_argument('--fail-reward', type=float, default=0.0) parser.add_argument('--reward-scale', type=float, default=1.) boolean_flag(parser, "flip-state-action", default=False) + boolean_flag(parser, 'debug', default=False, help="Run in single threaded mode for debugging") for agent in ["actor", "critic", "dynamics"]: parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64") @@ -196,45 +197,53 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl try: - if args.num_threads == args.num_train_threads == 1: - # # run a single thread in the foreground so we can debug easier - # args.thread = 1 - # multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward) - - # or debug the single thread funcs - args.thread = 1 - global_episode = Value("i", 0) - global_update_step = Value("i", 0) - episodes_queue = mp.Queue() - train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue) - elif args.num_threads == args.num_train_threads: - for rank in range(args.num_threads): - args.thread = rank - p = mp.Process( - target=multi_thread, - args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, - best_reward)) - p.start() - processes.append(p) - else: - global_episode = Value("i", 0) - global_update_step = Value("i", 0) - episodes_queue = mp.Queue() - for rank in range(args.num_threads): - args.thread = rank - if rank < args.num_train_threads: + if args.num_threads == args.num_train_threads: + print("training with train_multi_thread") + if args.debug: + # # run a single thread in the foreground so we can debug easier + args.thread = 1 + multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward) + else: + for rank in range(args.num_threads): + args.thread = rank p = mp.Process( - target=train_single, + target=multi_thread, args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, - global_episode, global_update_step, episodes_queue)) - else: - p = mp.Process( - target=play_single, - args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, - global_episode, global_update_step, episodes_queue, best_reward)) - p.start() - processes.append(p) + p.start() + processes.append(p) + else: + print("training with train_single_thread") + if args.debug: + # or debug the single thread funcs + global_episode = Value("i", 0) + global_update_step = Value("i", 0) + episodes_queue = mp.Queue() + + args.thread = 1 + train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue) + + args.thread = 2 + play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward) + else: + global_episode = Value("i", 0) + global_update_step = Value("i", 0) + episodes_queue = mp.Queue() + for rank in range(args.num_threads): + args.thread = rank + if rank < args.num_train_threads: + p = mp.Process( + target=train_single, + args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, + global_episode, global_update_step, episodes_queue)) + else: + p = mp.Process( + target=play_single, + args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, + global_episode, global_update_step, episodes_queue, + best_reward)) + p.start() + processes.append(p) for p in processes: p.join()