This commit is contained in:
wassname
2021-01-16 16:40:53 +08:00
parent cc6e0f2035
commit 90d207ca9b
3 changed files with 11 additions and 8 deletions
+2 -1
View File
@@ -7,10 +7,11 @@ from gym_recording_modified.playback import get_recordings
from tqdm.auto import tqdm
from replay_memory import ReplayMemory
from pathlib import Path
from loguru import logger
def load_demonstrations(mem: ReplayMemory, recordings: Path):
records = get_recordings(str(recordings))
print('picks in recordings', sum(records['reward']>10))
logger.info('picks in recordings', sum(records['reward']>10))
ends=records["episodes_end_point"]
for i in tqdm(range(len(ends)-1), desc='loading demonstrations'):
a = ends[i]
+6 -6
View File
@@ -1,8 +1,9 @@
import random
import numpy as np
import torch
import pickle
import hickle
import os
from loguru import logger
class ReplayMemory:
def __init__(self, capacity, seed):
@@ -26,12 +27,11 @@ class ReplayMemory:
return len(self.buffer)
def save(self, memory_path=None):
print('Saving memory to {}'.format(memory_path))
torch.save(self.buffer, memory_path)
logger.info(f'Saving memory to {memory_path}')
hickle.save(self.buffer, memory_path, compression='gzip', shuffle=True)
def load(self, memory_path):
print('Loading memory from {}'.format(memory_path))
logger.info('Loading memory from {memory_path}')
if memory_path is not None:
# print(self.buffer[0])
self.buffer = torch.load(memory_path)
self.buffer = hickle.load(memory_path)
self.position = len(self.buffer)
+3 -1
View File
@@ -4,6 +4,7 @@ import torch.nn.functional as F
from torch.optim import Adam
from utils import soft_update, hard_update
from model import GaussianPolicy, QNetwork, DeterministicPolicy
from loguru import logger
class SAC(object):
@@ -105,12 +106,13 @@ class SAC(object):
# Save model parameters
def save_model(self, actor_path=None, critic_path=None):
logger.debug(f'saving models to {actor_path} and {critic_path}'))
torch.save(self.policy.state_dict(), actor_path)
torch.save(self.critic.state_dict(), critic_path)
# Load model parameters
def load_model(self, actor_path, critic_path):
print('Loading models from {} and {}'.format(actor_path, critic_path))
logger.info(f'Loading models from {actor_path} and {critic_path}'))
if actor_path is not None:
self.policy.load_state_dict(torch.load(actor_path))
if critic_path is not None: