mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Merge pull request #161 from farizrahman4u/rm_test_file
save_load_utils_test.py removes 'model.h5'
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
import os
|
||||
from keras import backend as K
|
||||
from keras.layers import Input, Dense
|
||||
from keras.models import Model
|
||||
@@ -33,15 +34,16 @@ def test_save_and_load_all_weights():
|
||||
ow1value[0, 0:3] = [4, 2, 0]
|
||||
K.set_value(ow1, ow1value)
|
||||
# save all weights
|
||||
save_all_weights(m1, "model.h5")
|
||||
save_all_weights(m1, 'model.h5')
|
||||
# new model
|
||||
m2 = make_model()
|
||||
# load all weights
|
||||
load_all_weights(m2, "model.h5")
|
||||
load_all_weights(m2, 'model.h5')
|
||||
# check weights
|
||||
assert_allclose(K.get_value(m2.layers[1].kernel)[0, 0:4], [1, 3, 3, 7])
|
||||
# check optimizer weights
|
||||
assert_allclose(K.get_value(m2.optimizer.weights[3])[0, 0:3], [4, 2, 0])
|
||||
os.remove('model.h5')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user