This commit is contained in:
wassname
2021-01-16 16:41:14 +08:00
parent 90d207ca9b
commit 0805bfa98f
+15 -12
View File
@@ -13,7 +13,14 @@ from load_demonstrations import load_demonstrations
import apple_gym.env
import pickle
from tqdm.auto import tqdm
from loguru import logger
from rich import logger.info
from rich.logging import RichHandler
logging.basicConfig(level=logging.INFO, handlers=[RichHandler(rich_tracebacks=True, markup=True)])
logger.configure(handlers=[{"sink": RichHandler(markup=True),
"format": "{message}"}])
def get_args():
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
@@ -71,6 +78,10 @@ logger.info(f'args {args}')
env = gym.make(args.env_name, render=args.render)
env.seed(args.seed)
env.action_space.seed(args.seed)
keys_to_monitor = ['env_reward/apple_pick/tree/min_fruit_dist_reward',
'env_reward/apple_pick/tree/gripping_fruit_reward',
'env_reward/apple_pick/tree/force_tree_reward',
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
torch.manual_seed(args.seed)
np.random.seed(args.seed)
@@ -156,10 +167,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
if total_numsteps == 1:
logger.info(f'info {info.keys()}')
logger.debug(f'info {info}')
for k in ['env_reward/apple_pick/tree/min_fruit_dist_reward',
'env_reward/apple_pick/tree/gripping_fruit_reward',
'env_reward/apple_pick/tree/force_tree_reward',
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
for k in keys_to_monitor:
writer.add_scalar(k, info[k], total_numsteps)
# Ignore the "done" signal if it comes from hitting the time horizon. (that is, when it's an artificial terminal signal that isn't based on the agent's state)
@@ -186,10 +194,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
next_state, reward, done, _ = env.step(action)
episode_reward += reward
for k in ['env_reward/apple_pick/tree/min_fruit_dist_reward',
'env_reward/apple_pick/tree/gripping_fruit_reward',
'env_reward/apple_pick/tree/force_tree_reward',
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
for k in keys_to_monitor:
writer.add_scalar('test/' + k, info[k], total_numsteps)
@@ -200,8 +205,6 @@ with tqdm(unit='steps', mininterval=5) as prog:
writer.add_scalar('avg_reward/test', avg_reward, i_episode)
if args.train:
save(save_dir)
logger.info("----------------------------------------")
logger.info("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
@@ -210,7 +213,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
if total_numsteps >= args.num_steps:
break
if args.train:
save(save_dir)
env.close()
# if args.train:
# save(save_dir)