mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 18:06:10 +08:00
logging
This commit is contained in:
@@ -7,10 +7,11 @@ from gym_recording_modified.playback import get_recordings
|
||||
from tqdm.auto import tqdm
|
||||
from replay_memory import ReplayMemory
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
def load_demonstrations(mem: ReplayMemory, recordings: Path):
|
||||
records = get_recordings(str(recordings))
|
||||
print('picks in recordings', sum(records['reward']>10))
|
||||
logger.info('picks in recordings {}', sum(records['reward']>10))
|
||||
ends=records["episodes_end_point"]
|
||||
for i in tqdm(range(len(ends)-1), desc='loading demonstrations'):
|
||||
a = ends[i]
|
||||
|
||||
+6
-6
@@ -1,8 +1,9 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import pickle
|
||||
import hickle
|
||||
import os
|
||||
from loguru import logger
|
||||
|
||||
class ReplayMemory:
|
||||
def __init__(self, capacity, seed):
|
||||
@@ -26,12 +27,11 @@ class ReplayMemory:
|
||||
return len(self.buffer)
|
||||
|
||||
def save(self, memory_path=None):
|
||||
print('Saving memory to {}'.format(memory_path))
|
||||
torch.save(self.buffer, memory_path)
|
||||
logger.info(f'Saving memory to {memory_path}')
|
||||
hickle.save(self.buffer, memory_path, compression='gzip', shuffle=True)
|
||||
|
||||
def load(self, memory_path):
|
||||
print('Loading memory from {}'.format(memory_path))
|
||||
logger.info(f'Loading memory from {memory_path}')
|
||||
if memory_path is not None:
|
||||
# print(self.buffer[0])
|
||||
self.buffer = torch.load(memory_path)
|
||||
self.buffer = hickle.load(memory_path)
|
||||
self.position = len(self.buffer)
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch.nn.functional as F
|
||||
from torch.optim import Adam
|
||||
from utils import soft_update, hard_update
|
||||
from model import GaussianPolicy, QNetwork, DeterministicPolicy
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SAC(object):
|
||||
@@ -105,12 +106,13 @@ class SAC(object):
|
||||
|
||||
# Save model parameters
|
||||
def save_model(self, actor_path=None, critic_path=None):
|
||||
logger.debug(f'saving models to {actor_path} and {critic_path}')
|
||||
torch.save(self.policy.state_dict(), actor_path)
|
||||
torch.save(self.critic.state_dict(), critic_path)
|
||||
|
||||
# Load model parameters
|
||||
def load_model(self, actor_path, critic_path):
|
||||
print('Loading models from {} and {}'.format(actor_path, critic_path))
|
||||
logger.info(f'Loading models from {actor_path} and {critic_path}')
|
||||
if actor_path is not None:
|
||||
self.policy.load_state_dict(torch.load(actor_path))
|
||||
if critic_path is not None:
|
||||
|
||||
Reference in New Issue
Block a user