mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 22:01:12 +08:00
[RLlib] JAXPolicy prep PR #2 (move get_activation_fn (backward-compatibly), minor fixes and preparations). (#13091)
This commit is contained in:
@@ -2,7 +2,8 @@ import numpy as np
|
||||
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
vf_preds_fetches, compute_and_clip_gradients, setup_config, \
|
||||
ValueNetworkMixin
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.framework import get_activation_fn
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ from typing import Optional, Tuple
|
||||
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ import gym
|
||||
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils.framework import get_activation_fn, get_variable, \
|
||||
try_import_tf
|
||||
from ray.rllib.utils.framework import TensorType, TensorShape
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.utils.framework import get_variable, try_import_tf, \
|
||||
TensorType, TensorShape
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import gym
|
||||
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.utils import get_filter_config
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
|
||||
from ray.rllib.models.utils import get_activation_fn, get_filter_config
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
import numpy as np
|
||||
from typing import Union, Tuple, Any, List
|
||||
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Tuple
|
||||
|
||||
from ray.rllib.models.torch.misc import Reshape
|
||||
from ray.rllib.models.utils import get_initializer
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
|
||||
from ray.rllib.models.utils import get_activation_fn, get_initializer
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
if torch:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
|
||||
from ray.rllib.utils.framework import TensorType
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.utils.framework import try_import_torch, TensorType
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
+79
-2
@@ -1,4 +1,62 @@
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from typing import Optional
|
||||
|
||||
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
|
||||
try_import_torch
|
||||
|
||||
|
||||
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
|
||||
"""Returns a framework specific activation function, given a name string.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): One of "relu" (default), "tanh", "swish", or
|
||||
"linear" or None.
|
||||
framework (str): One of "jax", "tf|tfe|tf2" or "torch".
|
||||
|
||||
Returns:
|
||||
A framework-specific activtion function. e.g. tf.nn.tanh or
|
||||
torch.nn.ReLU. None if name in ["linear", None].
|
||||
|
||||
Raises:
|
||||
ValueError: If name is an unknown activation function.
|
||||
"""
|
||||
# Already a callable, return as-is.
|
||||
if callable(name):
|
||||
return name
|
||||
|
||||
# Infer the correct activation function from the string specifier.
|
||||
if framework == "torch":
|
||||
if name in ["linear", None]:
|
||||
return None
|
||||
if name == "swish":
|
||||
from ray.rllib.utils.torch_ops import Swish
|
||||
return Swish
|
||||
_, nn = try_import_torch()
|
||||
if name == "relu":
|
||||
return nn.ReLU
|
||||
elif name == "tanh":
|
||||
return nn.Tanh
|
||||
elif framework == "jax":
|
||||
if name in ["linear", None]:
|
||||
return None
|
||||
jax, _ = try_import_jax()
|
||||
if name == "swish":
|
||||
return jax.nn.swish
|
||||
if name == "relu":
|
||||
return jax.nn.relu
|
||||
elif name == "tanh":
|
||||
return jax.nn.hard_tanh
|
||||
else:
|
||||
assert framework in ["tf", "tfe", "tf2"],\
|
||||
"Unsupported framework `{}`!".format(framework)
|
||||
if name in ["linear", None]:
|
||||
return None
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
fn = getattr(tf.nn, name, None)
|
||||
if fn is not None:
|
||||
return fn
|
||||
|
||||
raise ValueError("Unknown activation ({}) for framework={}!".format(
|
||||
name, framework))
|
||||
|
||||
|
||||
def get_filter_config(shape):
|
||||
@@ -40,7 +98,7 @@ def get_initializer(name, framework="tf"):
|
||||
|
||||
Args:
|
||||
name (str): One of "xavier_uniform" (default), "xavier_normal".
|
||||
framework (str): One of "tf" or "torch".
|
||||
framework (str): One of "jax", "tf|tfe|tf2" or "torch".
|
||||
|
||||
Returns:
|
||||
A framework-specific initializer function, e.g.
|
||||
@@ -50,14 +108,33 @@ def get_initializer(name, framework="tf"):
|
||||
Raises:
|
||||
ValueError: If name is an unknown initializer.
|
||||
"""
|
||||
# Already a callable, return as-is.
|
||||
if callable(name):
|
||||
return name
|
||||
|
||||
if framework == "jax":
|
||||
_, flax = try_import_jax()
|
||||
assert flax is not None,\
|
||||
"`flax` not installed. Try `pip install jax flax`."
|
||||
import flax.linen as nn
|
||||
if name in [None, "default", "xavier_uniform"]:
|
||||
return nn.initializers.xavier_uniform()
|
||||
elif name == "xavier_normal":
|
||||
return nn.initializers.xavier_normal()
|
||||
if framework == "torch":
|
||||
_, nn = try_import_torch()
|
||||
assert nn is not None,\
|
||||
"`torch` not installed. Try `pip install torch`."
|
||||
if name in [None, "default", "xavier_uniform"]:
|
||||
return nn.init.xavier_uniform_
|
||||
elif name == "xavier_normal":
|
||||
return nn.init.xavier_normal_
|
||||
else:
|
||||
assert framework in ["tf", "tfe", "tf2"],\
|
||||
"Unsupported framework `{}`!".format(framework)
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
assert tf is not None,\
|
||||
"`tensorflow` not installed. Try `pip install tensorflow`."
|
||||
if name in [None, "default", "xavier_uniform"]:
|
||||
return tf.keras.initializers.GlorotUniform
|
||||
elif name == "xavier_normal":
|
||||
|
||||
@@ -15,8 +15,8 @@ def deprecation_warning(old, new=None, error=None):
|
||||
Args:
|
||||
old (str): A description of the "thing" that is to be deprecated.
|
||||
new (Optional[str]): A description of the new "thing" that replaces it.
|
||||
error (Optional[bool,Exception]): Whether or which exception to throw.
|
||||
If True, throw ValueError.
|
||||
error (Optional[Union[bool,Exception]]): Whether or which exception to
|
||||
throw. If True, throw ValueError.
|
||||
"""
|
||||
msg = "`{}` has been deprecated.{}".format(
|
||||
old, (" Use `{}` instead.".format(new) if new else ""))
|
||||
|
||||
@@ -9,11 +9,12 @@ from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
|
||||
TorchMultiCategorical
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import NullContextManager
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.exploration.exploration import Exploration
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_tf, \
|
||||
from ray.rllib.utils.framework import try_import_tf, \
|
||||
try_import_torch
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.tf_ops import get_placeholder, one_hot as tf_one_hot
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -252,7 +253,7 @@ def get_variable(value,
|
||||
return value
|
||||
|
||||
|
||||
# TODO: (sven) move to models/utils.py
|
||||
# Deprecated: Use rllib.models.utils::get_activation_fn instead.
|
||||
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
|
||||
"""Returns a framework specific activation function, given a name string.
|
||||
|
||||
@@ -268,6 +269,10 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
|
||||
Raises:
|
||||
ValueError: If name is an unknown activation function.
|
||||
"""
|
||||
deprecation_warning(
|
||||
"rllib/utils/framework.py::get_activation_fn",
|
||||
"rllib/models/utils.py::get_activation_fn",
|
||||
error=False)
|
||||
if framework == "torch":
|
||||
if name in ["linear", None]:
|
||||
return None
|
||||
|
||||
@@ -5,7 +5,7 @@ import numpy as np
|
||||
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
|
||||
try_import_torch
|
||||
|
||||
jax, flax = try_import_jax()
|
||||
jax, _ = try_import_jax()
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
if tf1:
|
||||
eager_mode = None
|
||||
|
||||
Reference in New Issue
Block a user