一、FM算法

什么是FM:FM(factor Machine,因子分解机)算法是一种基于矩阵分解的机器学习算法,是由Konstanz大学Steffen Rendle(现任职于Google)于2010年最早提出的,旨在解决稀疏数据下的特征组合问题

背景

  1. 什么是稀疏性
    假设一个广告分类的问题,根据用户和广告位相关的特征,预测用户是否点击了广告。源数据如下:
Clicked Country Day Ad_type
1 USA 26/11/15 Movie
0 China 1/7/14 Game
1 China 19/2/15 Game

“Clicked?“是label,Country、Day、Ad_type是特征。由于三种特征都是categorical类型的,需要经过独热编码(One-Hot Encoding)转换成数值型特征。

Clicked? Country=USA Country=China Day=26/11/15 Day=1/7/14 Day=19/2/15 Ad_type=Movie Ad_type=Game
1 1 0 1 0 0 1 0
0 0 1 0 1 0 0 1
1 0 1 0 0 1 0 1
  1. 什么是特征组合
    在推荐系统和CTR预估中,特征交叉是提升模型效果的关键。考虑一个广告点击预测场景:
1
2
用户特征: 年龄=25, 性别=男, 城市=北京
广告特征: 类别=游戏, 平台=iOS

单独看"性别=男"或"类别=游戏"可能意义不大,但 “男性用户 + 游戏广告” 这个组合特征可能有很强的预测能力。

方法1:手动特征交叉
多项式模型是包含特征组合的最直观的模型。在多项式模型中,特征 xix_ixjx_j 的组合采用 xixjx_ix_j 表示,即 xix_ixjx_j 都非零时,组合特征 xixjx_ix_j 才有意义。从对比的角度,本文只讨论二阶多项式模型。模型的表达式如下

y=w0+i=1nwixi+i=1nj=i+1nwijxixjy = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n}\sum_{j=i+1}^{n} w_{ij} x_i x_j

问题:

  • 参数量爆炸:nn 个特征需要 O(n2)O(n^2) 个交叉参数
  • 数据稀疏性:在高维稀疏数据中,大部分特征组合 (xi,xj)(x_i, x_j) 在训练集中从未同时出现,导致 wijw_{ij} 无法学习。如有商品维度有100w,用户维度100w。那么用户和商品的很多交叉组合再训练集中从未出现过,导致训练困难。

FM核心思想

2.1 隐向量分解

FM的核心创新:不直接学习 wijw_{ij},而是将每个特征映射到一个k维隐向量,用隐向量的内积表示交叉权重

wijvi,vj=f=1kvifvjfw_{ij} \approx \langle \mathbf{v}_i, \mathbf{v}_j \rangle = \sum_{f=1}^{k} v_{if} \cdot v_{jf}

其中 viRk\mathbf{v}_i \in \mathbb{R}^k 是特征 ii 的隐向量(embedding)。

2.2 FM模型公式

y^(x)=w0+i=1nwixi+i=1nj=i+1nvi,vjxixj\hat{y}(\mathbf{x}) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n}\sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle x_i x_j

  • w0w_0:全局偏置
  • wiw_i:一阶特征权重
  • vi,vj\langle \mathbf{v}_i, \mathbf{v}_j \rangle:二阶交叉权重

2.3 为什么能解决稀疏性问题?

关键洞察:即使 (xi,xj)(x_i, x_j) 从未同时出现,只要 xix_ixjx_j 分别与其他特征有交互,它们的隐向量就能被学习!

例子:

用户 商品 点击
A 1 1
A 2 1
B 2 1
B 3 1

虽然 (A,3)(A, 3) 从未出现,但:

  • 用户A的向量 vA\mathbf{v}_A 通过与商品1、2的交互学习
  • 商品3的向量 v3\mathbf{v}_3 通过与用户B的交互学习
  • 可以计算 vA,v3\langle \mathbf{v}_A, \mathbf{v}_3 \rangle 来预测!

这就是协同过滤的思想在FM中的体现。

三、计算复杂度优化(核心推导)

3.1 原始复杂度问题

直接计算二阶交叉项:

