Files
ray/python/ray/rllib/agents/pg/pg.py
T
Robert Nishihara e49839c73f Fix linting. (#3155)
2018-10-28 20:43:29 -07:00

57 lines
1.9 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils import merge_dicts
from ray.tune.trial import Resources
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Unless overridden, run without any remote workers.
    "num_workers": 0,
    # Learning rate used when training the policy.
    "lr": 0.0004,
})
# __sphinx_doc_end__
# yapf: enable
class PGAgent(Agent):
    """Policy-gradient agent in its most basic form.

    Serves mainly as a worked example of how an algorithm is assembled in
    RLlib; for most real workloads the PPO agent is the better choice.
    """

    _agent_name = "PG"
    _default_config = DEFAULT_CONFIG
    _policy_graph = PGPolicyGraph

    @classmethod
    def default_resource_request(cls, config):
        """Return the resources a trial of this agent asks for.

        ``config`` is merged over the class defaults; the request is one CPU,
        no GPU, and ``extra_cpu`` equal to the merged ``num_workers`` value.
        """
        num_workers = merge_dicts(cls._default_config, config)["num_workers"]
        return Resources(cpu=1, gpu=0, extra_cpu=num_workers)

    def _init(self):
        """Build the evaluators and the synchronous sample optimizer.

        A local evaluator is always created; the remote evaluators are
        created with the count taken from ``config["num_workers"]``. Both are
        passed to a ``SyncSamplesOptimizer`` configured by
        ``config["optimizer"]``.
        """
        creator = self.env_creator
        graph_cls = self._policy_graph
        self.local_evaluator = self.make_local_evaluator(creator, graph_cls)
        # NOTE(review): the trailing ``{}`` looks like empty extra arguments
        # for the remote actors — confirm against make_remote_evaluators.
        self.remote_evaluators = self.make_remote_evaluators(
            creator, graph_cls, self.config["num_workers"], {})
        self.optimizer = SyncSamplesOptimizer(
            self.local_evaluator, self.remote_evaluators,
            self.config["optimizer"])

    def _train(self):
        """Run a single optimizer step and return the resulting metrics.

        The returned dict is the one produced by ``collect_metrics`` (called
        with ``config["collect_metrics_timeout"]``), extended with
        ``timesteps_this_iter``. That value is how much the optimizer's
        ``num_steps_sampled`` counter grew since the start of this call.
        """
        opt = self.optimizer
        steps_before = opt.num_steps_sampled
        opt.step()
        metrics = opt.collect_metrics(self.config["collect_metrics_timeout"])
        steps_taken = opt.num_steps_sampled - steps_before
        metrics.update(timesteps_this_iter=steps_taken)
        return metrics