mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:46:28 +08:00
misc
This commit is contained in:
@@ -5,7 +5,14 @@ LOGURU_LEVEL=INFO
|
|||||||
run:
|
run:
|
||||||
ulimit -S -m 65000000
|
ulimit -S -m 65000000
|
||||||
ulimit -S -v 65000000
|
ulimit -S -v 65000000
|
||||||
LOGURU_LEVEL=INFO ${python} main.py --cuda --automatic_entropy_tuning true --replay_size 50000 --load auto
|
LOGURU_LEVEL=INFO ${python} \
|
||||||
|
-m pdb -c continue \
|
||||||
|
main.py \
|
||||||
|
--cuda \
|
||||||
|
--automatic_entropy_tuning true \
|
||||||
|
--replay_size 10000 \
|
||||||
|
--demonstrations data/demonstrations \
|
||||||
|
# --load auto \
|
||||||
# ${python} -m pdb main.py --cuda --automatic_entropy_tuning true --replay_size 10000 --load auto --start_steps 200
|
# ${python} -m pdb main.py --cuda --automatic_entropy_tuning true --replay_size 10000 --load auto --start_steps 200
|
||||||
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --automatic_entropy_tuning true --replay_size 20000 --load auto
|
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --automatic_entropy_tuning true --replay_size 20000 --load auto
|
||||||
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --alpha 0.1 --tau 1 --target_update_interval 1000
|
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --alpha 0.1 --tau 1 --target_update_interval 1000
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from load_demonstrations import load_demonstrations
|
|||||||
import apple_gym.env
|
import apple_gym.env
|
||||||
import pickle
|
import pickle
|
||||||
from process_obs import ProcessObservation
|
from process_obs import ProcessObservation
|
||||||
# from torchinfo import summary
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from progress import RichTQDM
|
from progress import RichTQDM
|
||||||
@@ -98,8 +97,14 @@ logger.info(f"process_obs reduces obs_space {env.observation_space.shape[0]}-{pr
|
|||||||
# Agent
|
# Agent
|
||||||
agent = SAC(observation_dim, env.action_space, args, process_obs)
|
agent = SAC(observation_dim, env.action_space, args, process_obs)
|
||||||
|
|
||||||
# TODO
|
# from torchinfo import summary
|
||||||
# summary(model, input_size=(batch_size, 1, 28, 28))
|
# print('process_obs')
|
||||||
|
# summary(process_obs, input_size=(2, *env.observation_space.shape), depth=2)
|
||||||
|
# print('critic')
|
||||||
|
# summary(agent.critic, input_size=((2, observation_dim), (2, action_dim)))
|
||||||
|
# print('policy')
|
||||||
|
# summary(agent.policy, input_size=(2, observation_dim))
|
||||||
|
# # print(process_obs, agent.critic, agent.policy)
|
||||||
|
|
||||||
#Tensorboard
|
#Tensorboard
|
||||||
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
||||||
@@ -117,13 +122,21 @@ def save(save_dir):
|
|||||||
try:
|
try:
|
||||||
save_dir.mkdir(exist_ok=True)
|
save_dir.mkdir(exist_ok=True)
|
||||||
logger.info(f'Saving to {save_dir}')
|
logger.info(f'Saving to {save_dir}')
|
||||||
agent.save_model(save_dir/'actor.pkl', save_dir/'critic.pkl')
|
agent.save_model(
|
||||||
|
save_dir / 'actor.pkl',
|
||||||
|
save_dir / 'critic.pkl',
|
||||||
|
save_dir / 'process_obs.pkl'
|
||||||
|
)
|
||||||
# memory.save(save_dir / 'memory.pkl') # crashes at over 200k
|
# memory.save(save_dir / 'memory.pkl') # crashes at over 200k
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("failed to save")
|
logging.exception("failed to save")
|
||||||
|
|
||||||
def load(save_dir):
|
def load(save_dir):
|
||||||
agent.load_model(save_dir / 'actor.pkl', save_dir / 'critic.pkl')
|
agent.load_model(
|
||||||
|
save_dir / 'actor.pkl',
|
||||||
|
save_dir / 'critic.pkl',
|
||||||
|
save_dir / 'process_obs.pkl'
|
||||||
|
)
|
||||||
# if args.train:
|
# if args.train:
|
||||||
# memory.load(save_dir/'memory.pkl')
|
# memory.load(save_dir/'memory.pkl')
|
||||||
|
|
||||||
@@ -145,10 +158,9 @@ updates = 0
|
|||||||
|
|
||||||
with RichTQDM() as prog:
|
with RichTQDM() as prog:
|
||||||
task1 = prog.add_task("[red]steps", total=args.num_steps)
|
task1 = prog.add_task("[red]steps", total=args.num_steps)
|
||||||
task2 = prog.add_task("[red]updates", total=args.num_steps)
|
task2 = prog.add_task("[blue]updates", total=args.num_steps)
|
||||||
task3 = prog.add_task("[red]test", total=args.num_steps)
|
task3 = prog.add_task("[green]test", total=args.num_steps)
|
||||||
for i_episode in itertools.count(0):
|
for i_episode in itertools.count(0):
|
||||||
print('1')
|
|
||||||
episode_reward = 0
|
episode_reward = 0
|
||||||
episode_steps = 0
|
episode_steps = 0
|
||||||
done = False
|
done = False
|
||||||
@@ -160,7 +172,7 @@ with RichTQDM() as prog:
|
|||||||
else:
|
else:
|
||||||
action = agent.select_action(state) # Sample action from policy
|
action = agent.select_action(state) # Sample action from policy
|
||||||
|
|
||||||
if len(memory) > args.batch_size:
|
if len(memory) > args.batch_size and (total_numsteps%20==0):
|
||||||
# Number of updates per step in environment
|
# Number of updates per step in environment
|
||||||
for i in range(args.updates_per_step):
|
for i in range(args.updates_per_step):
|
||||||
# Update parameters of all the networks
|
# Update parameters of all the networks
|
||||||
|
|||||||
+2
-2
@@ -122,7 +122,7 @@ class ProcessObservation(nn.Module):
|
|||||||
os.path.dirname(os.path.abspath(__file__)),
|
os.path.dirname(os.path.abspath(__file__)),
|
||||||
'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt'
|
'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt'
|
||||||
)
|
)
|
||||||
self.feature_extractor = GenerativeResnet3Headless().half()
|
self.feature_extractor = GenerativeResnet3Headless().train().half()
|
||||||
self.feature_extractor.load_state_dict(state_dict=torch.load(grconvnet3_path), strict=False)
|
self.feature_extractor.load_state_dict(state_dict=torch.load(grconvnet3_path), strict=False)
|
||||||
|
|
||||||
old_img_size = (res[0], res[1], 8)
|
old_img_size = (res[0], res[1], 8)
|
||||||
@@ -146,11 +146,11 @@ class ProcessObservation(nn.Module):
|
|||||||
# make a batch
|
# make a batch
|
||||||
x = torch.cat([base_rgbd, arm_rgbd], 0)
|
x = torch.cat([base_rgbd, arm_rgbd], 0)
|
||||||
x = x.permute((0, 3, 1, 2)) # to ((-1, 4, x, y))
|
x = x.permute((0, 3, 1, 2)) # to ((-1, 4, x, y))
|
||||||
x = x.half()
|
|
||||||
h = self.feature_extractor(x)
|
h = self.feature_extractor(x)
|
||||||
|
|
||||||
# undo fake batch
|
# undo fake batch
|
||||||
base_h, arm_h = h[:bs].reshape((bs, -1)), h[bs:].reshape((bs, -1))
|
base_h, arm_h = h[:bs].reshape((bs, -1)), h[bs:].reshape((bs, -1))
|
||||||
# add features together
|
# add features together
|
||||||
y = torch.cat([others, base_h, arm_h], 1)
|
y = torch.cat([others, base_h, arm_h], 1)
|
||||||
|
assert torch.isfinite(y).all()
|
||||||
return y
|
return y
|
||||||
|
|||||||
+32
-32
@@ -4,7 +4,7 @@ import torch
|
|||||||
import hickle
|
import hickle
|
||||||
import os
|
import os
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
# import bcolz
|
||||||
import lz4.frame
|
import lz4.frame
|
||||||
import cloudpickle as pickle
|
import cloudpickle as pickle
|
||||||
|
|
||||||
@@ -56,40 +56,40 @@ class ReplayMemory:
|
|||||||
self.position = len(self.buffer)
|
self.position = len(self.buffer)
|
||||||
|
|
||||||
|
|
||||||
class ReplayMemory2:
|
# class ReplayMemory:
|
||||||
def __init__(self, capacity, seed, observation_dim, action_dim):
|
# def __init__(self, capacity, seed, observation_dim, action_dim):
|
||||||
random.seed(seed)
|
# random.seed(seed)
|
||||||
self.capacity = capacity
|
# self.capacity = capacity
|
||||||
self._observations = np.zeros((capacity, observation_dim), dtype='float16')
|
# self._observations = (bcolz.zeros((capacity, observation_dim), dtype='float16'))
|
||||||
self._actions = np.zeros((capacity, action_dim))
|
# self._actions = (bcolz.zeros((capacity, action_dim)))
|
||||||
self._rewards = np.zeros((capacity, 1))
|
# self._rewards = (bcolz.zeros((capacity, 1)))
|
||||||
self._next_obs = np.zeros((capacity, observation_dim), dtype='float16')
|
# self._next_obs = (bcolz.zeros((capacity, observation_dim), dtype='float16'))
|
||||||
self._terminals = np.zeros((capacity, 1), dtype='uint8')
|
# self._terminals = (bcolz.zeros((capacity, 1), dtype='uint8'))
|
||||||
self.position = 0
|
# self.position = 0
|
||||||
self._size = 0
|
# self._size = 0
|
||||||
|
|
||||||
def push(self, state, action, reward, next_state, done):
|
# def push(self, state, action, reward, next_state, done):
|
||||||
self._observations[self.position] = state
|
# self._observations[self.position] = state
|
||||||
self._actions[self.position] = action
|
# self._actions[self.position] = action
|
||||||
self._rewards[self.position] = reward
|
# self._rewards[self.position] = reward
|
||||||
self._next_obs[self.position] = next_state
|
# self._next_obs[self.position] = next_state
|
||||||
self._terminals[self.position] = done
|
# self._terminals[self.position] = done
|
||||||
self.position = (self.position + 1) % self.capacity
|
# self.position = (self.position + 1) % self.capacity
|
||||||
if self._size<self.capacity:
|
# if self._size<self.capacity:
|
||||||
self._size += 1
|
# self._size += 1
|
||||||
|
|
||||||
def sample(self, batch_size):
|
# def sample(self, batch_size):
|
||||||
n = min(self.position, self.capacity)
|
# n = min(self.position, self.capacity)
|
||||||
indices = np.random.choice(n, size=batch_size)
|
# indices = np.random.choice(n, size=batch_size)
|
||||||
state = self._observations[indices]
|
# state = self._observations[indices]
|
||||||
action = self._actions[indices]
|
# action = self._actions[indices]
|
||||||
reward = self._rewards[indices]
|
# reward = self._rewards[indices]
|
||||||
next_state = self._next_obs[indices]
|
# next_state = self._next_obs[indices]
|
||||||
done = self._terminals[indices]
|
# done = self._terminals[indices]
|
||||||
return state, action, reward, next_state, done
|
# return state, action, reward, next_state, done
|
||||||
|
|
||||||
def __len__(self):
|
# def __len__(self):
|
||||||
return self._size
|
# return self._size
|
||||||
|
|
||||||
|
|
||||||
# class BatchedReplayMemory:
|
# class BatchedReplayMemory:
|
||||||
|
|||||||
@@ -14,28 +14,29 @@ class SAC(object):
|
|||||||
self.tau = args.tau
|
self.tau = args.tau
|
||||||
self.alpha = args.alpha
|
self.alpha = args.alpha
|
||||||
self.device = torch.device("cuda" if args.cuda else "cpu")
|
self.device = torch.device("cuda" if args.cuda else "cpu")
|
||||||
|
self.dtype = torch.float
|
||||||
|
|
||||||
self.policy_type = args.policy
|
self.policy_type = args.policy
|
||||||
self.target_update_interval = args.target_update_interval
|
self.target_update_interval = args.target_update_interval
|
||||||
self.automatic_entropy_tuning = args.automatic_entropy_tuning
|
self.automatic_entropy_tuning = args.automatic_entropy_tuning
|
||||||
|
|
||||||
self.process_obs = process_obs.to(self.device)
|
self.process_obs = process_obs.to(self.device).to(self.dtype)
|
||||||
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
|
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device).to(self.dtype)
|
||||||
self.critic_optim = Adam(
|
self.critic_optim = Adam(
|
||||||
list(self.critic.parameters()) + list(process_obs.parameters())
|
list(self.critic.parameters()) + list(process_obs.parameters())
|
||||||
, lr=args.lr)
|
, lr=args.lr)
|
||||||
|
|
||||||
self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
|
self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device).to(self.dtype)
|
||||||
hard_update(self.critic_target, self.critic)
|
hard_update(self.critic_target, self.critic)
|
||||||
|
|
||||||
if self.policy_type == "Gaussian":
|
if self.policy_type == "Gaussian":
|
||||||
# Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
|
# Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
|
||||||
if self.automatic_entropy_tuning is True:
|
if self.automatic_entropy_tuning is True:
|
||||||
self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
|
self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
|
||||||
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
|
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device, dtype=self.dtype)
|
||||||
self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
|
self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
|
||||||
|
|
||||||
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
|
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device).to(self.dtype)
|
||||||
self.policy_optim = Adam(
|
self.policy_optim = Adam(
|
||||||
list(self.policy.parameters()) + list(process_obs.parameters()),
|
list(self.policy.parameters()) + list(process_obs.parameters()),
|
||||||
lr=args.lr)
|
lr=args.lr)
|
||||||
@@ -43,14 +44,14 @@ class SAC(object):
|
|||||||
else:
|
else:
|
||||||
self.alpha = 0
|
self.alpha = 0
|
||||||
self.automatic_entropy_tuning = False
|
self.automatic_entropy_tuning = False
|
||||||
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
|
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device).to(self.dtype)
|
||||||
self.policy_optim = Adam(
|
self.policy_optim = Adam(
|
||||||
list(self.policy.parameters()) + list(process_obs.parameters()),
|
list(self.policy.parameters()) + list(process_obs.parameters()),
|
||||||
lr=args.lr)
|
lr=args.lr)
|
||||||
|
|
||||||
def select_action(self, obs, evaluate=False):
|
def select_action(self, obs, evaluate=False):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
|
obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0).to(self.dtype)
|
||||||
state = self.process_obs(obs)
|
state = self.process_obs(obs)
|
||||||
if evaluate is False:
|
if evaluate is False:
|
||||||
action, _, _ = self.policy.sample(state)
|
action, _, _ = self.policy.sample(state)
|
||||||
@@ -63,11 +64,11 @@ class SAC(object):
|
|||||||
# Sample a batch from memory
|
# Sample a batch from memory
|
||||||
obs_batch, action_batch, reward_batch, next_obs_batch, mask_batch = memory.sample(batch_size=batch_size)
|
obs_batch, action_batch, reward_batch, next_obs_batch, mask_batch = memory.sample(batch_size=batch_size)
|
||||||
|
|
||||||
obs_batch = torch.FloatTensor(obs_batch).to(self.device)
|
obs_batch = torch.FloatTensor(obs_batch).to(self.device).to(self.dtype)
|
||||||
next_obs_batch= torch.FloatTensor(next_obs_batch).to(self.device)
|
next_obs_batch= torch.FloatTensor(next_obs_batch).to(self.device).to(self.dtype)
|
||||||
action_batch = torch.FloatTensor(action_batch).to(self.device)
|
action_batch = torch.FloatTensor(action_batch).to(self.device).to(self.dtype)
|
||||||
reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
|
reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1).to(self.dtype)
|
||||||
mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
|
mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1).to(self.dtype)
|
||||||
|
|
||||||
|
|
||||||
state_batch = self.process_obs(obs_batch)
|
state_batch = self.process_obs(obs_batch)
|
||||||
@@ -83,6 +84,7 @@ class SAC(object):
|
|||||||
qf_loss = qf1_loss + qf2_loss
|
qf_loss = qf1_loss + qf2_loss
|
||||||
|
|
||||||
self.critic_optim.zero_grad()
|
self.critic_optim.zero_grad()
|
||||||
|
assert torch.isfinite(qf_loss).all()
|
||||||
qf_loss.backward()
|
qf_loss.backward()
|
||||||
self.critic_optim.step()
|
self.critic_optim.step()
|
||||||
|
|
||||||
@@ -95,6 +97,7 @@ class SAC(object):
|
|||||||
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
|
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
|
||||||
|
|
||||||
self.policy_optim.zero_grad()
|
self.policy_optim.zero_grad()
|
||||||
|
assert torch.isfinite(policy_loss).all()
|
||||||
policy_loss.backward()
|
policy_loss.backward()
|
||||||
self.policy_optim.step()
|
self.policy_optim.step()
|
||||||
|
|
||||||
@@ -108,7 +111,7 @@ class SAC(object):
|
|||||||
self.alpha = self.log_alpha.exp()
|
self.alpha = self.log_alpha.exp()
|
||||||
alpha_tlogs = self.alpha.clone() # For TensorboardX logs
|
alpha_tlogs = self.alpha.clone() # For TensorboardX logs
|
||||||
else:
|
else:
|
||||||
alpha_loss = torch.tensor(0.).to(self.device)
|
alpha_loss = torch.tensor(0.).to(self.device).to(self.dtype)
|
||||||
alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
|
alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
|
||||||
|
|
||||||
|
|
||||||
@@ -118,16 +121,19 @@ class SAC(object):
|
|||||||
return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
|
return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
|
||||||
|
|
||||||
# Save model parameters
|
# Save model parameters
|
||||||
def save_model(self, actor_path=None, critic_path=None):
|
def save_model(self, actor_path=None, critic_path=None, process_obs_path=None):
|
||||||
logger.debug(f'saving models to {actor_path} and {critic_path}')
|
logger.debug(f'saving models to {actor_path} and {critic_path} and {process_obs_path}')
|
||||||
torch.save(self.policy.state_dict(), actor_path)
|
torch.save(self.policy.state_dict(), actor_path)
|
||||||
torch.save(self.critic.state_dict(), critic_path)
|
torch.save(self.critic.state_dict(), critic_path)
|
||||||
|
torch.save(self.process_obs.state_dict(), process_obs_path)
|
||||||
|
|
||||||
# Load model parameters
|
# Load model parameters
|
||||||
def load_model(self, actor_path, critic_path):
|
def load_model(self, actor_path=None, critic_path=None, process_obs_path=None):
|
||||||
logger.info(f'Loading models from {actor_path} and {critic_path}')
|
logger.info(f'Loading models from {actor_path} and {critic_path} and {process_obs_path}')
|
||||||
if actor_path is not None:
|
if actor_path is not None:
|
||||||
self.policy.load_state_dict(torch.load(actor_path))
|
self.policy.load_state_dict(torch.load(actor_path))
|
||||||
if critic_path is not None:
|
if critic_path is not None:
|
||||||
self.critic.load_state_dict(torch.load(critic_path))
|
self.critic.load_state_dict(torch.load(critic_path))
|
||||||
|
if process_obs_path is not None:
|
||||||
|
self.process_obs.load_state_dict(torch.load(process_obs_path))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user