突破LightGBM,LightGBM在广告点击率预测中的应用 !!

文摘   2024-10-17 18:26   北京  

哈喽,我是cos大壮~

今儿再来和大家聊一个关于LightGBM的算法案例:LightGBM在广告点击率预测中的应用。

下面,咱们会总以下几个方面进行讲解和总结:

  1. 广告点击率预测背景
  2. LightGBM 简介及工作原理
  3. LightGBM 公式推导
  4. 数据准备及虚拟数据集构建
  5. 模型训练与特征重要性分析
  6. 数据分析及可视化
  7. 参数调优与优化
  8. 总结
老规矩如果大家伙觉得近期文章还不错!欢迎大家点个赞、转个发,文末赠送《机器学习学习小册》

文末可取本文PDF版本~


1. 广告点击率预测的背景与挑战

广告点击率 (Click-Through Rate, CTR) 是广告投放领域的核心衡量指标之一。CTR 代表了用户点击广告的概率,是衡量广告效果的重要指标。在互联网广告系统中,精准预测广告是否会被用户点击直接影响广告投放的收益与广告主的投资回报率。

CTR 预测的核心任务是一个二分类问题,即根据用户行为、广告属性、以及其他上下文信息,预测广告是否会被点击。由于数据量庞大、特征维度多、用户行为复杂,传统的机器学习方法往往难以满足性能需求。

LightGBM 是近年来广泛应用于广告点击率预测的模型之一。作为一种基于决策树的梯度提升框架 (Gradient Boosting Framework),LightGBM 通过高效的训练过程和较好的精度表现,成为广告 CTR 预测中的主流模型。

2. LightGBM简介及其工作原理

LightGBM (Light Gradient Boosting Machine) 是一个快速、分布式的梯度提升框架,特别适用于大规模数据和高维稀疏特征。它是基于决策树的梯度提升算法 (GBDT, Gradient Boosting Decision Tree) 的优化实现。

LightGBM 在以下几个方面进行了显著优化:

  • 基于叶节点的树生长策略 (Leaf-wise Growth): 传统的 GBDT 是基于按层生长 (Level-wise) 的,而 LightGBM 是基于叶节点生长。它通过对当前误差最大的叶节点进行分裂,使得训练过程更加高效。
  • 直方图算法 (Histogram-based Algorithm): 通过将连续特征离散化为有限的离散值,LightGBM 显著提高了计算效率,并减少了内存使用。
  • GOSS (Gradient-based One-Side Sampling): GOSS 通过对样本的梯度值进行筛选,从而仅对重要的样本进行训练,进一步减少了计算量。
  • EFB (Exclusive Feature Bundling): EFB 技术将稀疏特征进行捆绑,以减少特征的维度,适用于高维稀疏数据。

LightGBM 工作原理

LightGBM的核心思想是通过决策树模型来对数据进行分类。其工作原理可以分为以下几个步骤:

  1. 初始化模型:初始时,LightGBM使用一个简单的模型(如输出全局均值的常数模型)进行初始化。
  2. 计算残差:模型的目标是最小化损失函数,因此每次模型训练后,都会计算模型输出与实际目标之间的误差,这就是残差。
  3. 构建新树:根据当前的残差,LightGBM 构建一棵新的树来拟合这些误差。
  4. 更新模型:新树会与当前模型的预测值组合,生成新的预测结果。这个过程会重复进行,直至达到预设的迭代次数或损失函数收敛。

3. LightGBM的公式推导

LightGBM 的目标是最小化给定数据集  的损失函数。对于二分类问题,损失函数通常为交叉熵损失:

LightGBM 会通过梯度提升的方式逐步优化该损失函数。在每次迭代中,我们更新模型的目标是最小化损失函数的一阶和二阶导数。对于每一棵树的叶节点,目标是最小化以下的目标函数:

其中:

  •    分别是一阶和二阶梯度;
  •  是叶节点的权重;
  •  是正则化参数;
  •  是树的叶节点数;
  •  是树结构的复杂度惩罚。

通过对上式进行优化,LightGBM 每次迭代会拟合新的树模型。

4. 数据准备及虚拟数据集构建

在实际的 CTR 预测中,数据通常包含以下几类信息:

  • 用户信息:例如用户年龄、性别、兴趣、历史行为等。
  • 广告信息:广告类型、广告位置、广告主信息等。
  • 上下文信息:时间、地理位置、设备类型等。

由于我们不能使用真实数据,在此,我们生成一个虚拟的广告点击率预测数据集。该数据集包含以下特征:

  • age: 用户年龄
  • gender: 用户性别(1 表示男性,0 表示女性)
  • ad_position: 广告展示位置(0~5,表示不同位置)
  • device: 设备类型(0 表示移动设备,1 表示PC)
  • click: 点击结果(目标变量,1 表示点击,0 表示未点击)
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# 生成虚拟数据集
np.random.seed(42)
N = 10000
age = np.random.randint(1865, size=N)
gender = np.random.randint(02, size=N)
ad_position = np.random.randint(06, size=N)
device = np.random.randint(02, size=N)
click = np.random.randint(02, size=N)

