mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
[rllib] Initial RLLib documentation (#969)
* initial documentation for RLLib * more RL documentation * fix linting * fix comments * update * fix
This commit is contained in:
committed by
Robert Nishihara
parent
9ec3608eca
commit
1eb8c83314
@@ -1,3 +1,12 @@
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.action_dist import (ActionDistribution, Categorical,
|
||||
DiagGaussian, Deterministic)
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.convnet import ConvolutionalNetwork
|
||||
from ray.rllib.models.lstm import LSTM
|
||||
|
||||
__all__ = ["ModelCatalog"]
|
||||
|
||||
__all__ = ["ActionDistribution", "ActionDistribution", "Categorical",
|
||||
"DiagGaussian", "Deterministic", "ModelCatalog", "Model",
|
||||
"FullyConnectedNetwork", "ConvolutionalNetwork", "LSTM"]
|
||||
|
||||
@@ -17,15 +17,19 @@ class ActionDistribution(object):
|
||||
self.inputs = inputs
|
||||
|
||||
def logp(self, x):
|
||||
"""The log-likelihood of the action distribution."""
|
||||
raise NotImplementedError
|
||||
|
||||
def kl(self, other):
|
||||
"""The KL-divergene between two action distributions."""
|
||||
raise NotImplementedError
|
||||
|
||||
def entropy(self):
|
||||
"""The entroy of the action distribution."""
|
||||
raise NotImplementedError
|
||||
|
||||
def sample(self):
|
||||
"""Draw a sample from the action distribution."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ from ray.rllib.models.misc import (conv2d, linear, flatten,
|
||||
normc_initializer)
|
||||
from ray.rllib.models.model import Model
|
||||
|
||||
use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
distutils.version.LooseVersion("1.0.0"))
|
||||
|
||||
|
||||
class LSTM(Model):
|
||||
# TODO(rliaw): Add LSTM code for other algorithms
|
||||
def _init(self, inputs, num_outputs, options):
|
||||
use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
distutils.version.LooseVersion("1.0.0"))
|
||||
|
||||
self.x = x = inputs
|
||||
for i in range(4):
|
||||
x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
|
||||
|
||||
Reference in New Issue
Block a user