diff --git a/load_demonstrations.py b/load_demonstrations.py index 63dcae7..c6e21f6 100644 --- a/load_demonstrations.py +++ b/load_demonstrations.py @@ -7,10 +7,11 @@ from gym_recording_modified.playback import get_recordings from tqdm.auto import tqdm from replay_memory import ReplayMemory from pathlib import Path +from loguru import logger def load_demonstrations(mem: ReplayMemory, recordings: Path): records = get_recordings(str(recordings)) - print('picks in recordings', sum(records['reward']>10)) + logger.info('picks in recordings', sum(records['reward']>10)) ends=records["episodes_end_point"] for i in tqdm(range(len(ends)-1), desc='loading demonstrations'): a = ends[i] diff --git a/replay_memory.py b/replay_memory.py index 1675be3..37d6580 100644 --- a/replay_memory.py +++ b/replay_memory.py @@ -1,8 +1,9 @@ import random import numpy as np import torch -import pickle +import hickle import os +from loguru import logger class ReplayMemory: def __init__(self, capacity, seed): @@ -26,12 +27,11 @@ class ReplayMemory: return len(self.buffer) def save(self, memory_path=None): - print('Saving memory to {}'.format(memory_path)) - torch.save(self.buffer, memory_path) + logger.info(f'Saving memory to {memory_path}') + hickle.save(self.buffer, memory_path, compression='gzip', shuffle=True) def load(self, memory_path): - print('Loading memory from {}'.format(memory_path)) + logger.info('Loading memory from {memory_path}') if memory_path is not None: - # print(self.buffer[0]) - self.buffer = torch.load(memory_path) + self.buffer = hickle.load(memory_path) self.position = len(self.buffer) diff --git a/sac.py b/sac.py index 7743026..216a1cf 100644 --- a/sac.py +++ b/sac.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from torch.optim import Adam from utils import soft_update, hard_update from model import GaussianPolicy, QNetwork, DeterministicPolicy +from loguru import logger class SAC(object): @@ -105,12 +106,13 @@ class SAC(object): # Save model parameters def save_model(self, actor_path=None, critic_path=None): + logger.debug(f'saving models to {actor_path} and {critic_path}')) torch.save(self.policy.state_dict(), actor_path) torch.save(self.critic.state_dict(), critic_path) # Load model parameters def load_model(self, actor_path, critic_path): - print('Loading models from {} and {}'.format(actor_path, critic_path)) + logger.info(f'Loading models from {actor_path} and {critic_path}')) if actor_path is not None: self.policy.load_state_dict(torch.load(actor_path)) if critic_path is not None: