mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 17:14:10 +08:00
77 lines
2.7 KiB
Python
77 lines
2.7 KiB
Python
import numpy as np
|
|
|
|
|
|
def traj_segment_generator(pi, env, args, horizon, stochastic,
                           recompute_vpred=False):
    """Yield fixed-length rollout segments of ``horizon`` environment steps.

    Runs ``pi`` in ``env`` forever, recording observations, actions,
    rewards and value predictions into preallocated arrays of length
    ``horizon``. Every ``horizon`` steps a segment dict is yielded with keys
    ``ob``, ``rew``, ``vpred``, ``new``, ``ac``, ``prevac``, ``nextvpred``,
    ``ep_rets`` and ``ep_lens``. Episodes are not cut at segment
    boundaries; their running return/length carry over to later segments.

    Args:
        pi: policy object; ``pi.act(stochastic, ob)`` must return
            ``(action, value_prediction)``.
        env: environment exposing ``action_space.sample()``,
            ``reset(difficulty=...)`` and ``step(action)`` returning a
            4-tuple ``(ob, rew, done, info)``.
        args: object providing ``difficulty``, forwarded to every
            ``env.reset`` call.
        horizon: number of steps per yielded segment; must be positive.
        stochastic: forwarded unchanged to ``pi.act``.
        recompute_vpred: if True, re-evaluate the value of the pending
            observation with ``pi`` right after each yield (TRPO-style).
            The policy may have been updated while the generator was
            suspended, so this refresh can change the first stored value of
            the next segment. The action chosen before the yield is kept.
            The default (False) keeps the PPO behaviour of reusing the value
            computed before the yield.

    Yields:
        dict: the segment described above. ``nextvpred`` is the value of
        the observation following the segment, zeroed when that observation
        starts a new episode. The ``ob``/``rew``/``vpred``/``new``/``ac``/
        ``prevac`` arrays are reused and overwritten in place once the
        caller resumes the generator. Copy them (e.g. ``copy.deepcopy``)
        before requesting another segment if they must be kept.

    Raises:
        ValueError: on the first ``next()`` if ``horizon`` is not positive.
    """
    if horizon <= 0:
        raise ValueError(
            "horizon must be a positive integer, got %r" % (horizon,))

    # Initialize state variables
    t = 0
    # Used for the dtype/shape of the action buffers (and stored as the very
    # first prevac).
    ac = env.action_space.sample()
    new = True  # marks if we're on the first timestep of an episode
    ob = env.reset(difficulty=args.difficulty)

    cur_ep_ret = 0  # return in current episode
    cur_ep_len = 0  # length of current episode
    ep_rets = []  # returns of episodes completed in this segment
    ep_lens = []  # lengths of episodes completed in this segment

    # Initialize history arrays; shape/dtype come from the first ob / ac.
    obs = np.array([ob for _ in range(horizon)])
    rews = np.zeros(horizon, 'float32')
    vpreds = np.zeros(horizon, 'float32')
    news = np.zeros(horizon, 'int32')
    acs = np.array([ac for _ in range(horizon)])
    prevacs = acs.copy()

    while True:
        prevac = ac
        ac, vpred = pi.act(stochastic, ob)
        # Slight weirdness here: we need the value of the observation at
        # time T before returning segment [0, T-1] so the bootstrap value
        # ("nextvpred") is correct.
        if t > 0 and t % horizon == 0:
            yield {"ob": obs, "rew": rews, "vpred": vpreds, "new": news,
                   "ac": acs, "prevac": prevacs,
                   "nextvpred": vpred * (1 - new),
                   "ep_rets": ep_rets, "ep_lens": ep_lens}
            if recompute_vpred:
                # TRPO-style: refresh the value of the pending observation
                # with the (possibly updated) policy; the action is kept.
                _, vpred = pi.act(stochastic, ob)
            # Rebind (don't clear) so the yielded segment keeps its episode
            # statistics. The arrays above, by contrast, are overwritten in
            # place; if a downstream algorithm aggregates several segments it
            # must deepcopy them.
            ep_rets = []
            ep_lens = []
        i = t % horizon
        obs[i] = ob
        vpreds[i] = vpred
        news[i] = new
        acs[i] = ac
        prevacs[i] = prevac

        ob, rew, new, _ = env.step(ac)
        rews[i] = rew

        cur_ep_ret += rew
        cur_ep_len += 1
        if new:
            ep_rets.append(cur_ep_ret)
            ep_lens.append(cur_ep_len)
            cur_ep_ret = 0
            cur_ep_len = 0
            ob = env.reset(difficulty=args.difficulty)
        t += 1
|
|
|
|
|
|
def add_vtarg_and_adv(seg, gamma, lam):
    """Attach GAE(lambda) advantages and TD(lambda) value targets to ``seg``.

    Reads ``seg["new"]``, ``seg["vpred"]``, ``seg["nextvpred"]`` and
    ``seg["rew"]``, then stores two new entries in ``seg`` (in place,
    returning nothing):

    * ``seg["adv"]``: float32 array of GAE(lambda) advantages, one per step.
    * ``seg["tdlamret"]``: the TD(lambda) return, ``adv + vpred``.

    Args:
        seg: rollout segment dict.
        gamma: discount factor.
        lam: GAE lambda parameter.
    """
    # Pad both sequences by one entry so ``step + 1`` is always a valid index.
    # The padded episode-start flag is 0 because the bootstrap value
    # ("nextvpred") is expected to be zeroed upstream when needed.
    episode_starts = np.append(seg["new"], 0)
    values = np.append(seg["vpred"], seg["nextvpred"])
    horizon = len(seg["rew"])
    advantages = seg["adv"] = np.empty(horizon, 'float32')
    rewards = seg["rew"]

    running_adv = 0
    for step in range(horizon - 1, -1, -1):
        # A following episode start means this step's successor value must
        # not be bootstrapped from.
        not_done = 1 - episode_starts[step + 1]
        td_error = rewards[step] + gamma * values[step + 1] * not_done - values[step]
        running_adv = td_error + gamma * lam * not_done * running_adv
        advantages[step] = running_adv

    seg["tdlamret"] = seg["adv"] + seg["vpred"]
|