mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 17:14:10 +08:00
tidy
This commit is contained in:
@@ -2,142 +2,11 @@ import numpy as np
|
||||
import gym
|
||||
from gym.spaces import Box
|
||||
import sys
|
||||
# from osim.env import RunEnv
|
||||
|
||||
from common.state_transform import StateVelCentr
|
||||
|
||||
|
||||
class DdpgWrapper(gym.Wrapper):
    """Gym wrapper that reshapes actions, observations and rewards for DDPG.

    On every ``step`` the wrapper:

    * maps the agent's action from [-1, 1] to [0, 1],
    * repeats that action for ``args.skip_frames`` steps of the wrapped env,
    * sums the env rewards together with the second value returned by
      ``StateVelCentr.process`` (named ``obst_rew`` here -- presumably an
      obstacle-related reward term; confirm against ``StateVelCentr``),
    * adds ``args.fail_reward`` if the episode ends before
      ``EPISODE_STEP_LIMIT`` raw env steps,
    * multiplies the total by ``args.reward_scale``.

    Args:
        env: The environment to wrap.
        args: Object exposing ``skip_frames``, ``reward_scale`` and
            ``fail_reward`` attributes.
    """

    # Affine map between the agent's [-1, 1] actions and the env's [0, 1].
    ACTION_MEAN = .5
    ACTION_STD = .5

    # An episode that terminates before this many raw env steps is treated
    # as a failure and receives ``fail_reward``.
    EPISODE_STEP_LIMIT = 1000

    def __init__(self, env, args):
        gym.Wrapper.__init__(self, env)
        self.state_transform = StateVelCentr(
            obstacles_mode='standard',
            exclude_centr=True,
            vel_states=[])
        self.observation_space = Box(-1000, 1000, self.state_transform.state_size)
        self.skip_frames = args.skip_frames
        self.reward_scale = args.reward_scale
        self.fail_reward = args.fail_reward
        # Defined up front so the counter exists even before the first reset().
        self.env_step = 0

    def normalize_action(self, x):
        """Map an env-space action in [0, 1] to agent space [-1, 1]."""
        return (x - self.ACTION_MEAN) / self.ACTION_STD

    def denormalise_action(self, x):
        """Map an agent-space action in [-1, 1] to env space [0, 1]."""
        return x * self.ACTION_STD + self.ACTION_MEAN

    def reset(self, **kwargs):
        """Public entry point; delegates to ``_reset``."""
        return self._reset(**kwargs)

    def _reset(self, **kwargs):
        """Reset the wrapped env, the step counter and the state transform.

        Returns:
            The transformed initial observation as a float32 array.
        """
        observation = self.env.reset(**kwargs)
        self.env_step = 0
        self.state_transform.reset()
        observation, _ = self.state_transform.process(observation)
        return self.observation(observation)

    def step(self, action):
        """Public entry point; delegates to ``_step``.

        Mirrors the ``reset`` / ``observation`` delegators so the frame-skip
        and reward shaping in ``_step`` are used regardless of whether the
        installed gym dispatches to ``step`` or ``_step``.
        """
        return self._step(action)

    def _step(self, action):
        """Apply ``action`` for up to ``skip_frames`` env steps.

        Args:
            action: Agent-space action in [-1, 1].

        Returns:
            Tuple ``(observation, total_reward, done, info)`` where
            ``observation`` is the transformed float32 observation after the
            last executed env step, ``total_reward`` is the shaped and scaled
            reward accumulated over the executed steps, and ``info`` is the
            info object from the last env step (an empty dict if the env
            supplied ``None``).
        """
        action = self.denormalise_action(action)
        total_reward = 0.
        info = None
        for _ in range(self.skip_frames):
            observation, reward, done, info = self.env.step(action)
            observation, obst_rew = self.state_transform.process(observation)
            total_reward += reward + obst_rew
            self.env_step += 1
            if done:
                # Ending before the step limit counts as a failure.
                if self.env_step < self.EPISODE_STEP_LIMIT:
                    total_reward += self.fail_reward
                break

        observation = self.observation(observation)
        total_reward *= self.reward_scale
        # Always hand back a dict-like info so callers following the gym
        # step contract can index it safely.
        return observation, total_reward, done, ({} if info is None else info)

    def observation(self, observation):
        """Public entry point; delegates to ``_observation``."""
        return self._observation(observation)

    def _observation(self, observation):
        """Convert ``observation`` to a float32 numpy array."""
        return np.array(observation, dtype=np.float32)
