这是一篇由 Adrian Rosebrock 撰写的客座文章。Adrian 是 PyImageSearch.com 的作者,这是一个关于计算机视觉和深度学习的博客。Adrian 最近完成了新书 Deep Learning for Computer Vision with Python 的撰写,这本书介绍了如何使用 Keras 进行计算机视觉和图像识别的深度学习。
在本教程中,我们将介绍一种简单的方法,将 Keras 模型部署为 REST API。
这篇文章中涵盖的示例将作为构建您自己的深度学习 API 的模板/起点,您可以根据 API 端点的可扩展性和健壮性需求来扩展和自定义代码。
具体来说,我们将学习
- 如何(以及如何不)将 Keras 模型加载到内存中,以便有效地用于推理
- 如何使用 Flask Web 框架为我们的 API 创建端点
- 如何使用我们的模型进行预测,将它们转换为 JSON 格式,并将结果返回给客户端
- 如何使用 cURL 和 Python 调用我们的 Keras REST API
在本教程结束时,您将很好地理解创建 Keras REST API 所需的组件(以其最简单的形式)。
您可以随意使用本指南中提供的代码作为您自己深度学习 REST API 的起点。
注意:此处介绍的方法仅供参考。它并非 用于生产环境,也不能在高负载下扩展。如果您对利用消息队列和批处理的更高级的 Keras REST API 感兴趣,请参阅本教程。
配置您的开发环境
我们假设您的机器上已经配置并安装了 Keras。如果没有,请确保使用 官方安装说明 安装 Keras。
接下来,我们需要安装 Flask(及其相关依赖项),这是一个 Python Web 框架,以便我们可以构建 API 端点。我们还需要 requests,以便我们也可以使用我们的 API。
相关的 pip
安装命令如下所示
$ pip install flask gevent requests pillow
构建您的 Keras REST API
我们的 Keras REST API 自包含在一个名为 run_keras_server.py
的文件中。为了简单起见,我们将安装保存在一个文件中,实现也可以轻松模块化。
在 run_keras_server.py
中,您将找到三个函数,分别是
load_model
:用于加载我们训练好的 Keras 模型并为推理做好准备。
prepare_image
:此函数在将输入图像传递给网络进行预测之前对其进行预处理。如果您不使用图像数据,则可以考虑将其名称更改为更通用的 prepare_datapoint
,并应用您可能需要的任何缩放/归一化。
predict
:我们 API 的实际端点,它将对来自请求的传入数据进行分类,并将结果返回给客户端。
本教程的完整代码可以在 此处 找到。
# import the necessary packages
from keras.applications import ResNet50
from keras.preprocessing.image import img_to_array
from keras.applications import imagenet_utils
from PIL import Image
import numpy as np
import flask
import io
# initialize our Flask application and the Keras model
app = flask.Flask(__name__)
model = None
我们的第一个代码片段处理导入所需的包以及初始化 Flask 应用程序和我们的 model
。
接下来,我们定义 load_model
函数
def load_model():
# load the pre-trained Keras model (here we are using a model
# pre-trained on ImageNet and provided by Keras, but you can
# substitute in your own networks just as easily)
global model
model = ResNet50(weights="imagenet")
顾名思义,此方法负责实例化我们的架构并从磁盘加载权重。
为了简单起见,我们将使用已经在 ImageNet 数据集上预先训练过的 ResNet50 架构。
如果您使用自己的自定义模型,则需要修改此函数以从磁盘加载您的架构+权重。
在我们对来自客户端的任何数据执行预测之前,我们首先需要准备和预处理数据
def prepare_image(image, target):
# if the image mode is not RGB, convert it
if image.mode != "RGB":
image = image.convert("RGB")
# resize the input image and preprocess it
image = image.resize(target)
image = img_to_array(image)
image = np.expand_dims(image, axis=0)
image = imagenet_utils.preprocess_input(image)
# return the processed image
return image
此函数
- 接受输入图像
- 将模式转换为 RGB(如果需要)
- 将其调整为 224x224 像素(ResNet 的输入空间维度)
- 通过均值减法和缩放对数组进行预处理
同样,您应该根据在将输入数据传递给模型之前需要进行的任何预处理、缩放和/或归一化来修改此函数。
我们现在可以定义 predict
函数了,此方法处理对 /predict
端点的任何请求
@app.route("/predict", methods=["POST"])
def predict():
# initialize the data dictionary that will be returned from the
# view
data = {"success": False}
# ensure an image was properly uploaded to our endpoint
if flask.request.method == "POST":
if flask.request.files.get("image"):
# read the image in PIL format
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image))
# preprocess the image and prepare it for classification
image = prepare_image(image, target=(224, 224))
# classify the input image and then initialize the list
# of predictions to return to the client
preds = model.predict(image)
results = imagenet_utils.decode_predictions(preds)
data["predictions"] = []
# loop over the results and add them to the list of
# returned predictions
for (imagenetID, label, prob) in results[0]:
r = {"label": label, "probability": float(prob)}
data["predictions"].append(r)
# indicate that the request was a success
data["success"] = True
# return the data dictionary as a JSON response
return flask.jsonify(data)
data
字典用于存储我们要返回给客户端的任何数据。目前,这包括一个用于指示预测是否成功的布尔值,我们还将使用此字典来存储我们对传入数据进行的任何预测的结果。
为了接受传入数据,我们检查是否
- 请求方法是 POST(使我们能够向端点发送任意数据,包括图像、JSON、编码数据等)
- 在 POST 期间,已将
image
传递到 files
属性中
然后,我们获取传入数据并
- 以 PIL 格式读取它
- 对其进行预处理
- 将其传递给我们的网络
- 遍历结果并将它们分别添加到
data["predictions"]
列表中
- 以 JSON 格式将响应返回给客户端
如果您使用的是非图像数据,则应删除 request.files
代码,并自行解析原始输入数据,或使用 request.get_json()
自动将输入数据解析为 Python 字典/对象。此外,请考虑阅读以下教程,其中讨论了 Flask request 对象
的基础知识。
现在剩下的就是启动我们的服务了
# if this is the main thread of execution first load the model and
# then start the server
if __name__ == "__main__":
print(("* Loading Keras model and Flask starting server..."
"please wait until server has fully started"))
load_model()
app.run()
首先,我们调用 load_model
,它从磁盘加载我们的 Keras 模型。
对 load_model
的调用是一个阻塞操作,它会阻止 Web 服务启动,直到模型完全加载为止。如果我们没有确保在启动 Web 服务之前将模型完全加载到内存中并准备好进行推理,则可能会遇到以下情况
- 向服务器发送了一个请求。
- 服务器接受请求,对数据进行预处理,然后尝试将其传递给模型
- ...但由于模型尚未完全加载,因此我们的脚本将出错!
在构建您自己的 Keras REST API 时,请确保插入逻辑以确保在接受请求之前加载模型并准备好进行推理。
如何在 REST API 中不 加载 Keras 模型
您可能会试图在 predict
函数内部加载模型,如下所示
...
# ensure an image was properly uploaded to our endpoint
if request.method == "POST":
if request.files.get("image"):
# read the image in PIL format
image = request.files["image"].read()
image = Image.open(io.BytesIO(image))
# preprocess the image and prepare it for classification
image = prepare_image(image, target=(224, 224))
# load the model
model = ResNet50(weights="imagenet")
# classify the input image and then initialize the list
# of predictions to return to the client
preds = model.predict(image)
results = imagenet_utils.decode_predictions(preds)
data["predictions"] = []
...
这段代码意味着每次有新请求传入时都会加载 model
。这是非常低效的,甚至会导致您的系统内存不足。
如果您尝试运行上面的代码,您会注意到您的 API 将运行得非常慢(特别是如果您的模型很大),这是因为为每个新请求加载模型在 I/O 和 CPU 操作中都有很大的开销。
要了解这如何轻易耗尽服务器的内存,假设我们同时有 N 个传入服务器的请求。这意味着内存中将加载 N 个模型...同样是同时加载。如果您的模型很大,例如 ResNet,则在 RAM 中存储 N 个模型副本很容易耗尽系统内存。
为此,除非您有非常具体、合理的理由,否则请尽量避免为每个新传入请求加载新的模型实例。
注意:我们假设您使用的是默认的单线程 Flask 服务器。如果您部署到多线程服务器,即使使用本文前面讨论的“更正确”方法,您仍然可能会遇到在内存中加载多个模型的情况。如果您打算使用专用服务器(如 Apache 或 nginx),则应考虑使您的管道更具可扩展性,如此处所述。
启动您的 Keras Rest API
启动 Keras REST API 服务很容易。
打开终端并执行
$ python run_keras_server.py
Using TensorFlow backend.
* Loading Keras model and Flask starting server...please wait until server has fully started
...
* Running on http://127.0.0.1:5000
从输出中可以看到,我们的模型是首先加载的,然后我们才能启动 Flask 服务器。
您现在可以通过 http://127.0.0.1:5000
访问服务器。
但是,如果您将 IP 地址+端口复制并粘贴到浏览器中,您将看到以下图像
出现这种情况的原因是在 Flask URL 路由中没有设置索引/主页。
相反,请尝试通过浏览器访问 /predict
端点
您将看到“不允许使用的方法”错误。出现此错误的原因是您的浏览器正在执行 GET 请求,但 /predict
仅接受 POST(我们将在下一节中演示如何执行)。
使用 cURL 测试 Keras REST API
在测试和调试您的 Keras REST API 时,请考虑使用 cURL(无论如何,这都是一个值得学习的好工具)。
下面您可以看到我们想要分类的图像,一只狗,更具体地说是一只比格犬
我们可以使用 curl
将此图像传递给我们的 API,并了解 ResNet 认为图像包含的内容
$ curl -X POST -F image=@dog.jpg 'https://127.0.0.1:5000/predict'
{
"predictions": [
{
"label": "beagle",
"probability": 0.9901360869407654
},
{
"label": "Walker_hound",
"probability": 0.002396771451458335
},
{
"label": "pot",
"probability": 0.0013951235450804234
},
{
"label": "Brittany_spaniel",
"probability": 0.001283277408219874
},
{
"label": "bluetick",
"probability": 0.0010894243605434895
}
],
"success": true
}
-X
标志和 POST
值表示我们正在执行 POST 请求。
我们提供 -F [email protected]
来指示我们正在提交表单编码数据。然后,将 image
键设置为 dog.jpg
文件的内容。在 dog.jpg
之前提供 @
表示我们希望 cURL 加载图像的内容并将数据传递给请求。
最后,我们有我们的端点:https://127.0.0.1:5000/predict
请注意,输入图像如何被正确分类为 "beagle",置信度为 99.01%。其余的前 5 个预测及其相关概率也包含在我们 Keras API 的响应中。
以编程方式使用 Keras REST API
很有可能,您将向您的 Keras REST API 提交数据,然后以某种方式使用返回的预测,这需要我们以编程方式处理来自服务器的响应。
使用 requests Python 包,这是一个简单的过程
# import the necessary packages
import requests
# initialize the Keras REST API endpoint URL along with the input
# image path
KERAS_REST_API_URL = "https://127.0.0.1:5000/predict"
IMAGE_PATH = "dog.jpg"
# load the input image and construct the payload for the request
image = open(IMAGE_PATH, "rb").read()
payload = {"image": image}
# submit the request
r = requests.post(KERAS_REST_API_URL, files=payload).json()
# ensure the request was successful
if r["success"]:
# loop over the predictions and display them
for (i, result) in enumerate(r["predictions"]):
print("{}. {}: {:.4f}".format(i + 1, result["label"],
result["probability"]))
# otherwise, the request failed
else:
print("Request failed")
KERAS_REST_API_URL
指定我们的端点,而 IMAGE_PATH
是驻留在磁盘上的输入图像的路径。
使用 IMAGE_PATH
,我们加载图像,然后构造请求的 payload
。
给定 payload
,我们可以使用对 requests.post
的调用将数据 POST 到我们的端点。在调用末尾附加 .json()
会指示 requests
- 来自服务器的响应应为 JSON 格式
- 我们希望自动解析和反序列化 JSON 对象
获得请求的输出 r
后,我们可以检查分类是否成功,然后遍历 r["predictions"]
。
要运行 simple_request.py
,首先确保 run_keras_server.py
(即 Flask Web 服务器)当前正在运行。然后,在单独的 shell 中执行以下命令
$ python simple_request.py
1. beagle: 0.9901
2. Walker_hound: 0.0024
3. pot: 0.0014
4. Brittany_spaniel: 0.0013
5. bluetick: 0.0011
我们已成功调用 Keras REST API 并通过 Python 获得了模型的预测。
在这篇文章中,您学习了如何
本教程中涵盖的代码可以在 此处 找到,并可用作您自己的 Keras REST API 的模板 - 您可以根据需要随意修改它。
请记住,这篇文章中的代码仅供参考。它并非用于生产环境,并且无法在高负载和大量传入请求下进行扩展。
此方法最适合在以下情况下使用:
- 您需要为您的 Keras 深度学习模型快速建立一个 REST API。
- 您的端点不会受到大量访问。
如果您对利用消息队列和批处理的更高级的 Keras REST API 感兴趣,请参阅 这篇博文。
如果您对这篇文章有任何问题或意见,请联系 PyImageSearch 的 Adrian(今天这篇文章的作者)。有关未来主题的建议,请在 Twitter 上找到 Francois。