process_obs works, with training by both obs and critic

This commit is contained in:
wassname
2021-01-16 18:14:57 +08:00
parent 093876e414
commit 5534d4b078
5 changed files with 99 additions and 26 deletions
+1 -1
View File
@@ -2,7 +2,7 @@ python=/home/wassname/anaconda/envs/diygym3/bin/python
date=2021-01-03_13-30-07
LOGURU_LEVEL=INFO
run:
LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 4 --automatic_entropy_tuning
LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 4 --automatic_entropy_tuning true
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --alpha 0.1 --tau 1 --target_update_interval 1000
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --tau 1 --target_update_interval 1000 --policy Deterministic
+1 -1
View File
@@ -13,7 +13,7 @@ def load_demonstrations(mem: ReplayMemory, recordings: Path):
records = get_recordings(str(recordings))
logger.info('picks in recordings', sum(records['reward']>10))
ends=records["episodes_end_point"]
for i in tqdm(range(len(ends)-1), desc='loading demonstrations'):
for i in range(len(ends)-1):
a = ends[i]
b = ends[i+1]
for s in range(a+1, b):
+15 -3
View File
@@ -12,11 +12,13 @@ from replay_memory import ReplayMemory
from load_demonstrations import load_demonstrations
import apple_gym.env
import pickle
from process_obs import ProcessObservation
# from torchinfo import summary
from tqdm.auto import tqdm
from loguru import logger
from rich import logger.info
from rich import print
from rich.logging import RichHandler
logging.basicConfig(level=logging.INFO, handlers=[RichHandler(rich_tracebacks=True, markup=True)])
logger.configure(handlers=[{"sink": RichHandler(markup=True),
@@ -81,13 +83,22 @@ env.action_space.seed(args.seed)
keys_to_monitor = ['env_reward/apple_pick/tree/min_fruit_dist_reward',
'env_reward/apple_pick/tree/gripping_fruit_reward',
'env_reward/apple_pick/tree/force_tree_reward',
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']:
'env_reward/apple_pick/tree/force_fruit_reward', 'env_obs/apple_pick/tree/picks']
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# A visual network
observation_space=env.observation_space.shape[0]
process_obs=ProcessObservation()
observation_space=observation_space - process_obs.reduce_action_space
logger.info(f"process_obs reduces obs_space {env.observation_space.shape[0]}-{process_obs.reduce_action_space}={observation_space}")
# Agent
agent = SAC(env.observation_space.shape[0], env.action_space, args)
agent = SAC(observation_space, env.action_space, args, process_obs)
# TODO
# summary(model, input_size=(batch_size, 1, 28, 28))
#Tensorboard
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
@@ -132,6 +143,7 @@ updates = 0
with tqdm(unit='steps', mininterval=5) as prog:
for i_episode in itertools.count(0):
print('1')
episode_reward = 0
episode_steps = 0
done = False
+59 -2
View File
@@ -1,4 +1,6 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -42,7 +44,7 @@ class GenerativeResnet3Headless(nn.Module):
self.res2 = ResidualBlock(channel_size * 4, channel_size * 4)
self.res3 = ResidualBlock(channel_size * 4, channel_size * 4)
self.res4 = ResidualBlock(channel_size * 4, channel_size * 4)
self.res5 = ResidualBlock(channel_size * 4, channel_size * 4)
self.conv4 = nn.ConvTranspose2d(channel_size * 4, channel_size * 2, kernel_size=4, stride=2, padding=1,
output_padding=1)
@@ -65,9 +67,17 @@ class GenerativeResnet3Headless(nn.Module):
self.dropout_sin = nn.Dropout(p=prob)
self.dropout_wid = nn.Dropout(p=prob)
# freeze above params
for param in self.parameters():
param.requires_grad = False
self.res5 = ResidualBlock(channel_size * 4, channel_size * 4)
self.head = nn.Conv2d(64, 4, 1, bias=False)
def forward(self, x_in):
# Freeze these layers
with torch.no_grad():
x = F.relu(self.bn1(self.conv1(x_in)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
@@ -75,11 +85,15 @@ class GenerativeResnet3Headless(nn.Module):
x = self.res2(x)
x = self.res3(x)
x = self.res4(x)
x = self.res5(x)
# 1x1 conv to reduce feature state, with random weights
# 1x1 conv to reduce feature state, init with random weights
x = self.head(x)
x = F.max_pool2d(x, kernel_size=3, stride=2)
x = F.max_pool2d(x, kernel_size=3, stride=2)
# ignore the old head which made it larger
# x = F.relu(self.bn4(self.conv4(x)))
# x = F.relu(self.bn5(self.conv5(x)))
# x = self.conv6(x)
@@ -96,3 +110,46 @@ class GenerativeResnet3Headless(nn.Module):
# width_output = self.width_output(x)
return x
class ProcessObservation(nn.Module):
    """Swap the raw RGB-D pixels at the tail of an observation for conv features.

    Each observation row is expected to end with two flattened RGB-D images
    (the "arm" image followed by the "base" image), each of ``h * w * 4``
    values where ``(h, w) = res``. Everything before those images is passed
    through untouched.

    Attributes set in ``__init__``:
        res: ``(height, width)`` of each input image.
        feature_extractor: a ``GenerativeResnet3Headless`` loaded from the
            bundled checkpoint and put in eval mode.
        reduce_action_space: how many entries shorter a processed observation
            row is than a raw one (despite the name, it is used to shrink the
            *observation* size).
    """

    def __init__(self, res=(224, 224)):
        super().__init__()
        self.res = res

        # Load visual model; the checkpoint path is resolved relative to this file.
        grconvnet3_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt'
        )
        self.feature_extractor = GenerativeResnet3Headless().eval()
        # map_location='cpu' lets a checkpoint that was saved from a GPU be
        # loaded on a CPU-only host; callers move the module to their device
        # afterwards with ``.to(device)``.
        self.feature_extractor.load_state_dict(
            state_dict=torch.load(grconvnet3_path, map_location='cpu')
        )

        # Two RGB-D images -> 8 raw channels in total.
        old_img_size = (res[0], res[1], 8)
        # NOTE(review): assumes the extractor emits 4 channels per image at
        # (h // 16 - 1, w // 16 - 1) spatial size -- confirm against
        # GenerativeResnet3Headless if its head or pooling changes.
        new_img_size = (res[0] // 16 - 1, res[1] // 16 - 1, 8)
        self.reduce_action_space = int(np.prod(old_img_size) - np.prod(new_img_size))

    def forward(self, obs):
        """Process a batch of flat observations into flat feature vectors.

        Takes in a torch array of observations and processes the images into
        features. This assumes each observation ends in 2 RGB-D images with
        shape ``(h, w, 4)`` (``(224, 224, 4)`` by default).

        Args:
            obs: 2-D tensor, one flat observation per row.

        Returns:
            2-D tensor: the non-image part of each row, followed by the
            flattened base-image features and then the arm-image features.
        """
        height, width = self.res
        px = height * width
        bs = obs.shape[0]

        # Slice the two images off the end of every row.
        base_rgbd = obs[:, -px * 4:].reshape((-1, height, width, 4))
        arm_rgbd = obs[:, -px * 8: -px * 4].reshape((-1, height, width, 4))
        others = obs[:, : -px * 8]

        # Stack both cameras along the batch axis so the extractor runs once.
        x = torch.cat([base_rgbd, arm_rgbd], 0)
        x = x.permute((0, 3, 1, 2))  # channels-last -> (-1, 4, height, width)
        features = self.feature_extractor(x)

        # Undo the fake batch: first half is base, second half is arm.
        base_h = features[:bs].reshape((bs, -1))
        arm_h = features[bs:].reshape((bs, -1))

        # Re-assemble a flat row: untouched values, then both feature vectors.
        return torch.cat([others, base_h, arm_h], 1)
+16 -12
View File
@@ -13,18 +13,19 @@ class SAC(object):
self.gamma = args.gamma
self.tau = args.tau
self.alpha = args.alpha
self.process_obs = process_obs
self.device = torch.device("cuda" if args.cuda else "cpu")
self.policy_type = args.policy
self.target_update_interval = args.target_update_interval
self.automatic_entropy_tuning = args.automatic_entropy_tuning
self.device = torch.device("cuda" if args.cuda else "cpu")
self.process_obs = process_obs.to(self.device)
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
self.critic_optim = Adam(
list(self.critic.parameters()) + list(process_obs.parameters())
, lr=args.lr)
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size, process_obs).to(device=self.device)
self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size, process_obs).to(self.device)
self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
hard_update(self.critic_target, self.critic)
if self.policy_type == "Gaussian":
@@ -34,13 +35,15 @@ class SAC(object):
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space, process_obs).to(self.device)
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
self.policy_optim = Adam(
list(self.policy.parameters()) + list(process_obs.parameters()),
lr=args.lr)
else:
self.alpha = 0
self.automatic_entropy_tuning = False
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space, process_obs).to(self.device)
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
def select_action(self, obs, evaluate=False):
@@ -64,8 +67,8 @@ class SAC(object):
state_batch = self.process_obs(obs_batch)
next_state_batch = self.process_obs(next_obs_batch)
with torch.no_grad():
next_state_batch = self.process_obs(next_obs_batch)
next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
@@ -79,6 +82,7 @@ class SAC(object):
qf_loss.backward()
self.critic_optim.step()
state_batch = self.process_obs(obs_batch)
pi, log_pi, _ = self.policy.sample(state_batch)
qf1_pi, qf2_pi = self.critic(state_batch, pi)
@@ -111,13 +115,13 @@ class SAC(object):
# Save model parameters
def save_model(self, actor_path=None, critic_path=None):
logger.debug(f'saving models to {actor_path} and {critic_path}'))
logger.debug(f'saving models to {actor_path} and {critic_path}')
torch.save(self.policy.state_dict(), actor_path)
torch.save(self.critic.state_dict(), critic_path)
# Load model parameters
def load_model(self, actor_path, critic_path):
logger.info(f'Loading models from {actor_path} and {critic_path}'))
logger.info(f'Loading models from {actor_path} and {critic_path}')
if actor_path is not None:
self.policy.load_state_dict(torch.load(actor_path))
if critic_path is not None: