add debug arg

This commit is contained in:
wassname
2018-01-19 14:09:23 +08:00
parent 66d6a74093
commit 8d9e2024c1
+45 -36
View File
@@ -30,6 +30,7 @@ def parse_args():
parser.add_argument('--fail-reward', type=float, default=0.0)
parser.add_argument('--reward-scale', type=float, default=1.)
boolean_flag(parser, "flip-state-action", default=False)
boolean_flag(parser, 'debug', default=False, help="Run in single threaded mode for debugging")
for agent in ["actor", "critic", "dynamics"]:
parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64")
@@ -196,45 +197,53 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
try:
if args.num_threads == args.num_train_threads == 1:
# # run a single thread in the foreground so we can debug easier
# args.thread = 1
# multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
# or debug the single thread funcs
args.thread = 1
global_episode = Value("i", 0)
global_update_step = Value("i", 0)
episodes_queue = mp.Queue()
train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue)
elif args.num_threads == args.num_train_threads:
for rank in range(args.num_threads):
args.thread = rank
p = mp.Process(
target=multi_thread,
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
best_reward))
p.start()
processes.append(p)
else:
global_episode = Value("i", 0)
global_update_step = Value("i", 0)
episodes_queue = mp.Queue()
for rank in range(args.num_threads):
args.thread = rank
if rank < args.num_train_threads:
if args.num_threads == args.num_train_threads:
print("training with train_multi_thread")
if args.debug:
# run a single thread in the foreground so it is easier to debug
args.thread = 1
multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
else:
for rank in range(args.num_threads):
args.thread = rank
p = mp.Process(
target=train_single,
target=multi_thread,
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
global_episode, global_update_step, episodes_queue))
else:
p = mp.Process(
target=play_single,
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
global_episode, global_update_step, episodes_queue,
best_reward))
p.start()
processes.append(p)
p.start()
processes.append(p)
else:
print("training with train_single_thread")
if args.debug:
# debug the single-thread funcs in the foreground: train_single first, then play_single
global_episode = Value("i", 0)
global_update_step = Value("i", 0)
episodes_queue = mp.Queue()
args.thread = 1
train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue)
args.thread = 2
play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward)
else:
global_episode = Value("i", 0)
global_update_step = Value("i", 0)
episodes_queue = mp.Queue()
for rank in range(args.num_threads):
args.thread = rank
if rank < args.num_train_threads:
p = mp.Process(
target=train_single,
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
global_episode, global_update_step, episodes_queue))
else:
p = mp.Process(
target=play_single,
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
global_episode, global_update_step, episodes_queue,
best_reward))
p.start()
processes.append(p)
for p in processes:
p.join()