diff --git a/tests/keras_contrib/utils/save_load_utils_test.py b/tests/keras_contrib/utils/save_load_utils_test.py index 14da651..67f55fc 100644 --- a/tests/keras_contrib/utils/save_load_utils_test.py +++ b/tests/keras_contrib/utils/save_load_utils_test.py @@ -28,7 +28,7 @@ def test_save_and_load_all_weights(): w1value[0, 0:4] = [1, 3, 3, 7] K.set_value(w1, w1value) # set optimizer weights - ow1 = m1.optimizer.weights[4] # momentum weights + ow1 = m1.optimizer.weights[3] # momentum weights ow1value = K.get_value(ow1) ow1value[0, 0:3] = [4, 2, 0] K.set_value(ow1, ow1value) @@ -41,7 +41,7 @@ def test_save_and_load_all_weights(): # check weights assert_allclose(K.get_value(m2.layers[1].kernel)[0, 0:4], [1, 3, 3, 7]) # check optimizer weights - assert_allclose(K.get_value(m2.optimizer.weights[4])[0, 0:3], [4, 2, 0]) + assert_allclose(K.get_value(m2.optimizer.weights[3])[0, 0:3], [4, 2, 0]) if __name__ == '__main__':