mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:06:25 +08:00
[RLlib] Fix JAX import bug. (#12621)
This commit is contained in:
@@ -281,7 +281,7 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
|
||||
elif framework == "jax":
|
||||
if name in ["linear", None]:
|
||||
return None
|
||||
jax = try_import_jax()
|
||||
jax, flax = try_import_jax()
|
||||
if name == "swish":
|
||||
return jax.nn.swish
|
||||
if name == "relu":
|
||||
|
||||
@@ -5,7 +5,7 @@ import numpy as np
|
||||
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
|
||||
try_import_torch
|
||||
|
||||
jax = try_import_jax()
|
||||
jax, flax = try_import_jax()
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
if tf1:
|
||||
eager_mode = None
|
||||
|
||||
Reference in New Issue
Block a user