smaller camera, detach critic inputs

This commit is contained in:
wassname
2021-01-17 18:27:36 +08:00
parent 16ca1a351b
commit f28ca774ed
4 changed files with 33 additions and 21 deletions
+3 -2
View File
@@ -9,9 +9,10 @@ run:
-m pdb -c continue \
main.py \
--cuda \
--opt_level O1 \
--automatic_entropy_tuning true \
--replay_size 10000 \
--demonstrations data/demonstrations \
--replay_size 100000 \
# --demonstrations data/demonstrations \
# --load auto \
# ${python} -m pdb main.py --cuda --automatic_entropy_tuning true --replay_size 10000 --load auto --start_steps 200
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --automatic_entropy_tuning true --replay_size 20000 --load auto
+15 -13
View File
@@ -66,6 +66,8 @@ def get_args():
help='Load models')
parser.add_argument('-r', '--render', action="store_true",
help='show')
parser.add_argument('-O', '--opt_level', default='O0',
help='Apex Amp Optimisation level')
args = parser.parse_args()
return args
@@ -95,16 +97,16 @@ observation_dim=observation_dim - process_obs.reduce_obs_space
logger.info(f"process_obs reduces obs_space {env.observation_space.shape[0]}-{process_obs.reduce_obs_space}={observation_dim}")
# Agent
agent = SAC(observation_dim, env.action_space, args, process_obs)
agent = SAC(observation_dim, env.action_space, args, process_obs, args.opt_level)
# from torchinfo import summary
# print('process_obs')
# summary(process_obs, input_size=(2, *env.observation_space.shape), depth=2)
# print('critic')
# summary(agent.critic, input_size=((2, observation_dim), (2, action_dim)))
# print('policy')
# summary(agent.policy, input_size=(2, observation_dim))
# # print(process_obs, agent.critic, agent.policy)
from torchinfo import summary
print('process_obs')
summary(process_obs, input_size=(2, *env.observation_space.shape), depth=2)
print('critic')
summary(agent.critic, input_size=((2, observation_dim), (2, action_dim)))
print('policy')
summary(agent.policy, input_size=(2, observation_dim))
# print(process_obs, agent.critic, agent.policy)
#Tensorboard
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
@@ -172,7 +174,7 @@ with RichTQDM() as prog:
else:
action = agent.select_action(state) # Sample action from policy
if len(memory) > args.batch_size and (total_numsteps%20==0):
if len(memory) > args.batch_size and (total_numsteps%1==0):
# Number of updates per step in environment
for i in range(args.updates_per_step):
# Update parameters of all the networks
@@ -211,7 +213,7 @@ with RichTQDM() as prog:
logger.info("\nEpisode: {}, total numsteps: {}, episode steps: {}, reward: {}, updates: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), updates))
prog.desc = "e: {}, r: {}, u: {}, m: {}".format(i_episode, round(episode_reward, 2), updates, len(memory))
if (i_episode % 100 == 0) and (args.eval is True) and i_episode>0:
if (i_episode % 10 == 0) and (args.eval is True) and i_episode>0:
avg_reward = 0.
episodes = 10
for _ in range(episodes):
@@ -244,7 +246,7 @@ with RichTQDM() as prog:
if total_numsteps >= args.num_steps:
break
if args.train:
save(save_dir)
if args.train:
save(save_dir)
env.close()
+2 -2
View File
@@ -113,7 +113,7 @@ class GenerativeResnet3Headless(nn.Module):
class ProcessObservation(nn.Module):
def __init__(self, res=(224, 224)):
def __init__(self, res=(112, 112)):
super().__init__()
self.res = res
@@ -122,7 +122,7 @@ class ProcessObservation(nn.Module):
os.path.dirname(os.path.abspath(__file__)),
'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt'
)
self.feature_extractor = GenerativeResnet3Headless().train().half()
self.feature_extractor = GenerativeResnet3Headless().train()
self.feature_extractor.load_state_dict(state_dict=torch.load(grconvnet3_path), strict=False)
old_img_size = (res[0], res[1], 8)
+13 -4
View File
@@ -5,10 +5,11 @@ from torch.optim import Adam
from utils import soft_update, hard_update
from model import GaussianPolicy, QNetwork, DeterministicPolicy
from loguru import logger
from apex import amp
class SAC(object):
def __init__(self, num_inputs, action_space, args, process_obs=None):
def __init__(self, num_inputs, action_space, args, process_obs=None, opt_level='O1'):
self.gamma = args.gamma
self.tau = args.tau
@@ -49,6 +50,12 @@ class SAC(object):
list(self.policy.parameters()) + list(process_obs.parameters()),
lr=args.lr)
if opt_level is not None:
model, optimizer = amp.initialize(
[self.policy, self.process_obs, self.critic, self.critic_target],
[self.policy_optim, self.critic_optim],
opt_level=opt_level)
def select_action(self, obs, evaluate=False):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0).to(self.dtype)
@@ -85,20 +92,22 @@ class SAC(object):
self.critic_optim.zero_grad()
assert torch.isfinite(qf_loss).all()
qf_loss.backward()
with amp.scale_loss(qf_loss, self.critic_optim) as qf_loss:
qf_loss.backward()
self.critic_optim.step()
state_batch = self.process_obs(obs_batch)
pi, log_pi, _ = self.policy.sample(state_batch)
qf1_pi, qf2_pi = self.critic(state_batch, pi)
qf1_pi, qf2_pi = self.critic(state_batch.detach(), pi)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
self.policy_optim.zero_grad()
assert torch.isfinite(policy_loss).all()
policy_loss.backward()
with amp.scale_loss(policy_loss, self.policy_optim) as policy_loss:
policy_loss.backward()
self.policy_optim.step()
if self.automatic_entropy_tuning: