mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 11:51:09 +08:00
b6a18cb39b
* wip * works with cartpole * lint * fix pg * comment * action dist rename * preprocessor * fix test * typo * fix the action[0] nonsense * revert * satisfy the lint * wip * works with cartpole * lint * fix pg * comment * action dist rename * preprocessor * fix test * typo * fix the action[0] nonsense * revert * satisfy the lint * Minor indentation changes. * fix merge * add humanoid * initial dqn refactor * remove tfutil * fix calls * fix tf errors 1 * closer * runs now * lint * tensorboard graph * fix linting * more 4 space * fix * fix linT * more lint * oops * es parity * remove example.py * fix training bug * add cartpole demo * try fixing cartpole * allow model options, configure cartpole * debug * simplify * no dueling * avoid out of file handles * Test dqn in jenkins. * Minor formatting. * fix issue * fix another * Fix problem in which we log to a directory that hasn't been created.
30 lines
1.0 KiB
Python
30 lines
1.0 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
|
|
class Model(object):
    """Defines an abstract network model for use with RLlib.

    Models convert input tensors to a number of output features. These
    features can then be interpreted by ActionDistribution classes to
    determine e.g. agent action values.

    The last layer of the network can also be retrieved if the algorithm
    needs further post-processing (e.g. Actor and Critic networks in A3C).

    Subclasses implement ``_init`` to build the network; the constructor
    calls it and stores its two return values.

    Attributes:
        inputs (Tensor): The input placeholder for this model.
        outputs (Tensor): The output vector of this model.
        last_layer (Tensor): The network layer right before the model output.
    """

    def __init__(self, inputs, num_outputs, options):
        """Stores the inputs and builds the network via ``_init``.

        Args:
            inputs (Tensor): The input placeholder for this model.
            num_outputs: Forwarded unchanged to ``_init``; presumably the
                size of the output vector (confirm against subclasses).
            options: Forwarded unchanged to ``_init``; presumably
                model-specific configuration (confirm against subclasses).

        Raises:
            NotImplementedError: If the subclass does not override ``_init``.
        """
        self.inputs = inputs
        # _init must return exactly two values: (outputs, last_layer).
        self.outputs, self.last_layer = self._init(
            inputs, num_outputs, options)

    def _init(self, inputs, num_outputs, options):
        """Builds and returns the output and last layer of the network.

        The signature mirrors the call made from ``__init__`` so that a
        missing override surfaces as NotImplementedError rather than as a
        TypeError about the argument count.

        Args:
            inputs: The value passed to the constructor as ``inputs``.
            num_outputs: The value passed to the constructor as
                ``num_outputs``.
            options: The value passed to the constructor as ``options``.

        Returns:
            A 2-tuple ``(outputs, last_layer)``.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError
|