tf.keras: 保存与加载模型
日期: 2020-12-14 分类: 跨站数据 703次阅读
tf.keras.save 和 tf.keras.models.load_model 问题
使用 tf.keras.models.load_model 加载模型后会出现准确率极低的问题, 就好像没有进行训练过一样. 这是因为我们在模型中使用了 v1.x 优化器 (来自 tf.compat.v1.train), 而此类优化器由于与检查点不兼容, 所以在载入模型时会丢失优化器的状态值. 我们只能通过重新编译模型来恢复优化器的状态.
本文如何讲述
本文用到的库
# 这是本文用到的库
import tensorflow as tf
import pathlib
设当下有一名为 model
的已训练完的采用了 v1.x 优化器 ‘adam’ 的模型.
构建模型
model = tf.keras.Sequential([
layers.experimental.preprocessing.Rescaling(1 / 255),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(10)
])
model.compile(
optimizer='adam', # 这里的优化器 'adam' 就是一个 v1.x 优化器
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics=['accuracy']
)
训练模型
# 假设我们已经有了一个名为 train_set 的训练集, 使用该训练集对模型进行训练(5个轮次).
model.fit(train_set, epoch=5)
保存模型
我们想把 model 保存到当前工作目录下, 文件名为 model.h5 (HDF5 格式), 步骤如下:
首先, 设定模型保存的路径:
# 使用 pathlib.Path 类构造路径, 可以免受平台差异性的困扰
model_path = pathlib.Path(r'assets/model.h5')
然后, 将模型保存成文件.
model = model.save(model_path)
加载模型
# 导入模型
new_model = tf.keras.models.load_model(model_path)
# 需再次编译该模型, 编译条件须与训练时相同
model.compile(
optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
评估模型
假设我们已经有了一个测试集 test_set
, 我们可以使用此测试集对模型进行评估:
model.evaluate(test_set)
总结
由于 tf.compat.v1.train 里的优化器不兼容检查点, 所以我们后续加载模型时会丢失优化器的状态值, 导致后续评估结果过差. 我们可通过重新编译的方式解决此问题.
除特别声明,本站所有文章均为原创,如需转载请以超级链接形式注明出处:SmartCat's Blog
上一篇: mysql配置
精华推荐