|
||||
|
||||
|
||||
# def create_env_old(args):
|
||||
# env = RunEnv(visualize=False, max_obstacles=args.max_obstacles)
|
||||
#
|
||||
# if hasattr(args, "baseline_wrapper") or hasattr(args, "ddpg_wrapper"):
|
||||
# env = DdpgWrapper(env, args)
|
||||
#
|
||||
# return env
|
||||
|
||||
|
||||
# class BasicTask:
|
||||
# def __init__(self):
|
||||
# self.normalized_state = True
|
||||
#
|
||||
# def normalize_state(self, state):
|
||||
# return state
|
||||
#
|
||||
# def reset(self):
|
||||
# state = self.env.reset()
|
||||
# if self.normalized_state:
|
||||
# return self.normalize_state(state)
|
||||
# return state
|
||||
#
|
||||
# def step(self, action):
|
||||
# next_state, reward, done, info = self.env.step(action)
|
||||
# if self.normalized_state:
|
||||
# next_state = self.normalize_state(next_state)
|
||||
# return next_state, np.sign(reward), done, info
|
||||
#
|
||||
# def random_action(self):
|
||||
# return self.env.action_space.sample()
|
||||
#
|
||||
#
|
||||
# class Pendulum(BasicTask):
|
||||
# name = 'Pendulum-v0'
|
||||
# success_threshold = -10
|
||||
#
|
||||
# def __init__(self):
|
||||
# BasicTask.__init__(self)
|
||||
# self.env = gym.make(self.name)
|
||||
# self.max_episode_steps = self.env._max_episode_steps
|
||||
# self.env._max_episode_steps = sys.maxsize
|
||||
# self.action_dim = self.env.action_space.shape[0]
|
||||
# self.state_dim = self.env.observation_space.shape[0]
|
||||
#
|
||||
# def step(self, action):
|
||||
# action = np.clip(action, -2, 2)
|
||||
# next_state, reward, done, info = self.env.step(action)
|
||||
# return next_state, reward, done, info
|
||||
|
||||
|
||||
def create_env(args):
    """Build the training environment.

    ``args`` is accepted so the factory keeps a uniform signature, but it is
    not read here: the environment id is fixed to ``'Pendulum-v0'``.

    Returns:
        The environment object produced by ``gym.make``.
    """
    return gym.make('Pendulum-v0')
|
||||
|
||||
|
||||
def create_observation_handler(args):
    """Return a callable that converts raw observations to float32 arrays.

    When ``args`` has a ``baseline_wrapper`` or ``ddpg_wrapper`` attribute
    (only presence is checked, not the value), the observation is also passed
    through a ``StateVelCentr`` transform and the first element of its
    ``process`` result is returned. Otherwise the float32 array is returned
    as is.

    Args:
        args: Configuration object; only attribute presence is inspected.

    Returns:
        ``observation_handler(observation, previous_action=None)``.
        ``previous_action`` is accepted but not used.
    """
    use_transform = (
        hasattr(args, "baseline_wrapper") or hasattr(args, "ddpg_wrapper"))
    # One transform instance is shared by every call of the returned handler.
    state_transform = StateVelCentr(
        obstacles_mode='standard',
        exclude_centr=True,
        vel_states=[]) if use_transform else None

    def observation_handler(observation, previous_action=None):
        observation = np.array(observation, dtype=np.float32)
        if state_transform is None:
            return observation
        observation, _ = state_transform.process(observation)
        return observation

    return observation_handler
|
||||
|
||||
|
||||
def create_action_handler(args):
    """Return a callable mapping agent actions from [-1, 1] to [0, 1].

    The mapping is ``x * 0.5 + 0.5``. It works on scalars and on anything
    supporting ``*`` and ``+`` with floats (e.g. numpy arrays).

    Args:
        args: Unused; accepted so all handler factories share a signature.

    Returns:
        ``action_handler(x)`` applying the affine rescaling.
    """
    action_mean = .5
    action_std = .5

    def action_handler(x):
        return x * action_std + action_mean

    return action_handler
