From 2e6f9bedf28edf4e4e962cd4c9dd0822f20b4f02 Mon Sep 17 00:00:00 2001 From: Si-Yuan Date: Wed, 5 Dec 2018 13:09:08 -0800 Subject: [PATCH] Add the extra fallback for serialization (#3468) * Add the extra fallback for serialization. * Better comments & warnings. quotes. * Update test/runtest.py Co-Authored-By: suquark * Update test/runtest.py Co-Authored-By: suquark * linting * Don't hijack too much errors. * simplify the test * Update runtest.py * simplify --- python/ray/worker.py | 11 +++++++++++ test/runtest.py | 17 +++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/python/ray/worker.py b/python/ray/worker.py index c3c01f485..b09afe824 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -418,6 +418,17 @@ class Worker(object): logger.info( "The object with ID {} already exists in the object store." .format(object_id)) + except TypeError: + # This error can happen because one of the members of the object + # may not be serializable for cloudpickle. So we need these extra + # fallbacks here to start from the beginning. Hopefully the object + # could have a `__reduce__` method. + register_custom_serializer(type(value), use_pickle=True) + warning_message = ("WARNING: Serializing the class {} failed, " + "so are are falling back to cloudpickle." + .format(type(value))) + logger.warning(warning_message) + self.store_and_register(object_id, value) def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): start_time = time.time() diff --git a/test/runtest.py b/test/runtest.py index 767960a16..22251efd7 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -378,6 +378,23 @@ def test_custom_serializers(shutdown_only): assert ray.get(f.remote()) == ((3, "string1", Bar.__name__), "string2") +def test_serialization_final_fallback(ray_start): + pytest.importorskip("catboost") + # This test will only run when "catboost" is installed. + from catboost import CatBoostClassifier + + model = CatBoostClassifier( + iterations=2, + depth=2, + learning_rate=1, + loss_function="Logloss", + logging_level="Verbose") + + reconstructed_model = ray.get(ray.put(model)) + assert set(model.get_params().items()) == set( + reconstructed_model.get_params().items()) + + def test_register_class(shutdown_only): ray.init(num_cpus=2)