温馨提示×

Rust在Debian上的机器学习实践

小樊
35
2026-01-01 16:16:27
栏目: 编程语言

在 Debian 上用 Rust 做机器学习,可以兼顾性能、内存安全与工程化部署。下面给出从环境准备到训练、推理与部署的完整实践路线,并配套可直接运行的示例与要点。

环境准备与工具链

  • 安装 Rust 工具链(rustup):
    • 执行:curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
    • 生效:source $HOME/.cargo/env
    • 验证:rustc --versioncargo --version
  • 可选:配置国内 crates 镜像(加速依赖下载),在 ~/.cargo/config.toml 添加:
    [source.crates-io]
    replace-with = 'tuna'
    [source.tuna]
    registry = "https://mirrors.tuna.tsinghua.edu.cn/git/crates.io-index.git"
    
  • Debian 编译依赖(构建部分 crate 时需要):
    • 执行:sudo apt-get update && sudo apt-get install -y build-essential cmake ccache pkg-config libssl-dev libclang-dev clang llvm-dev git-lfs
  • 建议:使用较新的 Debian 12/BookwormDebian 11/Bullseye,以获得较新的编译器与库版本支持。

常用库与生态选择

场景 推荐库 关键点
传统机器学习 linfasmartcore 类似 scikit-learn 的 API,涵盖回归、分类、聚类等
深度学习 tch-rscandle tch-rsPyTorch 绑定;candle 轻量、纯 Rust、易部署
数据处理 ndarraypolars ndarray 多维数组;polars 高性能数据框
推理部署 tractwonnx ONNX 模型推理与优化,适合跨语言与服务化部署
以上库在 Debian 上均可直接用 Cargo 构建,生态成熟,示例丰富。

快速上手示例

  • 示例一 传统机器学习:用 linfa 做线性回归
    1. 新建项目:cargo new linfa-lr && cd linfa-lr
    2. 添加依赖(Cargo.toml):
      [dependencies]
      linfa = "0.6"
      ndarray = "0.15"
      
    3. 示例代码(src/main.rs):
      use linfa::prelude::*;
      use ndarray::array;
      
      fn main() {
          // 训练数据:y = 2x + 1
          let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
          let y = array![3.0, 5.0, 7.0, 9.0, 11.0];
      
          // 训练
          let model = linfa::linear_regression::LinearRegression::default().fit(&x, &y).unwrap();
      
          // 预测
          let preds = model.predict(&array![[6.0], [7.0]]);
          println!("Predictions: {:?}", preds); // 期望接近 [13.0, 15.0]
      }
      
    4. 运行:cargo run
  • 示例二 深度学习:用 tch-rs 训练 MNIST 手写数字识别
    1. 新建项目:cargo new tch-mnist && cd tch-mnist
    2. 添加依赖(Cargo.toml):
      [dependencies]
      tch = { version = "0.22", features = ["vision"] }
      anyhow = "1.0"
      
    3. 示例代码(src/main.rs,CPU 版):
      use anyhow::Result;
      use tch::{nn, nn::ModuleT, Device, Tensor, Kind};
      use tch::vision::mnist;
      
      fn lenet(vs: &nn::Path) -> impl ModuleT {
          nn::seq()
              .add(nn::conv2d(vs / "conv1", 1, 6, 5, Default::default()))
              .add_fn(|xs| xs.relu().max_pool2d_default(2))
              .add(nn::conv2d(vs / "conv2", 6, 16, 5, Default::default()))
              .add_fn(|xs| xs.relu().max_pool2d_default(2))
              .add_fn(|xs| xs.flatten(1, -1))
              .add(nn::linear(vs / "fc1", 16 * 5 * 5, 120, Default::default()))
              .add_fn(|xs| xs.relu())
              .add(nn::linear(vs / "fc2", 120, 84, Default::default()))
              .add_fn(|xs| xs.relu())
              .add(nn::linear(vs / "fc3", 84, 10, Default::default()))
      }
      
      fn main() -> Result<()> {
          let device = Device::Cpu; // 有 NVIDIA GPU 可改为 Device::CudaIfAvailable
          let vs = nn::VarStore::new(device);
          let net = lenet(&vs.root());
          let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
      
          // 加载 MNIST 数据集
          let m = mnist::load_dir("data")?;
      
          for epoch in 1..=5 {
              let loss = net
                  .forward_t(&m.train_images, true)
                  .cross_entropy_for_logits(&m.train_labels);
              opt.backward_step(&loss);
      
              let acc = net
                  .forward_t(&m.test_images, false)
                  .accuracy_for_logits(&m.test_labels);
              println!("Epoch: {:2}, Loss: {:.4}, Test Acc: {:.2}%",
                       epoch, f64::from(loss), 100.0 * acc);
          }
          Ok(())
      }
      
    4. 运行前准备数据:mkdir -p data && wget -O data/train-images-idx3-ubyte.gz http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz && wget -O data/train-labels-idx1-ubyte.gz http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz && wget -O data/t10k-images-idx3-ubyte.gz http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz && wget -O data/t10k-labels-idx1-ubyte.gz http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
    5. 运行:cargo run --release
  • 说明
    • 若需 GPU:安装 CUDAcuDNN,并将设备改为 Device::CudaIfAvailable;tch-rs 将自动使用 GPU。
    • 若使用 candle 做 MNIST,可参考其官方示例,支持 cargo build --examples --features cuda 启用 CUDA。

GPU加速与依赖要点

  • tch-rs(PyTorch 绑定)
    • 需要 libtorch 运行时;可通过环境变量指定下载地址,例如:
      export LIBTORCH=https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcpu.zip
      export LIBTORCH_USE_PYTORCH=1
      
    • CPU 版直接可用;GPU 版需匹配 CUDA 版本的 libtorch,并将设备设为 tch::Device::CudaIfAvailable
  • candle(纯 Rust)
    • 启用 CUDA:在构建或运行示例时添加 --features cuda;确保系统已安装 CUDA Toolkit 与相应驱动。
  • 通用建议
    • 使用 --release 编译以获得最佳性能。
    • 在容器/CI 中固定 CUDAcuDNNlibtorch 版本,避免环境漂移。

模型部署与服务化

  • ONNX 推理
    • 训练或导出得到 ONNX 模型后,可用 tractwonnx 在 Rust 中加载与推理,适合跨语言、低开销的服务端部署。
  • 原生二进制与服务
    • tch-rscandle 均可导出/保存模型参数,构建静态二进制,结合 actix-webwarp 暴露 HTTP 推理接口,实现低延迟在线服务。
  • 工程化建议
    • 分离模型训练与推理代码;推理侧仅保留前向计算图与必要的预处理/后处理。
    • 使用 serde 做输入输出序列化(如 JSON),并加入批处理与并发控制以充分利用多核与 GPU。

0