mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:23:17 +08:00
[rllib] [hotfix] Remove assert that trips on pytorch multiagent (#8241)
This commit is contained in:
@@ -5,7 +5,6 @@ import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.policy import Policy, TorchPolicy
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.utils import merge_dicts, try_import_tf
|
||||
@@ -279,25 +278,4 @@ class WorkerSet:
|
||||
_fake_sampler=config.get("_fake_sampler", False),
|
||||
extra_python_environs=extra_python_environs)
|
||||
|
||||
# Check for correct policy class (only locally, remote Workers should
|
||||
# create the exact same Policy types).
|
||||
if type(worker) is RolloutWorker:
|
||||
actual_class = type(worker.get_policy())
|
||||
|
||||
# Pytorch case: Policy must be a TorchPolicy.
|
||||
if config["use_pytorch"]:
|
||||
assert issubclass(actual_class, TorchPolicy), \
|
||||
"Worker policy must be subclass of `TorchPolicy`, " \
|
||||
"but is {}!".format(actual_class.__name__)
|
||||
# non-Pytorch case:
|
||||
# Policy may be None AND must not be a TorchPolicy.
|
||||
else:
|
||||
assert issubclass(actual_class, type(None)) or \
|
||||
(issubclass(actual_class, Policy) and
|
||||
not issubclass(actual_class, TorchPolicy)), "Worker " \
|
||||
"policy must be subclass of `Policy`, but NOT " \
|
||||
"`TorchPolicy` (your class={})! If you have a torch " \
|
||||
"Trainer, make sure to set `use_pytorch=True` in " \
|
||||
"your Trainer's config)!".format(actual_class.__name__)
|
||||
|
||||
return worker
|
||||
|
||||
Reference in New Issue
Block a user