[半监督学习] FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling
2022/2/6 23:12:57
本文主要是介绍[半监督学习] FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
在 FixMatch 中, 对所有类别使用预定义的常量阈值来选择有助于训练的未标记数据, 因此无法考虑不同类别的不同学习状态和学习难度, UDA 也是如此. 为解决这个问题, 提出课程伪标签(Curriculum Pseudo Labeling, CPL), 这是一种根据模型的学习状态利用未标记数据的课程学习方法. CPL 的核心是在不同时刻灵活地调整不同类别的阈值.
FlexMatch 使用了 CPL, CPL 是一种课程学习(Curriculum Learning)策略, 考虑到半监督学习中不同的学习状态, CPL 将预定义的阈值替换为灵活的阈值. FlexMatch 只需不到 FixMatch 训练时间的1/5就可以达到最终精度.
课程学习(Curriculum Learning)
根据样本的难易程度, 给不同难度的训练样本分配不同的权重. 初始阶段, 给简单样本的权重最高, 随着训练过程的持续, 较难样本的权重将会逐渐被调高. 将权重动态分配的过程称之为课程(Curriculum), 课程初始阶段简易样本居多, 课程末尾阶段样本难度增加, 即"先易后难".
针对不同的实际问题可以设置不同的样本难易程度评价标准. 例如对于一个原始样本, 对其进行强扰动后, 样本的就由简单变向复杂.
课程伪标签(Curriculum Pseudo Labeling, CPL)
根据学习状态动态确定阈值并非易事. 最理想的方法是计算每个类的评估准确度并使用它们来缩放阈值:
τ
t
(
c
)
=
a
t
⋅
τ
(1)
\tau_t(c)=a_t \cdot\tau \tag{1}
τt(c)=at⋅τ(1)
其中
τ
t
(
c
)
\tau_t(c)
τt(c) 是
t
t
t 时刻
c
c
c 类别的灵活阈值,
a
t
(
c
)
a_t(c)
at(c) 是相应的评估精度. 由于不能在模型学习过程中使用评估集, 因此必须从训练集中分离一个额外的验证集来进行准确性评估. 但是在 SSL 中, 标记数据原本就十分稀缺, 不能再剥离一部分出去. 其次, 为了在训练过程中动态调整阈值, 必须连续在每个时刻
t
t
t 进行准确度评估, 这将大大减慢训练速度.
为解决上述问题, CPL 使用另一种方法来估计学习状态, 它不引入额外的推理过程, 也不需要额外的验证集. 其关键假设是, 可以通过预测属于该类且高于阈值的样本数量来反映一个类的学习效果, 然后使用它们来调整阈值
τ
τ
\tau_τ
ττ. 如下图所示:
定义具有较少样本且其预测置信度达到阈值的类为具有较大的学习难度或较差的学习状态:
σ
t
(
c
)
=
∑
n
=
1
N
1
(
max
(
p
m
,
t
(
y
∣
u
n
)
)
>
τ
)
⋅
1
(
arg max
(
p
m
,
t
(
y
∣
u
n
)
)
=
c
)
(2)
\sigma_t(c)=\sum_{n=1}^N \mathbb{1}(\max(p_{m,t}(y\vert u_n))>\tau) \cdot \mathbb{1}(\argmax(p_{m,t}(y\vert u_n))=c) \tag{2}
σt(c)=n=1∑N1(max(pm,t(y∣un))>τ)⋅1(argmax(pm,t(y∣un))=c)(2)
其中
σ
t
(
c
)
\sigma_t(c)
σt(c) 反映了类
c
c
c 在
t
t
t 步的学习效果.
p
m
,
t
(
y
∣
u
n
)
p_{m,t}(y\vert u_n)
pm,t(y∣un) 是模型在
t
t
t 步对未标记数据
u
n
u_n
un 的预测,
N
N
N 是未标记数据的总数. 当未标记数据集是平衡的(即属于不同类别的未标记数据的数量相等或接近)时, 较大的
σ
t
(
c
)
\sigma_t(c)
σt(c) 表示更好的估计学习效果. 通过对
σ
t
(
c
)
\sigma_t(c)
σt(c) 应用以下归一化使其范围在 0 到 1 之间, 然后可以使用它来缩放固定阈值
τ
\tau
τ:
β
t
(
c
)
=
σ
t
(
c
)
max
c
σ
t
(3)
\beta_t(c)=\frac{\sigma_t(c)}{\underset{c}{\max}\sigma_t} \tag{3}
βt(c)=cmaxσtσt(c)(3)
τ
t
(
c
)
=
β
t
(
c
)
⋅
τ
(4)
\tau_t(c)=\beta_t(c) \cdot \tau \tag{4}
τt(c)=βt(c)⋅τ(4)
随着学习的进行, 学习良好的类的阈值会提高, 以选择性地提取更高质量的样本. 最终, 当所有类都达到可靠的准确度时, 阈值都将接近
τ
\tau
τ. 不过阈值并不总是增长态, 如果未标记的数据在后面的迭代中被分类到不同的类别, 阈值也可能会降低. 这个新阈值用于计算 FlexMatch 中的无监督损失, 可以表示为:
L
u
,
t
=
1
μ
B
∑
b
=
1
μ
B
1
(
max
(
q
b
)
≥
τ
t
)
H
(
q
^
b
,
p
m
(
y
∣
A
(
u
b
)
)
)
(5)
\mathcal{L}_{u,t}=\frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}(\max(q_b)\geq \tau_t) \mathrm{H}(\hat{q}_b,p_m(y\vert \mathcal{A}(u_b))) \tag{5}
Lu,t=μB1b=1∑μB1(max(qb)≥τt)H(q^b,pm(y∣A(ub)))(5)
其中
q
b
=
p
m
(
y
∣
α
(
u
b
)
)
q_b=p_m(y\vert \alpha(u_b))
qb=pm(y∣α(ub)), 这份损失的形式结构与 FixMatch 基本一致. 最后, FlexMatch 中的损失表示为有监督和无监督损失的加权组合:
L
t
=
L
s
+
λ
L
u
,
t
(6)
\mathcal{L}_t=\mathcal{L}_s+\lambda\mathcal{L}_{u,t} \tag{6}
Lt=Ls+λLu,t(6)
其中
L
s
\mathcal{L}_s
Ls 为有监督损失:
L
s
=
1
B
∑
b
=
1
B
H
(
y
b
,
p
m
(
y
∣
α
(
x
b
)
)
)
(7)
\mathcal{L}_{s}=\frac{1}{B} \sum_{b=1}^{B}\mathrm{H}(y_b,p_m(y\vert \alpha(x_b))) \tag{7}
Ls=B1b=1∑BH(yb,pm(y∣α(xb)))(7)
其他
为避免早阶段训练可能出现的盲目预测, 将式(3)改写为:
β
t
(
c
)
=
σ
t
(
c
)
max
{
max
c
σ
t
,
N
−
∑
c
σ
t
}
(8)
\beta_t(c)=\frac{\sigma_t(c)}{\max \{ \underset{c}{\max}\sigma_t,N-\underset{c}{\sum}\sigma_t \}\tag{8}}
βt(c)=max{cmaxσt,N−c∑σt}σt(c)(8)
这确保了在训练开始时, 所有估计的学习效果从 0 逐渐上升, 直到未使用的未标记数据的数量
N
−
∑
c
σ
t
N-\underset{c}{\sum}\sigma_t
N−c∑σt 不再占主导地位.
同时, 还提出一个非线性映射函数
M
\mathcal{M}
M, 当
β
t
(
c
)
\beta_t(c)
βt(c) 均匀地从 0 到 1 范围内变化时, 使阈值具有非线性的增加曲线:
τ
t
(
c
)
=
M
(
β
t
(
c
)
)
⋅
τ
(9)
\tau_t(c)=\mathcal{M}(\beta_t(c)) \cdot \tau \tag{9}
τt(c)=M(βt(c))⋅τ(9)
显然, 如果
M
\mathcal{M}
M 为恒等函数时, 式(9)与式(4)相同. 并且映射函数是单调递增的, 最大值不大于
1
/
τ
1/\tau
1/τ. 在文献中, 选择凸函数
M
(
x
)
=
x
2
−
x
\mathcal{M}(x) = \frac{x}{2−x}
M(x)=2−xx 作为映射函数.
FlexMatch 完整算法如下:
这篇关于[半监督学习] FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23增量更新怎么做?-icode9专业技术文章分享
- 2024-11-23压缩包加密方案有哪些?-icode9专业技术文章分享
- 2024-11-23用shell怎么写一个开机时自动同步远程仓库的代码?-icode9专业技术文章分享
- 2024-11-23webman可以同步自己的仓库吗?-icode9专业技术文章分享
- 2024-11-23在 Webman 中怎么判断是否有某命令进程正在运行?-icode9专业技术文章分享
- 2024-11-23如何重置new Swiper?-icode9专业技术文章分享
- 2024-11-23oss直传有什么好处?-icode9专业技术文章分享
- 2024-11-23如何将oss直传封装成一个组件在其他页面调用时都可以使用?-icode9专业技术文章分享
- 2024-11-23怎么使用laravel 11在代码里获取路由列表?-icode9专业技术文章分享
- 2024-11-22怎么实现ansible playbook 备份代码中命名包含时间戳功能?-icode9专业技术文章分享