【论文阅读 | TPAMI 2025 | RWKVFusion:利用统一语言与掩码引导的高效图像融合网络】

16086 字
80 分钟
【论文阅读 | TPAMI 2025 | RWKVFusion:利用统一语言与掩码引导的高效图像融合网络】

[TOC]

01 论文信息#

  • 论文题目: An Efficient Image Fusion Network Exploiting Unifying Language and Mask Guidance
  • 论文作者: Zi-Han Cao, Yu-Jie Liang, Liang-Jian Deng, Gemine Vivone
  • 发表单位:
    • School of Mathematical Sciences, University of Electronic Science and Technology of China (UESTC), Chengdu, China
    • Multi-Hazard Early Warning Key Laboratory of Sichuan Province, UESTC, Chengdu, China
    • National Research Council, Institute of Methodologies for Environmental Analysis (CNR-IMAA), Tito, Italy
    • National Biodiversity Future Center (NBFC), Palermo, Italy
  • 发表会议\期刊: IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI, Vol.47, No.11, 2025)
  • DOI:10.1109/TPAMI.2025.3591930
  • 论文状态:Received 2025-01-08 · Revised 2025-04-27 · Accepted 2025-07-12 · Published 2025-07-23
  • 代码链接: https://github.com/294coder/RWKVFusion

02 论文主要贡献#

2.1 关键判断#

  1. 方法层判断: RWKVFusion 的贡献在于把“全局语义文本 TT + 对象掩码 MM”直接写进融合定义 Eq.(5),并通过 MFM 在编码阶段持续注入。语义引导消融实验 Table VII 对此提供正面支持。
  2. 算子层判断: BRWKV 通过 WKV 的递推化表达 Eq.(12)-Eq.(14),把全局建模代价压到与 token 长度 LL 线性相关;结合 ERF 结果 Fig.10 与算子替换消融 Table VI,在“感受野-复杂度”上找到了相对稳定的折中。
  3. 实证层判断: 六类融合任务主实验 Table II-V、机制消融实验 Table VI-VIII、下游迁移任务 Fig.11-Fig.13 + Table IX-X 形成闭环证据。说明此方法跨任务能力稳健领先,运行机制可解释,可迁移。

03 论文创新点#

  1. 统一语义引导任务定义:将语言语义 TT 与对象掩码 MM 显式写入融合定义 Eq.(5),并通过 MFM 在编码阶段持续注入,实现“按语义意图融合”。
  2. 高效全局建模主干:提出 BRWKV + ESS 组合,以递推化 WKV 完成全局关系建模,将计算复杂度降为与 token 长度 LL 线性相关,同时保持跨区域依赖能力。
  3. 完整可迁移验证:在六类融合任务上给出主实验、机制消融与下游迁移(分割/检测)闭环验证,证明方法效益来自结构协同。

04 方法#

4.1 工作定位#

本论文提出的工作解决是更具体的结构性矛盾: 如何在不依赖复杂先验(如 GAN、扩散或下游任务头)的前提下,把语义信息显式注入融合过程,同时保证高分辨率场景下的计算可承受性。因此本文提出了 RWKVFusion: 以 RWKV 为高效主干,以语言与语义掩码作为统一引导信号,在六类融合任务上给出主结果、消融和下游任务证据链。

图1 RWKVFusion 与多种方法在多任务融合上的综合性能雷达图。图中更强调“跨任务的整体包络”,而不是单一任务的偶然峰值。
图1 RWKVFusion 与多种方法在多任务融合上的综合性能雷达图。图中更强调“跨任务的整体包络”,而不是单一任务的偶然峰值。

4.1.1 传统融合定义与瓶颈#

4.1.1.1 传统定义表达与语义缺口#

经典图像融合任务可写为:

F=Fθ(S1,S2,,Sn)(1)F = \mathcal{F}_{\theta}(S_1, S_2, \dots, S_n) \tag{1}

式(1)可以表达多模态输入到融合输出的映射关系,但没有给出“应优先保留哪些语义目标”的显式约束。结果是,网络通常只能依靠统计相关性去学习“哪里重要”,在低照、遮挡、烟雾与多目标等复杂场景下,容易出现目标弱化、边界漂移与结构失衡等问题。

图2 不同融合任务的输入样例及其语言/掩码引导信息示例。语言描述更偏全局语义,掩码更偏空间定位,两者互补。
图2 不同融合任务的输入样例及其语言/掩码引导信息示例。语言描述更偏全局语义,掩码更偏空间定位,两者互补。

4.1.2 既有融合框架的三类问题#

图3 已有融合框架与 RWKVFusion 的方法学对比。重点在于:自动语义生成 + 线性复杂度主干,试图同时规避“标注依赖、复杂先验、算子低效”三类代价。
图3 已有融合框架与 RWKVFusion 的方法学对比。重点在于:自动语义生成 + 线性复杂度主干,试图同时规避“标注依赖、复杂先验、算子低效”三类代价。

从文中的视角看,既有方法主要面临三类现实问题:

  1. 语义信息注入依赖额外任务头,例如分割/检测联训,会带来标注成本与训练开销;
  2. 复杂先验链路带来系统复杂化,例如 GAN 双网络、扩散推理过程与深先验迭代;
  3. 高分辨率条件下算子代价偏高,传统注意力在 token 维度上存在二次梯度爆炸。

4.2 从注意力到 RWKV 理论过渡与计算动机#

为了说明为何选择 RWKV 而不是直接继续改造 Transformer,文中先回到标准注意力形式:

Attn(Q,K,V)=softmax(QK)V(2)Attn(Q, K, V) = softmax(QK^\top)V \tag{2}

按 token 写作:

Attn(Q,K,V)t=i=1Teqtkivii=1Teqtki(3)Attn(Q, K, V)_t = \frac{\sum_{i=1}^{T}e^{q_t^\top k_i} \odot v_i}{\sum_{i=1}^{T}e^{q_t^\top k_i}} \tag{3}

进一步引入位置相关权重 WW 后,得到可递推改写的形式:

Attn(W,K,V)t=i=1Tewt,i+kivii=1Tewt,i+ki(4)Attn(W, K, V)_t = \frac{\sum_{i=1}^{T}e^{w_{t,i}+k_i} \odot v_i}{\sum_{i=1}^{T}e^{w_{t,i}+k_i}} \tag{4}

文中先把注意力重写成更接近递推计算的表达,再引入 RWKV 的衰减记忆机制。这样可以在保留全局依赖建模能力的同时,避免标准自注意力在高分辨率场景中的高代价路径。


4.3 方法总览#

RWKVFusion 的关键变化是把语义从“外部后验信息”提升为“前向过程中的条件变量”。文中将任务定义升级为:

F=Fθ(S1,,Sn,T,M)(5)F = \mathcal{F}_{\theta}(S_1, \dots, S_n, T, M) \tag{5}

其中 TT 是文本语义编码,MM 是语义掩码。式(5)是全文最重要的任务层变化:网络不仅学习“怎么融合”,还要在训练和推理阶段始终回答“按什么语义意图融合”。

图4 语义分支流程,包含 caption、检测框、掩码与文本编码,以及与融合主干的连接关系。
图4 语义分支流程,包含 caption、检测框、掩码与文本编码,以及与融合主干的连接关系。

图5 RWKVFusion 主干结构,包括多尺度编码解码、BRWKV、MFM 与 ESS。它把“多尺度表达、语义注入、线性主干算子”放进同一计算图
图5 RWKVFusion 主干结构,包括多尺度编码解码、BRWKV、MFM 与 ESS。它把“多尺度表达、语义注入、线性主干算子”放进同一计算图

4.3.1 语义分支#

语义分支由 Florence、DINO、SAM、T5 模型组成,其目标是构造可被主干使用的文本描述和语义掩膜。具体流程为:先由 Florence 生成 caption 候选,再经 DINO 做开放词汇检测,SAM 产出实例掩码,最后由 T5 编码文本并与 mask merging 后的掩码一起送入融合主干。

4.3.2 融合主干#

融合分支以多尺度 BRWKV 为主干,并通过 MFM 在编码阶段执行语义条件注入。编码器负责语义调制和跨尺度表征建立,解码器负责结构恢复和纹理重建。特别的,不在解码端重复注入语义条件,目的是降低重建阶段被高层语义扰动的风险。

该设计可概括为:语义分支负责“提供条件”,融合分支负责“条件化融合”,两者在训练与推理阶段均处于统一前向图中,而非后处理型松耦合架构。

4.3.3 代码对照 主干前向入口与条件注入时序#