|
||||
|
||||
+14
-26
@@ -16,6 +16,7 @@ from common.env_wrappers import create_env
|
||||
from common.random_process import create_random_process
|
||||
|
||||
def create_model(args):
|
||||
# TODO still using actor layers etc
|
||||
base = Base(
|
||||
args.n_observation, args.n_action, args.actor_layers,
|
||||
activation=args.actor_activation,
|
||||
@@ -53,26 +54,6 @@ def create_model(args):
|
||||
pprint(dynamics)
|
||||
return actor, critic, dynamics
|
||||
|
||||
# def create_model_old(args):
|
||||
# actor = Actor(
|
||||
# args.n_observation, args.n_action, args.actor_layers,
|
||||
# activation=args.actor_activation,
|
||||
# layer_norm=args.actor_layer_norm,
|
||||
# parameters_noise=args.actor_parameters_noise,
|
||||
# parameters_noise_factorised=args.actor_parameters_noise_factorised,
|
||||
# last_activation=nn.Tanh)
|
||||
# critic = Critic(
|
||||
# args.n_observation, args.n_action, args.critic_layers,
|
||||
# activation=args.critic_activation,
|
||||
# layer_norm=args.critic_layer_norm,
|
||||
# parameters_noise=args.critic_parameters_noise,
|
||||
# parameters_noise_factorised=args.critic_parameters_noise_factorised)
|
||||
#
|
||||
# pprint(actor)
|
||||
# pprint(critic)
|
||||
#
|
||||
# return actor, critic
|
||||
|
||||
|
||||
def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args):
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
@@ -118,7 +99,10 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
|
||||
weights = to_tensor(weights, requires_grad=False)
|
||||
|
||||
# Dynamics update
|
||||
next_observations_pred = dynamics(observations, actions)
|
||||
next_observations_pred = dynamics(
|
||||
to_tensor(observations),
|
||||
to_tensor(actions)
|
||||
)
|
||||
dynamics_loss = criterion(
|
||||
next_observations_pred,
|
||||
to_tensor(next_observations),
|
||||
@@ -149,16 +133,18 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
|
||||
td_target = rewards + reward_predicted
|
||||
|
||||
critic.zero_grad()
|
||||
|
||||
# v_values = critic(to_tensor(observations), to_tensor(actions))
|
||||
v_values = critic(dynamics(to_tensor(observations), to_tensor(actions)))
|
||||
v_values = critic(
|
||||
dynamics(
|
||||
to_tensor(observations),
|
||||
to_tensor(actions)
|
||||
)
|
||||
)
|
||||
value_loss = criterion(v_values, td_target, weights=weights)
|
||||
value_loss.backward()
|
||||
|
||||
torch.nn.utils.clip_grad_norm(critic.parameters(), args.grad_clip)
|
||||
for param_group in critic_optim.param_groups:
|
||||
param_group["lr"] = critic_lr
|
||||
|
||||
critic_optim.step()
|
||||
|
||||
# Actor update
|
||||
@@ -280,7 +266,8 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
"epsilon": epsilon
|
||||
}
|
||||
|
||||
observation = env.reset()#seed=seed, difficulty=args.difficulty)
|
||||
env.seed(seed)
|
||||
observation = env.reset() #seed=seed, difficulty=args.difficulty)
|
||||
random_process.reset_states()
|
||||
done = False
|
||||
|
||||
@@ -497,6 +484,7 @@ def play_single_thread(
|
||||
"epsilon": epsilon
|
||||
}
|
||||
|
||||
env.seed(seed)
|
||||
observation = env.reset()#seed=seed, difficulty=args.difficulty)
|
||||
random_process.reset_states()
|
||||
done = False
|
||||
|
||||
@@ -150,7 +150,6 @@ class CriticHead(nn.Module):
|
||||
|
||||
def forward(self, observation):
    """Compute the critic head output for ``observation``.

    The observation is passed through ``self.base`` and the result through
    ``self.value_net``. Both are invoked as callables rather than through
    ``.forward()`` so that, for ``nn.Module`` sub-modules, registered hooks
    are executed as PyTorch recommends.

    Args:
        observation: Input accepted by ``self.base``.

    Returns:
        The output of ``self.value_net`` applied to the base features.
    """
    features = self.base(observation)
    return self.value_net(features)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user