[RLlib] Fix JAX import bug. (#12621)

This commit is contained in:
Sven Mika
2020-12-07 20:05:08 +01:00
committed by GitHub
parent 7e1422e925
commit 340b1e99fc
2 changed files with 2 additions and 2 deletions
+1 -1
View File
@@ -281,7 +281,7 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
elif framework == "jax":
if name in ["linear", None]:
return None
jax = try_import_jax()
jax, flax = try_import_jax()
if name == "swish":
return jax.nn.swish
if name == "relu":
+1 -1
View File
@@ -5,7 +5,7 @@ import numpy as np
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
try_import_torch
jax = try_import_jax()
jax, flax = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
eager_mode = None