def _foward_prior(self, inp, modal): # 定义函数 _foward_prior
# fusion prior
if self.fusion_prior == "max": # 按条件分支
prior = torch.max(inp, modal) # 调用 max 更新 prior
if self.feature_prior: # 按条件分支
prior = self.prior_convs(prior) # 调用 prior_convs 更新 prior
elif self.fusion_prior == "mean": # 分支判断 self.fusion_prior == "mean"
prior = (inp + modal) / 2.0 # 更新变量 prior
elif self.fusion_prior == "lr": # 分支判断 self.fusion_prior == "lr"
prior = inp # 更新变量 prior
if self.feature_prior: # 按条件分支 self.feature_prior
prior = self.prior_convs(prior) # 调用 prior_convs 更新 prior
elif self.fusion_prior.startswith("grad"): # 分支判断 self.fusion_prior.startswith("grad")
inp_grad = sobel_op(inp, normalized=True) # 调用 sobel_op 更新 inp_grad
modal_grad = sobel_op(modal, normalized=True) # 调用 sobel_op 更新 modal_grad
prior = inp_grad.maximum(modal_grad) # 调用 maximum 更新 prior
elif self.fusion_prior == "none": # 分支判断 self.fusion_prior == "none"
prior = 0.0 # 更新变量 prior
else: # 进入兜底分支
raise ValueError(f"Invalid fusion_prior: {self.fusion_prior}") # 执行 ValueError 调用
# if self.feature_prior:
# prior = self.prior_convs(prior)
return prior # 返回 prior
def _forward_recon(self, feature_out, prior, patch_embd_x=None): # 定义函数 _forward_recon
if self.reconstruction_head == "sr": # 按条件分支 self.reconstruction_head == "sr"
feature_out = self.conv_after_body(feature_out) # 调用 conv_after_body 更新 feature_out
if self.feature_prior: # 按条件分支 self.feature_prior
return self.fusion_head(feature_out + prior) # 返回 self.fusion_head(feature_out +...
else: # 进入兜底分支
return self.fusion_head(feature_out) + prior # 返回 self.fusion_head(feature_out) ...
elif self.reconstruction_head == "fusion": # 分支判断 self.reconstruction_head == "fusion"
assert ( # 执行断言校验
patch_embd_x is not None # 执行当前语句
), "patch_embd_x should be provided for fusion reconstruction" # 结束上一层表达式
feature_out = self.conv_after_body(feature_out) + patch_embd_x # 调用 conv_after_body 更新 feature_out
if self.feature_prior: # 按条件分支 self.feature_prior
return self.fusion_head(feature_out + prior) # 返回 self.fusion_head(feature_out +...
else: # 进入兜底分支
return self.fusion_head(feature_out) + prior # 返回 self.fusion_head(feature_out) ...
else: # 进入兜底分支
raise ValueError(f"Invalid reconstruction_head: {self.reconstruction_head}") # 执行 ValueError 调用
def _forward_implem(self, inp, modal, mask=None, llm_feature=None): # 定义函数 _forward_implem
# inp: vi/over/far/SPECT/LRMS
# modal: ir/under/near/MRI/PAN
# forward features
feature_out, patch_embd_x = self._forward_features( # 调用 _forward_features 更新 feature_out, patch_embd_x
inp, modal, mask, llm_feature # 执行当前语句
) # 结束上一层表达式
# forward prior
prior = self._foward_prior(inp, modal) # 调用 _foward_prior 更新 prior
# reconstruction
outp = self._forward_recon(feature_out, prior, patch_embd_x) # 调用 _forward_recon 更新 outp
return outp # 返回 outp
def _forward_features(self, inp, modal, mask=None, llm_feature=None): # 定义函数 _forward_features
bs, _, H, W = inp.shape # 更新变量 bs, _, H, W
cat_x_modal = torch.cat([inp, modal], dim=1) # 调用 cat 更新 cat_x_modal
patch_embd_x = self.patch_embd(cat_x_modal) # 调用 patch_embd 更新 patch_embd_x
modal = self.modal_in(cat_x_modal) # 调用 modal_in 更新 modal
if self.if_abs_pos: # 按条件分支 self.if_abs_pos
x = patch_embd_x + self.resize_pos_embd( # 调用 resize_pos_embd 更新 x
self.abs_pos, inp_size=(H, W) # 更新变量 self.abs_pos, inp_size
).expand(bs, patch_embd_x.size(1), -1, -1) # 结束上一层表达式
else: # 进入兜底分支
x = patch_embd_x # 更新变量 x
if llm_feature is not None and self.llm_channel is not None: # 按条件分支 llm_feature is not None and self.llm...
llm_feature = F.normalize(llm_feature, p=2, dim=1) # 调用 normalize 更新 llm_feature
llm_feature = self.llm_embd(llm_feature) # 调用 llm_embd 更新 llm_feature
llm_feature = (llm_feature + self.txt_sine_pe).float() # 调用 float 更新 llm_feature
llm_feature = self.embd_drop(llm_feature) # 调用 embd_drop 更新 llm_feature
if mask is not None and hasattr(self, "mask_embd"): # 按条件分支 mask is not None and hasattr(self, "...
mask = self.mask_embd(mask) # 调用 mask_embd 更新 mask
# init condition
u_cond = ConditionInput(modal, llm_feature, mask) # 调用 ConditionInput 更新 u_cond
## encoder
encs = [] # 更新变量 encs
for encoder, down in zip(self.encoders, self.downs): # 遍历 encoder, down in zip(self.encoders, self.downs)
u_cond.modalities = self.resize_conds( # 调用 resize_conds 更新 u_cond.modalities
H, W, u_cond.modalities, resample_type="bilinear" # 更新变量 H, W, u_cond.modalities, res...
) # 结束上一层表达式
u_cond.mask_input = self.resize_conds( # 调用 resize_conds 更新 u_cond.mask_input
H, W, u_cond.mask_input, resample_type="nearest" # 更新变量 H, W, u_cond.mask_input, res...
) # 结束上一层表达式
x, u_cond = encoder.enc_forward(x, u_cond, (H, W)) # 调用 enc_forward 更新 x, u_cond
encs.append(x) # 执行 append 调用
x = down(x) # 调用 down 更新 x
H = H // 2 # 更新变量 H
W = W // 2 # 更新变量 W
## middle layer
u_cond.modalities = self.resize_conds( # 调用 resize_conds 更新 u_cond.modalities
H, W, u_cond.modalities, resample_type="bilinear" # 更新变量 H, W, u_cond.modalities, res...
) # 结束上一层表达式
u_cond.mask_input = self.resize_conds( # 调用 resize_conds 更新 u_cond.mask_input
H, W, u_cond.mask_input, resample_type="nearest" # 更新变量 H, W, u_cond.mask_input, res...
) # 结束上一层表达式
x, _ = self.middle_blks.enc_forward(x, u_cond, (H, W)) # 调用 enc_forward 更新 x, _
## decoder
for decoder, up, enc_skip, skip_scale in zip( # 遍历 decoder, up, enc_skip, skip_scale in zip(
self.decoders, self.ups, encs[::-1], self.skip_scales # 执行当前语句
): # 结束上一层表达式
x = up(x) # 调用 up 更新 x
H = H * 2 # 更新变量 H
W = W * 2 # 更新变量 W
u_cond.modalities = self.resize_conds( # 调用 resize_conds 更新 u_cond.modalities
H, W, u_cond.modalities, resample_type="bilinear" # 更新变量 H, W, u_cond.modalities, res...
) # 结束上一层表达式
u_cond.mask_input = self.resize_conds( # 调用 resize_conds 更新 u_cond.mask_input
H, W, u_cond.mask_input, resample_type="nearest" # 更新变量 H, W, u_cond.mask_input, res...
) # 结束上一层表达式
x = torch.cat([x, enc_skip * skip_scale.view(1, -1, 1, 1)], dim=1) # 调用 cat 更新 x
x = decoder.dec_forward(x, u_cond, (H, W)) # 调用 dec_forward 更新 x
return x, patch_embd_x # 返回 x, patch_embd_x

这段实现表明,“语义引导并非后处理补丁”在工程上由 ConditionInput 的层级传递机制保障,并最终作用于融合特征形成过程。

4.4 融合主干 BRWKV 机制拆解 空间混合、通道混合与递推化#

4.4.1 空间混合 Spatial Mixing#

在 BRWKV 的空间分支中,输入序列先经线性投影:

Rs=XsWR,Ks=XsWK,Vs=XsWV(6)R_s = X_sW_R,\quad K_s = X_sW_K,\quad V_s = X_sW_V \tag{6}

随后通过 WKV 进行全局聚合:

At=i=1,itLeti1Lw+kivi+eu+ktvti=1,itLeti1Lw+ki+eu+kt(7)A_t = \frac{\sum_{i=1, i\neq t}^{L}e^{-\frac{|t-i|-1}{L}w+k_i}v_i + e^{u+k_t}v_t} {\sum_{i=1, i\neq t}^{L}e^{-\frac{|t-i|-1}{L}w+k_i}+e^{u+k_t}} \tag{7}

最后由门控得到空间输出:

Os=(σ(Rs)A)WOs(8)O_s = (\sigma(R_s)\odot A)W_{O_s} \tag{8}

式(7)中的两个参数含义非常关键:ww 控制通道级空间衰减,uu 控制当前位置 bonus。其机制本质为“位置衰减记忆 + 当前 token 增益”。与显式输出 attention map 的标准注意力不同,WKV 可视为先将历史 token 压缩到递推状态,再执行状态读取。

4.4.2 通道混合 Channel Mixing#

空间输出先做归一化:

Xc=RMSNorm(Os)(9)X_c = RMSNorm(O_s) \tag{9}

再做通道域投影与非线性:

Rc=XcWR,Kc=XcWK,Vc=ReLU2(Kc)WV(10)R_c = X_cW_R,\quad K_c = X_cW_K,\quad V_c = ReLU^2(K_c)W_V \tag{10}

门控后得到通道输出:

Oc=(σ(Rc)Vc)WOc(11)O_c = (\sigma(R_c)\odot V_c)W_{O_c} \tag{11}

这一段可以理解为“先在空间域做全局关系聚合,再在通道域做非线性重标定”。它对应 Transformer 块中 attention + FFN 的角色分工,但实现路径和复杂度结构不同。

4.4.3 递推化改写与复杂度来源#

文中将式(7)进一步改写为隐藏状态递推形式。定义:

at1=i=0t1eti1Lw+kivi,bt1=i=t+1L1eti1Lw+kivi,ct1=i=0t1eti1Lw+ki,dt1=i=t+1L1eti1Lw+ki(12)\begin{aligned} a_{t-1} &= \sum_{i=0}^{t-1}e^{-\frac{|t-i|-1}{L}w+k_i}v_i,\\ b_{t-1} &= \sum_{i=t+1}^{L-1}e^{-\frac{|t-i|-1}{L}w+k_i}v_i,\\ c_{t-1} &= \sum_{i=0}^{t-1}e^{-\frac{|t-i|-1}{L}w+k_i},\\ d_{t-1} &= \sum_{i=t+1}^{L-1}e^{-\frac{|t-i|-1}{L}w+k_i} \end{aligned} \tag{12}

对应地:

At=at1+bt1+ekt+uvtct1+dt1+ekt+u(13)A_t = \frac{a_{t-1}+b_{t-1}+e^{k_t+u}v_t}{c_{t-1}+d_{t-1}+e^{k_t+u}} \tag{13}

WKV 的 FLOPs 记为:

FLOPs(OWKV)=2×13×L×C(14)FLOPs(O_{WKV}) = 2\times 13\times L\times C \tag{14}

式(14)对应文中关于效率的论据:开销与 token 长度 LL 线性相关,而不是标准注意力常见的二次关系。对图像融合这类高分辨率、像素级任务,该差异是能否落地部署的关键。

4.4.4 BRWKV 的机制优势#

仅以“线性复杂度”概括 BRWKV 并不充分。更关键的是,RWKV 在保持全局交互能力时,将“显式两两相关矩阵”替换为“可递推的衰减记忆状态”,从而把全局建模状态压缩放在同一步中完成。这意味着:

  1. 该机制是对信息组织方式的重构;
  2. 对低层视觉任务,递推状态可减少高分辨率输入下的显存突增;
  3. 与仅依赖局部窗口的方案相比,它在跨区域结构一致性上更具理论优势。

因此,RWKVFusion 的效率收益与结构收益是耦合产生的:前者来自递推化,后者来自全局关系保留,而非二者择一。

4.4.5 代码对照 BRWKV 的空间混合与通道混合#

文中在式(6)-式(14)给出数学形式,仓库中对应实现主要落在 DoubleStreamRWKVBlock.BRWKV_img_forwardVRWKV_SpatialMix_wkv5.forward 两处。

