mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 17:14:10 +08:00
add debug arg
This commit is contained in:
+45
-36
@@ -30,6 +30,7 @@ def parse_args():
|
||||
parser.add_argument('--fail-reward', type=float, default=0.0)
|
||||
parser.add_argument('--reward-scale', type=float, default=1.)
|
||||
boolean_flag(parser, "flip-state-action", default=False)
|
||||
boolean_flag(parser, 'debug', default=False, help="Run in single threaded mode for debugging")
|
||||
|
||||
for agent in ["actor", "critic", "dynamics"]:
|
||||
parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64")
|
||||
@@ -196,45 +197,53 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
|
||||
|
||||
try:
|
||||
|
||||
if args.num_threads == args.num_train_threads == 1:
|
||||
# # run a single thread in the foreground so we can debug easier
|
||||
# args.thread = 1
|
||||
# multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
|
||||
|
||||
# or debug the single thread funcs
|
||||
args.thread = 1
|
||||
global_episode = Value("i", 0)
|
||||
global_update_step = Value("i", 0)
|
||||
episodes_queue = mp.Queue()
|
||||
train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue)
|
||||
elif args.num_threads == args.num_train_threads:
|
||||
for rank in range(args.num_threads):
|
||||
args.thread = rank
|
||||
p = mp.Process(
|
||||
target=multi_thread,
|
||||
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
|
||||
best_reward))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
else:
|
||||
global_episode = Value("i", 0)
|
||||
global_update_step = Value("i", 0)
|
||||
episodes_queue = mp.Queue()
|
||||
for rank in range(args.num_threads):
|
||||
args.thread = rank
|
||||
if rank < args.num_train_threads:
|
||||
if args.num_threads == args.num_train_threads:
|
||||
print("training with train_multi_thread")
|
||||
if args.debug:
|
||||
# # run a single thread in the foreground so we can debug easier
|
||||
args.thread = 1
|
||||
multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
|
||||
else:
|
||||
for rank in range(args.num_threads):
|
||||
args.thread = rank
|
||||
p = mp.Process(
|
||||
target=train_single,
|
||||
target=multi_thread,
|
||||
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
|
||||
global_episode, global_update_step, episodes_queue))
|
||||
else:
|
||||
p = mp.Process(
|
||||
target=play_single,
|
||||
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
|
||||
global_episode, global_update_step, episodes_queue,
|
||||
best_reward))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
else:
|
||||
print("training with train_single_thread")
|
||||
if args.debug:
|
||||
# or debug the single thread funcs
|
||||
global_episode = Value("i", 0)
|
||||
global_update_step = Value("i", 0)
|
||||
episodes_queue = mp.Queue()
|
||||
|
||||
args.thread = 1
|
||||
train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue)
|
||||
|
||||
args.thread = 2
|
||||
play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward)
|
||||
else:
|
||||
global_episode = Value("i", 0)
|
||||
global_update_step = Value("i", 0)
|
||||
episodes_queue = mp.Queue()
|
||||
for rank in range(args.num_threads):
|
||||
args.thread = rank
|
||||
if rank < args.num_train_threads:
|
||||
p = mp.Process(
|
||||
target=train_single,
|
||||
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
|
||||
global_episode, global_update_step, episodes_queue))
|
||||
else:
|
||||
p = mp.Process(
|
||||
target=play_single,
|
||||
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
|
||||
global_episode, global_update_step, episodes_queue,
|
||||
best_reward))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
Reference in New Issue
Block a user