一、FM算法
什么是FM:FM(factor Machine,因子分解机)算法是一种基于矩阵分解的机器学习算法,是由Konstanz大学Steffen Rendle(现任职于Google)于2010年最早提出的,旨在解决稀疏数据下的特征组合问题
背景
- 什么是稀疏性
假设一个广告分类的问题,根据用户和广告位相关的特征,预测用户是否点击了广告。源数据如下:
| Clicked |
Country |
Day |
Ad_type |
| 1 |
USA |
26/11/15 |
Movie |
| 0 |
China |
1/7/14 |
Game |
| 1 |
China |
19/2/15 |
Game |
“Clicked?“是label,Country、Day、Ad_type是特征。由于三种特征都是categorical类型的,需要经过独热编码(One-Hot Encoding)转换成数值型特征。
| Clicked? |
Country=USA |
Country=China |
Day=26/11/15 |
Day=1/7/14 |
Day=19/2/15 |
Ad_type=Movie |
Ad_type=Game |
| 1 |
1 |
0 |
1 |
0 |
0 |
1 |
0 |
| 0 |
0 |
1 |
0 |
1 |
0 |
0 |
1 |
| 1 |
0 |
1 |
0 |
0 |
1 |
0 |
1 |
- 什么是特征组合
在推荐系统和CTR预估中,特征交叉是提升模型效果的关键。考虑一个广告点击预测场景:
1 2
| 用户特征: 年龄=25, 性别=男, 城市=北京 广告特征: 类别=游戏, 平台=iOS
|
单独看"性别=男"或"类别=游戏"可能意义不大,但 “男性用户 + 游戏广告” 这个组合特征可能有很强的预测能力。
方法1:手动特征交叉
多项式模型是包含特征组合的最直观的模型。在多项式模型中,特征 xi 和 xj 的组合采用 xixj 表示,即 xi 和 xj 都非零时,组合特征 xixj 才有意义。从对比的角度,本文只讨论二阶多项式模型。模型的表达式如下
y=w0+i=1∑nwixi+i=1∑nj=i+1∑nwijxixj
问题:
- 参数量爆炸:n 个特征需要 O(n2) 个交叉参数
- 数据稀疏性:在高维稀疏数据中,大部分特征组合 (xi,xj) 在训练集中从未同时出现,导致 wij 无法学习。如有商品维度有100w,用户维度100w。那么用户和商品的很多交叉组合再训练集中从未出现过,导致训练困难。
FM核心思想
2.1 隐向量分解
FM的核心创新:不直接学习 wij,而是将每个特征映射到一个k维隐向量,用隐向量的内积表示交叉权重。
wij≈⟨vi,vj⟩=f=1∑kvif⋅vjf
其中 vi∈Rk 是特征 i 的隐向量(embedding)。
2.2 FM模型公式
y^(x)=w0+i=1∑nwixi+i=1∑nj=i+1∑n⟨vi,vj⟩xixj
- w0:全局偏置
- wi:一阶特征权重
- ⟨vi,vj⟩:二阶交叉权重
2.3 为什么能解决稀疏性问题?
关键洞察:即使 (xi,xj) 从未同时出现,只要 xi 和 xj 分别与其他特征有交互,它们的隐向量就能被学习!
例子:
| 用户 |
商品 |
点击 |
| A |
1 |
1 |
| A |
2 |
1 |
| B |
2 |
1 |
| B |
3 |
1 |
虽然 (A,3) 从未出现,但:
- 用户A的向量 vA 通过与商品1、2的交互学习
- 商品3的向量 v3 通过与用户B的交互学习
- 可以计算 ⟨vA,v3⟩ 来预测!
这就是协同过滤的思想在FM中的体现。
三、计算复杂度优化(核心推导)
3.1 原始复杂度问题
直接计算二阶交叉项:
i=1∑nj=i+1∑n⟨vi,vj⟩xixj
需要 O(kn2) 的时间复杂度,这在高维场景下不可接受。
3.2 数学推导:O(kn2)→O(kn)
核心技巧:利用平方和公式 (∑a)2=∑a2+2∑i<jaiaj
Step 1:扩展求和范围
i=1∑nj=i+1∑n⟨vi,vj⟩xixj=21(i=1∑nj=1∑n⟨vi,vj⟩xixj−i=1∑n⟨vi,vi⟩xi2)
解释:
- ∑i∑j 包含了 i<j、i>j、i=j 三部分
- i<j 和 i>j 对称,各贡献一半
- 减去 i=j 的对角线部分
Step 2:展开内积
⟨vi,vj⟩=f=1∑kvifvjf
代入:
=21i=1∑nj=1∑nf=1∑kvifvjfxixj−i=1∑nf=1∑kvif2xi2
Step 3:交换求和顺序
=21f=1∑k(i=1∑nj=1∑nvifvjfxixj−i=1∑nvif2xi2)
Step 4:分解乘积
注意到 ∑i∑jvifvjfxixj=(∑ivifxi)(∑jvjfxj)=(∑ivifxi)2
=21f=1∑k(i=1∑nvifxi)2−i=1∑nvif2xi2
最终公式:
FM二阶项=21f=1∑k(i=1∑nvifxi)2−i=1∑n(vifxi)2
即:“和的平方” 减去 “平方的和”
3.3 复杂度分析
- ∑i=1nvifxi:O(n)
- 平方:O(1)
- 对 k 个维度求和:O(k)
- 总复杂度:O(kn)
四、代码实现解析
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| class FM(nn.Module): def forward(self, inputs): fm_input = inputs square_of_sum = torch.pow(torch.sum(fm_input, dim=1, keepdim=True), 2) sum_of_square = torch.sum(fm_input * fm_input, dim=1, keepdim=True) cross_term = square_of_sum - sum_of_square cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False) return cross_term
|
维度解释:
假设输入 fm_input 的shape是 (batch_size=32, field_size=10, embedding_size=8):
| 操作 |
公式对应 |
Shape变化 |
torch.sum(fm_input, dim=1) |
∑i=1nvifxi |
(32,10,8) → (32,1,8) |
torch.pow(..., 2) |
(∑vifxi)2 |
(32,1,8) |
fm_input * fm_input |
(vifxi)2 |
(32,10,8) |
torch.sum(..., dim=1) |
∑(vifxi)2 |
(32,1,8) |
square_of_sum - sum_of_square |
差值 |
(32,1,8) |
torch.sum(..., dim=2) |
∑f=1k |
(32,1) |
* 0.5 |
系数 |
(32,1) |
五、完整数值例子
5.1 简单例子
假设有3个特征,embedding维度k=2:
v1=[1,2],v2=[3,4],v3=[5,6]
x1=1,x2=1,x3=0
方法1:直接计算(验证用)
只有 x1,x2 非零,所以只有一个交叉项:
⟨v1,v2⟩x1x2=(1×3+2×4)×1×1=11
方法2:优化公式
输入矩阵(只考虑非零特征):
V⋅X=[13](第1维),[24](第2维)
对于第1维(f=1):
- 和的平方:(1+3)2=16
- 平方的和:12+32=10
- 差值:16−10=6
对于第2维(f=2):
- 和的平方:(2+4)2=36
- 平方的和:22+42=20
- 差值:36−20=16
总交叉项:21(6+16)=11 ✓
5.2 推荐系统实例
场景:预测用户是否点击某电影
| 特征 |
One-Hot编码 |
Embedding (k=4) |
| 用户ID=A |
[1,0,0] |
vA=[0.1,0.2,0.3,0.1] |
| 电影ID=M |
[0,1,0] |
vM=[0.2,0.1,0.4,0.2] |
| 类型=动作 |
[0,0,1] |
v动作=[0.3,0.3,0.2,0.1] |
FM计算:
-
一阶项:wA⋅1+wM⋅1+w动作⋅1
-
二阶项(3个交叉):
- ⟨vA,vM⟩=0.1×0.2+0.2×0.1+0.3×0.4+0.1×0.2=0.18
- ⟨vA,v动作⟩=0.1×0.3+0.2×0.3+0.3×0.2+0.1×0.1=0.16
- ⟨vM,v动作⟩=0.2×0.3+0.1×0.3+0.4×0.2+0.2×0.1=0.19
-
总二阶项:0.18+0.16+0.19=0.53