TypeError: load() missing 1 required positional argument: ‘sess’ when loading model from TF2 Object-Detection-API Tutorial SavedModel Example

原文标题TypeError: load() missing 1 required positional argument: ‘sess’ when loading model from TF2 Object-Detection-API Tutorial SavedModel Example

我正在阅读有关阅读文档 io 的 Tensor Flow 2 对象检测 API 教程,并正在使用 TF2 Saved Model 示例中的对象检测。https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/ auto_examples/plot_object_detection_saved_model.html

我使用教程中的代码成功下载了模型:

# Download and extract model
def download_model(model_name, model_date):
    base_url = 'http://download.tensorflow.org/models/object_detection/tf2/'
    model_file = model_name + '.tar.gz'
    model_dir = tf.keras.utils.get_file(fname=model_name,
                                        origin=base_url + model_date + '/' + model_file,
                                        untar=True)
    return str(model_dir)

MODEL_DATE = '20200711'
MODEL_NAME = 'centernet_hg104_1024x1024_coco17_tpu-32'
PATH_TO_MODEL_DIR = download_model(MODEL_NAME, MODEL_DATE)

但是当我从教程中运行加载脚本时:

import time
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils

PATH_TO_SAVED_MODEL = PATH_TO_MODEL_DIR + "/saved_model"

print('Loading model...', end='')
start_time = time.time()

# Load saved model and build the detection function
detect_fn = tf.saved_model.load(export_dir=PATH_TO_SAVED_MODEL, tags=None, options=None)

end_time = time.time()
elapsed_time = end_time - start_time
print('Done! Took {} seconds'.format(elapsed_time))

我收到 errorTypeError: load() missing 1 required positional argument: ‘sess’ 最初,我对这个函数的唯一参数是 ‘PATH_TO_SAVED_MODEL’,但之后添加了 ‘export_dir=’、’tags=None’ 和 ‘options=None’类似的错误信息提示我这样做。使用 sess,我尝试添加 ‘sess’、’sess=None’,并研究 tf 文档以获取有关 load() 函数中 sess 参数的详细信息,但没有任何运气。

另外,我想知道我下载模型的方式是否存在问题。 tf 文档在使用 load 函数之前总是使用 saved_model.save() 函数,我假设我的脚本下载模型做了同样的事情,但是我还需要在加载之前保存下载的模型吗?

如果有人对如何成功加载模型有任何建议,我将不胜感激!

原文链接:https://stackoverflow.com//questions/71516758/typeerror-load-missing-1-required-positional-argument-sess-when-loading-mo

回复

我来回复
  • Tillmann的头像
    Tillmann 评论

    尝试使用 `load_model() 函数:

    model = tf.keras.models.load_model('<path_to_model>')
    
    2年前 0条评论