一码在手,天下我有 —— m2cgen 让机器学习模型跨语言无缝转换

文摘   2024-07-13 07:00   江苏  

一码在手,天下我有 —— m2cgen 让机器学习模型跨语言无缝转换

m2cgen(Model 2 Code Generator)是一个轻量级的库,它提供了一个简单的方法,可以将训练好的统计模型转换成原生代码,支持多种编程语言,包括 Python、C、Java、Go、JavaScript、Visual Basic、C#、PowerShell、R、PHP、Dart、Haskell、Ruby、F#、Rust 和 Elixir。

安装

支持的 Python 版本是>= 3.7

pip install m2cgen

开发

在提交 PR 之前,请确保以下命令可以成功运行:

make pre-pr

或者,你可以运行相同命令的 Docker 版本:

make docker-build docker-pre-pr

支持的语言

  • C
  • C#
  • Dart
  • F#
  • Go
  • Haskell
  • Java
  • JavaScript
  • PHP
  • PowerShell
  • Python
  • R
  • Ruby
  • Rust
  • Visual Basic (与 VBA 兼容)
  • Elixir

支持的模型

类型分类回归
线性模型LogisticRegression 等ARDRegression 等
SVMLinearSVC、NuSVC 等LinearSVR、NuSVR 等
树模型DecisionTreeClassifier 等DecisionTreeRegressor 等
随机森林RandomForestClassifier 等RandomForestRegressor 等
提升方法LGBMClassifier(gbdt/dart/goss)等LGBMRegressor(gbdt/dart/goss)等

CI 测试保证兼容性的包版本可以在这里[1]找到。其他版本也可能得到支持,但尚未经过测试。

分类输出

  • 线性/线性 SVM/核 SVM:二元分类输出为样本到超平面的有符号距离;多类分类输出为每个类别的有符号距离向量。
  • SVM:异常检测、二元和多类分类输出。
  • 树/随机森林/提升方法:二元和多类分类输出为类概率向量。

使用方法

以下是一个简单的示例,展示如何在 Java 代码中表示在 Python 环境中训练的线性模型:

from sklearn.datasets import load_diabetes
from sklearn import linear_model
import m2cgen as m2c

X, y = load_diabetes(return_X_y=True)

estimator = linear_model.LinearRegression()
estimator.fit(X, y)

code = m2c.export_to_java(estimator)

生成的 Java 代码示例:

public class Model {
    public static double score(double[] input) {
        return ((((((((((152.1334841628965) + ((input[0]) * (-10.012197817470472))) + ... ) + ((input[9]) * (67.62538639104406));
    }
}

更多不同模型/语言生成的代码示例可以在这里[2]找到。

命令行界面(CLI)

m2cgen 可以用作 CLI 工具,使用序列化的模型对象(pickle 协议)生成代码:

$ m2cgen <pickle_file> --language <language> [--indent <indent>] [--function_name <function_name>]
         [--class_name <class_name>] [--module_name <module_name>] [--package_name <package_name>]
         [--namespace <namespace>] [--recursion-limit <recursion_limit>]

注意,对于反序列化序列化的模型对象,它们的类必须在反序列化环境中可导入模块的顶层定义。

也支持管道:

$ cat <pickle_file> | m2cgen --language <language>

常见问题解答(FAQ)

  • Q: 生成失败,出现 RecursionError: maximum recursion depth exceeded 错误。
    A: 如果在生成代码时出现此错误,请尝试减少该模型中的训练估计器数量。或者,你可以使用 sys.setrecursionlimit(<new_depth>) 增加最大递归深度。

  • Q: 在从序列化的模型对象转换模型时生成失败,出现 ImportError: No module named <module_name_here> 错误。
    A: 此错误表明 pickle 协议无法反序列化模型对象。对于反序列化序列化的模型对象,需要它们的类在反序列化环境中可导入模块的顶层定义。因此,安装提供模型类定义的包应该可以解决问题。

  • Q: 由 m2cgen 生成的代码对于一些输入与原始 Python 模型的结果不同。
    A: 一些模型在它们的原生 Python 库中在预测阶段会强制输入数据为特定类型。目前,m2cgen 仅使用 float64double)数据类型。你可以尝试手动将输入数据转换为另一种类型并再次检查结果。此外,由于目标语言中浮点算术的特定实现,可能会有一些小的差异。

附录

https://github.com/BayesWitnesses/m2cgen

参考资料
[1]

这里: https://github.com/BayesWitnesses/m2cgen/blob/master/requirements-test.txt#L1

[2]

这里: https://github.com/BayesWitnesses/m2cgen/tree/master/generated_code_examples


编程悟道
自制软件研发、软件商店,全栈,ARTS 、架构,模型,原生系统,后端(Node、React)以及跨平台技术(Flutter、RN).vue.js react.js next.js express koa hapi uniapp Astro
 最新文章