[sgd] Cleanup code from last PR (#9076)

This commit is contained in:
Richard Liaw
2020-06-22 15:17:07 -07:00
committed by GitHub
parent f76552d8db
commit e2330ffc35
3 changed files with 0 additions and 17 deletions
@@ -89,7 +89,6 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
model1.get_config()
model2.get_config()
# assert _compare(model1_config, model2_config, skip_keys=["name"])
model1_weights = model1.get_weights()
model2_weights = model2.get_weights()
@@ -97,7 +96,6 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
model1.optimizer.get_weights()
model2.optimizer.get_weights()
# assert _compare(model1_opt_weights, model2_opt_weights)
def _compare(d1, d2, skip_keys=None):
-7
View File
@@ -139,13 +139,6 @@ class TFRunner:
self.model = self.model_creator(self.config)
self.epoch = state["epoch"]
self.model.set_weights(state["weights"])
# This part is due to ray.get() changing scalar np.int64 object to int
# state["optimizer_weights"][0] = np.array(
# state["optimizer_weights"][0], dtype=np.int64)
# if self.model.optimizer.weights == []:
# self.model.make_train_function()
# self.model.optimizer.set_weights(state["optimizer_weights"])
def shutdown(self):
"""Attempts to shut down the worker."""
-8
View File
@@ -154,14 +154,6 @@ class TFTrainer:
model = self.model_creator(self.config)
model.set_weights(state["weights"])
# This part is due to ray.get() changing scalar np.int64 object to int
# state["optimizer_weights"][0] = np.array(
# state["optimizer_weights"][0], dtype=np.int64)
# model.make_train_function()
# model.optimizer.set_weights(state["optimizer_weights"])
return model