i=1nj=i+1nvi,vjxixj\sum_{i=1}^{n}\sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle x_i x_j

需要 O(kn2)O(kn^2) 的时间复杂度,这在高维场景下不可接受。

3.2 数学推导:O(kn2)O(kn)O(kn^2) \rightarrow O(kn)

核心技巧:利用平方和公式 (a)2=a2+2i<jaiaj(\sum a)^2 = \sum a^2 + 2\sum_{i<j} a_i a_j

Step 1:扩展求和范围

i=1nj=i+1nvi,vjxixj=12(i=1nj=1nvi,vjxixji=1nvi,vixi2)\sum_{i=1}^{n}\sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle x_i x_j = \frac{1}{2} \left( \sum_{i=1}^{n}\sum_{j=1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle x_i x_j - \sum_{i=1}^{n} \langle \mathbf{v}_i, \mathbf{v}_i \rangle x_i^2 \right)

解释:

  • ij\sum_i \sum_j 包含了 i<ji<ji>ji>ji=ji=j 三部分
  • i<ji<ji>ji>j 对称,各贡献一半
  • 减去 i=ji=j 的对角线部分

Step 2:展开内积

vi,vj=f=1kvifvjf\langle \mathbf{v}_i, \mathbf{v}_j \rangle = \sum_{f=1}^{k} v_{if} v_{jf}

代入:

=12(i=1nj=1nf=1kvifvjfxixji=1nf=1kvif2xi2)= \frac{1}{2} \left( \sum_{i=1}^{n}\sum_{j=1}^{n} \sum_{f=1}^{k} v_{if} v_{jf} x_i x_j - \sum_{i=1}^{n} \sum_{f=1}^{k} v_{if}^2 x_i^2 \right)

Step 3:交换求和顺序

=12f=1k(i=1nj=1nvifvjfxixji=1nvif2xi2)= \frac{1}{2} \sum_{f=1}^{k} \left( \sum_{i=1}^{n}\sum_{j=1}^{n} v_{if} v_{jf} x_i x_j - \sum_{i=1}^{n} v_{if}^2 x_i^2 \right)

Step 4:分解乘积

注意到 ijvifvjfxixj=(ivifxi)(jvjfxj)=(ivifxi)2\sum_i \sum_j v_{if} v_{jf} x_i x_j = (\sum_i v_{if} x_i)(\sum_j v_{jf} x_j) = (\sum_i v_{if} x_i)^2

=12f=1k((i=1nvifxi)2i=1nvif2xi2)= \frac{1}{2} \sum_{f=1}^{k} \left( \left(\sum_{i=1}^{n} v_{if} x_i\right)^2 - \sum_{i=1}^{n} v_{if}^2 x_i^2 \right)

最终公式:

FM二阶项=12f=1k((i=1nvifxi)2i=1n(vifxi)2)\boxed{\text{FM二阶项} = \frac{1}{2} \sum_{f=1}^{k} \left( \left(\sum_{i=1}^{n} v_{if} x_i\right)^2 - \sum_{i=1}^{n} (v_{if} x_i)^2 \right)}

即:“和的平方” 减去 “平方的和”

3.3 复杂度分析

  • i=1nvifxi\sum_{i=1}^{n} v_{if} x_iO(n)O(n)
  • 平方:O(1)O(1)
  • kk 个维度求和:O(k)O(k)
  • 总复杂度:O(kn)O(kn)

四、代码实现解析

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class FM(nn.Module):
def forward(self, inputs):
fm_input = inputs # shape: (batch_size, field_size, embedding_size)

# 和的平方: (Σv_i)^2
square_of_sum = torch.pow(torch.sum(fm_input, dim=1, keepdim=True), 2)

# 平方的和: Σ(v_i^2)
sum_of_square = torch.sum(fm_input * fm_input, dim=1, keepdim=True)

# 交叉项 = 0.5 * (和的平方 - 平方的和)
cross_term = square_of_sum - sum_of_square
cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False)

return cross_term

维度解释:

假设输入 fm_input 的shape是 (batch_size=32, field_size=10, embedding_size=8)