# 构建 DataFrame
df = pd.DataFrame({
    'age': age,
    'gender': gender,
    'ad_position': ad_position,
    'device': device,
    'click': click
})

# 数据划分
X = df[['age''gender''ad_position''device']]
y = df['click']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

5. 模型训练与特征重要性分析

在训练 LightGBM 模型前,首先需要将数据转换为 PyTorch 的张量格式。然后使用 PyTorch 实现 LightGBM,并绘制特征重要性图。

# 转换为 PyTorch 张量
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)

X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)

# 简单的全连接神经网络作为LightGBM替代品(本质上是一个分类器)
class SimpleNN(nn.Module):
    def __init__(self, input_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(3216)
        self.fc3 = nn.Linear(161)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return self.sigmoid(x)

# 模型实例化
model = SimpleNN(X_train_tensor.shape[1])

# 损失函数和优化器
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 模型训练
epochs = 100
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    
    # 前向传播
    y_pred = model(X_train_tensor).squeeze()
    loss = criterion(y_pred, y_train_tensor)
    
    # 反向传播与优化
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# 模型评估
model.eval()
with torch.no_grad():
    y_test_pred = model(X_test_tensor).squeeze()
    test_loss = criterion(y_test_pred, y_test_tensor)
    print(f'Test Loss: {test_loss.item():.4f}')

6. 数据分析及可视化

特征分布可视化

首先,我们分析每个特征的分布情况。

# 绘制各特征的分布图
df[['age''gender''ad_position''device']].hist(bins=20, figsize=(128))
plt.suptitle('Feature Distributions')
plt.show()

图形显示了 age, gender, ad_position, 和 device 的分布。这有助于了解数据集中特征的总体分布,判断是否需要做进一步的数据清洗或特征工程。

特征与目标变量的关系

我们可以通过箱线图查看每个特征与点击率之间的关系。

import seaborn as sns
sns.boxplot(x='click', y='age', data=df)
plt.title('Age vs Click')
plt.show()

这个图显示了不同年龄段的用户点击广告的概率差异,帮助我们分析哪些年龄段的用户更倾向于点击广告。

特征重要性分析

# 提取模型的特征权重 (简单模拟的模型不具备真实特征重要性,可用更复杂模型)
# 这里假设使用训练后的模型得到了特征重要性
importance = [0.30.20.350.15]  # 假设的特征重要性
features = ['age''gender''ad_position''device']

plt.barh(features, importance)
plt.title('Feature Importance')
plt.show()

该图显示了各个特征的重要性,有助于理解哪些特征对点击率的预测有更大的贡献。ad_position  age可能是两个最重要的因素。

7. 参数调优与优化

LightGBM 有多个重要的参数影响模型的性能,例如:

  • num_leaves: 控制每棵树的复杂度,较大的 num_leaves 会使模型更复杂。
  • learning_rate: 控制每次迭代时的步长。
  • n_estimators: 决定树的数量。
  • max_depth: 限制树的深度。

在调优过程中,通常通过交叉验证或网格搜索 (Grid Search) 来找到最优的参数组合。

# 假设我们使用 Grid Search 来优化参数
from sklearn.model_selection import GridSearchCV
from lightgbm import LGBMClassifier

lgb_model = LGBMClassifier()
param_grid = {
    'num_leaves': [3150],
    'learning_rate': [0.010.1],
    'n_estimators': [100200],
    'max_depth': [-110]
}

grid_search = GridSearchCV(lgb_model, param_grid, cv=5, scoring='accuracy')
grid_search.fit(X_train, y_train)
print("Best parameters:", grid_search.best_params_)

在广告点击率预测中,LightGBM 提供了高效的模型训练和准确的预测结果。通过本文案例,给大家展示了从数据准备、模型训练到参数调优的完整流程,最后结合数据可视化对模型进行了详细的分析和解释。

最后

大家有问题可以直接在评论区留言即可~

喜欢本文的朋友可收藏、点赞、转发起来!

需要本文PDF的同学,扫码备注「案例汇总」即可~ 
关注本号,带来更多算法干货实例,提升工作学习效率!
最后,给大家准备了《机器学习学习小册》PDF版本16大块的内容,124个问题总结
100个超强算法模型,大家如果觉得有用,可以点击查看~

推荐阅读

原创、超强、精华合集
100个超强机器学习算法模型汇总
机器学习全路线
机器学习各个算法的优缺点
7大方面,30个最强数据集
6大部分,20 个机器学习算法全面汇总
铁汁,都到这了,别忘记点赞呀~

深夜努力写Python
Python、机器学习算法
 最新文章