diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 9025b8736..5615d0fe0 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -281,7 +281,7 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"): elif framework == "jax": if name in ["linear", None]: return None - jax = try_import_jax() + jax, flax = try_import_jax() if name == "swish": return jax.nn.swish if name == "relu": diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 05713c7cb..5460d9c27 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -5,7 +5,7 @@ import numpy as np from ray.rllib.utils.framework import try_import_jax, try_import_tf, \ try_import_torch -jax = try_import_jax() +jax, flax = try_import_jax() tf1, tf, tfv = try_import_tf() if tf1: eager_mode = None