mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[sgd] Cleanup code from last PR (#9076)
This commit is contained in:
@@ -89,7 +89,6 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
|
||||
model1.get_config()
|
||||
model2.get_config()
|
||||
# assert _compare(model1_config, model2_config, skip_keys=["name"])
|
||||
|
||||
model1_weights = model1.get_weights()
|
||||
model2_weights = model2.get_weights()
|
||||
@@ -97,7 +96,6 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
|
||||
model1.optimizer.get_weights()
|
||||
model2.optimizer.get_weights()
|
||||
# assert _compare(model1_opt_weights, model2_opt_weights)
|
||||
|
||||
|
||||
def _compare(d1, d2, skip_keys=None):
|
||||
|
||||
@@ -139,13 +139,6 @@ class TFRunner:
|
||||
self.model = self.model_creator(self.config)
|
||||
self.epoch = state["epoch"]
|
||||
self.model.set_weights(state["weights"])
|
||||
# This part is due to ray.get() changing scalar np.int64 object to int
|
||||
# state["optimizer_weights"][0] = np.array(
|
||||
# state["optimizer_weights"][0], dtype=np.int64)
|
||||
|
||||
# if self.model.optimizer.weights == []:
|
||||
# self.model.make_train_function()
|
||||
# self.model.optimizer.set_weights(state["optimizer_weights"])
|
||||
|
||||
def shutdown(self):
|
||||
"""Attempts to shut down the worker."""
|
||||
|
||||
@@ -154,14 +154,6 @@ class TFTrainer:
|
||||
|
||||
model = self.model_creator(self.config)
|
||||
model.set_weights(state["weights"])
|
||||
|
||||
# This part is due to ray.get() changing scalar np.int64 object to int
|
||||
# state["optimizer_weights"][0] = np.array(
|
||||
# state["optimizer_weights"][0], dtype=np.int64)
|
||||
|
||||
# model.make_train_function()
|
||||
# model.optimizer.set_weights(state["optimizer_weights"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user