mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[rllib] Default to truncate_episodes and add some more config validators (#2967)
* update * link it * warn about truncation * fix * Update rllib-training.rst * deprecate tests failing
This commit is contained in:
@@ -48,7 +48,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# Whether to allocate CPUs for workers (if > 0).
|
||||
"num_cpus_per_worker": 1,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes"
|
||||
"batch_mode": "complete_episodes",
|
||||
"batch_mode": "truncate_episodes",
|
||||
# Which observation filter to apply to the observation
|
||||
"observation_filter": "MeanStdFilter",
|
||||
# Use the sync samples optimizer instead of the multi-gpu one
|
||||
@@ -80,17 +80,7 @@ class PPOAgent(Agent):
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
def _init(self):
|
||||
waste_ratio = (
|
||||
self.config["sample_batch_size"] * self.config["num_workers"] /
|
||||
self.config["train_batch_size"])
|
||||
if waste_ratio > 1:
|
||||
msg = ("sample_batch_size * num_workers >> train_batch_size. "
|
||||
"This means that many steps will be discarded. Consider "
|
||||
"reducing sample_batch_size, or increase train_batch_size.")
|
||||
if waste_ratio > 1.5:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
print("Warning: " + msg)
|
||||
self._validate_config()
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, self._policy_graph)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
@@ -114,6 +104,28 @@ class PPOAgent(Agent):
|
||||
"standardize_fields": ["advantages"],
|
||||
})
|
||||
|
||||
def _validate_config(self):
|
||||
waste_ratio = (
|
||||
self.config["sample_batch_size"] * self.config["num_workers"] /
|
||||
self.config["train_batch_size"])
|
||||
if waste_ratio > 1:
|
||||
msg = ("sample_batch_size * num_workers >> train_batch_size. "
|
||||
"This means that many steps will be discarded. Consider "
|
||||
"reducing sample_batch_size, or increase train_batch_size.")
|
||||
if waste_ratio > 1.5:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
print("Warning: " + msg)
|
||||
if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]:
|
||||
raise ValueError(
|
||||
"Minibatch size {} must be <= train batch size {}.".format(
|
||||
self.config["sgd_minibatch_size"],
|
||||
self.config["train_batch_size"]))
|
||||
if (self.config["batch_mode"] == "truncate_episodes"
|
||||
and not self.config["use_gae"]):
|
||||
raise ValueError(
|
||||
"Episode truncation is not supported without a value function")
|
||||
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
fetches = self.optimizer.step()
|
||||
|
||||
@@ -10,3 +10,4 @@ hopper-ppo:
|
||||
train_batch_size: 160000
|
||||
num_workers: 64
|
||||
num_gpus: 4
|
||||
batch_mode: complete_episodes
|
||||
|
||||
@@ -17,3 +17,4 @@ humanoid-ppo-gae:
|
||||
free_log_std: true
|
||||
num_workers: 64
|
||||
num_gpus: 4
|
||||
batch_mode: complete_episodes
|
||||
|
||||
@@ -15,3 +15,4 @@ humanoid-ppo:
|
||||
use_gae: false
|
||||
num_workers: 64
|
||||
num_gpus: 4
|
||||
batch_mode: complete_episodes
|
||||
|
||||
@@ -13,4 +13,4 @@ pendulum-ppo:
|
||||
num_sgd_iter: 10
|
||||
model:
|
||||
fcnet_hiddens: [64, 64]
|
||||
squash_to_range: True
|
||||
batch_mode: complete_episodes
|
||||
|
||||
@@ -6,3 +6,4 @@ cartpole-ppo:
|
||||
time_total_s: 300
|
||||
config:
|
||||
num_workers: 1
|
||||
batch_mode: complete_episodes
|
||||
|
||||
@@ -15,3 +15,4 @@ pendulum-ppo:
|
||||
num_sgd_iter: 10
|
||||
model:
|
||||
fcnet_hiddens: [64, 64]
|
||||
batch_mode: complete_episodes
|
||||
|
||||
@@ -9,3 +9,4 @@ walker2d-v1-ppo:
|
||||
train_batch_size: 320000
|
||||
num_workers: 64
|
||||
num_gpus: 4
|
||||
batch_mode: complete_episodes
|
||||
|
||||
Reference in New Issue
Block a user