mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 11:20:09 +08:00
[rllib] Parallelize sample collection and gradient computation in DQN (#746)
* wip * works with cartpole * lint * fix pg * comment * action dist rename * preprocessor * fix test * typo * fix the action[0] nonsense * revert * satisfy the lint * wip * wip * works with cartpole * lint * fix pg * comment * action dist rename * preprocessor * fix test * typo * fix the action[0] nonsense * revert * satisfy the lint * Minor indentation changes. * fix merge * add humanoid * initial dqn refactor * remove tfutil * fix calls * fix tf errors 1 * closer * runs now * lint * tensorboard graph * fix linting * more 4 space * fix * fix linT * more lint * oops * es parity * remove example.py * fix training bug * add cartpole demo * try fixing cartpole * allow model options, configure cartpole * debug * simplify * no dueling * avoid out of file handles * Test dqn in jenkins. * Minor formatting. * lint * fix py3 * fix issue * remove chekcpoint * revert * Fixit * sanity check configs * update cuda * fix * parallel gradient computation * update * upd * bug * upd * always record training stats * fix * comments * revert assert * add gpu mask * fofset * a tie * Merge * fix * fix * fix examples * A3C -> DQN * fix dqn test * remove submodule * fix linting
This commit is contained in:
committed by
Philipp Moritz
parent
10027974b1
commit
9f3a4fce50
+215
-113
@@ -10,6 +10,7 @@ import pickle
|
||||
import os
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.common import Agent, TrainingResult
|
||||
from ray.rllib.dqn import logger, models
|
||||
from ray.rllib.dqn.common.atari_wrappers_deprecated \
|
||||
@@ -41,17 +42,15 @@ from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
annealed
|
||||
exploration_final_eps: float
|
||||
final value of random action probability
|
||||
train_freq: int
|
||||
update the model every `train_freq` steps.
|
||||
batch_size: int
|
||||
sample_batch_size: int
|
||||
update the replay buffer with this many samples at once
|
||||
num_workers: int
|
||||
the number of workers to use for parallel batch sample collection
|
||||
train_batch_size: int
|
||||
size of a batch sampled from the replay buffer for training
|
||||
print_freq: int
|
||||
how often to print out training progress
|
||||
set to None to disable printing
|
||||
checkpoint_freq: int
|
||||
how often to save the model. This is so that the best version is
|
||||
restored at the end of the training. If you do not wish to restore
|
||||
the best version at the end of the training set this variable to None.
|
||||
learning_starts: int
|
||||
how many steps of the model to collect transitions for before learning
|
||||
starts
|
||||
@@ -80,16 +79,17 @@ DEFAULT_CONFIG = dict(
|
||||
double_q=True,
|
||||
hiddens=[256],
|
||||
model={},
|
||||
gpu_offset=0,
|
||||
lr=5e-4,
|
||||
schedule_max_timesteps=100000,
|
||||
timesteps_per_iteration=1000,
|
||||
buffer_size=50000,
|
||||
exploration_fraction=0.1,
|
||||
exploration_final_eps=0.02,
|
||||
train_freq=1,
|
||||
batch_size=32,
|
||||
sample_batch_size=1,
|
||||
num_workers=1,
|
||||
train_batch_size=32,
|
||||
print_freq=1,
|
||||
checkpoint_freq=10000,
|
||||
learning_starts=1000,
|
||||
gamma=1.0,
|
||||
grad_norm_clipping=10,
|
||||
@@ -102,22 +102,14 @@ DEFAULT_CONFIG = dict(
|
||||
num_cpu=16)
|
||||
|
||||
|
||||
class DQNAgent(Agent):
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "DQN"})
|
||||
|
||||
Agent.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
self._init()
|
||||
|
||||
def _init(self):
|
||||
config = self.config
|
||||
env = gym.make(self.env_name)
|
||||
class Actor(object):
|
||||
def __init__(self, env_name, config, logdir):
|
||||
env = gym.make(env_name)
|
||||
# TODO(ekl): replace this with RLlib preprocessors
|
||||
if "NoFrameskip" in self.env_name:
|
||||
if "NoFrameskip" in env_name:
|
||||
env = ScaledFloatFrame(wrap_dqn(env))
|
||||
self.env = env
|
||||
self.config = config
|
||||
|
||||
num_cpu = config["num_cpu"]
|
||||
tf_config = tf.ConfigProto(
|
||||
@@ -131,11 +123,11 @@ class DQNAgent(Agent):
|
||||
self.replay_buffer = PrioritizedReplayBuffer(
|
||||
config["buffer_size"],
|
||||
alpha=config["prioritized_replay_alpha"])
|
||||
prioritized_replay_beta_iters = (
|
||||
config["prioritized_replay_beta_iters"])
|
||||
prioritized_replay_beta_iters = \
|
||||
config["prioritized_replay_beta_iters"]
|
||||
if prioritized_replay_beta_iters is None:
|
||||
prioritized_replay_beta_iters = (
|
||||
config["schedule_max_timesteps"])
|
||||
prioritized_replay_beta_iters = \
|
||||
config["schedule_max_timesteps"]
|
||||
self.beta_schedule = LinearSchedule(
|
||||
prioritized_replay_beta_iters,
|
||||
initial_p=config["prioritized_replay_beta0"],
|
||||
@@ -154,135 +146,245 @@ class DQNAgent(Agent):
|
||||
# Initialize the parameters and copy them to the target network.
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.dqn_graph.update_target(self.sess)
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
tf.group(self.dqn_graph.q_tp1, self.dqn_graph.q_t), self.sess)
|
||||
|
||||
self.episode_rewards = [0.0]
|
||||
self.episode_lengths = [0.0]
|
||||
self.saved_mean_reward = None
|
||||
self.obs = self.env.reset()
|
||||
self.num_timesteps = 0
|
||||
self.file_writer = tf.summary.FileWriter(logdir, self.sess.graph)
|
||||
|
||||
def step(self, cur_timestep):
    """Advance the environment by one action and return the transition.

    The action is chosen by ``self.dqn_graph.act`` using the exploration
    value scheduled for ``cur_timestep``.

    Returns:
        A tuple ``(obs, action, reward, new_obs, done)`` where ``obs`` is
        the observation the action was taken from and ``done`` is a float.
    """
    # [None] adds a leading axis to the observation; [0] takes the first
    # entry of what act() returns.
    batched_obs = np.array(self.obs)[None]
    epsilon = self.exploration.value(cur_timestep)
    action = self.dqn_graph.act(self.sess, batched_obs, epsilon)[0]

    next_obs, reward, finished, _ = self.env.step(action)
    transition = (self.obs, action, reward, next_obs, float(finished))

    # Bookkeeping for the episode that is currently running.
    self.obs = next_obs
    self.episode_rewards[-1] += reward
    self.episode_lengths[-1] += 1

    if finished:
        # Start a new episode and open fresh reward/length accumulators.
        self.obs = self.env.reset()
        self.episode_rewards.append(0.0)
        self.episode_lengths.append(0.0)

    return transition
|
||||
|
||||
def do_steps(self, num_steps, cur_timestep):
    """Run ``num_steps`` environment steps and store each transition.

    Every step uses the same ``cur_timestep`` for the exploration schedule;
    each resulting transition is added to ``self.replay_buffer``.
    """
    for _ in range(num_steps):
        transition = self.step(cur_timestep)
        obs, action, reward, next_obs, finished = transition
        self.replay_buffer.add(obs, action, reward, next_obs, finished)
|
||||
|
||||
def get_gradient(self, cur_timestep):
    """Sample a training batch and return the gradient computed on it.

    A batch of ``config["train_batch_size"]`` transitions is drawn from
    ``self.replay_buffer`` and passed to
    ``self.dqn_graph.compute_gradients``. With prioritized replay enabled,
    the sampled entries get new priorities ``|td_error| + eps``.

    Returns:
        The gradient value returned by ``compute_gradients``.
    """
    use_priorities = self.config["prioritized_replay"]
    batch_size = self.config["train_batch_size"]

    if use_priorities:
        beta = self.beta_schedule.value(cur_timestep)
        batch = self.replay_buffer.sample(batch_size, beta=beta)
        # The sixth element (sampling weights) is intentionally unused here.
        obses_t, actions, rewards, obses_tp1, dones, _, batch_idxes = batch
    else:
        batch = self.replay_buffer.sample(batch_size)
        obses_t, actions, rewards, obses_tp1, dones = batch
        batch_idxes = None

    # Uniform importance weights are passed regardless of the replay mode.
    uniform_weights = np.ones_like(rewards)
    td_errors, grad = self.dqn_graph.compute_gradients(
        self.sess, obses_t, actions, rewards, obses_tp1, dones,
        uniform_weights)

    if self.config["prioritized_replay"]:
        eps = self.config["prioritized_replay_eps"]
        self.replay_buffer.update_priorities(
            batch_idxes, np.abs(td_errors) + eps)

    return grad
|
||||
|
||||
def apply_gradients(self, grad):
    """Apply ``grad`` to this actor's graph using the actor's session."""
    graph = self.dqn_graph
    graph.apply_gradients(self.sess, grad)
|
||||
|
||||
def stats(self, num_timesteps):
    """Summarize this actor's recent episodes and buffer state.

    Args:
        num_timesteps: timestep at which the exploration schedule is read.

    Returns:
        A tuple ``(mean_100ep_reward, mean_100ep_length, num_episodes,
        exploration, buffer_size)``. Both means are rounded to one decimal
        and are NaN while no episode has finished yet.
    """
    def recent_mean(values):
        # Last (up to) 100 entries, excluding the final one: the final entry
        # belongs to the episode that is still in progress.
        window = values[-101:-1]
        if len(window) == 0:
            # np.mean([]) also yields NaN but emits a RuntimeWarning on
            # every call; return the same value without the noise.
            return float("nan")
        return round(np.mean(window), 1)

    mean_100ep_reward = recent_mean(self.episode_rewards)
    mean_100ep_length = recent_mean(self.episode_lengths)
    exploration = self.exploration.value(num_timesteps)
    return (
        mean_100ep_reward,
        mean_100ep_length,
        len(self.episode_rewards),
        exploration,
        len(self.replay_buffer))
|
||||
|
||||
def get_weights(self):
    """Return the weights held by ``self.variables``."""
    current = self.variables.get_weights()
    return current
|
||||
|
||||
def set_weights(self, weights):
    """Load ``weights`` into ``self.variables``."""
    variables = self.variables
    variables.set_weights(weights)
|
||||
|
||||
def save(self):
    """Return this actor's checkpointable Python state as a list.

    The order is: beta schedule, exploration schedule, episode rewards,
    episode lengths, saved mean reward, current observation. ``restore``
    reads the entries back by position.
    """
    field_order = (
        "beta_schedule",
        "exploration",
        "episode_rewards",
        "episode_lengths",
        "saved_mean_reward",
        "obs",
    )
    return [getattr(self, name) for name in field_order]
|
||||
|
||||
def restore(self, data):
    """Restore actor state from a sequence in the order produced by ``save``.

    ``data[0]`` through ``data[5]`` are assigned, in order, to the beta
    schedule, exploration schedule, episode rewards, episode lengths,
    saved mean reward and current observation.
    """
    field_order = (
        "beta_schedule",
        "exploration",
        "episode_rewards",
        "episode_lengths",
        "saved_mean_reward",
        "obs",
    )
    for position, name in enumerate(field_order):
        setattr(self, name, data[position])
|
||||
|
||||
|
||||
@ray.remote
class RemoteActor(Actor):
    """Ray actor wrapper around ``Actor`` with a per-process GPU mask."""

    def __init__(self, env_name, config, logdir, gpu_mask):
        """Set the GPU mask, then initialize the wrapped ``Actor``.

        Args:
            env_name: passed through to ``Actor.__init__``.
            config: passed through to ``Actor.__init__``.
            logdir: passed through to ``Actor.__init__``.
            gpu_mask: string written to ``CUDA_VISIBLE_DEVICES``.
        """
        # Must be set before the base initializer runs -- presumably so the
        # TensorFlow session created there only sees the selected device(s).
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_mask
        # Explicit base-class call: the decorator rebinds the name
        # `RemoteActor`, so it is not used to reach the parent here.
        Actor.__init__(self, env_name, config, logdir)
|
||||
|
||||
|
||||
class DQNAgent(Agent):
|
||||
def __init__(self, env_name, config, upload_dir=None):
    """Create a DQN agent for ``env_name``.

    Args:
        env_name: environment name forwarded to ``Agent.__init__`` and
            ``_init``.
        config: configuration dict; it is modified in place to record
            ``"alg": "DQN"``.
        upload_dir: forwarded to ``Agent.__init__``.
    """
    config.update(alg="DQN")

    Agent.__init__(self, env_name, config, upload_dir=upload_dir)

    # Build everything inside a fresh graph made the default for the
    # duration of _init.
    agent_graph = tf.Graph()
    with agent_graph.as_default():
        self._init(config, env_name)
|
||||
|
||||
def _init(self, config, env_name):
|
||||
self.actor = Actor(env_name, config, self.logdir)
|
||||
self.workers = [
|
||||
RemoteActor.remote(
|
||||
env_name, config, self.logdir,
|
||||
"{}".format(i + config["gpu_offset"]))
|
||||
for i in range(config["num_workers"])]
|
||||
|
||||
self.cur_timestep = 0
|
||||
self.num_iterations = 0
|
||||
self.file_writer = tf.summary.FileWriter(self.logdir, self.sess.graph)
|
||||
self.num_target_updates = 0
|
||||
self.steps_since_update = 0
|
||||
self.file_writer = tf.summary.FileWriter(
|
||||
self.logdir, self.actor.sess.graph)
|
||||
self.saver = tf.train.Saver(max_to_keep=None)
|
||||
|
||||
def _update_worker_weights(self):
    """Push the local actor's current weights to every remote worker.

    The weights are fetched once and placed in the Ray object store once;
    each worker's ``set_weights`` is then invoked remotely with the shared
    object reference. The remote calls are not waited on here.
    """
    # Previously get_weights() was called twice and the first result was
    # discarded (and its name shadowed by the loop variable).
    weights_ref = ray.put(self.actor.get_weights())
    for worker in self.workers:
        worker.set_weights.remote(weights_ref)
|
||||
|
||||
def _train(self):
|
||||
config = self.config
|
||||
sample_time, learn_time = 0, 0
|
||||
iter_init_timesteps = self.num_timesteps
|
||||
sample_time, sync_time, learn_time, apply_time = 0, 0, 0, 0
|
||||
iter_init_timesteps = self.cur_timestep
|
||||
|
||||
for _ in range(config["timesteps_per_iteration"]):
|
||||
self.num_timesteps += 1
|
||||
num_loop_iters = 0
|
||||
steps_per_iter = config["sample_batch_size"] * len(self.workers)
|
||||
while (self.cur_timestep - iter_init_timesteps <
|
||||
config["timesteps_per_iteration"]):
|
||||
dt = time.time()
|
||||
# Take action and update exploration to the newest value
|
||||
action = self.dqn_graph.act(
|
||||
self.sess, np.array(self.obs)[None],
|
||||
self.exploration.value(self.num_timesteps))[0]
|
||||
new_obs, rew, done, _ = self.env.step(action)
|
||||
# Store transition in the replay buffer.
|
||||
self.replay_buffer.add(self.obs, action, rew, new_obs, float(done))
|
||||
self.obs = new_obs
|
||||
|
||||
self.episode_rewards[-1] += rew
|
||||
self.episode_lengths[-1] += 1
|
||||
if done:
|
||||
self.obs = self.env.reset()
|
||||
self.episode_rewards.append(0.0)
|
||||
self.episode_lengths.append(0.0)
|
||||
ray.get([
|
||||
w.do_steps.remote(
|
||||
config["sample_batch_size"], self.cur_timestep)
|
||||
for w in self.workers])
|
||||
num_loop_iters += 1
|
||||
self.cur_timestep += steps_per_iter
|
||||
self.steps_since_update += steps_per_iter
|
||||
sample_time += time.time() - dt
|
||||
|
||||
if self.num_timesteps > config["learning_starts"] and \
|
||||
self.num_timesteps % config["train_freq"] == 0:
|
||||
if self.cur_timestep > config["learning_starts"]:
|
||||
dt = time.time()
|
||||
# Minimize the error in Bellman's equation on a batch sampled
|
||||
# from replay buffer.
|
||||
if config["prioritized_replay"]:
|
||||
experience = self.replay_buffer.sample(
|
||||
config["batch_size"],
|
||||
beta=self.beta_schedule.value(self.num_timesteps))
|
||||
(obses_t, actions, rewards, obses_tp1,
|
||||
dones, _, batch_idxes) = experience
|
||||
else:
|
||||
obses_t, actions, rewards, obses_tp1, dones = (
|
||||
self.replay_buffer.sample(config["batch_size"]))
|
||||
batch_idxes = None
|
||||
td_errors = self.dqn_graph.train(
|
||||
self.sess, obses_t, actions, rewards, obses_tp1, dones,
|
||||
np.ones_like(rewards))
|
||||
if config["prioritized_replay"]:
|
||||
new_priorities = np.abs(td_errors) + (
|
||||
config["prioritized_replay_eps"])
|
||||
self.replay_buffer.update_priorities(
|
||||
batch_idxes, new_priorities)
|
||||
self._update_worker_weights()
|
||||
sync_time += (time.time() - dt)
|
||||
dt = time.time()
|
||||
gradients = ray.get(
|
||||
[w.get_gradient.remote(self.cur_timestep)
|
||||
for w in self.workers])
|
||||
learn_time += (time.time() - dt)
|
||||
dt = time.time()
|
||||
for grad in gradients:
|
||||
self.actor.apply_gradients(grad)
|
||||
apply_time += (time.time() - dt)
|
||||
|
||||
if self.num_timesteps > config["learning_starts"] and (
|
||||
self.num_timesteps %
|
||||
config["target_network_update_freq"] == 0):
|
||||
if (self.cur_timestep > config["learning_starts"] and
|
||||
self.steps_since_update >
|
||||
config["target_network_update_freq"]):
|
||||
self.actor.dqn_graph.update_target(self.actor.sess)
|
||||
# Update target network periodically.
|
||||
self.dqn_graph.update_target(self.sess)
|
||||
self._update_worker_weights()
|
||||
self.steps_since_update -= config["target_network_update_freq"]
|
||||
self.num_target_updates += 1
|
||||
|
||||
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1)
|
||||
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1)
|
||||
num_episodes = len(self.episode_rewards)
|
||||
mean_100ep_reward = 0.0
|
||||
mean_100ep_length = 0.0
|
||||
num_episodes = 0
|
||||
buffer_size_sum = 0
|
||||
for mean_rew, mean_len, episodes, exploration, buf_sz in ray.get(
|
||||
[w.stats.remote(self.cur_timestep) for w in self.workers]):
|
||||
mean_100ep_reward += mean_rew
|
||||
mean_100ep_length += mean_len
|
||||
num_episodes += episodes
|
||||
buffer_size_sum += buf_sz
|
||||
mean_100ep_reward /= len(self.workers)
|
||||
mean_100ep_length /= len(self.workers)
|
||||
|
||||
info = {
|
||||
"sample_time": sample_time,
|
||||
"learn_time": learn_time,
|
||||
"steps": self.num_timesteps,
|
||||
"episodes": num_episodes,
|
||||
"exploration": int(
|
||||
100 * self.exploration.value(self.num_timesteps))
|
||||
}
|
||||
info = [
|
||||
("mean_100ep_reward", mean_100ep_reward),
|
||||
("exploration_frac", exploration),
|
||||
("steps", self.cur_timestep),
|
||||
("episodes", num_episodes),
|
||||
("buffer_sizes_sum", buffer_size_sum),
|
||||
("target_updates", self.num_target_updates),
|
||||
("sample_time", sample_time),
|
||||
("weight_sync_time", sync_time),
|
||||
("apply_time", apply_time),
|
||||
("learn_time", learn_time),
|
||||
("samples_per_s",
|
||||
num_loop_iters * np.float64(steps_per_iter) / sample_time),
|
||||
("learn_samples_per_s",
|
||||
num_loop_iters * np.float64(config["train_batch_size"]) *
|
||||
np.float64(config["num_workers"]) / learn_time),
|
||||
]
|
||||
|
||||
logger.record_tabular("sample_time", sample_time)
|
||||
logger.record_tabular("learn_time", learn_time)
|
||||
logger.record_tabular("steps", self.num_timesteps)
|
||||
logger.record_tabular("buffer_size", len(self.replay_buffer))
|
||||
logger.record_tabular("episodes", num_episodes)
|
||||
logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
|
||||
logger.record_tabular(
|
||||
"% time spent exploring",
|
||||
int(100 * self.exploration.value(self.num_timesteps)))
|
||||
for k, v in info:
|
||||
logger.record_tabular(k, v)
|
||||
logger.dump_tabular()
|
||||
|
||||
result = TrainingResult(
|
||||
episode_reward_mean=mean_100ep_reward,
|
||||
episode_len_mean=mean_100ep_length,
|
||||
timesteps_this_iter=self.num_timesteps - iter_init_timesteps,
|
||||
timesteps_this_iter=self.cur_timestep - iter_init_timesteps,
|
||||
info=info)
|
||||
|
||||
return result
|
||||
|
||||
def _save(self):
|
||||
checkpoint_path = self.saver.save(
|
||||
self.sess,
|
||||
self.actor.sess,
|
||||
os.path.join(self.logdir, "checkpoint"),
|
||||
global_step=self.num_iterations)
|
||||
extra_data = [
|
||||
self.actor.save(),
|
||||
self.replay_buffer,
|
||||
self.beta_schedule,
|
||||
self.exploration,
|
||||
self.episode_rewards,
|
||||
self.episode_lengths,
|
||||
self.saved_mean_reward,
|
||||
self.obs,
|
||||
self.num_timesteps,
|
||||
self.num_iterations]
|
||||
self.cur_timestep,
|
||||
self.num_iterations,
|
||||
self.num_target_updates,
|
||||
self.steps_since_update]
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
self.saver.restore(self.sess, checkpoint_path)
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
self.replay_buffer = extra_data[0]
|
||||
self.beta_schedule = extra_data[1]
|
||||
self.exploration = extra_data[2]
|
||||
self.episode_rewards = extra_data[3]
|
||||
self.episode_lengths = extra_data[4]
|
||||
self.saved_mean_reward = extra_data[5]
|
||||
self.obs = extra_data[6]
|
||||
self.num_timesteps = extra_data[7]
|
||||
self.num_iterations = extra_data[8]
|
||||
self.actor.restore(extra_data[0])
|
||||
self.replay_buffer = extra_data[1]
|
||||
self.cur_timestep = extra_data[2]
|
||||
self.num_iterations = extra_data[3]
|
||||
self.num_target_updates = extra_data[4]
|
||||
self.steps_since_update = extra_data[5]
|
||||
|
||||
def compute_action(self, observation):
|
||||
return self.dqn_graph.act(
|
||||
self.sess, np.array(observation)[None], 0.0)[0]
|
||||
return self.actor.dqn_graph.act(
|
||||
self.actor.sess, np.array(observation)[None], 0.0)[0]
|
||||
|
||||
@@ -70,7 +70,7 @@ def _minimize_and_clip(optimizer, objective, var_list, clip_val=10):
|
||||
for i, (grad, var) in enumerate(gradients):
|
||||
if grad is not None:
|
||||
gradients[i] = (tf.clip_by_norm(grad, clip_val), var)
|
||||
return optimizer.apply_gradients(gradients)
|
||||
return gradients
|
||||
|
||||
|
||||
def _scope_vars(scope, trainable_only=False):
|
||||
@@ -169,12 +169,16 @@ class DQNGraph(object):
|
||||
weighted_error = tf.reduce_mean(self.importance_weights * errors)
|
||||
# compute optimization op (potentially with gradient clipping)
|
||||
if config["grad_norm_clipping"] is not None:
|
||||
self.optimize_expr = _minimize_and_clip(
|
||||
self.grads_and_vars = _minimize_and_clip(
|
||||
optimizer, weighted_error, var_list=q_func_vars,
|
||||
clip_val=config["grad_norm_clipping"])
|
||||
else:
|
||||
self.optimize_expr = optimizer.minimize(
|
||||
self.grads_and_vars = optimizer.compute_gradients(
|
||||
weighted_error, var_list=q_func_vars)
|
||||
self.grads_and_vars = [
|
||||
(g, v) for (g, v) in self.grads_and_vars if g is not None]
|
||||
self.grads = [g for (g, v) in self.grads_and_vars]
|
||||
self.train_expr = optimizer.apply_gradients(self.grads_and_vars)
|
||||
|
||||
# update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network
|
||||
@@ -197,11 +201,11 @@ class DQNGraph(object):
|
||||
self.eps: eps,
|
||||
})
|
||||
|
||||
def train(
|
||||
def compute_gradients(
|
||||
self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask,
|
||||
importance_weights):
|
||||
td_err, _ = sess.run(
|
||||
[self.td_error, self.optimize_expr],
|
||||
td_err, grads = sess.run(
|
||||
[self.td_error, self.grads],
|
||||
feed_dict={
|
||||
self.obs_t: obs_t,
|
||||
self.act_t: act_t,
|
||||
@@ -210,4 +214,9 @@ class DQNGraph(object):
|
||||
self.done_mask: done_mask,
|
||||
self.importance_weights: importance_weights
|
||||
})
|
||||
return td_err
|
||||
return td_err, grads
|
||||
|
||||
def apply_gradients(self, sess, grads):
    """Run the training op with externally supplied gradient values.

    Args:
        sess: session object whose ``run`` executes ``self.train_expr``.
        grads: one value per entry of ``self.grads_and_vars``, in the same
            order as ``self.grads``.
    """
    assert len(grads) == len(self.grads_and_vars)
    # Feed each supplied value in place of the matching gradient tensor.
    overrides = dict(zip(self.grads, grads))
    sess.run(self.train_expr, feed_dict=overrides)
|
||||
|
||||
+44
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn import DQNAgent, DEFAULT_CONFIG
|
||||
|
||||
|
||||
def main():
    """Train a DQN agent on CartPole-v0, printing each iteration's result.

    ``--iterations`` bounds the number of training iterations; with the
    default of -1 the loop never terminates on its own.
    """
    parser = argparse.ArgumentParser(description="Run the DQN algorithm.")
    parser.add_argument(
        "--iterations", default=-1, type=int,
        help="The number of training iterations to run.")
    args = parser.parse_args()

    # NOTE(review): the overrides set `model_config`; confirm this is the
    # key the agent actually reads for model options.
    overrides = dict(
        lr=1e-3,
        schedule_max_timesteps=100000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        dueling=False,
        hiddens=[],
        model_config=dict(
            fcnet_hiddens=[64],
            fcnet_activation='relu',
        ))
    config = DEFAULT_CONFIG.copy()
    config.update(overrides)

    ray.init()
    dqn = DQNAgent("CartPole-v0", config)

    completed = 0
    while completed != args.iterations:
        status = dqn.train()
        completed += 1
        print("current status: {}".format(status))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Executable
+43
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn import DQNAgent, DEFAULT_CONFIG
|
||||
|
||||
|
||||
def main():
    """Train a DQN agent on PongNoFrameskip-v4, printing each result.

    ``--iterations`` bounds the number of training iterations; with the
    default of -1 the loop never terminates on its own.
    """
    parser = argparse.ArgumentParser(description="Run the DQN algorithm.")
    parser.add_argument(
        "--iterations", default=-1, type=int,
        help="The number of training iterations to run.")
    args = parser.parse_args()

    # NOTE(review): TODO confirm `train_freq` is still honored by DQNAgent.
    overrides = dict(
        lr=1e-4,
        schedule_max_timesteps=2000000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=True)
    config = DEFAULT_CONFIG.copy()
    config.update(overrides)

    ray.init()
    dqn = DQNAgent("PongNoFrameskip-v4", config)

    completed = 0
    while completed != args.iterations:
        status = dqn.train()
        completed += 1
        print("current status: {}".format(status))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user