mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:45:44 +08:00
[rllib] Update concepts docs and add "Building Policies in Torch/TensorFlow" section (#4821)
* wip * fix index * fix bugs * todo * add imports * note on get ph * note on get ph * rename to building custom algs * add rnn state info
This commit is contained in:
@@ -29,7 +29,7 @@ def get_policy_class(config):
|
||||
|
||||
|
||||
PGTrainer = build_trainer(
|
||||
name="PG",
|
||||
name="PGTrainer",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=PGTFPolicy,
|
||||
get_policy_class=get_policy_class)
|
||||
|
||||
@@ -63,7 +63,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def make_optimizer(local_evaluator, remote_evaluators, config):
|
||||
def choose_policy_optimizer(local_evaluator, remote_evaluators, config):
|
||||
if config["simple_optimizer"]:
|
||||
return SyncSamplesOptimizer(
|
||||
local_evaluator,
|
||||
@@ -155,10 +155,10 @@ def validate_config(config):
|
||||
|
||||
|
||||
PPOTrainer = build_trainer(
|
||||
name="PPO",
|
||||
name="PPOTrainer",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=PPOTFPolicy,
|
||||
make_policy_optimizer=make_optimizer,
|
||||
make_policy_optimizer=choose_policy_optimizer,
|
||||
validate_config=validate_config,
|
||||
after_optimizer_step=update_kl,
|
||||
before_train_step=warn_about_obs_filter,
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
@@ -44,13 +44,12 @@ def build_trainer(name,
|
||||
a Trainer instance that uses the specified args.
|
||||
"""
|
||||
|
||||
if name.endswith("Trainer"):
|
||||
raise ValueError("Algorithm name should not include *Trainer suffix",
|
||||
name)
|
||||
if not name.endswith("Trainer"):
|
||||
raise ValueError("Algorithm name should have *Trainer suffix", name)
|
||||
|
||||
class trainer_cls(Trainer):
|
||||
_name = name
|
||||
_default_config = default_config or Trainer.COMMON_CONFIG
|
||||
_default_config = default_config or COMMON_CONFIG
|
||||
_policy = default_policy
|
||||
|
||||
def _init(self, config, env_creator):
|
||||
@@ -92,6 +91,6 @@ def build_trainer(name,
|
||||
after_train_result(self, res)
|
||||
return res
|
||||
|
||||
trainer_cls.__name__ = name + "Trainer"
|
||||
trainer_cls.__qualname__ = name + "Trainer"
|
||||
trainer_cls.__name__ = name
|
||||
trainer_cls.__qualname__ = name
|
||||
return trainer_cls
|
||||
|
||||
Reference in New Issue
Block a user