Route logging through rich, hoist the monitored info keys, move the save call.

* Configure stdlib logging and loguru to emit through rich's RichHandler.
* Define the list of monitored `info` keys once as `keys_to_monitor` and reuse
  it in both the training and the evaluation loop.
* Drop the save inside the evaluation block and the commented-out save after
  env.close(); save after the step-limit check instead.

Review corrections relative to the patch as submitted:
* `from rich import logger.info` is not valid Python; replaced by
  `import logging`, which the added logging.basicConfig(...) call needs.
* Removed the stray trailing `:` after the `keys_to_monitor` list literal.
* The evaluation loop discarded the step info (`_`) but then read `info[k]`;
  it now binds `info` from env.step(action).

NOTE(review): whitespace was reconstructed from a flattened copy of the patch,
so indentation (in particular of the relocated `if args.train:`) must be
confirmed against main.py; the index line below was not regenerated.

diff --git a/main.py b/main.py
index aeb5e84..fb7ebb7 100644
--- a/main.py
+++ b/main.py
@@ -13,7 +13,14 @@ from load_demonstrations import load_demonstrations
 import apple_gym.env
 import pickle
 from tqdm.auto import tqdm
+
+
 from loguru import logger
+import logging
+from rich.logging import RichHandler
+logging.basicConfig(level=logging.INFO, handlers=[RichHandler(rich_tracebacks=True, markup=True)])
+logger.configure(handlers=[{"sink": RichHandler(markup=True),
+                            "format": "{message}"}])
 
 def get_args():
     parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
@@ -71,6 +78,10 @@ logger.info(f'args {args}')
 env = gym.make(args.env_name, render=args.render)
 env.seed(args.seed)
 env.action_space.seed(args.seed)
+keys_to_monitor = ['env_reward/apple_pick/tree/min_fruit_dist_reward',
+                   'env_reward/apple_pick/tree/gripping_fruit_reward',
+                   'env_reward/apple_pick/tree/force_tree_reward',
+                   'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']
 
 torch.manual_seed(args.seed)
 np.random.seed(args.seed)
@@ -156,10 +167,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
             if total_numsteps == 1:
                 logger.info(f'info {info.keys()}')
                 logger.debug(f'info {info}')
-            for k in ['env_reward/apple_pick/tree/min_fruit_dist_reward',
-                      'env_reward/apple_pick/tree/gripping_fruit_reward',
-                      'env_reward/apple_pick/tree/force_tree_reward',
-                      'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
+            for k in keys_to_monitor:
                 writer.add_scalar(k, info[k], total_numsteps)
 
             # Ignore the "done" signal if it comes from hitting the time horizon. (that is, when it's an artificial terminal signal that isn't based on the agent's state)
@@ -186,10 +194,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
 
-                    next_state, reward, done, _ = env.step(action)
+                    next_state, reward, done, info = env.step(action)
                     episode_reward += reward
-                    for k in ['env_reward/apple_pick/tree/min_fruit_dist_reward',
-                              'env_reward/apple_pick/tree/gripping_fruit_reward',
-                              'env_reward/apple_pick/tree/force_tree_reward',
-                              'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
+                    for k in keys_to_monitor:
                         writer.add_scalar('test/' + k, info[k], total_numsteps)
 
 
@@ -200,8 +205,6 @@ with tqdm(unit='steps', mininterval=5) as prog:
 
 
             writer.add_scalar('avg_reward/test', avg_reward, i_episode)
-            if args.train:
-                save(save_dir)
 
             logger.info("----------------------------------------")
             logger.info("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
@@ -210,7 +213,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
 
         if total_numsteps >= args.num_steps:
             break
+    if args.train:
+        save(save_dir)
 
 env.close()
-# if args.train:
-#     save(save_dir)