[rllib] Update concepts docs and add "Building Policies in Torch/TensorFlow" section (#4821)

* wip

* fix index

* fix bugs

* todo

* add imports

* note on get ph

* note on get ph

* rename to building custom algs

* add rnn state info
This commit is contained in:
Eric Liang
2019-05-27 14:17:32 -07:00
committed by GitHub
parent 574e1c7695
commit a45c61e19b
7 changed files with 433 additions and 51 deletions
+1 -1
View File
@@ -29,7 +29,7 @@ def get_policy_class(config):
PGTrainer = build_trainer(
name="PG",
name="PGTrainer",
default_config=DEFAULT_CONFIG,
default_policy=PGTFPolicy,
get_policy_class=get_policy_class)
+3 -3
View File
@@ -63,7 +63,7 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
def make_optimizer(local_evaluator, remote_evaluators, config):
def choose_policy_optimizer(local_evaluator, remote_evaluators, config):
if config["simple_optimizer"]:
return SyncSamplesOptimizer(
local_evaluator,
@@ -155,10 +155,10 @@ def validate_config(config):
PPOTrainer = build_trainer(
name="PPO",
name="PPOTrainer",
default_config=DEFAULT_CONFIG,
default_policy=PPOTFPolicy,
make_policy_optimizer=make_optimizer,
make_policy_optimizer=choose_policy_optimizer,
validate_config=validate_config,
after_optimizer_step=update_kl,
before_train_step=warn_about_obs_filter,
+6 -7
View File
@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -44,13 +44,12 @@ def build_trainer(name,
a Trainer instance that uses the specified args.
"""
if name.endswith("Trainer"):
raise ValueError("Algorithm name should not include *Trainer suffix",
name)
if not name.endswith("Trainer"):
raise ValueError("Algorithm name should have *Trainer suffix", name)
class trainer_cls(Trainer):
_name = name
_default_config = default_config or Trainer.COMMON_CONFIG
_default_config = default_config or COMMON_CONFIG
_policy = default_policy
def _init(self, config, env_creator):
@@ -92,6 +91,6 @@ def build_trainer(name,
after_train_result(self, res)
return res
trainer_cls.__name__ = name + "Trainer"
trainer_cls.__qualname__ = name + "Trainer"
trainer_cls.__name__ = name
trainer_cls.__qualname__ = name
return trainer_cls