温馨提示×

Linux平台上PyTorch的模型部署方法

小樊
45
2025-12-28 07:14:03
栏目: 智能运维

Linux平台上PyTorch模型部署方法

一 环境准备与模型保存加载

  • 建议使用 Ubuntu 20.04/22.04CentOS 7/8,搭配 Python 3.8+Anaconda/Minicondavenv 管理依赖。PyTorch 安装示例:
    • CPU 版(pip):pip3 install torch torchvision torchaudio
    • GPU 版(pip):pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
    • CPU 版(conda):conda install pytorch torchvision torchaudio cpuonly -c pytorch
    • GPU 版(conda):conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
      验证命令:
      python3 - <<‘PY’ import torch print(“PyTorch version:”, torch.version) print(“CUDA available:”, torch.cuda.is_available()) PY
  • 模型保存与加载要点:
    • 仅保存权重:torch.save(model.state_dict(), ‘model.pth’);加载时先实例化模型结构,再 model.load_state_dict(torch.load(‘model.pth’))。
    • GPU 训练权重加载到 CPU:torch.load(‘model.pth’, map_location=torch.device(‘cpu’))。
    • 推理前务必 model.eval(),并使用 with torch.no_grad(): 关闭梯度计算。

二 部署路径选型

方式 适用场景 关键要点 典型命令或接口
Python 原生服务(Flask/FastAPI) 快速上线、内部服务 易开发、热更新;并发需自行优化(如多进程/异步/队列) Flask/FastAPI + gunicorn/uvicorn
TorchScript(LibTorch/C++) 无 Python 依赖、C++ 服务 通过 torch.jit.trace/script 导出;C++ 端用 LibTorch 加载 torch.jit.trace/script → model.pt;C++ 中 torch::jit::load
ONNX Runtime 跨框架、跨平台、CPU/GPU 导出 ONNX,ORT 高性能推理;注意 opset 与动态轴 torch.onnx.export → onnxruntime.InferenceSession
TorchServe 官方模型服务、生产级 模型管理、批处理、多模型、指标与健康检查 torchserve --start …;curl 调用 /predictions
TensorRT NVIDIA GPU 极致性能 ONNX→TensorRT 引擎;FP16/INT8、严格形状优化 trtexec --onnx=model.onnx --saveEngine=model.engine

三 关键步骤与示例

  • 原生 Python 服务(Flask 最小示例)
    • 安装:pip install flask
    • 代码要点: from flask import Flask, request, jsonify import torch, torch.nn as nn app = Flask(name) class Net(nn.Linear): pass model = Net(10, 5); model.load_state_dict(torch.load(‘model.pth’, map_location=‘cpu’)); model.eval() @app.route(‘/predict’, methods=[‘POST’]) def predict(): x = torch.tensor(request.json[‘x’], dtype=torch.float32).unsqueeze(0) with torch.no_grad(): y = model(x) return jsonify({‘y’: y.squeeze().tolist()}) if name == ‘main’: app.run(host=‘0.0.0.0’, port=5000)
    • 运行与测试:python app.py;curl -X POST -H “Content-Type: application/json” -d ‘{“x”: [1,2,3,4,5,6,7,8,9,10]}’ http://localhost:5000/predict
  • TorchScript 导出与 C++ 推理
    • 导出: import torch model = torchvision.models.resnet18(pretrained=True); model.eval() example = torch.rand(1, 3, 224, 224) traced = torch.jit.trace(model, example); traced.save(“resnet18_traced.pt”)
    • C++ 侧(LibTorch): #include <torch/torch.h> int main() { auto module = torch::jit::load(“resnet18_traced.pt”); auto input = torch::randn({1, 3, 224, 224}); auto output = module.forward({input}).toTensor(); }
  • ONNX 导出与 ONNX Runtime 推理
    • 导出: import torch, torchvision model = torchvision.models.resnet18(pretrained=True); model.eval() dummy = torch.randn(1, 3, 224, 224) torch.onnx.export(model, dummy, “resnet18.onnx”, input_names=[“input”], output_names=[“output”], opset_version=11)
    • 推理(Python): import onnxruntime as ort, numpy as np sess = ort.InferenceSession(“resnet18.onnx”) x = np.random.randn(1, 3, 224, 224).astype(np.float32) out = sess.run(None, {“input”: x})
  • TorchServe 部署
    • 安装:pip install torchserve torch-model-archiver
    • 打包:torch-model-archiver --model-name resnet18 --version 1.0 --serialized-file resnet18_traced.pt --handler image_classifier --export-path model_store
    • 启动:torchserve --start --model-store model_store --models resnet18=resnet18.mar
    • 调用:curl -X POST http://localhost:8080/predictions/resnet18 -T input.json
  • TensorRT 加速
    • ONNX→Engine:trtexec --onnx=resnet18.onnx --saveEngine=resnet18.engine --explicitBatch
    • 随后可用 C++/Python TensorRT API 加载 engine 推理(适合高吞吐 GPU 场景)。

四 性能与安全优化

  • 性能
    • 批处理与异步:在服务端启用批量(TorchServe 批处理)、异步队列与多工作进程(如 gunicorn 的 workers/threads)。
    • 设备与精度:GPU 推理使用 .to(‘cuda’);在 TensorRT 中启用 FP16/INT8 量化(需校准)。
    • 图优化:ONNX 导出选择合适 opset,并用 onnx-simplifier 简化图结构;TensorRT 使用严格形状与层融合。
    • 预热与复用:进程内长期持有模型与 session,避免频繁加载开销。
  • 安全
    • 仅暴露必要端口,启用 HTTPS 与鉴权;输入校验与大小限制;容器内以 非 root 运行;日志与异常隔离。

五 常见问题与排查

  • 版本与设备不匹配:加载 GPU 权重到 CPU 需 map_location;跨机器/容器部署时保持 PyTorch 版本一致
  • 动态控制流:有 if/for 等控制流请用 torch.jit.script;纯前向可用 trace,但需覆盖代表性输入。
  • ONNX 导出失败或不一致:检查自定义算子、动态轴、opset;用 onnx.checker 与 onnx-simplifier 验证。
  • TensorRT 转换/运行失败:确认 CUDA/cuDNN/TensorRT 版本匹配;必要时降低 opset 或固定输入形状;检查严格形状约束。
  • 服务无响应或崩溃:添加日志与异常捕获,使用进程管理工具(如 systemd/supervisord)自动拉起;逐步回退到最小可复现示例定位问题。

0