def BRWKV_img_forward( # 定义函数 BRWKV_img_forward
self, # 执行当前语句
x: torch.Tensor, # 执行当前语句
MIMF_modals: torch.Tensor = None, # 更新变量 MIMF_modals: torch.Tensor
MIFM_mm: torch.Tensor = None, # 更新变量 MIFM_mm: torch.Tensor
llm_feat: torch.Tensor = None, # 更新变量 llm_feat: torch.Tensor
patch_resolution=None, # 更新变量 patch_resolution
): # 结束上一层表达式
""" # 文档字符串边界
##! 1. Image-only attention with clipping windows
w_img = window_partition(img, window_size) # 调用 window_partition 更新 w_img
w_img = scan(img) # 调用 scan 更新 w_img
w_img = only_img_attention(img) # 调用 only_img_attention 更新 w_img
##! 2. Multi-modal attention with txt feature
# solution 1 (downsample image)
img = window_reverse(w_img, window_size, H, W) # 调用 window_reverse 更新 img
ds_img = downsample_scanned_img(img) # 调用 downsample_scanned_img 更新 ds_img
ds_img, txt = multi_modal_attention(ds_img, txt) # 调用 multi_modal_attention 更新 ds_img, txt
upsample_img = upsample_scanned_img(ds_img) # 调用 upsample_scanned_img 更新 upsample_img
img = (upsample_img + img) / 2 # 更新变量 img
# solution 2 (repeat txt feature on bs-dim)
rep_bs_txt = repeat(txt, 'bs d l -> (bs n_winds) d l') # 调用 repeat 更新 rep_bs_txt
w_img, rep_bs_txt = multi_modal_attention(w_img, rep_bs_txt) # 调用 multi_modal_attention 更新 w_img, rep_bs_txt
img = window_reverse(w_img, window_size, H, W) # 调用 window_reverse 更新 img
txt = rearrange(rep_bs_txt, '(bs n_winds) d l -> bs n_winds d l', bs=B) # 调用 rearrange 更新 txt
txt = txt.mean(dim=1) # 调用 mean 更新 txt
##! 3. Multi-modal FFN
#! Flux.1 says we should have two FFNs, one for image and one for txt
img = FFN(img, txt) # 调用 FFN 更新 img
""" # 文档字符串边界
B, C, H, W = x.shape # 更新变量 B, C, H, W
# has_llm = (llm_feat is not None) and self.has_llm
# pre-fusion previous (non-downsampled) image and MIFM image
# x = x + MIFM_img
# ver 3
# MIFM_mm = MIFM_mm * self.lerp_factor.view(1, -1, 1, 1)
# x = self.mm_proj_in(torch.cat([x, MIFM_mm], dim=1))
x = self.mm_proj_in( # 调用 mm_proj_in 更新 x
torch.cat( # 执行 cat 调用
[x, MIFM_mm * self.lerp_factor.view(1, -1, 1, 1), MIMF_modals], dim=1 # 更新变量 [x, MIFM_mm * self.lerp_fact...
) # 结束上一层表达式
) # 结束上一层表达式
# we perform scanning first and then window partition
# different from v11 where we perform window partition first
# window partition
# (bs, c, h * w) -> (bs x n_winds, c, window_size, window_size)
h, w, x = self.partition_img_by_window(H, W, x) # 调用 partition_img_by_window 更新 h, w, x
# ESS
x = ( # 更新变量 x
self.scan(x).view(x.size(0) * self.K, -1, h * w).transpose(1, 2) # 执行 scan 调用
) # [b*k, l, c]
######################## Image-only Attention ##########################
# Image spatial mixing
x = self.only_img_attn_ln(x) # 调用 only_img_attn_ln 更新 x
sc1, sh1, ga1 = self.modulator1(x) # 调用 modulator1 更新 sc1, sh1, ga1
prenorm_x = (1 + sc1) * x + sh1 # modulated img
x_attn = self.drop_path(self.att_img(prenorm_x, None, (h, w))) # 调用 drop_path 更新 x_attn
x = ga1 * x_attn + x # 更新变量 x
# merge ESS
# (bs, c * K, H * W) -> (bs, c, H, W)
x = rearrange(x, "(b k) (h w) d -> b k d h w", k=self.K, h=h, w=w) # 调用 rearrange 更新 x
x = self.merge(x) / self.K # 调用 merge 更新 x
# window reverse
# (bs x n_winds, c, wind x wind) -> (bs, c, h x w)
x = self.reverse_img_by_window(H, W, x) # 调用 reverse_img_by_window 更新 x
x = x.view(B, C, H * W).transpose(1, 2) # 调用 view 更新 x
########################################################################
######################## Image-modality FFN ############################
# Channel mixing
x = self.fusion_img_llm_ln(x) # 调用 fusion_img_llm_ln 更新 x
sc2, sh2, ga2 = self.modulator2(x) # 调用 modulator2 更新 sc2, sh2, ga2
prenorm_x = (1 + sc2) * x + sh2 # 更新变量 prenorm_x
# 2. pure image FFN
x_ffn = self.drop_path(self.ffn_fusion(prenorm_x, None, (H, W))) # 调用 drop_path 更新 x_ffn
x = ga2 * x_ffn + x # 更新变量 x
x = x.transpose(1, 2).view(B, C, H, W) # 调用 transpose 更新 x
#########################################################################
# txt reduced channel
if (not self.last_enc_block) and self.has_llm and llm_feat is not None: # 按条件分支 (not self.last_enc_block) and self.h...
llm_feat = self.txt_chan_reduce(llm_feat) # 调用 txt_chan_reduce 更新 llm_feat
return x, llm_feat # 返回 x, llm_feat
def forward( # 定义函数 forward
self, # 执行当前语句
x, # 执行当前语句
txt=None, # 更新变量 txt
patch_resolution=None, # 更新变量 patch_resolution
mm_tokens: "tuple[torch.Tensor] | None" = None, # 更新变量 mm_tokens: "tuple[torch.Tens...
): # 结束上一层表达式
def _inner_forward(x, txt, patch_resolution, mm_tokens): # 定义函数 _inner_forward
B, T, C = x.size() # 调用 size 更新 B, T, C
# sr, k, v, T = self.jit_func(x, txt, patch_resolution, mm_tokens)
r, k, v, g, T = self.jit_func(x, txt, patch_resolution, mm_tokens) # 调用 jit_func 更新 r, k, v, g, T
# x = RUN_CUDA_RWKV5(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v)
x = RUN_CUDA_RWKV5_2( # 调用 RUN_CUDA_RWKV5_2 更新 x
B, T, C, self.n_head, r, k, v, w=self.time_decay, u=self.time_faaaa # 更新变量 B, T, C, self.n_head, r, k, ...
) # 结束上一层表达式
# print(f'Spatial mix - x max: {x.abs().max()}, x norm: {x.norm()}')
x = self.key_norm(x) # 调用 key_norm 更新 x
# x = sr * x
# x = self.output(x)
x = self.jit_func_2(x, g) # 调用 jit_func_2 更新 x
return x # 返回 x
if self.with_cp and x.requires_grad: # 按条件分支 self.with_cp and x.requires_grad
x = cp.checkpoint( # 调用 checkpoint 更新 x
_inner_forward, x, txt, patch_resolution, mm_tokens, use_reentrant=False # 更新变量 _inner_forward, x, txt, patc...
) # 结束上一层表达式
else: # 进入兜底分支
x = _inner_forward(x, txt, patch_resolution, mm_tokens) # 调用 _inner_forward 更新 x
return x # 返回 x

att_img 对应式(6)-式(14)的空间混合,ffn_fusion 对应式(9)-式(11)的通道混合。


4.5 ESS 把二维图像转成可递推序列#

BRWKV 原生更接近序列建模,因此文中引入 ESS,即 Efficient Scanning Strategy,实现二维到一维的桥接。其扫描配置包括:

  1. 横/纵交替 + 翻转,2 scans;
  2. 横纵全量 + 翻转,4 scans;
  3. 在 4 scans 基础上加入对角扫描,8 scans。

ESS 的目标并非提升扫描方向数量本身,而是在空间覆盖率与 FLOPs 之间取得平衡。后续消融表 VI 显示,8 scans 在部分指标可略增,但代价上升明显;默认策略在综合性能与效率之间更平衡。

4.5.1 代码对照 CrossScan / CrossMerge 的多方向扫描#

代码实现中,ESS 的核心不在高层 scan_mode 字符串,而在 CrossScanCrossMerge 的真实张量重排逻辑。

class CrossScan(torch.autograd.Function): # 定义类 CrossScan
# ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改
@staticmethod # 装饰器声明
def forward(ctx, x: torch.Tensor): # 定义函数 forward
B, C, H, W = x.shape # 更新变量 B, C, H, W
ctx.shape = (B, C, H, W) # 更新变量 ctx.shape
# xs = x.new_empty((B, 4, C, H * W))
xs = x.new_empty((B, 8, C, H * W)) # 调用 new_empty 更新 xs
# 添加横向和竖向的扫描
xs[:, 0] = x.flatten(2, 3) # 调用 flatten 更新 xs[:, 0]
xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3) # 调用 transpose 更新 xs[:, 1]
xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) # 调用 flip 更新 xs[:, 2:4]
# 提供斜向和反斜向的扫描
xs[:, 4] = diagonal_gather(x) # 调用 diagonal_gather 更新 xs[:, 4]
xs[:, 5] = antidiagonal_gather(x) # 调用 antidiagonal_gather 更新 xs[:, 5]
xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1]) # 调用 flip 更新 xs[:, 6:8]
return xs # 返回 xs
@staticmethod # 装饰器声明
def backward(ctx, ys: torch.Tensor): # 定义函数 backward
# out: (b, k, d, l)
B, C, H, W = ctx.shape # 更新变量 B, C, H, W
L = H * W # 更新变量 L
# 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) # 调用 flip 更新 y_rb
# 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
# y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose( # 调用 view 更新 y_rb
dim0=2, dim1=3 # 更新变量 dim0
).contiguous().view(B, -1, L) # 结束上一层表达式
y_rb = y_rb.view(B, -1, H, W) # 调用 view 更新 y_rb
# 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L) # 调用 flip 更新 y_da
# 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
y_da = diagonal_scatter(y_da[:, 0], (B, C, H, W)) + antidiagonal_scatter( # 调用 diagonal_scatter 更新 y_da
y_da[:, 1], (B, C, H, W) # 执行当前语句
) # 结束上一层表达式
y_res = y_rb + y_da # 更新变量 y_res
# return y.view(B, -1, H, W)
return y_res # 返回 y_res
class CrossMerge(torch.autograd.Function): # 定义类 CrossMerge
@staticmethod # 装饰器声明
def forward(ctx, ys: torch.Tensor): # 定义函数 forward
B, K, D, H, W = ys.shape # 更新变量 B, K, D, H, W
ctx.shape = (H, W) # 更新变量 ctx.shape
ys = ys.view(B, K, D, -1) # 调用 view 更新 ys
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
# y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) # 调用 flip 更新 y_rb
# 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose( # 调用 view 更新 y_rb
dim0=2, dim1=3 # 更新变量 dim0
).contiguous().view(B, D, -1) # 结束上一层表达式
y_rb = y_rb.view(B, -1, H, W) # 调用 view 更新 y_rb
# 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1) # 调用 flip 更新 y_da
# 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
y_da = diagonal_scatter(y_da[:, 0], (B, D, H, W)) + antidiagonal_scatter( # 调用 diagonal_scatter 更新 y_da
y_da[:, 1], (B, D, H, W) # 执行当前语句
) # 结束上一层表达式
y_res = y_rb + y_da # 更新变量 y_res
return y_res.view(B, D, -1) # 返回 y_res.view(B, D, -1)
# return y
@staticmethod # 装饰器声明
def backward(ctx, x: torch.Tensor): # 定义函数 backward
# B, D, L = x.shape
# out: (b, k, d, l)
H, W = ctx.shape # 更新变量 H, W
B, C, L = x.shape # 更新变量 B, C, L
# xs = x.new_empty((B, 4, C, L))
xs = x.new_empty((B, 8, C, L)) # 调用 new_empty 更新 xs
# 横向和竖向扫描
xs[:, 0] = x # 更新变量 xs[:, 0]
xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3) # 调用 view 更新 xs[:, 1]
xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) # 调用 flip 更新 xs[:, 2:4]
# xs = xs.view(B, 4, C, H, W)
# 提供斜向和反斜向的扫描
xs[:, 4] = diagonal_gather(x.view(B, C, H, W)) # 调用 diagonal_gather 更新 xs[:, 4]
xs[:, 5] = antidiagonal_gather(x.view(B, C, H, W)) # 调用 antidiagonal_gather 更新 xs[:, 5]
xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1]) # 调用 flip 更新 xs[:, 6:8]
# return xs
return xs.view(B, 8, C, H, W) # 返回 xs.view(B, 8, C, H, W)
############### K = 4 ################

由此可见,“2/4/8 scans”是可追踪的扫描方向组合与逆向重排过程。


4.6 MFM 语言、掩码与模态特征的三路径融合#

图6 MFM 特征可视化。掩码路径强化对象区域响应,文本路径提供全局语义一致性,两者叠加后同时改善目标显著性与整体语义稳定性。
图6 MFM 特征可视化。掩码路径强化对象区域响应,文本路径提供全局语义一致性,两者叠加后同时改善目标显著性与整体语义稳定性。

