mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:46:28 +08:00
tune tau etc
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
python=/home/wassname/anaconda/envs/diy-gym2/bin/python
|
||||
date=2021-01-03_13-30-07
|
||||
run:
|
||||
python main.py --demonstrations data/demonstrations --tau 1 --target_update_interval 100
|
||||
${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2
|
||||
|
||||
play:
|
||||
python play.py --load-actor models/actor_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl --load-critic models/critic_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl --render
|
||||
${python} play.py --load-actor models/actor_${date}_SAC_ApplePick-v0_Gaussian_autotune.pkl --load-critic models/critic_${date}_SAC_ApplePick-v0_Gaussian_autotune.pkl --render
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import datetime
|
||||
import gym
|
||||
import numpy as np
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from sac import SAC
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
@@ -73,34 +74,39 @@ log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%
|
||||
args.policy, "autotune" if args.automatic_entropy_tuning else "")
|
||||
writer = SummaryWriter('runs/' + log_name)
|
||||
|
||||
save_dir=Path("models") / log_name
|
||||
|
||||
# Memory
|
||||
memory=ReplayMemory(args.replay_size, args.seed)
|
||||
if args.demonstrations:
|
||||
load_demonstrations(memory, args.demonstrations)
|
||||
|
||||
def save():
|
||||
agent.save_model(args.env_name, "", "models/actor_" + log_name+'.pkl', "models/critic_"+log_name+'.pkl')
|
||||
memory.save(args.env_name, "", "models/memory_" + log_name +'.pkl')
|
||||
|
||||
def load(log_name):
|
||||
agent.load_model("models/actor_" + log_name + '.pkl', "models/critic_" + log_name + '.pkl')
|
||||
memory.load("models/memory_" + log_name +'.pkl')
|
||||
def save(save_dir):
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
agent.save_model(save_dir/'actor.pkl', save_dir/'critic.pkl')
|
||||
memory.save(save_dir/'memory.pkl')
|
||||
|
||||
def load(save_dir):
|
||||
agent.load_model(save_dir/'actor.pkl', save_dir/'critic.pkl')
|
||||
memory.load(save_dir/'memory.pkl')
|
||||
|
||||
if args.load:
|
||||
load(args.load)
|
||||
|
||||
if args.demonstrations:
|
||||
load_demonstrations(memory, args.demonstrations)
|
||||
|
||||
# Training Loop
|
||||
total_numsteps = 0
|
||||
updates = 0
|
||||
|
||||
with tqdm(unit='frames') as prog:
|
||||
with tqdm(unit='steps', mininterval=5) as prog:
|
||||
for i_episode in itertools.count(1):
|
||||
episode_reward = 0
|
||||
episode_steps = 0
|
||||
done = False
|
||||
state = env.reset()
|
||||
|
||||
while not done:
|
||||
for i_step in itertools.count(1):
|
||||
if args.start_steps > total_numsteps:
|
||||
action = env.action_space.sample() # Sample random action
|
||||
else:
|
||||
@@ -117,21 +123,24 @@ with tqdm(unit='frames') as prog:
|
||||
writer.add_scalar('loss/policy', policy_loss, updates)
|
||||
writer.add_scalar('loss/entropy_loss', ent_loss, updates)
|
||||
writer.add_scalar('entropy_temperature/alpha', alpha, updates)
|
||||
|
||||
updates += 1
|
||||
|
||||
next_state, reward, done, info = env.step(action) # Step
|
||||
next_state, reward, done, info = env.step(action) # Step
|
||||
episode_steps += 1
|
||||
total_numsteps += 1
|
||||
prog.update(1)
|
||||
episode_reward += reward
|
||||
|
||||
prog.update(1)
|
||||
prog.desc = f'er={episode_reward/episode_steps:2.2f}'
|
||||
# for k, v in info.items():
|
||||
# if len(v) == 1:
|
||||
# writer.add_scalar('env/'+k, v, episode_steps)
|
||||
# log env stuff
|
||||
for k in ['env_reward/apple_pick/tree/min_fruit_dist_reward',
|
||||
'env_reward/apple_pick/tree/gripping_fruit_reward',
|
||||
'env_reward/apple_pick/tree/force_tree_reward',
|
||||
'env_reward/apple_pick/tree/force_fruit_reward']:
|
||||
writer.add_scalar(k, info[k], episode_steps)
|
||||
|
||||
# Ignore the "done" signal if it comes from hitting the time horizon.
|
||||
# (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
|
||||
# Ignore the "done" signal if it comes from hitting the time horizon. (that is, when it's an artificial terminal signal that isn't based on the agent's state)
|
||||
# (https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/sac/sac.py)
|
||||
mask = 1 if episode_steps == env._max_episode_steps else float(not done)
|
||||
|
||||
memory.push(state, action, reward, next_state, mask) # Append transition to memory
|
||||
@@ -165,11 +174,11 @@ with tqdm(unit='frames') as prog:
|
||||
|
||||
writer.add_scalar('avg_reward/test', avg_reward, i_episode)
|
||||
|
||||
save()
|
||||
save(save_dir)
|
||||
|
||||
print("----------------------------------------")
|
||||
print("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
|
||||
print("----------------------------------------")
|
||||
|
||||
env.close()
|
||||
save()
|
||||
save(save_dir)
|
||||
|
||||
+4
-8
@@ -1,5 +1,6 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import pickle
|
||||
import os
|
||||
|
||||
@@ -24,16 +25,11 @@ class ReplayMemory:
|
||||
def __len__(self):
|
||||
return len(self.buffer)
|
||||
|
||||
def save(self, env_name, suffix="", memory_path=None):
|
||||
if not os.path.exists('models/'):
|
||||
os.makedirs('models/')
|
||||
|
||||
if memory_path is None:
|
||||
memory_path = "models/memory_buffer_{}_{}".format(env_name, suffix)
|
||||
def save(self, memory_path=None):
|
||||
print('Saving memory to {}'.format(memory_path))
|
||||
pickle.dump(self.buffer, open(memory_path, 'wb'))
|
||||
torch.save(self.buffer, memory_path)
|
||||
|
||||
def load(self, memory_path):
|
||||
print('Loading memory from {}'.format(memory_path))
|
||||
if memory_path is not None:
|
||||
self.buffer = pickle.load(open(memory_path, 'rb'))
|
||||
self.buffer = torch.load(memory_path)
|
||||
|
||||
@@ -104,14 +104,7 @@ class SAC(object):
|
||||
return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
|
||||
|
||||
# Save model parameters
|
||||
def save_model(self, env_name, suffix="", actor_path=None, critic_path=None):
|
||||
if not os.path.exists('models/'):
|
||||
os.makedirs('models/')
|
||||
|
||||
if actor_path is None:
|
||||
actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
|
||||
if critic_path is None:
|
||||
critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
|
||||
def save_model(self, actor_path=None, critic_path=None):
|
||||
print('Saving models to {} and {}'.format(actor_path, critic_path))
|
||||
torch.save(self.policy.state_dict(), actor_path)
|
||||
torch.save(self.critic.state_dict(), critic_path)
|
||||
|
||||
Reference in New Issue
Block a user