mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:10:40 +08:00
[rllib] Add n-step Q learning for DQN (#1439)
* n-step * add sample adjustm * Oops * fix nstep * metric adjustment * Sat Jan 20 23:30:34 PST 2018 * Sun Jan 21 16:40:46 PST 2018 * Mon Jan 22 22:24:57 PST 2018
This commit is contained in:
@@ -2,9 +2,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
@@ -24,6 +25,8 @@ DEFAULT_CONFIG = dict(
|
||||
double_q=True,
|
||||
# Hidden layer sizes of the state and action value networks
|
||||
hiddens=[256],
|
||||
# N-step Q learning
|
||||
n_step=1,
|
||||
# Config options to pass to the model constructor
|
||||
model={},
|
||||
# Discount factor for the MDP
|
||||
|
||||
@@ -12,6 +12,34 @@ from ray.rllib.dqn.common.schedules import LinearSchedule
|
||||
from ray.rllib.optimizers import SampleBatch, TFMultiGPUSupport
|
||||
|
||||
|
||||
def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
|
||||
"""Rewrites the given trajectory fragments to encode n-step rewards.
|
||||
|
||||
reward[i] = (
|
||||
reward[i] * gamma**0 +
|
||||
reward[i+1] * gamma**1 +
|
||||
... +
|
||||
reward[i+n_step-1] * gamma**(n_step-1))
|
||||
|
||||
The ith new_obs is also adjusted to point to the (i+n_step-1)'th new obs.
|
||||
|
||||
If the episode finishes, the reward will be truncated. After this rewrite,
|
||||
all the arrays will be shortened by (n_step - 1).
|
||||
"""
|
||||
for i in range(len(rewards) - n_step + 1):
|
||||
if dones[i]:
|
||||
continue # episode end
|
||||
for j in range(1, n_step):
|
||||
new_obs[i] = new_obs[i + j]
|
||||
rewards[i] += gamma ** j * rewards[i + j]
|
||||
if dones[i + j]:
|
||||
break # episode end
|
||||
# truncate ends of the trajectory
|
||||
new_len = len(obs) - n_step + 1
|
||||
for arr in [obs, actions, rewards, new_obs, dones]:
|
||||
del arr[new_len:]
|
||||
|
||||
|
||||
class DQNEvaluator(TFMultiGPUSupport):
|
||||
"""The base DQN Evaluator that does not include the replay buffer.
|
||||
|
||||
@@ -59,17 +87,29 @@ class DQNEvaluator(TFMultiGPUSupport):
|
||||
|
||||
def sample(self):
|
||||
obs, actions, rewards, new_obs, dones = [], [], [], [], []
|
||||
for _ in range(self.config["sample_batch_size"]):
|
||||
for _ in range(
|
||||
self.config["sample_batch_size"] + self.config["n_step"] - 1):
|
||||
ob, act, rew, ob1, done = self._step(self.global_timestep)
|
||||
obs.append(ob)
|
||||
actions.append(act)
|
||||
rewards.append(rew)
|
||||
new_obs.append(ob1)
|
||||
dones.append(done)
|
||||
return SampleBatch({
|
||||
|
||||
# N-step Q adjustments
|
||||
if self.config["n_step"] > 1:
|
||||
# Adjust for steps lost from truncation
|
||||
self.local_timestep -= (self.config["n_step"] - 1)
|
||||
adjust_nstep(
|
||||
self.config["n_step"], self.config["gamma"],
|
||||
obs, actions, rewards, new_obs, dones)
|
||||
|
||||
batch = SampleBatch({
|
||||
"obs": obs, "actions": actions, "rewards": rewards,
|
||||
"new_obs": new_obs, "dones": dones,
|
||||
"weights": np.ones_like(rewards)})
|
||||
assert batch.count == self.config["sample_batch_size"]
|
||||
return batch
|
||||
|
||||
def compute_gradients(self, samples):
|
||||
_, grad = self.dqn_graph.compute_gradients(
|
||||
|
||||
@@ -136,7 +136,8 @@ class ModelAndLoss(object):
|
||||
q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
|
||||
|
||||
# compute RHS of bellman equation
|
||||
q_t_selected_target = rew_t + config["gamma"] * q_tp1_best_masked
|
||||
q_t_selected_target = (
|
||||
rew_t + config["gamma"] ** config["n_step"] * q_tp1_best_masked)
|
||||
|
||||
# compute the error (potentially clipped)
|
||||
self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
|
||||
|
||||
@@ -10,9 +10,25 @@ import tempfile
|
||||
import ray
|
||||
from ray.rllib.a3c import DEFAULT_CONFIG
|
||||
from ray.rllib.a3c.a3c_evaluator import A3CEvaluator
|
||||
from ray.rllib.dqn.dqn_evaluator import adjust_nstep
|
||||
from ray.tune.registry import get_registry
|
||||
|
||||
|
||||
class DQNEvaluatorTest(unittest.TestCase):
|
||||
def testNStep(self):
|
||||
obs = [1, 2, 3, 4, 5, 6, 7]
|
||||
actions = ["a", "b", "a", "a", "a", "b", "a"]
|
||||
rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100000.0]
|
||||
new_obs = [2, 3, 4, 5, 6, 7, 8]
|
||||
dones = [1, 0, 0, 0, 0, 1, 0]
|
||||
adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones)
|
||||
self.assertEqual(obs, [1, 2, 3, 4, 5])
|
||||
self.assertEqual(actions, ["a", "b", "a", "a", "a"])
|
||||
self.assertEqual(rewards, [10.0, 171.0, 271.0, 271.0, 190.0])
|
||||
self.assertEqual(new_obs, [2, 5, 6, 7, 7])
|
||||
self.assertEqual(dones, [1, 0, 0, 0, 0])
|
||||
|
||||
|
||||
class A3CEvaluatorTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -25,7 +41,7 @@ class A3CEvaluatorTest(unittest.TestCase):
|
||||
self._temp_dir = tempfile.mkdtemp("a3c_evaluator_test")
|
||||
self.e = A3CEvaluator(
|
||||
get_registry(),
|
||||
lambda: gym.make("Pong-v0"),
|
||||
lambda config: gym.make("CartPole-v0"),
|
||||
config,
|
||||
logdir=self._temp_dir)
|
||||
|
||||
|
||||
@@ -121,6 +121,7 @@ class AsyncSampler(threading.Thread):
|
||||
self.policy = policy
|
||||
self._obs_filter = obs_filter
|
||||
self.started = False
|
||||
self.daemon = True
|
||||
|
||||
def run(self):
|
||||
self.started = True
|
||||
|
||||
@@ -42,6 +42,9 @@ class Resources(
|
||||
return super(Resources, cls).__new__(
|
||||
cls, cpu, gpu, driver_cpu_limit, driver_gpu_limit)
|
||||
|
||||
def summary_string(self):
|
||||
return "{} CPUs, {} GPUs".format(self.cpu, self.gpu)
|
||||
|
||||
|
||||
class Trial(object):
|
||||
"""A trial object holds the state for one model training run.
|
||||
|
||||
@@ -79,9 +79,12 @@ class TrialRunner(object):
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
if not self.has_resources(trial.resources):
|
||||
raise TuneError(
|
||||
"Insufficient cluster resources to launch trial",
|
||||
(trial.resources, self._avail_resources))
|
||||
raise TuneError((
|
||||
"Insufficient cluster resources to launch trial: "
|
||||
"trial requested {} but the cluster only has {} "
|
||||
"available.").format(
|
||||
trial.resources.summary_string(),
|
||||
self._avail_resources.summary_string()))
|
||||
elif trial.status == Trial.PAUSED:
|
||||
raise TuneError(
|
||||
"There are paused trials, but no more pending "
|
||||
|
||||
Reference in New Issue
Block a user