diff --git a/python/ray/util/sgd/tests/test_tensorflow.py b/python/ray/util/sgd/tests/test_tensorflow.py index b16f826fb..b6b33690b 100644 --- a/python/ray/util/sgd/tests/test_tensorflow.py +++ b/python/ray/util/sgd/tests/test_tensorflow.py @@ -89,7 +89,6 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811 model1.get_config() model2.get_config() - # assert _compare(model1_config, model2_config, skip_keys=["name"]) model1_weights = model1.get_weights() model2_weights = model2.get_weights() @@ -97,7 +96,6 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811 model1.optimizer.get_weights() model2.optimizer.get_weights() - # assert _compare(model1_opt_weights, model2_opt_weights) def _compare(d1, d2, skip_keys=None): diff --git a/python/ray/util/sgd/tf/tf_runner.py b/python/ray/util/sgd/tf/tf_runner.py index 3d0b0b5e1..53ae600cc 100644 --- a/python/ray/util/sgd/tf/tf_runner.py +++ b/python/ray/util/sgd/tf/tf_runner.py @@ -139,13 +139,6 @@ class TFRunner: self.model = self.model_creator(self.config) self.epoch = state["epoch"] self.model.set_weights(state["weights"]) - # This part is due to ray.get() changing scalar np.int64 object to int - # state["optimizer_weights"][0] = np.array( - # state["optimizer_weights"][0], dtype=np.int64) - - # if self.model.optimizer.weights == []: - # self.model.make_train_function() - # self.model.optimizer.set_weights(state["optimizer_weights"]) def shutdown(self): """Attempts to shut down the worker.""" diff --git a/python/ray/util/sgd/tf/tf_trainer.py b/python/ray/util/sgd/tf/tf_trainer.py index e31032002..27857d20d 100644 --- a/python/ray/util/sgd/tf/tf_trainer.py +++ b/python/ray/util/sgd/tf/tf_trainer.py @@ -154,14 +154,6 @@ class TFTrainer: model = self.model_creator(self.config) model.set_weights(state["weights"]) - - # This part is due to ray.get() changing scalar np.int64 object to int - # state["optimizer_weights"][0] = np.array( - # state["optimizer_weights"][0], dtype=np.int64) - - # model.make_train_function() - # model.optimizer.set_weights(state["optimizer_weights"]) - return model