Custom action distributions (#5164)

* custom action dist wip

* Test case for custom action dist

* ActionDistribution.get_parameter_shape_for_action_space pattern

* Edit exception message to also suggest using a custom action distribution

* Clean up ModelCatalog.get_action_dist

* Pass model config to ActionDistribution constructors

* Update custom action distribution test case

* Name fix

* Autoformatter

* parameter shape static methods for torch distributions

* Fix docstring

* Generalize fake array for graph initialization

* Fix action dist constructors

* Correct parameter shape static methods for multicategorical and gaussian

* Make suggested changes to custom action dist's

* Correct instances of not passing model config to action dist

* Autoformatter

* fix tuple distribution constructor

* bugfix
This commit is contained in:
Matthew A. Wright
2019-08-06 18:13:16 +00:00
committed by Eric Liang
parent 94bff244e4
commit e3c9f7e83a
22 changed files with 252 additions and 73 deletions
+3 -1
View File
@@ -14,8 +14,10 @@ TRAINABLE_CLASS = "trainable_class"
ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
RLLIB_ACTION_DIST = "rllib_action_dist"
KNOWN_CATEGORIES = [
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
RLLIB_ACTION_DIST
]
logger = logging.getLogger(__name__)