mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 17:01:47 +08:00
process_obs works, with training by both obs and critic
This commit is contained in:
@@ -2,7 +2,7 @@ python=/home/wassname/anaconda/envs/diygym3/bin/python
|
||||
date=2021-01-03_13-30-07
|
||||
LOGURU_LEVEL=INFO
|
||||
run:
|
||||
LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 4 --automatic_entropy_tuning
|
||||
LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 4 --automatic_entropy_tuning true
|
||||
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --alpha 0.1 --tau 1 --target_update_interval 1000
|
||||
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --tau 1 --target_update_interval 1000 --policy Deterministic
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ def load_demonstrations(mem: ReplayMemory, recordings: Path):
|
||||
records = get_recordings(str(recordings))
|
||||
logger.info('picks in recordings', sum(records['reward']>10))
|
||||
ends=records["episodes_end_point"]
|
||||
for i in tqdm(range(len(ends)-1), desc='loading demonstrations'):
|
||||
for i in range(len(ends)-1):
|
||||
a = ends[i]
|
||||
b = ends[i+1]
|
||||
for s in range(a+1, b):
|
||||
|
||||
@@ -12,11 +12,13 @@ from replay_memory import ReplayMemory
|
||||
from load_demonstrations import load_demonstrations
|
||||
import apple_gym.env
|
||||
import pickle
|
||||
from process_obs import ProcessObservation
|
||||
# from torchinfo import summary
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
from loguru import logger
|
||||
from rich import logger.info
|
||||
from rich import print
|
||||
from rich.logging import RichHandler
|
||||
logging.basicConfig(level=logging.INFO, handlers=[RichHandler(rich_tracebacks=True, markup=True)])
|
||||
logger.configure(handlers=[{"sink": RichHandler(markup=True),
|
||||
@@ -81,13 +83,22 @@ env.action_space.seed(args.seed)
|
||||
keys_to_monitor = ['env_reward/apple_pick/tree/min_fruit_dist_reward',
|
||||
'env_reward/apple_pick/tree/gripping_fruit_reward',
|
||||
'env_reward/apple_pick/tree/force_tree_reward',
|
||||
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
|
||||
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# A visual network
|
||||
observation_space=env.observation_space.shape[0]
|
||||
process_obs=ProcessObservation()
|
||||
observation_space=observation_space - process_obs.reduce_action_space
|
||||
logger.info(f"process_obs reduces obs_space {env.observation_space.shape[0]}-{process_obs.reduce_action_space}={observation_space}")
|
||||
|
||||
# Agent
|
||||
agent = SAC(env.observation_space.shape[0], env.action_space, args)
|
||||
agent = SAC(observation_space, env.action_space, args, process_obs)
|
||||
|
||||
# TODO
|
||||
# summary(model, input_size=(batch_size, 1, 28, 28))
|
||||
|
||||
#Tensorboard
|
||||
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
||||
@@ -132,6 +143,7 @@ updates = 0
|
||||
|
||||
with tqdm(unit='steps', mininterval=5) as prog:
|
||||
for i_episode in itertools.count(0):
|
||||
print('1')
|
||||
episode_reward = 0
|
||||
episode_steps = 0
|
||||
done = False
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -42,7 +44,7 @@ class GenerativeResnet3Headless(nn.Module):
|
||||
self.res2 = ResidualBlock(channel_size * 4, channel_size * 4)
|
||||
self.res3 = ResidualBlock(channel_size * 4, channel_size * 4)
|
||||
self.res4 = ResidualBlock(channel_size * 4, channel_size * 4)
|
||||
self.res5 = ResidualBlock(channel_size * 4, channel_size * 4)
|
||||
|
||||
|
||||
self.conv4 = nn.ConvTranspose2d(channel_size * 4, channel_size * 2, kernel_size=4, stride=2, padding=1,
|
||||
output_padding=1)
|
||||
@@ -65,9 +67,17 @@ class GenerativeResnet3Headless(nn.Module):
|
||||
self.dropout_sin = nn.Dropout(p=prob)
|
||||
self.dropout_wid = nn.Dropout(p=prob)
|
||||
|
||||
# freeze above params
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.res5 = ResidualBlock(channel_size * 4, channel_size * 4)
|
||||
|
||||
self.head = nn.Conv2d(64, 4, 1, bias=False)
|
||||
|
||||
def forward(self, x_in):
|
||||
# Freeze these layers
|
||||
with torch.no_grad():
|
||||
x = F.relu(self.bn1(self.conv1(x_in)))
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
@@ -75,11 +85,15 @@ class GenerativeResnet3Headless(nn.Module):
|
||||
x = self.res2(x)
|
||||
x = self.res3(x)
|
||||
x = self.res4(x)
|
||||
|
||||
x = self.res5(x)
|
||||
|
||||
# 1x1 conv to reduce feature state, with random weights
|
||||
# 1x1 conv to reduce feature state, init with random weights
|
||||
x = self.head(x)
|
||||
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||
|
||||
# ignore the old head which made it larger
|
||||
# x = F.relu(self.bn4(self.conv4(x)))
|
||||
# x = F.relu(self.bn5(self.conv5(x)))
|
||||
# x = self.conv6(x)
|
||||
@@ -96,3 +110,46 @@ class GenerativeResnet3Headless(nn.Module):
|
||||
# width_output = self.width_output(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ProcessObservation(nn.Module):
    """Replace the raw RGBD pixels at the end of an observation with features.

    The flat observation is expected to end with two RGBD images (arm camera
    first, base camera last), each flattened from ``(res[0], res[1], 4)``.
    Both images are passed through a ``GenerativeResnet3Headless`` feature
    extractor and the resulting features are concatenated back onto the
    leading non-image part of the observation.

    Attributes:
        res: ``(height, width)`` of each RGBD image.
        feature_extractor: the visual network applied to the images.
        reduce_action_space: number of entries by which this module shrinks
            the flat observation (raw pixel count minus feature count).
            The name is kept for callers, which subtract it from the
            environment's observation size.
    """

    def __init__(self, res=(224, 224)):
        super().__init__()
        self.res = res

        # Load visual model; the checkpoint path is resolved relative to this file.
        grconvnet3_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt'
        )
        self.feature_extractor = GenerativeResnet3Headless().eval()
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only host; the owner moves the module with ``.to(device)`` later.
        self.feature_extractor.load_state_dict(
            state_dict=torch.load(grconvnet3_path, map_location='cpu')
        )

        # Two RGBD images contribute 8 channels in total, before and after
        # feature extraction.
        # NOTE(review): the ``// 16 - 1`` spatial size presumably mirrors the
        # strides/pooling of the feature extractor -- confirm if that changes.
        old_img_size = (res[0], res[1], 8)
        new_img_size = (res[0] // 16 - 1, res[1] // 16 - 1, 8)
        self.reduce_action_space = int(np.prod(old_img_size) - np.prod(new_img_size))

    def forward(self, obs):
        """Process the trailing images of a batch of observations into features.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``nn.Module`` hooks keep working; ``process_obs(obs)`` is unchanged
        for callers.

        Args:
            obs: 2-D torch tensor ``(batch, n)`` whose last
                ``2 * height * width * 4`` columns are the flattened arm RGBD
                image followed by the flattened base RGBD image.

        Returns:
            2-D tensor ``(batch, m)``: the leading non-image columns, then the
            flattened base-image features, then the flattened arm-image
            features.
        """
        height, width = self.res
        px = height * width
        batch_size = obs.shape[0]

        # Slice from the end: the base image is last, the arm image precedes it.
        base_rgbd = obs[:, -px * 4:].reshape((-1, height, width, 4))
        arm_rgbd = obs[:, -px * 8: -px * 4].reshape((-1, height, width, 4))
        others = obs[:, : -px * 8]

        # Stack both cameras along the batch axis so the extractor runs once.
        images = torch.cat([base_rgbd, arm_rgbd], 0)
        images = images.permute((0, 3, 1, 2))  # channels-last -> (-1, 4, height, width)
        features = self.feature_extractor(images)

        # Undo the fake batch: first half is base, second half is arm.
        base_feat = features[:batch_size].reshape((batch_size, -1))
        arm_feat = features[batch_size:].reshape((batch_size, -1))

        # Append the features to the non-image part of the observation.
        return torch.cat([others, base_feat, arm_feat], 1)
|
||||
@@ -13,18 +13,19 @@ class SAC(object):
|
||||
self.gamma = args.gamma
|
||||
self.tau = args.tau
|
||||
self.alpha = args.alpha
|
||||
self.process_obs = process_obs
|
||||
self.device = torch.device("cuda" if args.cuda else "cpu")
|
||||
|
||||
self.policy_type = args.policy
|
||||
self.target_update_interval = args.target_update_interval
|
||||
self.automatic_entropy_tuning = args.automatic_entropy_tuning
|
||||
|
||||
self.device = torch.device("cuda" if args.cuda else "cpu")
|
||||
self.process_obs = process_obs.to(self.device)
|
||||
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
|
||||
self.critic_optim = Adam(
|
||||
list(self.critic.parameters()) + list(process_obs.parameters())
|
||||
, lr=args.lr)
|
||||
|
||||
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size, process_obs).to(device=self.device)
|
||||
self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
|
||||
|
||||
self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size, process_obs).to(self.device)
|
||||
self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
|
||||
hard_update(self.critic_target, self.critic)
|
||||
|
||||
if self.policy_type == "Gaussian":
|
||||
@@ -34,13 +35,15 @@ class SAC(object):
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
|
||||
self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
|
||||
|
||||
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space, process_obs).to(self.device)
|
||||
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
|
||||
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
|
||||
self.policy_optim = Adam(
|
||||
list(self.policy.parameters()) + list(process_obs.parameters()),
|
||||
lr=args.lr)
|
||||
|
||||
else:
|
||||
self.alpha = 0
|
||||
self.automatic_entropy_tuning = False
|
||||
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space, process_obs).to(self.device)
|
||||
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
|
||||
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
|
||||
|
||||
def select_action(self, obs, evaluate=False):
|
||||
@@ -64,8 +67,8 @@ class SAC(object):
|
||||
|
||||
|
||||
state_batch = self.process_obs(obs_batch)
|
||||
next_state_batch = self.process_obs(next_obs_batch)
|
||||
with torch.no_grad():
|
||||
next_state_batch = self.process_obs(next_obs_batch)
|
||||
next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
|
||||
qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
|
||||
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
|
||||
@@ -79,6 +82,7 @@ class SAC(object):
|
||||
qf_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
state_batch = self.process_obs(obs_batch)
|
||||
pi, log_pi, _ = self.policy.sample(state_batch)
|
||||
|
||||
qf1_pi, qf2_pi = self.critic(state_batch, pi)
|
||||
@@ -111,13 +115,13 @@ class SAC(object):
|
||||
|
||||
# Save model parameters
|
||||
def save_model(self, actor_path=None, critic_path=None):
    """Save the policy and critic state dicts to the given paths.

    Args:
        actor_path: destination for ``self.policy.state_dict()``; skipped
            when ``None``.
        critic_path: destination for ``self.critic.state_dict()``; skipped
            when ``None``.
    """
    logger.debug(f'saving models to {actor_path} and {critic_path}')
    # Both paths default to None; skip those instead of handing None to
    # torch.save, matching how load_model treats a None path.
    if actor_path is not None:
        torch.save(self.policy.state_dict(), actor_path)
    if critic_path is not None:
        torch.save(self.critic.state_dict(), critic_path)
|
||||
|
||||
# Load model parameters
|
||||
def load_model(self, actor_path, critic_path):
|
||||
logger.info(f'Loading models from {actor_path} and {critic_path}'))
|
||||
logger.info(f'Loading models from {actor_path} and {critic_path}')
|
||||
if actor_path is not None:
|
||||
self.policy.load_state_dict(torch.load(actor_path))
|
||||
if critic_path is not None:
|
||||
|
||||
Reference in New Issue
Block a user