操作 公式对应 Shape变化
torch.sum(fm_input, dim=1) i=1nvifxi\sum_{i=1}^{n} v_{if} x_i (32,10,8) → (32,1,8)
torch.pow(..., 2) (vifxi)2(\sum v_{if} x_i)^2 (32,1,8)
fm_input * fm_input (vifxi)2(v_{if} x_i)^2 (32,10,8)
torch.sum(..., dim=1) (vifxi)2\sum (v_{if} x_i)^2 (32,1,8)
square_of_sum - sum_of_square 差值 (32,1,8)
torch.sum(..., dim=2) f=1k\sum_{f=1}^{k} (32,1)
* 0.5 系数 (32,1)

五、完整数值例子

5.1 简单例子

假设有3个特征,embedding维度k=2:

v1=[1,2],v2=[3,4],v3=[5,6]\mathbf{v}_1 = [1, 2], \quad \mathbf{v}_2 = [3, 4], \quad \mathbf{v}_3 = [5, 6]

x1=1,x2=1,x3=0x_1 = 1, \quad x_2 = 1, \quad x_3 = 0

方法1:直接计算(验证用)

只有 x1,x2x_1, x_2 非零,所以只有一个交叉项:

v1,v2x1x2=(1×3+2×4)×1×1=11\langle \mathbf{v}_1, \mathbf{v}_2 \rangle x_1 x_2 = (1 \times 3 + 2 \times 4) \times 1 \times 1 = 11

方法2:优化公式

输入矩阵(只考虑非零特征):

VX=[13](第1维),[24](第2维)V \cdot X = \begin{bmatrix} 1 \\ 3 \end{bmatrix} \text{(第1维)}, \begin{bmatrix} 2 \\ 4 \end{bmatrix} \text{(第2维)}

对于第1维(f=1):

  • 和的平方:(1+3)2=16(1 + 3)^2 = 16
  • 平方的和:12+32=101^2 + 3^2 = 10
  • 差值:1610=616 - 10 = 6

对于第2维(f=2):

  • 和的平方:(2+4)2=36(2 + 4)^2 = 36
  • 平方的和:22+42=202^2 + 4^2 = 20
  • 差值:3620=1636 - 20 = 16

总交叉项:12(6+16)=11\frac{1}{2}(6 + 16) = 11

5.2 推荐系统实例

场景:预测用户是否点击某电影

特征 One-Hot编码 Embedding (k=4)
用户ID=A [1,0,0] vA=[0.1,0.2,0.3,0.1]\mathbf{v}_A = [0.1, 0.2, 0.3, 0.1]
电影ID=M [0,1,0] vM=[0.2,0.1,0.4,0.2]\mathbf{v}_M = [0.2, 0.1, 0.4, 0.2]
类型=动作 [0,0,1] v动作=[0.3,0.3,0.2,0.1]\mathbf{v}_{动作} = [0.3, 0.3, 0.2, 0.1]

FM计算

  1. 一阶项:wA1+wM1+w动作1w_A \cdot 1 + w_M \cdot 1 + w_{动作} \cdot 1

  2. 二阶项(3个交叉):

    • vA,vM=0.1×0.2+0.2×0.1+0.3×0.4+0.1×0.2=0.18\langle \mathbf{v}_A, \mathbf{v}_M \rangle = 0.1 \times 0.2 + 0.2 \times 0.1 + 0.3 \times 0.4 + 0.1 \times 0.2 = 0.18
    • vA,v动作=0.1×0.3+0.2×0.3+0.3×0.2+0.1×0.1=0.16\langle \mathbf{v}_A, \mathbf{v}_{动作} \rangle = 0.1 \times 0.3 + 0.2 \times 0.3 + 0.3 \times 0.2 + 0.1 \times 0.1 = 0.16
    • vM,v动作=0.2×0.3+0.1×0.3+0.4×0.2+0.2×0.1=0.19\langle \mathbf{v}_M, \mathbf{v}_{动作} \rangle = 0.2 \times 0.3 + 0.1 \times 0.3 + 0.4 \times 0.2 + 0.2 \times 0.1 = 0.19
  3. 总二阶项:0.18+0.16+0.19=0.530.18 + 0.16 + 0.19 = 0.53