This commit is contained in:
wassname
2018-01-18 17:23:35 +08:00
parent d9b67e5f9a
commit a87a3ad7bb
3 changed files with 14 additions and 158 deletions
-131
View File
@@ -2,142 +2,11 @@ import numpy as np
import gym
from gym.spaces import Box
import sys
# from osim.env import RunEnv
from common.state_transform import StateVelCentr
class DdpgWrapper(gym.Wrapper):
def __init__(self, env, args):
gym.Wrapper.__init__(self, env)
self.state_transform = StateVelCentr(
obstacles_mode='standard',
exclude_centr=True,
vel_states=[])
self.observation_space = Box(-1000, 1000, self.state_transform.state_size)
self.skip_frames = args.skip_frames
self.reward_scale = args.reward_scale
self.fail_reward = args.fail_reward
# [-1, 1] <-> [0, 1]
action_mean = .5
action_std = .5
self.normalize_action = lambda x: (x - action_mean) / action_std
self.denormalise_action = lambda x: x * action_std + action_mean
def reset(self, **kwargs):
return self._reset(**kwargs)
def _reset(self, **kwargs):
observation = self.env.reset(**kwargs)
self.env_step = 0
self.state_transform.reset()
observation, _ = self.state_transform.process(observation)
observation = self.observation(observation)
return observation
def _step(self, action):
action = self.denormalise_action(action)
total_reward = 0.
for _ in range(self.skip_frames):
observation, reward, done, _ = self.env.step(action)
observation, obst_rew = self.state_transform.process(observation)
total_reward += reward + obst_rew
self.env_step += 1
if done:
if self.env_step < 1000: # hardcoded
total_reward += self.fail_reward
break
observation = self.observation(observation)
total_reward *= self.reward_scale
return observation, total_reward, done, None
def observation(self, observation):
return self._observation(observation)
def _observation(self, observation):
observation = np.array(observation, dtype=np.float32)
return observation
# def create_env_old(args):
# env = RunEnv(visualize=False, max_obstacles=args.max_obstacles)
#
# if hasattr(args, "baseline_wrapper") or hasattr(args, "ddpg_wrapper"):
# env = DdpgWrapper(env, args)
#
# return env
# class BasicTask:
# def __init__(self):
# self.normalized_state = True
#
# def normalize_state(self, state):
# return state
#
# def reset(self):
# state = self.env.reset()
# if self.normalized_state:
# return self.normalize_state(state)
# return state
#
# def step(self, action):
# next_state, reward, done, info = self.env.step(action)
# if self.normalized_state:
# next_state = self.normalize_state(next_state)
# return next_state, np.sign(reward), done, info
#
# def random_action(self):
# return self.env.action_space.sample()
#
#
# class Pendulum(BasicTask):
# name = 'Pendulum-v0'
# success_threshold = -10
#
# def __init__(self):
# BasicTask.__init__(self)
# self.env = gym.make(self.name)
# self.max_episode_steps = self.env._max_episode_steps
# self.env._max_episode_steps = sys.maxsize
# self.action_dim = self.env.action_space.shape[0]
# self.state_dim = self.env.observation_space.shape[0]
#
# def step(self, action):
# action = np.clip(action, -2, 2)
# next_state, reward, done, info = self.env.step(action)
# return next_state, reward, done, info
def create_env(args):
# env = Pendulum()
env = gym.make('Pendulum-v0')
return env
def create_observation_handler(args):
if hasattr(args, "baseline_wrapper") or hasattr(args, "ddpg_wrapper"):
state_transform = StateVelCentr(
obstacles_mode='standard',
exclude_centr=True,
vel_states=[])
def observation_handler(observation, previous_action=None):
observation = np.array(observation, dtype=np.float32)
observation, _ = state_transform.process(observation)
return observation
else:
def observation_handler(observation, previous_action=None):
observation = np.array(observation, dtype=np.float32)
return observation
return observation_handler
def create_action_handler(args):
action_mean = .5
action_std = .5
action_handler = lambda x: x * action_std + action_mean
return action_handler
+14 -26
View File
@@ -16,6 +16,7 @@ from common.env_wrappers import create_env
from common.random_process import create_random_process
def create_model(args):
# TODO still using actor layers etc
base = Base(
args.n_observation, args.n_action, args.actor_layers,
activation=args.actor_activation,
@@ -53,26 +54,6 @@ def create_model(args):
pprint(dynamics)
return actor, critic, dynamics
# def create_model_old(args):
# actor = Actor(
# args.n_observation, args.n_action, args.actor_layers,
# activation=args.actor_activation,
# layer_norm=args.actor_layer_norm,
# parameters_noise=args.actor_parameters_noise,
# parameters_noise_factorised=args.actor_parameters_noise_factorised,
# last_activation=nn.Tanh)
# critic = Critic(
# args.n_observation, args.n_action, args.critic_layers,
# activation=args.critic_activation,
# layer_norm=args.critic_layer_norm,
# parameters_noise=args.critic_parameters_noise,
# parameters_noise_factorised=args.critic_parameters_noise_factorised)
#
# pprint(actor)
# pprint(critic)
#
# return actor, critic
def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args):
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
@@ -118,7 +99,10 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
weights = to_tensor(weights, requires_grad=False)
# Dynamics update
next_observations_pred = dynamics(observations, actions)
next_observations_pred = dynamics(
to_tensor(observations),
to_tensor(actions)
)
dynamics_loss = criterion(
next_observations_pred,
to_tensor(next_observations),
@@ -149,16 +133,18 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
td_target = rewards + reward_predicted
critic.zero_grad()
# v_values = critic(to_tensor(observations), to_tensor(actions))
v_values = critic(dynamics(to_tensor(observations), to_tensor(actions)))
v_values = critic(
dynamics(
to_tensor(observations),
to_tensor(actions)
)
)
value_loss = criterion(v_values, td_target, weights=weights)
value_loss.backward()
torch.nn.utils.clip_grad_norm(critic.parameters(), args.grad_clip)
for param_group in critic_optim.param_groups:
param_group["lr"] = critic_lr
critic_optim.step()
# Actor update
@@ -280,7 +266,8 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
"epsilon": epsilon
}
observation = env.reset()#seed=seed, difficulty=args.difficulty)
env.seed(seed)
observation = env.reset() #seed=seed, difficulty=args.difficulty)
random_process.reset_states()
done = False
@@ -497,6 +484,7 @@ def play_single_thread(
"epsilon": epsilon
}
env.seed(seed)
observation = env.reset()#seed=seed, difficulty=args.difficulty)
random_process.reset_states()
done = False
-1
View File
@@ -150,7 +150,6 @@ class CriticHead(nn.Module):
def forward(self, observation):
x = self.base.forward(observation)
# x = torch.cat((x, action), dim=1)
x = self.value_net.forward(x)
return x