mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:46:28 +08:00
logging
This commit is contained in:
@@ -13,7 +13,14 @@ from load_demonstrations import load_demonstrations
|
||||
import apple_gym.env
|
||||
import pickle
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
from loguru import logger
|
||||
from rich import logger.info
|
||||
from rich.logging import RichHandler
|
||||
logging.basicConfig(level=logging.INFO, handlers=[RichHandler(rich_tracebacks=True, markup=True)])
|
||||
logger.configure(handlers=[{"sink": RichHandler(markup=True),
|
||||
"format": "{message}"}])
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
|
||||
@@ -71,6 +78,10 @@ logger.info(f'args {args}')
|
||||
env = gym.make(args.env_name, render=args.render)
|
||||
env.seed(args.seed)
|
||||
env.action_space.seed(args.seed)
|
||||
# Keys that are looked up in `info` and logged with `writer.add_scalar`
# (as-is in one loop, prefixed with 'test/' in the other).
keys_to_monitor = [
    'env_reward/apple_pick/tree/min_fruit_dist_reward',
    'env_reward/apple_pick/tree/gripping_fruit_reward',
    'env_reward/apple_pick/tree/force_tree_reward',
    'env_reward/apple_pick/tree/force_fruit_reward',
    'env_obs/apple_pick/tree/picks',
]
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
@@ -156,10 +167,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
|
||||
if total_numsteps == 1:
|
||||
logger.info(f'info {info.keys()}')
|
||||
logger.debug(f'info {info}')
|
||||
for k in ['env_reward/apple_pick/tree/min_fruit_dist_reward',
|
||||
'env_reward/apple_pick/tree/gripping_fruit_reward',
|
||||
'env_reward/apple_pick/tree/force_tree_reward',
|
||||
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
|
||||
for k in keys_to_monitor:
|
||||
writer.add_scalar(k, info[k], total_numsteps)
|
||||
|
||||
# Ignore the "done" signal if it comes from hitting the time horizon. (that is, when it's an artificial terminal signal that isn't based on the agent's state)
|
||||
@@ -186,10 +194,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
episode_reward += reward
|
||||
|
||||
for k in ['env_reward/apple_pick/tree/min_fruit_dist_reward',
|
||||
'env_reward/apple_pick/tree/gripping_fruit_reward',
|
||||
'env_reward/apple_pick/tree/force_tree_reward',
|
||||
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
|
||||
for k in keys_to_monitor:
|
||||
writer.add_scalar('test/' + k, info[k], total_numsteps)
|
||||
|
||||
|
||||
@@ -200,8 +205,6 @@ with tqdm(unit='steps', mininterval=5) as prog:
|
||||
|
||||
writer.add_scalar('avg_reward/test', avg_reward, i_episode)
|
||||
|
||||
if args.train:
|
||||
save(save_dir)
|
||||
|
||||
logger.info("----------------------------------------")
|
||||
logger.info("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
|
||||
@@ -210,7 +213,7 @@ with tqdm(unit='steps', mininterval=5) as prog:
|
||||
if total_numsteps >= args.num_steps:
|
||||
break
|
||||
|
||||
if args.train:
|
||||
save(save_dir)
|
||||
|
||||
env.close()
|
||||
# if args.train:
|
||||
# save(save_dir)
|
||||
|
||||
Reference in New Issue
Block a user