MFM,即 Multi-Modal Fusion Module,是语义分支与融合主干的接口层。其设计并非简单拼接层,而是由原始模态补偿、掩码引导与文本调制构成的三路径机制。

首先,原始模态与上一层特征通过门控形成主路径:

Xfeat=RMSNorm(Conv(Ocl1)),Smod=RMSNorm(Conv(S)),Xact=κ(AdapPool(Xfeat+Smod)),Xfeat=(Xfeat+Smod)Xact(15)\begin{aligned} X_{feat} &= RMSNorm(Conv(O_c^{l-1})),\\ S_{mod} &= RMSNorm(Conv(S)),\\ X_{act} &= \kappa(AdapPool(X_{feat}+S_{mod})),\\ X_{feat} &= (X_{feat}+S_{mod})\odot X_{act} \end{aligned} \tag{15}

随后注入掩码:

Mfeat=Conv(M),Xmask=(Xfeat+Smod)Mfeat(16)M_{feat}=Conv(M),\quad X_{mask}=(X_{feat}+S_{mod})\odot M_{feat} \tag{16}

合并图像特征:

Ximg=Xfeat+Xmask(17)X_{img}=X_{feat}+X_{mask} \tag{17}

并通过奇偶层交替拼接文本:

Xtxt={Concat(T,Ximg),j is evenConcat(Ximg,T),j is odd(18)X_{txt}= \begin{cases} Concat(T,X_{img}), & j\text{ is even}\\ Concat(X_{img},T), & j\text{ is odd} \end{cases} \tag{18}

式(18)中的交替拼接并非形式性操作,其目的在于避免文本长期位于固定序列端点造成的信息偏置。文中后续的模块替换消融中,MLP/cross-attention 对照显示 MFM 默认设计优于替代结构,说明三路径耦合提供了可测收益。

4.6.1 代码对照 MFM 的门控、掩码注入与文本交互#

下面给出对应式(15)-式(18)的逐行注释代码切块。

def forward(self, img_feat: torch.Tensor, cond_input: ConditionInput): # 定义函数 forward
modals = cond_input.modalities # 更新变量 modals
mask = cond_input.mask_input # 更新变量 mask
llm_feat = cond_input.llm_feature # 更新变量 llm_feat
# modalities encoder
modals = self.modal_convs(modals) # 调用 modal_convs 更新 modals
modals = self.modal_norm(modals) # 调用 modal_norm 更新 modals
# feature encoder
img_feat = self.feat_drop(self.feat_convs(img_feat)) # 调用 feat_drop 更新 img_feat
# feature gate
img_gate = self.feat_norm(img_feat + modals) # 调用 feat_norm 更新 img_gate
feat_pool = self.adap_pool(img_gate) # 调用 adap_pool 更新 feat_pool
feat_mul = self.act_conv(feat_pool) # to gate
# m1, m2 = feat_mul.chunk(2, dim=1)
# gate
img_feat_gate = img_feat * feat_mul # + img_feat
modals_gate = modals * feat_mul # + modals
img_feat_gate = img_feat_gate + modals_gate # 更新变量 img_feat_gate
# img_feat_gate = img_feat * m1 + modals * m2
# fusion mask in feat and modals
# then needs to be normalized
if mask is not None and self.with_mask: # 按条件分支 mask is not None and self.with_mask
mask = self.mask_convs(mask) # 调用 mask_convs 更新 mask
img_feat_mask = img_feat * mask # + img_feat
modals_mask = modals * mask # + modals
img_feat = img_feat_gate + img_feat_mask + modals_mask # 更新变量 img_feat
# img_feat = img_feat_gate + img_feat_mask * m1 + modals_mask * m2
else: # 进入兜底分支
img_feat = img_feat_gate # 更新变量 img_feat
# fusion feat and modals
img_feat = self.fusion_proj(img_feat) # 调用 fusion_proj 更新 img_feat
# img_feat_modals = self.modals_out_ln(img_feat)
img_feat_modals = img_feat # 更新变量 img_feat_modals
################################# LLM text and image attention ###########################
# fusion image and llm features
if llm_feat is not None and self.with_llm_feat: # 按条件分支 llm_feat is not None and self.with_l...
llm_feat = self.llm_drop(self.llm_dense(llm_feat)) # 调用 llm_drop 更新 llm_feat
llm_feat = self.add_llm_pe(llm_feat) # 调用 add_llm_pe 更新 llm_feat
# downsample image
img_feat, full_size = self.downsample_img_by_llm(img_feat) # 调用 downsample_img_by_llm 更新 img_feat, full_size
# ESS
bs, C, H, W = img_feat.shape # 更新变量 bs, C, H, W
img_feat = self.scan(img_feat).view(bs * self.K, C, H * W).transpose(1, 2) # 调用 scan 更新 img_feat
img_feat = self.add_img_pe(img_feat, H, W) # 调用 add_img_pe 更新 img_feat
# multi-modal RWKV attention
img_feat = self.multi_modal_img_ln(img_feat) # 调用 multi_modal_img_ln 更新 img_feat
llm_feat = self.multi_modal_txt_ln(llm_feat) # 调用 multi_modal_txt_ln 更新 llm_feat
# llm_feat: [bs, llm_len, c], img_feat: [bs, img_len, c]
if self.add_mm_tokens: # 按条件分支 self.add_mm_tokens
multi_modal_attn_out = self.multi_modal_attn( # 调用 multi_modal_attn 更新 multi_modal_attn_out
img_feat, # 执行当前语句
llm_feat.repeat(self.K, 1, 1), # 执行 repeat 调用
(H, W), # 执行当前语句
(self.img_token, self.txt_token), # 执行当前语句
) # 结束上一层表达式
else: # 进入兜底分支
multi_modal_attn_out = self.multi_modal_attn( # 调用 multi_modal_attn 更新 multi_modal_attn_out
img_feat, llm_feat.repeat(self.K, 1, 1), (H, W) # 执行 repeat 调用
) # 结束上一层表达式
# extract img and llm feat
img_feat_out, llm_feat_out = self.extract_img_and_llm_feat( # 调用 extract_img_and_llm_feat 更新 img_feat_out, llm_feat_out
H, W, multi_modal_attn_out, self.with_llm_feat # 执行当前语句
) # 结束上一层表达式
# residual
img_feat = img_feat + img_feat_out # 更新变量 img_feat
llm_feat = llm_feat + llm_feat_out.view(self.K, bs, -1, C).mean(dim=0) # 调用 view 更新 llm_feat
# merge ESS
img_feat = rearrange( # 调用 rearrange 更新 img_feat
img_feat, "(b k) (h w) d -> b k d h w", k=self.K, h=H, w=W # 更新变量 img_feat, "(b k) (h w) d -> ...
) # 结束上一层表达式
img_feat = self.merge(img_feat) / self.K # 调用 merge 更新 img_feat
# upsample image
img_feat = self.upsample_img_by_llm(img_feat, full_size) # 调用 upsample_img_by_llm 更新 img_feat
# mlps
img_feat = self.img_mlp(img_feat) + img_feat # 调用 img_mlp 更新 img_feat
if not self.last_enc_block: # 按条件分支 not self.last_enc_block
llm_feat = self.txt_mlp(llm_feat) + llm_feat # 调用 txt_mlp 更新 llm_feat
# img_feat_llm = self.llm_out_ln(img_feat)
img_feat_llm = img_feat # 更新变量 img_feat_llm
return img_feat_modals, img_feat_llm, modals, llm_feat # 返回 img_feat_modals, img_feat_llm,...
###########################################################################################
return img_feat_modals, modals, llm_feat # 返回 img_feat_modals, modals, llm_f...

从代码实现可见,MFM 并非“先拼接再卷积”的浅层融合,而是门控、掩码与文本交互的串联机制;文本交互发生在下采样后的序列域,以控制计算开销。


4.7 语义分支细化 掩码生成与 mask merging#

图7 掩码生成与合并流程。跨模态掩码直接并用会引入重复与错配,合并机制用于提升语义引导的一致性与鲁棒性。
图7 掩码生成与合并流程。跨模态掩码直接并用会引入重复与错配,合并机制用于提升语义引导的一致性与鲁棒性。

在语义链路中,文中采用 Florence 生成描述、DINO 开集检测、SAM 分割实例,再将结果送入 mask merging。这个设计的必要性来自一个常被忽视的问题:不同模态对同一对象的响应强度并不一致,导致“同 prompt 下的掩码质量差异”。

文中在主文给出流程,在补充材料中给出算法细节。就主文信息而言,mask merging 的作用可归纳为:

  • 抑制重复实例,即 duplicate objects;
  • 缓解漏检导致的语义空洞;
  • 减少错位掩码对主干更新的误导。

就结果解释而言,表 VII 中“caption + merged mask”优于“caption + unmerged mask”,可以直接视为 mask merging 的实证支持。

4.7.1 代码对照 掩码输入规范化与训练入口#

主文强调 mask merging 的算法细节在补充材料中;仓库主干训练代码中可直接看到“掩码输入规范化 + one-hot 化 + 进入融合训练”的路径。

def check_multi_value_mask_to_one_hot(self, mask: torch.Tensor | None): # 定义函数 check_multi_value_mask_to_one_hot
if mask is not None and self.has_mask: # 按条件分支 mask is not None and self.has_mask
if self.multi_value_mask_max_classes is not None and mask.ndim == 3: # 按条件分支 self.multi_value_mask_max_classes is...
# * if mask is dtype float, must be careful about its values, must be in [0, multi_value_mask_max_classes]
# * if float value in it, the code will raise CUDA error by `scatter_` without clear hint.
mask[mask >= self.multi_value_mask_max_classes] = ( # 执行当前语句
0.0 # larger than max classes are set to background
) # 结束上一层表达式
new_mask = torch.zeros( # 调用 zeros 更新 new_mask
mask.size(0), # 执行 size 调用
self.multi_value_mask_max_classes, # 执行当前语句
*mask.shape[-2:], # 执行当前语句
device=mask.device, # 更新变量 device
dtype=torch.float32, # 更新变量 dtype
) # 结束上一层表达式
new_mask.scatter_(1, mask.long()[:, None], 1.0) # 执行 scatter_ 调用
elif mask.ndim == 4: # 分支判断 mask.ndim == 4
new_mask = mask # 更新变量 new_mask
else: # 进入兜底分支
raise ValueError( # 执行 ValueError 调用
f"Invalid mask shape: {mask.shape} when `multi_value_mask_max_classes` is {self.multi_value_mask_max_classes}", # 执行当前语句
"or model has `has_mask` as True. Input mask should be 3D or 4D tensor", # 执行当前语句
"such as [Bs, 1, H, W] or [Bs, *, H, W] (to be one-hot encoded in this case).", # 执行当前语句
) # 结束上一层表达式
else: # 进入兜底分支
new_mask = None # 更新变量 new_mask
return new_mask # 返回 new_mask
def fusion_train_step( # 定义函数 fusion_train_step
self, # 执行当前语句
vi: "torch.Tensor", # 执行当前语句
ir: "torch.Tensor", # 执行当前语句
mask: "torch.Tensor | None" = None, # 更新变量 mask: "torch.Tensor | None"
gt: "torch.Tensor | None" = None, # 更新变量 gt: "torch.Tensor | None"
txt: "torch.Tensor | None" = None, # 更新变量 txt: "torch.Tensor | None"
fusion_criterion: "Callable | None" = None, # 更新变量 fusion_criterion: "Callable ...
to_rgb_fn: "Callable | None" = None, # 更新变量 to_rgb_fn: "Callable | None"
has_gt: bool = False, # 更新变量 has_gt: bool
**_kwargs, # 执行当前语句
): # 结束上一层表达式
vi, ir, mask, txt = self.check_multi_modal_inputs( # 调用 check_multi_modal_inputs 更新 vi, ir, mask, txt
vi, ir, mask, txt, vi.device, vi.dtype # 执行当前语句
) # 结束上一层表达式
txt = self.drop_txt(txt) # 调用 drop_txt 更新 txt
mask = self.check_multi_value_mask_to_one_hot(mask) # 调用 check_multi_value_mask_to_one_hot 更新 mask
fused_outp = self.only_fusion_step(vi, ir, mask, txt) # 调用 only_fusion_step 更新 fused_outp
# mask for loss, should be detached
if mask is not None: # 按条件分支 mask is not None
mask_for_loss = mask.clone().detach() # 调用 clone 更新 mask_for_loss
else: # 进入兜底分支
mask_for_loss = None # 更新变量 mask_for_loss
fused_for_loss = to_rgb_fn(fused_outp) if to_rgb_fn is not None else fused_outp # 调用 to_rgb_fn 更新 fused_for_loss
if ( # 按条件分支 (
has_gt or gt.size(1) == 3 # 执行 size 调用
): # TODO: find more robust way to check if gt is available
# if we have supervised GT, we use it to compute the supervised loss
# sometimes, for MEF fusion task, we can access the GTs
fusion_gt = gt # 更新变量 fusion_gt
# two different modalities for GT to compute the unsupervised loss
boundary_gt = (vi, ir) # 更新变量 boundary_gt
else: # 进入兜底分支
fusion_gt = None # 更新变量 fusion_gt
boundary_gt = (vi, ir) # 更新变量 boundary_gt
# compute supervised and unsupervised losses
assert fusion_criterion is not None, "fusion_criterion should be provided" # 执行断言校验
loss = list( # 调用 list 更新 loss
fusion_criterion( # 执行 fusion_criterion 调用
fused_for_loss, # 执行当前语句
boundary_gt=boundary_gt, # 更新变量 boundary_gt
fusion_gt=fusion_gt, # 更新变量 fusion_gt
mask=mask_for_loss, # 更新变量 mask
) # 结束上一层表达式
) # 结束上一层表达式
return fused_for_loss.clip(0, 1), loss # 返回 fused_for_loss.clip(0, 1), los...
@torch.no_grad() # 装饰器声明

这表明 mask 引导在工程实现中属于训练闭环的必要输入路径。


4.8 损失函数设计 监督场景分流#

文中按任务监督属性划分损失函数。

4.8.1 有监督 HMIF / Pansharpening#

Lsharpening=FGT1+λ(1SSIM(F,GT))(19)L_{sharpening}=\lVert F-GT\rVert_1 + \lambda(1-SSIM(F,GT)) \tag{19}

式(19)由像素一致性与结构一致性共同约束。其意义是避免单纯 L1 导致结构退化,也避免单纯结构项造成光谱/亮度偏移。

4.8.2 无监督 VIF / MIF / MEF / MFF#

Lfusion=η1Linten+η2Lssim+η3Lgrad(20)L_{fusion}=\eta_1L_{inten}+\eta_2L_{ssim}+\eta_3L_{grad} \tag{20}Linten=FS11+FS21(21)L_{inten}=\lVert F-S_1\rVert_1+\lVert F-S_2\rVert_1 \tag{21}Lssim=2SSIM(F,S1)SSIM(F,S2)(22)L_{ssim}=2-SSIM(F,S_1)-SSIM(F,S_2) \tag{22}Lgrad=Fmax(S1,S2)1(23)L_{grad}=\lVert\nabla F-\max(\nabla S_1,\nabla S_2)\rVert_1 \tag{23}

式(20)–式(23)对应典型的“强度-结构-边缘”三约束配比。对于无 GT 的融合任务,这是可解释性较强、工程上较稳定的选择:强度项控制内容保留,SSIM 项控制结构一致,梯度项控制细节锐度。

4.8.3 代码对照 损失实现与文中公式映射#

损失实现位于 utils/loss_utils.pyDRMFFusionLoss.forward。该实现是“文中损失族”的通用化版本,支持有监督/无监督、多任务权重与可选掩码约束。

def forward( # 定义函数 forward
self, # 执行当前语句
img_fusion: Tensor, # 执行当前语句
boundary_gt: "Tensor | tuple", # cat([vi, ir]) or tuple(vi, ir)
fusion_gt: "Tensor" = None, # ground truth provided in dataset
mask: "Tensor | None" = None, # 更新变量 mask: "Tensor | None"
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: # 结束上一层表达式
if mask is not None and self.mask_loss: # 按条件分支 mask is not None and self.mask_loss
self.check_dtype_and_device(img_fusion, boundary_gt, mask) # 执行 check_dtype_and_device 调用
with torch.no_grad(): # 进入上下文 torch.no_grad()
if self.reduce_label: # 按条件分支 self.reduce_label
mask2 = mask.detach().clone() # 调用 detach 更新 mask2
mask2[mask2 > 1.0] = 1.0 # 更新变量 mask2[mask2 > 1.0]
else: # 进入兜底分支
mask2 = mask # 更新变量 mask2
## TODO: consider this case
# mask.size(1) == 2, two boundaries all have masks
if mask2.size(1) == 2: # 按条件分支 mask2.size(1) == 2
assert False, "mask.size(1) == 2, two boundaries all have masks" # 执行断言校验
mask_B, mask_A = mask2.chunk(2, dim=1) # 调用 chunk 更新 mask_B, mask_A
else: # 进入兜底分支
mask_B, mask_A = mask2, mask2 # 更新变量 mask_B, mask_A
elif not self.mask_loss: # 分支判断 not self.mask_loss
mask2 = None # cast to None
self.check_dtype_and_device(img_fusion, boundary_gt) # 执行 check_dtype_and_device 调用
wd = self.weight_dict # 更新变量 wd
loss_intensity = 0 # 更新变量 loss_intensity
loss_color = 0 # 更新变量 loss_color
loss_grad = 0 # 更新变量 loss_grad
loss_fusion = 0 # 更新变量 loss_fusion
loss = {} # 更新变量 loss
# split boundary gt
no_batch_ndim = img_fusion.ndim - 1 # 更新变量 no_batch_ndim
broadcast_fn = lambda x: x.reshape(-1, *[1] * no_batch_ndim) # noqa: py3.11 supported
# img_A: ir/under/MRI, img_B: vi/over/spect
if isinstance(boundary_gt, (tuple, list)): # 按条件分支 isinstance(boundary_gt, (tuple, list...
img_B, img_A = boundary_gt # 更新变量 img_B, img_A
else: # 进入兜底分支
img_B, img_A = self.split_boundary_gt_tensor(boundary_gt) # 调用 split_boundary_gt_tensor 更新 img_B, img_A
# if has gt (e.g., when we train model on multi-exposure image fusion task)
if exist_gt := exists(fusion_gt): # 按条件分支 exist_gt := exists(fusion_gt)
gt_loss = wd["fusion_gt"] * self.loss_func(img_fusion, fusion_gt).nanmean() # 调用 loss_func 更新 gt_loss
loss["gt_loss"] = gt_loss # 更新变量 loss["gt_loss"]
loss_fusion += gt_loss # 更新变量 loss_fusion +
if self.grad_norm: # 按条件分支 self.grad_norm
img_fusion = grad_norm(img_fusion) # 调用 grad_norm 更新 img_fusion
# ir, vi (detached in this function)
(img_A, A_is_color), (img_B, B_is_color) = ( # 更新变量 (img_A, A_is_color), (img_B,...
self.check_rgb(img_A), # 执行 check_rgb 调用
self.check_rgb(img_B), # 执行 check_rgb 调用
) # 结束上一层表达式
## vgg latent weights
# the latent weights comes from U2Fusion paper
if self.latent_weighted: # 按条件分支 self.latent_weighted
vi_w, ir_w = self.dynamic_weight(img_A, img_B) # 调用 dynamic_weight 更新 vi_w, ir_w
vi_w = broadcast_fn(vi_w) # 调用 broadcast_fn 更新 vi_w
ir_w = broadcast_fn(ir_w) # 调用 broadcast_fn 更新 ir_w
else: # 进入兜底分支
vi_w, ir_w = 1.0, 1.0 # 更新变量 vi_w, ir_w
# YCbCr decomposition
# use for intensity, color, and gradient loss
Y_fusion, Cb_fusion, Cr_fusion = kornia.color.rgb_to_ycbcr(img_fusion).chunk( # 调用 rgb_to_ycbcr 更新 Y_fusion, Cb_fusion, Cr_fusi...
3, dim=1 # 更新变量 3, dim
) # 结束上一层表达式
Y_A, Cb_A, Cr_A = kornia.color.rgb_to_ycbcr(img_A).chunk(3, dim=1) # ir
Y_B, Cb_B, Cr_B = kornia.color.rgb_to_ycbcr(img_B).chunk(3, dim=1) # vi
## intensity and color loss
# * compute boundary intensity and color loss when GT is not provided
# * or we omit computing the loss
# * in this case, we can say intensity and color are supervised by the GT
# * we do not need to use the unsupervised loss
# * `still_boundary_loss_when_gt` is to keep the boundary loss when GT is provided (optional)
if (not exist_gt) or self.still_boundary_loss_when_gt: # 按条件分支 (not exist_gt) or self.still_boundar...
if self.prior == "max": # 按条件分支 self.prior == "max"
Y_joint = torch.max(Y_A, Y_B) # 调用 max 更新 Y_joint
elif self.prior == "mean": # 分支判断 self.prior == "mean"
Y_joint = (Y_A + Y_B) / 2 # 更新变量 Y_joint
if self.grad_norm: # 按条件分支 self.grad_norm
Y_fusion = grad_norm(Y_fusion) # 调用 grad_norm 更新 Y_fusion
Cb_fusion = grad_norm(Cb_fusion) # 调用 grad_norm 更新 Cb_fusion
Cr_fusion = grad_norm(Cr_fusion) # 调用 grad_norm 更新 Cr_fusion
if mask2 is not None: # 按条件分支 mask2 is not None
# * intensity loss
# 1. || fusion - max(vi, ir) || max fusion
# 2. || mask * fusion - mask * ir || ir loss (pedistrain)
# 3. || (1 - mask) * fusion - (1 - mask) * vi || vi loss (background)
loss_intensity = ( # 更新变量 loss_intensity
( # 执行当前语句
wd["inten_f_joint"] * self.loss_func(Y_fusion, Y_joint) # 执行 loss_func 调用
if self.use_prior # 按条件分支 self.use_prior
else 0.0 # 执行当前语句
) # 结束上一层表达式
+ ( # 执行当前语句
wd["inten_f_ir"] # 执行当前语句
* ir_w # 执行当前语句
* self.loss_func(mask2 * Y_fusion, mask2 * Y_A) # 执行 loss_func 调用
if self.boundary # 按条件分支 self.boundary
else 0.0 # 执行当前语句
) # 结束上一层表达式
+ ( # 执行当前语句
wd["inten_f_vi"] # 执行当前语句
* vi_w # 执行当前语句
* self.loss_func(Y_fusion * (1 - mask2), Y_B * (1 - mask2)) # 执行 loss_func 调用
if self.boundary # 按条件分支 self.boundary
else 0.0 # 执行当前语句
) # 结束上一层表达式
) # 结束上一层表达式
# * color loss
# 1. || (1-mask) * fusion_Cb - (1-mask) * ir_Cb || ir_Cb loss
# 2. || (1-mask) * fusion_Cr - (1-mask) * ir_Cr || ir_Cr loss
if self.color_loss_bg_masked: # 按条件分支 self.color_loss_bg_masked
bg_mask = 1.0 - mask2 # 更新变量 bg_mask
else: # 进入兜底分支
bg_mask = torch.ones_like(mask2) # 调用 ones_like 更新 bg_mask
if self.color_loss: # 按条件分支 self.color_loss
loss_color = 0.0 # 更新变量 loss_color
if B_is_color: # 按条件分支 B_is_color
loss_color += wd["color_f_cb"] * vi_w * self.loss_func( # 调用 loss_func 更新 loss_color +
Cb_fusion * bg_mask, Cb_B * bg_mask # 执行当前语句
) + wd["color_f_cr"] * ir_w * self.loss_func( # 结束上一层表达式
Cr_fusion * bg_mask, Cr_B * bg_mask # 执行当前语句
) # 结束上一层表达式
if A_is_color: # 按条件分支 A_is_color
loss_color += wd["color_f_cb"] * ir_w * self.loss_func( # 调用 loss_func 更新 loss_color +
Cb_fusion * bg_mask, Cb_A * bg_mask # 执行当前语句
) + wd["color_f_cr"] * vi_w * self.loss_func( # 结束上一层表达式
Cr_fusion * bg_mask, Cr_A * bg_mask # 执行当前语句
) # 结束上一层表达式
else: # 进入兜底分支
loss_intensity = ( # 更新变量 loss_intensity
( # 执行当前语句
wd["inten_f_joint"] * self.loss_func(Y_fusion, Y_joint) # 执行 loss_func 调用
if self.use_prior # 按条件分支 self.use_prior
else 0.0 # 执行当前语句
) # 结束上一层表达式
+ ( # 执行当前语句
wd["inten_f_ir"] * ir_w * self.loss_func(Y_fusion, Y_A) # 执行 loss_func 调用
if self.boundary # 按条件分支 self.boundary
else 0.0 # 执行当前语句
) # 结束上一层表达式
+ ( # 执行当前语句
wd["inten_f_vi"] * vi_w * self.loss_func(Y_fusion, Y_B) # 执行 loss_func 调用
if self.boundary # 按条件分支 self.boundary
else 0.0 # 执行当前语句
) # 结束上一层表达式
) # 结束上一层表达式
if self.color_loss: # 按条件分支 self.color_loss
loss_color = 0.0 # 更新变量 loss_color
if B_is_color: # 按条件分支 B_is_color
loss_color += wd["color_f_cb"] * self.loss_func( # 调用 loss_func 更新 loss_color +
Cb_fusion, Cb_B # 执行当前语句
) + wd["color_f_cr"] * self.loss_func(Cr_fusion, Cr_B) # 结束上一层表达式
if A_is_color: # 按条件分支 A_is_color
loss_color += wd["color_f_cb"] * self.loss_func( # 调用 loss_func 更新 loss_color +
Cb_fusion, Cb_A # 执行当前语句
) + wd["color_f_cr"] * self.loss_func(Cr_fusion, Cr_A) # 结束上一层表达式
loss_intensity = loss_intensity.nanmean() # 调用 nanmean 更新 loss_intensity
loss_fusion += loss_intensity # 更新变量 loss_fusion +
loss["intensity_loss"] = loss_intensity # 更新变量 loss["intensity_loss"]
if self.color_loss: # 按条件分支 self.color_loss
loss_color = loss_color.nanmean() # 调用 nanmean 更新 loss_color
loss_fusion += loss_color # 更新变量 loss_fusion +
loss["loss_color"] = loss_color # 更新变量 loss["loss_color"]
## lpips loss
# * lpips loss is to enhance perceptual visuality
if self.lpips: # 按条件分支 self.lpips
lpips_loss = 0.0 # 更新变量 lpips_loss
if exist_gt: # 按条件分支 exist_gt
lpips_loss += self.lpips_loss(img_fusion, fusion_gt) * wd["lpips_f_gt"] # 调用 lpips_loss 更新 lpips_loss +
# if self.still_boundary_loss_when_gt or (not exist_gt):
# lpips_loss += self.lpips_loss(img_fusion, img_A) * wd['lpips_f_ir'] + \
# self.lpips_loss(img_fusion, img_B) * wd['lpips_f_vi']
# if self.use_prior:
# if self.prior == 'max':
# img_joint = torch.max(img_A, img_B)
# elif self.prior == 'mean':
# img_joint = (img_A + img_B) / 2
# lpips_loss += self.lpips_loss(img_fusion, img_joint) * wd['lpips_f_joint']
loss_fusion += lpips_loss # 更新变量 loss_fusion +
loss["lpips_loss"] = lpips_loss # 更新变量 loss["lpips_loss"]
## grad loss
# * gradient loss is to enhance the fused image
# * keep it although the GT is provided
if self.grad: # 按条件分支 self.grad
if self.grad_only_on_Y: # 按条件分支 self.grad_only_on_Y
grad_A = self.grad_op(Y_A) # 调用 grad_op 更新 grad_A
grad_B = self.grad_op(Y_B) # 调用 grad_op 更新 grad_B
grad_fusion = self.grad_op(Y_fusion) # 调用 grad_op 更新 grad_fusion
else: # 进入兜底分支
grad_A = self.grad_op(img_A) # 调用 grad_op 更新 grad_A
grad_B = self.grad_op(img_B) # 调用 grad_op 更新 grad_B
grad_fusion = self.grad_op(img_fusion) # 调用 grad_op 更新 grad_fusion
if self.grad_norm: # 按条件分支 self.grad_norm
grad_fusion = grad_norm(grad_fusion) # 调用 grad_norm 更新 grad_fusion
grad_joint = torch.max(grad_A, grad_B) # 调用 max 更新 grad_joint
loss_grad += wd["grad_f_joint"] * self.loss_func(grad_fusion, grad_joint) # 调用 loss_func 更新 loss_grad +
# if mask is not None:
# mask_expand = mask
# if grad_fusion.ndim > 4:
# mask_expand = mask_expand.unsqueeze(1)
# loss_grad += wd['grad_f_ir'] * self.loss_func(grad_fusion * mask_expand, grad_A * mask_expand)
# loss_grad += wd['grad_f_ir'] * self.loss_func(grad_fusion, grad_A) + \
# wd['grad_f_vi'] * self.loss_func(grad_fusion, grad_B)
loss_grad = loss_grad.nanmean() # 调用 nanmean 更新 loss_grad
loss_fusion += loss_grad # 更新变量 loss_fusion +
loss.update({"loss_grad": loss_grad}) # 执行 update 调用
## ssim loss
# * ssim loss is to enhance the fused image
# * keep it although the GT is provided
if self.ssim: # 按条件分支 self.ssim
w_A = grad_A.norm() # 调用 norm 更新 w_A
w_B = grad_B.norm() # 调用 norm 更新 w_B
Z = w_A + w_B # 更新变量 Z
w_A /= Z # 更新变量 w_A /
w_B /= Z # 更新变量 w_B /
if self.grad_only_on_Y: # 按条件分支 self.grad_only_on_Y
ssim_A = self.ssim_func(Y_fusion, Y_A) # 调用 ssim_func 更新 ssim_A
ssim_B = self.ssim_func(Y_fusion, Y_B) # 调用 ssim_func 更新 ssim_B
else: # 进入兜底分支
ssim_A = self.ssim_func(img_fusion, img_A) # 调用 ssim_func 更新 ssim_A
ssim_B = self.ssim_func(img_fusion, img_B) # 调用 ssim_func 更新 ssim_B
loss_ssim = wd["ssim_f_joint"] * (w_A * ssim_A + w_B * ssim_B) # 更新变量 loss_ssim
loss_fusion += loss_ssim # 更新变量 loss_fusion +
loss.update({"loss_ssim": loss_ssim}) # 执行 update 调用
## tv loss
if self.tv: # 按条件分支 self.tv
img_fusion_tv = img_fusion # 更新变量 img_fusion_tv
if self.grad_norm: # 按条件分支 self.grad_norm
img_fusion_tv = grad_norm(img_fusion_tv) # 调用 grad_norm 更新 img_fusion_tv
tv_loss = wd["tv_f"] * self.tv_loss(img_fusion_tv).nanmean() # 调用 tv_loss 更新 tv_loss
loss_fusion += tv_loss # 更新变量 loss_fusion +
loss.update({"tv_loss": tv_loss}) # 执行 update 调用
## correlation loss
if self.correlation: # 按条件分支 self.correlation
img_fusion_tv = img_fusion # 更新变量 img_fusion_tv
if self.grad_norm: # 按条件分支 self.grad_norm
img_fusion_tv = grad_norm(img_fusion_tv) # 调用 grad_norm 更新 img_fusion_tv
loss_corr = wd["crr_f"] * self.cc_loss(img_A, img_B, img_fusion_tv) # 调用 cc_loss 更新 loss_corr
loss_fusion += loss_corr # 更新变量 loss_fusion +
loss.update({"loss_corr": loss_corr}) # 执行 update 调用
loss.update({"loss_fusion": loss_fusion}) # 执行 update 调用
return loss_fusion, loss # 返回 loss_fusion, loss
##### 4.8.3.1 EMMA stage two fusion training loss
loss_cfg: # 配置项 loss_cfg
drmffusion: # 配置项 drmffusion
latent_weighted: no # 配置项 latent_weighted
grad_loss: yes # 配置项 grad_loss
color_loss: no # VIF: yes, others: no
ssim_loss: yes # 配置项 ssim_loss
prior: null # 配置项 prior
boundary_loss: yes # 配置项 boundary_loss
mask_loss: no # VIF: yes, others: no
lpips_loss: no # 配置项 lpips_loss
color_loss_bg_masked: yes # 配置项 color_loss_bg_masked
tv_loss: no # 配置项 tv_loss
pseudo_l1_const: 0. # [0.002, 0.]
correlation_loss: no # 配置项 correlation_loss
reduce_label: no # VIF: no
grad_only_on_Y: no # 配置项 grad_only_on_Y
grad_op: "sobel_add" # 配置项 grad_op
still_boundary_loss_when_gt: yes # 配置项 still_boundary_loss_when_gt
grad_norm: no # 配置项 grad_norm
ssim_implm_by: torch # 配置项 ssim_implm_by
ssim_window_size: 5 # 配置项 ssim_window_size
weight_dict: # 配置项 weight_dict
fusion_gt: 5. # 配置项 fusion_gt
inten_f_joint: 10. # prior used
inten_f_ir: 10. # 配置项 inten_f_ir
inten_f_vi: 10. # 配置项 inten_f_vi
color_f_cb: 2. # 配置项 color_f_cb
color_f_cr: 2. # 配置项 color_f_cr
grad_f_joint: 20. # [5, 40] for VIF and MEF
ssim_f_joint: 2. # 配置项 ssim_f_joint
lpips_f_gt: 1. # 配置项 lpips_f_gt
tv_f: 0.1 # 配置项 tv_f
crr_f: 0.02 # 配置项 crr_f
# unused
lpips_f_joint: 0.2 # 配置项 lpips_f_joint
lpips_f_ir: 0.2 # 配置项 lpips_f_ir
lpips_f_vi: 0.2 # 配置项 lpips_f_vi
filmfusion: # 配置项 filmfusion
prior: mean # 配置项 prior
weight_grad: 1000 # 配置项 weight_grad
weight_ssim: 1. # 配置项 weight_ssim
ssim_window_size: 5 # 配置项 ssim_window_size
l1ssim: # 配置项 l1ssim
weighted_r: [1.0, 0.1] # 配置项 weighted_r
implem_by: torch # 配置项 implem_by
window_size: 5 # 配置项 window_size
grad_norm: no # 配置项 grad_norm

因此,式(19)-式(23)在仓库中并未采用逐式静态硬编码,而是通过统一损失框架按任务配置激活并加权;这也是其可迁移覆盖 VIF/MEF/MFF/MIF/HMIF/Pansharpening 多任务训练流程的工程基础。


05 实验分析#

主要分析 SOTA 对比表以及消融实验。

5.1 实验设置与评测协议#

5.1.1 任务覆盖与数据集#

文中覆盖六类融合任务:

  • VIF:MSRS、M3FD、TNO;
  • MIF:Medical Harvard;
  • MEF:SICE、MEFB;
  • MFF:MFI-WHU、RealMFF;
  • Pansharpening:WV3、GF2、QB;
  • HMIF:Chikusei、Pavia。

这一覆盖范围保证结论不局限于单任务分布。尤其是同时包含跨传感器融合,如 VIF、Pansharpening、HMIF,以及同传感器参数差异融合,如 MEF、MFF,使方法泛化判断更可信。

5.1.2 基线与指标#

文中基线覆盖 decomposition、task-designed、prior-based、architecture-designed、modality-guided 等多类方法。指标体系包含 MI、VIF、SF、Qcb、Qy、Qcv、Qabf、LPIPS,以及 SAM、ERGAS、Q2n、HQNR、PSNR、SSIM 等。

从评估逻辑看,这套指标组合同时约束了信息量、结构质量、感知一致性与遥感光谱质量,避免“单指标最优”误导。

5.1.3 语义输入配置与任务差异化处理#

文中在语义输入上并非“一套配置跑全部任务”,而是根据任务属性做了条件化策略,这一点对复现实验非常关键:

  • 在 VIF 场景中,M3FD 采用固定 prompt:People、Car、Bus、Lamp、Motorcycle、Truck,并执行 mask merging;
  • MSRS 直接使用数据集给出的人工掩码;
  • MEF/MFF/MIF 由于场景对象差异较大,采用 Florence 自动提取 prompt 以支持开集语义;
  • Pansharpening 与 HMIF 因训练样本常用 64×64 小块,文中设置为仅保留语言引导,不使用 mask 引导。

这个配置策略背后有明确原则:当空间尺寸较小、对象边界不稳定或掩码噪声可能压过有效信息时,强行注入 mask 反而可能引入误导;而在目标显著、语义对象明确的场景,如 VIF,mask 的收益更容易释放。换言之,RWKVFusion 并非在所有任务中执行等强度语义注入,而是依据数据形态进行可解释的注入强度调度。

5.1.4 指标解释的组织顺序#

文中指标较多,若不做层次化组织,容易出现“指标堆叠但结论模糊”的问题。本文采用如下解释顺序:

  1. 先看任务核心指标,如遥感任务的 SAM/ERGAS/Q2n/HQNR,VIF 的 MI/VIF/Q 系列;
  2. 再看感知相关指标,如 LPIPS 与结构保持,判断是否存在“分数上升但视觉退化”;
  3. 最后结合可视化图 Fig.8、Fig.9,验证指标变化是否对应可解释的视觉差异。

按这个顺序,RWKVFusion 的优势会更清晰:它在多数任务上不是单一指标尖峰,而是在“信息量、结构一致、下游可用性”三条轴上同时保持稳定。这也是图1 雷达图呈现外扩包络的根本原因。


5.2 主实验结果 结论与证据链#

5.2.1 VIF 与 MIF#

表 II VIF 与 MIF 的定量结果。RWKVFusion 在 MSRS、M3FD、TNO 与 Medical Harvard 多数据集上保持较高占优比,优势并非来自单一数据集的偶然波动。
表 II VIF 与 MIF 的定量结果。RWKVFusion 在 MSRS、M3FD、TNO 与 Medical Harvard 多数据集上保持较高占优比,优势并非来自单一数据集的偶然波动。

文中对 VIF 的报告显示:

  • 在 MSRS 上,8 项指标中 7 项最优;
  • 在 M3FD、TNO 上,多数指标第一或第二;
  • 在 Medical Harvard 数据集 MIF 上,除 LPIPS 外其余指标领先。

文中节选值包括:M3FD 上 MI 2.57、VIF 0.78、Qabf 0.70、Qy 0.96;Medical Harvard 上 MI 2.02、VIF 0.57、SF 22.03、Qy 0.90。该证据链支持一个关键结论:语义引导并没有以牺牲纹理/结构指标为代价,而是在目标保持与视觉一致性之间取得了联合收益。

5.2.2 MEF 与 MFF#

表 III MEF 与 MFF 的定量结果。RWKVFusion 在多曝光与多焦点场景中保持较稳的综合领先,尤其在 SICE、MEFB 与 RealMFF 上优势更明显。
表 III MEF 与 MFF 的定量结果。RWKVFusion 在多曝光与多焦点场景中保持较稳的综合领先,尤其在 SICE、MEFB 与 RealMFF 上优势更明显。

图8 M3FD、SICE、RealMFF 的可视化比较。RWKVFusion 在烟雾遮挡、曝光失衡、离焦区域等困难场景下,兼顾显著目标保留与背景纹理稳定。
图8 M3FD、SICE、RealMFF 的可视化比较。RWKVFusion 在烟雾遮挡、曝光失衡、离焦区域等困难场景下,兼顾显著目标保留与背景纹理稳定。

结合 Table III 与图8 :

  • 定量上,RWKVFusion 在 SICE、MEFB、MFI-WHU、RealMFF 的多数指标中处于领先;
  • 定性上,方法在高亮区和暗区之间维持更一致的曝光平衡,并在焦内/焦外边界处保留更完整纹理。

5.2.3 Pansharpening 与 HMIF#

表 IV WV3 上的 pansharpening 定量结果。RWKVFusion 在 RR 与 FR 协议下均表现突出,并兼顾光谱质量与结构保真。
表 IV WV3 上的 pansharpening 定量结果。RWKVFusion 在 RR 与 FR 协议下均表现突出,并兼顾光谱质量与结构保真。

表 V Chikusei 与 Pavia 的 HMIF 定量结果。RWKVFusion 在主要质量指标上占优,同时保持较好的参数/FLOPs 效率。
表 V Chikusei 与 Pavia 的 HMIF 定量结果。RWKVFusion 在主要质量指标上占优,同时保持较好的参数/FLOPs 效率。

图9 Pansharpening 与 HMIF 的误差图。RWKVFusion 在细节边缘与结构过渡区的残差更低,尤其在高频结构区域更明显。
图9 Pansharpening 与 HMIF 的误差图。RWKVFusion 在细节边缘与结构过渡区的残差更低,尤其在高频结构区域更明显。

文中给出的关键数值包括:

  • WV3:SAM 2.78、ERGAS 2.03、Q2n 0.918、SCC 0.988、DλD_\lambda 0.016、DsD_s 0.036、HQNR 0.949;
  • Chikusei:PSNR 43.89、SAM 1.93、ERGAS 3.33、SSIM 0.963;
  • Pavia:PSNR 36.06、SAM 3.95、ERGAS 3.07、SSIM 0.936。

文中同时指出,在 HMIF 对比中,RWKVFusion 相对 DHIF 仅使用约 8.41% 参数量和 0.67% FLOPs。该“性能 + 复杂度”联合结果,是其方法价值的重要支撑。

5.2.4 跨任务一致性与结果边界#

把 VIF/MIF/MEF/MFF 与 Pansharpening/HMIF 放在一起看,可以观察到 RWKVFusion 的一个鲜明特征: 优势不依赖单一模态组合或单一退化类型。前四类任务主要考察自然图像层面的纹理与目标保留,后两类任务强调光谱一致性与空间细节重建;RWKVFusion 在这两类评价体系中都能给出正向结果,说明其主干设计并未绑定于某一任务先验。

但该一致性并不意味着“无条件全胜”。从文中呈现可见,个别指标上仍可能出现次优,这与任务差异、上游语义质量和评价指标偏好有关。这里应避免把文中结论简化成“全指标第一”,更准确的表述是:RWKVFusion 在跨任务综合表现和性能-复杂度平衡上具备显著优势。


5.3 消融实验#

表 VI 算子替换、扫描策略、MFM 设计与 plain backbone 的主消融结果。RWKVFusion 的性能收益来自结构协同,而非单个组件的偶然增益。
表 VI 算子替换、扫描策略、MFM 设计与 plain backbone 的主消融结果。RWKVFusion 的性能收益来自结构协同,而非单个组件的偶然增益。

5.3.1 BRWKV 与替代算子#

文中将 BRWKV 替换为 flash attention、flatten attention、window attention、VMamba。结果显示 BRWKV 在大多数指标上更稳。该结论说明性能提升并非来自“简单扩大参数规模”,而是来自 WKV 机制与融合任务需求匹配。

5.3.2 ESS 扫描策略#

消融显示:4 scans 并未稳定优于默认配置;8 scans 在部分指标略好但计算代价增大。结论是默认 ESS 在效率-性能之间达到更优折中。

5.3.3 MFM 结构替换#

将 MFM 替换为简单 MLP 或 cross-attention 变体后,指标下降。该结果与图6 可视化一致,说明三路径语义注入 raw/modality + mask + text 具有不可替代的机制价值。

5.3.4 语义引导与 mask merging#

表 VII 语义引导与 mask merging 消融结果。caption 与 merged mask 的组合优于单一路径与未合并掩码。
表 VII 语义引导与 mask merging 消融结果。caption 与 merged mask 的组合优于单一路径与未合并掩码。

在 MSRS 上,文中报告:

  • only caption:MI 3.20、VIF 0.76;
  • only merged mask:MI 3.36、VIF 0.78;
  • no guidance:MI 3.10、VIF 0.69;
  • caption + unmerged mask:MI 3.38、VIF 0.84;
  • caption + merged mask,默认设置:MI 3.42、VIF 0.87。

该组结果支持以下两条因果判断:

  1. caption 与 mask 是互补信号,不是冗余信号;
  2. mask merging 的收益是独立可验证的,不是与 caption 强绑定后的偶然增益。

5.3.5 Prompt 设定#

表 VIII auto-prompt 与 fixed-prompt 的消融结果。fixed-prompt 略优,但两种设定下方法整体都保持较强竞争力。
表 VIII auto-prompt 与 fixed-prompt 的消融结果。fixed-prompt 略优,但两种设定下方法整体都保持较强竞争力。

fixed-prompt 略优于 auto-prompt,但二者均优于多数对比方法。该现象说明:开放式语义引导是可行的,但上游提示词质量仍会影响最终融合上限。

5.3.6 Plain 与 Multi-scale#

文中在 plain backbone 对照下显示多尺度结构更优,且参数量同量级。对于像素级融合任务,这说明“跨尺度上下文传递”仍是必要条件,而非可随意替代的结构装饰。

5.3.7 ERF 证据#

图10 不同算子 CNN、Mamba、Attention、RWKV 的有效感受野比较。RWKVFusion 的 ERF 更广且响应更集中,支撑其“低代价全局建模”主张。
图10 不同算子 CNN、Mamba、Attention、RWKV 的有效感受野比较。RWKVFusion 的 ERF 更广且响应更集中,支撑其“低代价全局建模”主张。


5.4 下游任务验证 融合质量是否可迁移#

下游任务部分的关键问题是:融合结果是否真正提升感知模型表现,而不仅是视觉观感更“清晰”。这里分别给出 Table IX 语义分割与 Table X 目标检测,并在对应小节逐项解读。

5.4.1 单目深度估计#

图11 Depth Anything v2 的可视化结果对比。RWKVFusion 生成的融合图在轮廓连续性和远景层次上更有利于深度估计。
图11 Depth Anything v2 的可视化结果对比。RWKVFusion 生成的融合图在轮廓连续性和远景层次上更有利于深度估计。

该部分文中给出的是可视化证据。统一量化深度指标,如 Abs Rel、RMSE,文中未给出/未报告

5.4.2 语义分割#

图12 SegFormer 分割结果对比。RWKVFusion 在小目标与边界区域的分割连续性更好,对 mIoU 提升有直接贡献。
图12 SegFormer 分割结果对比。RWKVFusion 在小目标与边界区域的分割连续性更好,对 mIoU 提升有直接贡献。

表 IX MSRS 数据集上的语义分割定量结果,含各类别 IoU、mIoU、mAcc。RWKVFusion 在总体分割指标上保持领先,即 mIoU 79.61、mAcc 88.72,并在关键类别上呈现更稳定表现。
表 IX MSRS 数据集上的语义分割定量结果,含各类别 IoU、mIoU、mAcc。RWKVFusion 在总体分割指标上保持领先,即 mIoU 79.61、mAcc 88.72,并在关键类别上呈现更稳定表现。

文中表格结果显示:Proposed 的 mIoU 为 79.61、mAcc 为 88.72,高于 FILM 79.43/88.64 与 TextIF 79.27/88.30。这说明语义增强的融合结果并非仅提升视觉观感,而是能够实质性提升语义任务表现。

5.4.3 目标检测#

图13 YOLOv5 检测可视化,Prediction 与 GT 对照。RWKVFusion 在多数检测场景下提供更高的目标可分辨性,特别是行人与车辆类别。
图13 YOLOv5 检测可视化,Prediction 与 GT 对照。RWKVFusion 在多数检测场景下提供更高的目标可分辨性,特别是行人与车辆类别。

表 X MSRS 数据集目标检测定量结果,People、Car、mAP@0.5、mAP0.5:0.9。RWKVFusion 在多数检测指标上占优,但在 mAP0.5:0.9 上并非绝对最优。
表 X MSRS 数据集目标检测定量结果,People、Car、mAP@0.5、mAP0.5:0.9。RWKVFusion 在多数检测指标上占优,但在 mAP0.5:0.9 上并非绝对最优。

文中给出关键数值:Proposed 在 People 0.966、Car 0.847、mAP@0.5 0.907、mAP0.5:0.9 0.697。需要如实指出,SegMIF 在 mAP0.5:0.9 上为 0.703,高于 RWKVFusion。即 RWKVFusion 不是“所有检测指标绝对第一”,而是多数指标上的更优综合表现。


5.5 局限性与进一步思考#

基于主文与结果呈现方式,可以归纳出以下边界条件:

  1. 上游语义链路误差传导:caption、检测框、掩码质量会直接影响融合效果;
  2. 主干高效不等于全链路高效:主文重点报告主干复杂度优势,但包含 Florence/DINO/SAM 的统一端到端时延,文中未给出/未报告;
  3. 小尺寸遥感训练块中弱化掩码分支:在 pansharpening 与 HMIF 中,因训练块尺寸较小,文中给出的尺寸为 64×64,实践上省略 mask,仅保留语言引导;
  4. 深度任务量化不完整:深度估计仅给出可视化,统一数值指标文中未给出/未报告。

这些限制并不削弱文中主结论,但提示后续方向应聚焦于:更轻量的语义生成链路、端到端系统级时延评估、任务自适应的语义注入策略,以及更完整的跨任务量化协议。

5.5.1 面向复现与改进的具体启示#

如果以“可复现可扩展”为目标,这项工作给出的直接启示至少有三点。第一,语义分支应被视为可替换部件,而不是固定实现。主文证明了“语言 + 掩码”这类条件本身有效,但并未限定必须使用 Florence/DINO/SAM 这一组合,因此后续工作可以在不改动主干的前提下,用更轻量的语义生成器替换上游链路。第二,mask 不是越多越好,关键是质量控制。文中通过 merged 与 unmerged 对照已经说明,不受控的掩码输入会稀释语义收益,甚至向主干注入噪声。第三,融合研究不应只停留在融合指标,必须绑定下游任务检验。本文在检测和分割上的结果提示我们:真正高价值的融合表示,应该在“视觉可读性”和“机器可判别性”两端同时成立。

从研究方法论看,RWKVFusion 也提供了一个可迁移范式:先在任务定义层明确缺口,再在结构层给出针对性设计,最后用主结果、消融、下游三层证据闭环验证。这个范式对于后续多模态低层视觉任务,不仅限于图像融合,同样具有参考意义。


5.6 从输入到输出的一次完整前向#

为系统梳理本文的方法逻辑,可沿一次前向传播完整追踪信息流。这样做的价值在于,不少解读停留在模块名词罗列,如 RWKV、MFM、mask,但未回答“这些模块在计算图中的时序作用与相互约束关系”。RWKVFusion 的设计恰恰依赖严格时序:语义先生成、再注入、再跨尺度传播、最后重建。

5.6.1 输入阶段 多模态图像与语义条件的并行准备#

在输入端,图像模态 S1,,SnS_1,\dots,S_n 进入融合分支;caption 与 mask 进入语义分支。两条分支虽为并行结构,但并不独立:语义分支输出是后续编码层的必需条件,因此在系统上属于前置条件准备,而非可选附加通道。这个设计与许多把语义当作后验打分器的融合方法有本质区别。

需要强调,文中没有把语义条件定义成单一向量,而是拆分为文本条件 TT 与空间条件 MM。这种拆分背后的假设是:全局语义一致性和局部目标定位是两个不同层面的约束,不能由单一信号替代。后续表 VII 的结果也验证了这一点:只用 caption 或只用 merged mask 都弱于二者联合。

5.6.2 编码阶段 语义条件在多尺度主干中的层内注入#

进入编码端后,RWKVFusion 并非“先完成融合再叠加语义”,而是在每个编码层通过 MFM 执行条件注入。其关键效果是把语义约束从输出端前移到特征形成端,减少后期补偿式修正造成的信息损失。

更具体地说,MFM 在每层做三件事:

  1. 保留并重校准原始模态信息,避免语义引导导致低层纹理被过度抑制;
  2. 利用掩码路径在空间上强化对象响应,把“哪里重要”显式落到特征图上;
  3. 通过文本交替拼接给出全局语义方向,把“该保留何种语义关系”加入序列建模。

这三步存在先后依赖:先有图像主路径,掩码对其进行空间约束,文本再做全局语义调制。若把顺序打乱,低层信息与高层语义的耦合稳定性会下降,这也能解释为何简单 MLP 或双 cross-attention 替代不能达到默认 MFM 的效果。

5.6.3 核心算子阶段 BRWKV 负责“全局关系建模 + 线性代价”#

在编码层内部,BRWKV 分为空间混合和通道混合两部分。空间混合阶段通过 WKV 聚合全局信息,通道混合阶段进行非线性重标定。该结构与 Transformer block 的功能映射相似,但实现机制不同:RWKV 不显式构造完整注意力矩阵,而是依赖衰减记忆与递推读取。

从任务角度看,这一点对图像融合非常关键。融合任务不是高层语义分类,而是像素级重建;输入常是高分辨率图像,token 长度大,显存和 FLOPs 都更敏感。

5.6.4 解码与重建阶段 语义不再重复注入的原因#

主文中解码器不再继续注入 caption/mask,该设计具有明确针对性。解码器职责是将融合特征恢复到像素域并保留细节连续性;如果在该阶段继续强语义注入,容易把重建过程变成语义重写过程,导致纹理不自然或边缘伪影。

因此,语义主要在编码端“定向”,解码端主要“还原”。这是本文设计中一个容易忽略但非常重要的工程细节: 语义引导与图像重建分阶段治理,而不是在每个阶段都做同强度语义干预。

5.6.5 系统层视角 主干高效与全链路开销的关系#

文中在主文中清晰给出了 BRWKV 主干的复杂度优势,并在遥感任务中给出参数/FLOPs 对比。但在系统层面仍应区分两件事:

  • 主干前向开销,RWKVFusion 网络本体;
  • 语义准备开销,Florence、DINO、SAM、T5。

前者文中报告充分,后者主文没有统一端到端时延统计。也就是说,“主干高效”这个结论是成立的,但“整链路总时延最优”在主文证据范围内尚不能直接推出。这一点在工程复现时尤其要注意。

5.7 证据链交叉验证#

从研究设计角度,本文结论建立在三层证据的相互支撑上。

  1. 主实验层 Table II 至 Table V 证明方法在六类融合任务中的综合性能优势。
  2. 机制层 Table VI 至 Table VIII 说明收益来源于 BRWKV、ESS、MFM 与语义注入的结构协同,而非单一技巧。
  3. 迁移层 Fig.11 至 Fig.13 与 Table IX 至 Table X 验证融合结果可向分割与检测任务迁移。

三层证据共同指向同一判断 RWKVFusion 的改进具有可解释性、可复核性与跨任务可迁移性。即便在个别指标出现次优,整体证据链仍支持其方法有效性。


5.8 总结#

RWKVFusion 的学术价值不在于“又一个融合网络”,而在于其把三个长期分离的问题放进了同一证据闭环:

  • 在任务定义层面,用式(5)把语言与掩码变为融合前向条件;
  • 在网络算子层面,用 BRWKV + ESS 构建线性主开销的全局建模路径;
  • 在实证层面,用跨任务主实验、结构消融与下游验证证明收益来源。

进一步看,这篇工作真正值得借鉴的是方法论:先明确任务缺口,再给出结构设计,再用消融和下游任务建立因果与迁移证据。对后续研究而言,可延展方向包括:更轻量的语义生成链路、更完整的端到端时延评估、以及面向任务差异的自适应语义注入机制。

若将该工作置于近两年的融合研究脉络中考察,其关键贡献并不止于“语言引导”标签,而在于语义条件、主干效率与跨任务验证的同步落地。许多方法只在其中一维突出,而 RWKVFusion 的价值在于让三维同时成立并可复核。由此可见,评价重点不应停留在单一指标增量,而应落在研究设计是否形成可迁移、可解释、可工程化的完整闭环。该视角也有助于后续选题判断。


5.9 参考#

  • Cao Z.-H., Liang Y.-J., Deng L.-J., Vivone G., An Efficient Image Fusion Network Exploiting Unifying Language and Mask Guidance, IEEE TPAMI, 2025.
  • 代码: https://github.com/294coder/RWKVFusion

06 个人声明#

本文为作者对原论文的学习笔记与心得分享,受个人学识与理解所限,文中对论文内容的解读或有不够周全之处,一切以原论文正式表述为准。本文仅用于学术交流与传播,内容均由作者独立整理完成,不代表本公众号立场。如文中所涉文字、图片等内容存在版权争议,请及时与作者联系,作者将在第一时间核实并妥善处理。

文章分享

如果这篇文章对你有帮助,欢迎分享给更多人!

【论文阅读 | TPAMI 2025 | RWKVFusion:利用统一语言与掩码引导的高效图像融合网络】
https://mjy.js.org/posts/paper-tpami-2025-rwkvfusion/
作者
MaJianyu
发布于
2026-02-11
许可协议
CC BY-NC-SA 4.0
Profile Image of the Author
MaJianyu
永远相信,美好的事情即将发生。
音乐
封面

音乐

暂未播放

0:00 0:00
暂无歌词
分类
标签
站点统计
文章
37
分类
7
标签
107
总字数
195,390
运行时长
0
最后活动
0 天前

目录