Young87

SmartCat's Blog

So happy to code my life!

游戏开发交流QQ群号60398951

当前位置:首页 >跨站数据

tf.keras: 保存与加载模型

tf.keras.save 和 tf.keras.models.load_model 问题

使用 tf.keras.models.load_model 加载模型后会出现准确率极低的问题, 就好像没有进行训练过一样. 这是因为我们在模型中使用了 v1.x 优化器 (来自 tf.compat.v1.train), 而此类优化器由于与检查点不兼容, 所以在载入模型时会丢失优化器的状态值. 我们只能通过重新编译模型来恢复优化器的状态.

本文如何讲述

本文用到的库
构建模型
训练模型
保存模型
加载模型
模型评估

本文用到的库

# 这是本文用到的库
import tensorflow as tf
import pathlib

当下有一名为 model 的已训练完的采用了 v1.x 优化器 ‘adam’ 的模型.

构建模型

model = tf.keras.Sequential([
    layers.experimental.preprocessing.Rescaling(1 / 255),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10)
])
model.compile(
	optimizer='adam',  # 这里的优化器 'adam' 就是一个 v1.x 优化器
	loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
	metrics=['accuracy']
)

训练模型

# 假设我们已经有了一个名为 train_set 的训练集, 使用该训练集对模型进行训练(5个轮次).
model.fit(train_set, epoch=5)

保存模型

我们想把 model 保存到当前工作目录下, 文件名为 model.h5 (HDF5 格式), 步骤如下:

首先, 设定模型保存的路径:

# 使用 pathlib.Path 类构造路径, 可以免受平台差异性的困扰
model_path = pathlib.Path(r'assets/model.h5') 

然后, 将模型保存成文件.

model = model.save(model_path)

加载模型

# 导入模型
new_model = tf.keras.models.load_model(model_path)

# 需再次编译该模型, 编译条件须与训练时相同
model.compile(
	optimizer='adam',
	loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
	metrics=['accuracy']	   
)

评估模型

假设我们已经有了一个测试集 test_set, 我们可以使用此测试集对模型进行评估:

model.evaluate(test_set)

总结

由于 tf.compat.v1.train 里的优化器不兼容检查点, 所以我们后续加载模型时会丢失优化器的状态值, 导致后续评估结果过差. 我们可通过重新编译的方式解决此问题.

除特别声明,本站所有文章均为原创,如需转载请以超级链接形式注明出处:SmartCat's Blog

上一篇: mysql配置

下一篇: 【UVM源码学习】uvm_registry

精华推荐