【论文阅读 | TPAMI 2025 | RWKVFusion:利用统一语言与掩码引导的高效图像融合网络】
[TOC]
题目: An Efficient Image Fusion Network Exploiting Unifying Language and Mask Guidance
期刊: IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)
年份: 2025
作者: Zi-Han Cao, Yu-Jie Liang, Liang-Jian Deng, Gemine Vivone
代码: https://github.com/294coder/RWKVFusion
0. 结论先行
0.1 三个关键判断
- 方法层判断: RWKVFusion 的贡献不在“替掉某一块”,而在把“全局语义文本 + 对象掩码 ”直接写进融合定义 Eq.(5),并通过 MFM 在编码阶段持续注入。语义引导消融 Table VII 对此给了正面支持。
- 算子层判断: BRWKV 通过 WKV 的递推化表达 Eq.(12)-Eq.(14),把全局建模代价压到与 token 长度 线性相关;结合 ERF 结果 Fig.10 与算子替换消融 Table VI,更像是在“感受野-复杂度”上找到了更稳的折中。
- 实证层判断: 六类融合任务主实验 Table II-V、机制消融 Table VI-VIII、下游迁移 Fig.11-Fig.13 + Table IX-X 形成闭环证据。更合适的表述是“跨任务稳健领先 + 机制可解释 + 迁移有效”,而不是“所有指标都第一”。
0.2 证据链导航
- 问题动机:为什么需要语义引导与高效主干,证据为 Fig.1-Fig.3 + Eq.(1)-Eq.(5)。
- 机制证据:网络如何实现,证据为 Fig.4-Fig.7 + Eq.(6)-Eq.(23)。
- 结果证据:是否有效、为何有效、能否迁移,证据为 Table II-X + Fig.8-Fig.13。
下文按“主实验 Table II-V → 消融 Table VI-VIII → 下游 Table IX-X”组织证据链。
1. 问题动机与工作定位
这项工作解决的不是“再做一个更高分的融合网络”,而是更具体的结构性矛盾: 如何在不依赖复杂先验、如 GAN、扩散或下游任务头的前提下,把语义信息显式注入融合过程,同时保证高分辨率场景下的计算可承受性。文中给出的答案是 RWKVFusion: 以 RWKV 为高效主干,以语言与语义掩码作为统一引导信号,在六类融合任务上给出主结果、消融和下游任务证据链。

图1:RWKVFusion 与多种方法在多任务融合上的综合性能雷达图。图中更强调“跨任务的整体包络”,而不是单一任务的偶然峰值。
1.1 传统融合定义与瓶颈
经典图像融合任务可写为:
式(1)可以表达多模态输入到融合输出的映射关系,但没有给出“应优先保留哪些语义目标”的显式约束。结果是,网络通常只能依靠统计相关性去学习“哪里重要”,在低照、遮挡、烟雾与多目标等复杂场景下,容易出现目标弱化、边界漂移与结构失衡等问题。

图2:不同融合任务的输入样例及其语言/掩码引导信息示例。语言描述更偏全局语义,掩码更偏空间定位,两者互补。
1.2 既有融合框架的三类代价

图3:已有融合框架与 RWKVFusion 的方法学对比。重点在于:自动语义生成 + 线性复杂度主干,试图同时规避“标注依赖、复杂先验、算子低效”三类代价。
从文中的比较视角看,既有方法主要面临三类现实问题:
- 语义信息注入依赖额外任务头,例如分割/检测联训,会带来标注成本与训练开销;
- 复杂先验链路带来系统复杂化,例如 GAN 双网络、扩散推理过程与深先验迭代;
- 高分辨率条件下算子代价偏高,传统注意力在 token 维度上存在二次项。
因此本文的目标并不是替换一个模块,而是把“语义可控 + 全局建模 + 计算可控”放在同一框架下联合成立。
2. 从注意力到 RWKV:理论过渡与计算动机
为了说明为何选择 RWKV 而不是直接继续改造 Transformer,文中先回到标准注意力形式:
按 token 写作:
进一步引入位置相关权重后,得到可递推改写的形式:
这三步的意义在于:文中并非经验性“换主干”,而是先把注意力重写成更接近递推计算的表达,再引入 RWKV 的衰减记忆机制。这样可以在保留全局依赖建模能力的同时,避免标准自注意力在高分辨率场景中的高代价路径。
3. 方法总览:统一语义引导的 RWKVFusion
RWKVFusion 的关键变化是把语义从“外部后验信息”提升为“前向过程中的条件变量”。文中将任务定义升级为:
其中 是文本语义编码, 是语义掩码。式(5)是全文最重要的任务层变化:网络不仅学习“怎么融合”,还要在训练和推理阶段始终回答“按什么语义意图融合”。

图4:语义分支流程,包含 caption、检测框、掩码与文本编码,以及与融合主干的连接关系。它不是后处理补丁,而是为主干提供可训练的条件输入。

图5:RWKVFusion 主干结构,包括多尺度编码解码、BRWKV、MFM 与 ESS。它把“多尺度表达、语义注入、线性主干算子”放进同一计算图,而非松散拼接。
3.1 语义分支:从开放词汇提示到可注入条件
语义分支由 Florence、DINO、SAM、T5 组成,其目标不是输出“可视化附属信息”,而是构造可被主干直接消费的条件变量。具体流程为:先由 Florence 生成 caption 候选,再经 DINO 做开放词汇检测,SAM 产出实例掩码,最后由 T5 编码文本并与 mask merging 后的掩码一起送入融合主干。
3.2 融合主干:多尺度 BRWKV 编码-解码架构
融合分支以多尺度 BRWKV 为主干,并通过 MFM 在编码阶段执行语义条件注入。编码器负责“语义调制 + 跨尺度表征建立”,解码器负责“结构恢复 + 纹理重建”。文中特意不在解码端重复注入语义条件,目的是降低重建阶段被高层语义扰动的风险。
该设计可概括为:语义分支负责“提供条件”,融合分支负责“条件化融合”,两者在训练与推理阶段均处于统一前向图中,而非后处理型松耦合架构。
3.3 代码对照:主干前向入口与条件注入时序
下面代码来自仓库 https://github.com/294coder/RWKVFusion,本地版本为 a00d633,对应方法总览在工程实现中的真实执行顺序。
def _forward_implem(self, inp, modal, mask=None, llm_feature=None): # inp: vi/over/far/SPECT/LRMS # modal: ir/under/near/MRI/PAN
# forward features feature_out, patch_embd_x = self._forward_features( inp, modal, mask, llm_feature )
# forward prior prior = self._foward_prior(inp, modal)
# reconstruction outp = self._forward_recon(feature_out, prior, patch_embd_x)
return outp def _forward_features(self, inp, modal, mask=None, llm_feature=None): bs, _, H, W = inp.shape
cat_x_modal = torch.cat([inp, modal], dim=1) patch_embd_x = self.patch_embd(cat_x_modal) modal = self.modal_in(cat_x_modal)
if self.if_abs_pos: x = patch_embd_x + self.resize_pos_embd( self.abs_pos, inp_size=(H, W) ).expand(bs, patch_embd_x.size(1), -1, -1) else: x = patch_embd_x
if llm_feature is not None and self.llm_channel is not None: llm_feature = F.normalize(llm_feature, p=2, dim=1) llm_feature = self.llm_embd(llm_feature) llm_feature = (llm_feature + self.txt_sine_pe).float() llm_feature = self.embd_drop(llm_feature)
if mask is not None and hasattr(self, "mask_embd"): mask = self.mask_embd(mask)
# init condition u_cond = ConditionInput(modal, llm_feature, mask)
## encoder encs = [] for encoder, down in zip(self.encoders, self.downs): u_cond.modalities = self.resize_conds( H, W, u_cond.modalities, resample_type="bilinear" ) u_cond.mask_input = self.resize_conds( H, W, u_cond.mask_input, resample_type="nearest" ) x, u_cond = encoder.enc_forward(x, u_cond, (H, W)) encs.append(x) x = down(x) H = H // 2 W = W // 2
## middle layer u_cond.modalities = self.resize_conds( H, W, u_cond.modalities, resample_type="bilinear" ) u_cond.mask_input = self.resize_conds( H, W, u_cond.mask_input, resample_type="nearest" ) x, _ = self.middle_blks.enc_forward(x, u_cond, (H, W))
## decoder for decoder, up, enc_skip, skip_scale in zip( self.decoders, self.ups, encs[::-1], self.skip_scales ): x = up(x) H = H * 2 W = W * 2 u_cond.modalities = self.resize_conds( H, W, u_cond.modalities, resample_type="bilinear" ) u_cond.mask_input = self.resize_conds( H, W, u_cond.mask_input, resample_type="nearest" ) x = torch.cat([x, enc_skip * skip_scale.view(1, -1, 1, 1)], dim=1) x = decoder.dec_forward(x, u_cond, (H, W))
return x, patch_embd_x这段实现表明,“语义引导并非后处理补丁”在工程上由 ConditionInput 的层级传递机制保障,并最终作用于融合特征形成过程。
4. 融合主干 BRWKV 机制拆解:空间混合、通道混合与递推化
4.1 空间混合 Spatial Mixing
在 BRWKV 的空间分支中,输入序列先经线性投影:
随后通过 WKV 进行全局聚合:
最后由门控得到空间输出:
式(7)中的两个参数含义非常关键: 控制通道级空间衰减, 控制当前位置 bonus。其机制本质为“位置衰减记忆 + 当前 token 增益”。与显式输出 attention map 的标准注意力不同,WKV 可视为先将历史 token 压缩到递推状态,再执行状态读取。
4.2 通道混合 Channel Mixing
空间输出先做归一化:
再做通道域投影与非线性:
门控后得到通道输出:
这一段可以理解为“先在空间域做全局关系聚合,再在通道域做非线性重标定”。它对应 Transformer 块中 attention + FFN 的角色分工,但实现路径和复杂度结构不同。
4.3 递推化改写与复杂度来源
文中将式(7)进一步改写为隐藏状态递推形式。定义:
对应地:
WKV 的 FLOPs 记为:
式(14)对应文中关于效率的核心论据:开销与 token 长度 线性相关,而不是标准注意力常见的二次关系。对图像融合这类高分辨率、像素级任务,该差异是能否落地部署的分水岭。
4.4 BRWKV 的机制优势不应被“线性复杂度”单一表述替代
仅以“线性复杂度”概括 BRWKV 并不充分。更关键的是,RWKV 在保持全局交互能力时,将“显式两两相关矩阵”替换为“可递推的衰减记忆状态”,从而把全局建模与状态压缩放在同一步中完成。这意味着:
- 该机制并非对注意力能力的简单削弱,而是对信息组织方式的重构;
- 对低层视觉任务,递推状态可减少高分辨率输入下的显存突增风险;
- 与仅依赖局部窗口的方案相比,它在跨区域结构一致性上更具理论优势。
因此,RWKVFusion 的效率收益与结构收益是耦合产生的:前者来自递推化,后者来自全局关系保留,而非二者择一。
4.5 代码对照:BRWKV 的空间混合与通道混合
文中在式(6)-式(14)中给出的是数学形式,仓库中对应实现落在 DoubleStreamRWKVBlock.BRWKV_img_forward 与 VRWKV_SpatialMix_wkv5.forward 两处。
def BRWKV_img_forward( self, x: torch.Tensor, MIMF_modals: torch.Tensor = None, MIFM_mm: torch.Tensor = None, llm_feat: torch.Tensor = None, patch_resolution=None, ): """ ##! 1. Image-only attention with clipping windows w_img = window_partition(img, window_size) w_img = scan(img) w_img = only_img_attention(img)
##! 2. Multi-modal attention with txt feature # solution 1 (downsample image) img = window_reverse(w_img, window_size, H, W) ds_img = downsample_scanned_img(img) ds_img, txt = multi_modal_attention(ds_img, txt) upsample_img = upsample_scanned_img(ds_img) img = (upsample_img + img) / 2
# solution 2 (repeat txt feature on bs-dim) rep_bs_txt = repeat(txt, 'bs d l -> (bs n_winds) d l') w_img, rep_bs_txt = multi_modal_attention(w_img, rep_bs_txt) img = window_reverse(w_img, window_size, H, W) txt = rearrange(rep_bs_txt, '(bs n_winds) d l -> bs n_winds d l', bs=B) txt = txt.mean(dim=1)
##! 3. Multi-modal FFN #! Flux.1 says we should have two FFNs, one for image and one for txt img = FFN(img, txt)
"""
B, C, H, W = x.shape # has_llm = (llm_feat is not None) and self.has_llm
# pre-fusion previous (non-downsampled) image and MIFM image # x = x + MIFM_img
# ver 3 # MIFM_mm = MIFM_mm * self.lerp_factor.view(1, -1, 1, 1) # x = self.mm_proj_in(torch.cat([x, MIFM_mm], dim=1)) x = self.mm_proj_in( torch.cat( [x, MIFM_mm * self.lerp_factor.view(1, -1, 1, 1), MIMF_modals], dim=1 ) )
# we perform scanning first and then window partition # different from v11 where we perform window partition first
# window partition # (bs, c, h * w) -> (bs x n_winds, c, window_size, window_size) h, w, x = self.partition_img_by_window(H, W, x)
# ESS x = ( self.scan(x).view(x.size(0) * self.K, -1, h * w).transpose(1, 2) ) # [b*k, l, c]
######################## Image-only Attention ##########################
# Image spatial mixing x = self.only_img_attn_ln(x) sc1, sh1, ga1 = self.modulator1(x) prenorm_x = (1 + sc1) * x + sh1 # modulated img x_attn = self.drop_path(self.att_img(prenorm_x, None, (h, w))) x = ga1 * x_attn + x
# merge ESS # (bs, c * K, H * W) -> (bs, c, H, W) x = rearrange(x, "(b k) (h w) d -> b k d h w", k=self.K, h=h, w=w) x = self.merge(x) / self.K
# window reverse # (bs x n_winds, c, wind x wind) -> (bs, c, h x w) x = self.reverse_img_by_window(H, W, x) x = x.view(B, C, H * W).transpose(1, 2)
########################################################################
######################## Image-modality FFN ############################
# Channel mixing x = self.fusion_img_llm_ln(x) sc2, sh2, ga2 = self.modulator2(x) prenorm_x = (1 + sc2) * x + sh2
# 2. pure image FFN x_ffn = self.drop_path(self.ffn_fusion(prenorm_x, None, (H, W))) x = ga2 * x_ffn + x
x = x.transpose(1, 2).view(B, C, H, W)
#########################################################################
# txt reduced channel if (not self.last_enc_block) and self.has_llm and llm_feat is not None: llm_feat = self.txt_chan_reduce(llm_feat)
return x, llm_feat def forward( self, x, txt=None, patch_resolution=None, mm_tokens: "tuple[torch.Tensor] | None" = None, ): def _inner_forward(x, txt, patch_resolution, mm_tokens): B, T, C = x.size() # sr, k, v, T = self.jit_func(x, txt, patch_resolution, mm_tokens) r, k, v, g, T = self.jit_func(x, txt, patch_resolution, mm_tokens) # x = RUN_CUDA_RWKV5(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v) x = RUN_CUDA_RWKV5_2( B, T, C, self.n_head, r, k, v, w=self.time_decay, u=self.time_faaaa ) # print(f'Spatial mix - x max: {x.abs().max()}, x norm: {x.norm()}') x = self.key_norm(x) # x = sr * x # x = self.output(x) x = self.jit_func_2(x, g)
return x
if self.with_cp and x.requires_grad: x = cp.checkpoint( _inner_forward, x, txt, patch_resolution, mm_tokens, use_reentrant=False ) else: x = _inner_forward(x, txt, patch_resolution, mm_tokens) return x这两段代码和文中公式是一一对应的: att_img 对应式(6)-式(14)的空间混合,ffn_fusion 对应式(9)-式(11)的通道混合。
5. ESS:把二维图像转成可递推序列
BRWKV 原生更接近序列建模,因此文中引入 ESS,即 Efficient Scanning Strategy,实现二维到一维的桥接。其扫描配置包括:
- 横/纵交替 + 翻转,2 scans;
- 横纵全量 + 翻转,4 scans;
- 在 4 scans 基础上加入对角扫描,8 scans。
ESS 的目标并非提升扫描方向数量本身,而是在空间覆盖率与参数/FLOPs 之间取得平衡。后续消融表 VI 显示,8 scans 在部分指标可略增,但代价上升明显;默认策略在综合性能与效率之间更稳。
5.1 代码对照:CrossScan / CrossMerge 的多方向扫描
仓库实现中,ESS 的核心不在高层 scan_mode 字符串,而在 CrossScan 与 CrossMerge 的真实张量重排逻辑:
class CrossScan(torch.autograd.Function): # ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改 @staticmethod def forward(ctx, x: torch.Tensor): B, C, H, W = x.shape ctx.shape = (B, C, H, W) # xs = x.new_empty((B, 4, C, H * W)) xs = x.new_empty((B, 8, C, H * W)) # 添加横向和竖向的扫描 xs[:, 0] = x.flatten(2, 3) xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3) xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
# 提供斜向和反斜向的扫描 xs[:, 4] = diagonal_gather(x) xs[:, 5] = antidiagonal_gather(x) xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
return xs
@staticmethod def backward(ctx, ys: torch.Tensor): # out: (b, k, d, l) B, C, H, W = ctx.shape L = H * W # 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加 # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式 # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose( dim0=2, dim1=3 ).contiguous().view(B, -1, L) y_rb = y_rb.view(B, -1, H, W)
# 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加 y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L) # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加 y_da = diagonal_scatter(y_da[:, 0], (B, C, H, W)) + antidiagonal_scatter( y_da[:, 1], (B, C, H, W) )
y_res = y_rb + y_da # return y.view(B, -1, H, W) return y_resclass CrossMerge(torch.autograd.Function): @staticmethod def forward(ctx, ys: torch.Tensor): B, K, D, H, W = ys.shape ctx.shape = (H, W) ys = ys.view(B, K, D, -1) # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式 y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose( dim0=2, dim1=3 ).contiguous().view(B, D, -1) y_rb = y_rb.view(B, -1, H, W)
# 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加 y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1) # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加 y_da = diagonal_scatter(y_da[:, 0], (B, D, H, W)) + antidiagonal_scatter( y_da[:, 1], (B, D, H, W) )
y_res = y_rb + y_da return y_res.view(B, D, -1) # return y
@staticmethod def backward(ctx, x: torch.Tensor): # B, D, L = x.shape # out: (b, k, d, l) H, W = ctx.shape B, C, L = x.shape # xs = x.new_empty((B, 4, C, L)) xs = x.new_empty((B, 8, C, L))
# 横向和竖向扫描 xs[:, 0] = x xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3) xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) # xs = xs.view(B, 4, C, H, W)
# 提供斜向和反斜向的扫描 xs[:, 4] = diagonal_gather(x.view(B, C, H, W)) xs[:, 5] = antidiagonal_gather(x.view(B, C, H, W)) xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
# return xs return xs.view(B, 8, C, H, W)
############### K = 4 ################由此可见,“2/4/8 scans”不是抽象表述,而是可追踪的扫描方向组合与逆向重排过程。
6. MFM:语言、掩码与模态特征的三路径融合

图6:MFM 特征可视化。掩码路径强化对象区域响应,文本路径提供全局语义一致性,两者叠加后同时改善目标显著性与整体语义稳定性。
MFM,即 Multi-Modal Fusion Module,是语义分支与融合主干的接口层。其设计并非简单拼接层,而是由原始模态补偿、掩码引导与文本调制构成的三路径机制。
首先,原始模态与上一层特征通过门控形成主路径:
随后注入掩码:
合并图像特征:
并通过奇偶层交替拼接文本:
式(18)中的交替拼接并非形式性操作,其目的在于避免文本长期位于固定序列端点造成的信息偏置。文中后续的模块替换消融中,MLP/cross-attention 对照显示 MFM 默认设计优于替代结构,说明三路径耦合提供了可测收益。
6.1 代码对照:MFM 的门控、掩码注入与文本交互
下面给出对应式(15)-式(18)的代码切块。
def forward(self, img_feat: torch.Tensor, cond_input: ConditionInput): modals = cond_input.modalities mask = cond_input.mask_input llm_feat = cond_input.llm_feature
# modalities encoder modals = self.modal_convs(modals) modals = self.modal_norm(modals)
# feature encoder img_feat = self.feat_drop(self.feat_convs(img_feat))
# feature gate img_gate = self.feat_norm(img_feat + modals) feat_pool = self.adap_pool(img_gate) feat_mul = self.act_conv(feat_pool) # to gate # m1, m2 = feat_mul.chunk(2, dim=1)
# gate img_feat_gate = img_feat * feat_mul # + img_feat modals_gate = modals * feat_mul # + modals img_feat_gate = img_feat_gate + modals_gate # img_feat_gate = img_feat * m1 + modals * m2
# fusion mask in feat and modals # then needs to be normalized if mask is not None and self.with_mask: mask = self.mask_convs(mask) img_feat_mask = img_feat * mask # + img_feat modals_mask = modals * mask # + modals img_feat = img_feat_gate + img_feat_mask + modals_mask # img_feat = img_feat_gate + img_feat_mask * m1 + modals_mask * m2 else: img_feat = img_feat_gate
# fusion feat and modals img_feat = self.fusion_proj(img_feat) # img_feat_modals = self.modals_out_ln(img_feat) img_feat_modals = img_feat
################################# LLM text and image attention ########################### # fusion image and llm features if llm_feat is not None and self.with_llm_feat: llm_feat = self.llm_drop(self.llm_dense(llm_feat)) llm_feat = self.add_llm_pe(llm_feat)
# downsample image img_feat, full_size = self.downsample_img_by_llm(img_feat)
# ESS bs, C, H, W = img_feat.shape img_feat = self.scan(img_feat).view(bs * self.K, C, H * W).transpose(1, 2) img_feat = self.add_img_pe(img_feat, H, W)
# multi-modal RWKV attention img_feat = self.multi_modal_img_ln(img_feat) llm_feat = self.multi_modal_txt_ln(llm_feat) # llm_feat: [bs, llm_len, c], img_feat: [bs, img_len, c]
if self.add_mm_tokens: multi_modal_attn_out = self.multi_modal_attn( img_feat, llm_feat.repeat(self.K, 1, 1), (H, W), (self.img_token, self.txt_token), ) else: multi_modal_attn_out = self.multi_modal_attn( img_feat, llm_feat.repeat(self.K, 1, 1), (H, W) )
# extract img and llm feat img_feat_out, llm_feat_out = self.extract_img_and_llm_feat( H, W, multi_modal_attn_out, self.with_llm_feat )
# residual img_feat = img_feat + img_feat_out llm_feat = llm_feat + llm_feat_out.view(self.K, bs, -1, C).mean(dim=0)
# merge ESS img_feat = rearrange( img_feat, "(b k) (h w) d -> b k d h w", k=self.K, h=H, w=W ) img_feat = self.merge(img_feat) / self.K
# upsample image img_feat = self.upsample_img_by_llm(img_feat, full_size)
# mlps img_feat = self.img_mlp(img_feat) + img_feat if not self.last_enc_block: llm_feat = self.txt_mlp(llm_feat) + llm_feat
# img_feat_llm = self.llm_out_ln(img_feat) img_feat_llm = img_feat
return img_feat_modals, img_feat_llm, modals, llm_feat ###########################################################################################
return img_feat_modals, modals, llm_feat从代码实现可见,MFM 并非“先拼接再卷积”的浅层融合,而是门控、掩码与文本交互的串联机制;文本交互发生在下采样后的序列域,以控制计算开销。
7. 语义分支细化:掩码生成与 mask merging

图7:掩码生成与合并流程。跨模态掩码直接并用会引入重复与错配,合并机制用于提升语义引导的一致性与鲁棒性。
在语义链路中,文中采用 Florence 生成描述、DINO 开集检测、SAM 分割实例,再将结果送入 mask merging。这个设计的必要性来自一个常被忽视的问题:不同模态对同一对象的响应强度并不一致,导致“同 prompt 下的掩码质量差异”。
文中在主文给出流程,在补充材料给出算法细节。就主文信息而言,mask merging 的作用可归纳为:
- 抑制重复实例,即 duplicate objects;
- 缓解漏检导致的语义空洞;
- 减少错位掩码对主干更新的误导。
就结果解释而言,表 VII 中“caption + merged mask”优于“caption + unmerged mask”,可以直接视为 mask merging 的实证支持。
7.1 代码对照:掩码输入规范化与训练入口
主文强调 mask merging 的算法细节在补充材料中;仓库主干训练代码中可直接看到“掩码输入规范化 + one-hot 化 + 进入融合训练”的路径。
def check_multi_value_mask_to_one_hot(self, mask: torch.Tensor | None): if mask is not None and self.has_mask: if self.multi_value_mask_max_classes is not None and mask.ndim == 3: # * if mask is dtype float, must be careful about its values, must be in [0, multi_value_mask_max_classes] # * if float value in it, the code will raise CUDA error by `scatter_` without clear hint.
mask[mask >= self.multi_value_mask_max_classes] = ( 0.0 # larger than max classes are set to background ) new_mask = torch.zeros( mask.size(0), self.multi_value_mask_max_classes, *mask.shape[-2:], device=mask.device, dtype=torch.float32, ) new_mask.scatter_(1, mask.long()[:, None], 1.0) elif mask.ndim == 4: new_mask = mask else: raise ValueError( f"Invalid mask shape: {mask.shape} when `multi_value_mask_max_classes` is {self.multi_value_mask_max_classes}", "or model has `has_mask` as True. Input mask should be 3D or 4D tensor", "such as [Bs, 1, H, W] or [Bs, *, H, W] (to be one-hot encoded in this case).", ) else: new_mask = None
return new_mask def fusion_train_step( self, vi: "torch.Tensor", ir: "torch.Tensor", mask: "torch.Tensor | None" = None, gt: "torch.Tensor | None" = None, txt: "torch.Tensor | None" = None, fusion_criterion: "Callable | None" = None, to_rgb_fn: "Callable | None" = None, has_gt: bool = False, **_kwargs, ): vi, ir, mask, txt = self.check_multi_modal_inputs( vi, ir, mask, txt, vi.device, vi.dtype ) txt = self.drop_txt(txt)
mask = self.check_multi_value_mask_to_one_hot(mask)
fused_outp = self.only_fusion_step(vi, ir, mask, txt)
# mask for loss, should be detached if mask is not None: mask_for_loss = mask.clone().detach() else: mask_for_loss = None
fused_for_loss = to_rgb_fn(fused_outp) if to_rgb_fn is not None else fused_outp if ( has_gt or gt.size(1) == 3 ): # TODO: find more robust way to check if gt is available # if we have supervised GT, we use it to compute the supervised loss # sometimes, for MEF fusion task, we can access the GTs fusion_gt = gt # two different modalities for GT to compute the unsupervised loss boundary_gt = (vi, ir) else: fusion_gt = None boundary_gt = (vi, ir)
# compute supervised and unsupervised losses assert fusion_criterion is not None, "fusion_criterion should be provided" loss = list( fusion_criterion( fused_for_loss, boundary_gt=boundary_gt, fusion_gt=fusion_gt, mask=mask_for_loss, ) )
return fused_for_loss.clip(0, 1), loss
@torch.no_grad()这表明 mask 引导在工程实现中属于训练闭环的必要输入路径。
8. 损失函数设计:监督场景分流
文中按任务监督属性划分损失函数,而非以“统一损失”覆盖所有任务。
8.1 有监督:HMIF / Pansharpening
式(19)由像素一致性与结构一致性共同约束。其意义是避免单纯 L1 导致结构退化,也避免单纯结构项造成光谱/亮度偏移。
8.2 无监督:VIF / MIF / MEF / MFF
式(20)–式(23)对应典型的“强度-结构-边缘”三约束配比。对于无 GT 的融合任务,这是可解释性较强、工程上较稳定的选择:强度项控制内容保留,SSIM 项控制结构一致,梯度项控制细节锐度。
8.3 代码对照:损失实现与文中公式映射
仓库中损失实现位于 utils/loss_utils.py 的 DRMFFusionLoss.forward。实现是“文中损失族”的通用化版本,支持有监督/无监督、多任务权重与可选掩码约束。
def forward( self, img_fusion: Tensor, boundary_gt: "Tensor | tuple", # cat([vi, ir]) or tuple(vi, ir) fusion_gt: "Tensor" = None, # ground truth provided in dataset mask: "Tensor | None" = None, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: if mask is not None and self.mask_loss: self.check_dtype_and_device(img_fusion, boundary_gt, mask) with torch.no_grad(): if self.reduce_label: mask2 = mask.detach().clone() mask2[mask2 > 1.0] = 1.0 else: mask2 = mask
## TODO: consider this case # mask.size(1) == 2, two boundaries all have masks if mask2.size(1) == 2: assert False, "mask.size(1) == 2, two boundaries all have masks" mask_B, mask_A = mask2.chunk(2, dim=1) else: mask_B, mask_A = mask2, mask2 elif not self.mask_loss: mask2 = None # cast to None
self.check_dtype_and_device(img_fusion, boundary_gt)
wd = self.weight_dict loss_intensity = 0 loss_color = 0 loss_grad = 0 loss_fusion = 0 loss = {}
# split boundary gt no_batch_ndim = img_fusion.ndim - 1 broadcast_fn = lambda x: x.reshape(-1, *[1] * no_batch_ndim) # noqa: py3.11 supported
# img_A: ir/under/MRI, img_B: vi/over/spect if isinstance(boundary_gt, (tuple, list)): img_B, img_A = boundary_gt else: img_B, img_A = self.split_boundary_gt_tensor(boundary_gt)
# if has gt (e.g., when we train model on multi-exposure image fusion task) if exist_gt := exists(fusion_gt): gt_loss = wd["fusion_gt"] * self.loss_func(img_fusion, fusion_gt).nanmean() loss["gt_loss"] = gt_loss loss_fusion += gt_loss
if self.grad_norm: img_fusion = grad_norm(img_fusion)
# ir, vi (detached in this function) (img_A, A_is_color), (img_B, B_is_color) = ( self.check_rgb(img_A), self.check_rgb(img_B), )
## vgg latent weights # the latent weights comes from U2Fusion paper if self.latent_weighted: vi_w, ir_w = self.dynamic_weight(img_A, img_B) vi_w = broadcast_fn(vi_w) ir_w = broadcast_fn(ir_w) else: vi_w, ir_w = 1.0, 1.0
# YCbCr decomposition # use for intensity, color, and gradient loss Y_fusion, Cb_fusion, Cr_fusion = kornia.color.rgb_to_ycbcr(img_fusion).chunk( 3, dim=1 ) Y_A, Cb_A, Cr_A = kornia.color.rgb_to_ycbcr(img_A).chunk(3, dim=1) # ir Y_B, Cb_B, Cr_B = kornia.color.rgb_to_ycbcr(img_B).chunk(3, dim=1) # vi
## intensity and color loss # * compute boundary intensity and color loss when GT is not provided # * or we omit computing the loss # * in this case, we can say intensity and color are supervised by the GT # * we do not need to use the unsupervised loss # * `still_boundary_loss_when_gt` is to keep the boundary loss when GT is provided (optional) if (not exist_gt) or self.still_boundary_loss_when_gt: if self.prior == "max": Y_joint = torch.max(Y_A, Y_B) elif self.prior == "mean": Y_joint = (Y_A + Y_B) / 2
if self.grad_norm: Y_fusion = grad_norm(Y_fusion) Cb_fusion = grad_norm(Cb_fusion) Cr_fusion = grad_norm(Cr_fusion)
if mask2 is not None: # * intensity loss # 1. || fusion - max(vi, ir) || max fusion # 2. || mask * fusion - mask * ir || ir loss (pedistrain) # 3. || (1 - mask) * fusion - (1 - mask) * vi || vi loss (background)
loss_intensity = ( ( wd["inten_f_joint"] * self.loss_func(Y_fusion, Y_joint) if self.use_prior else 0.0 ) + ( wd["inten_f_ir"] * ir_w * self.loss_func(mask2 * Y_fusion, mask2 * Y_A) if self.boundary else 0.0 ) + ( wd["inten_f_vi"] * vi_w * self.loss_func(Y_fusion * (1 - mask2), Y_B * (1 - mask2)) if self.boundary else 0.0 ) ) # * color loss # 1. || (1-mask) * fusion_Cb - (1-mask) * ir_Cb || ir_Cb loss # 2. || (1-mask) * fusion_Cr - (1-mask) * ir_Cr || ir_Cr loss if self.color_loss_bg_masked: bg_mask = 1.0 - mask2 else: bg_mask = torch.ones_like(mask2) if self.color_loss: loss_color = 0.0 if B_is_color: loss_color += wd["color_f_cb"] * vi_w * self.loss_func( Cb_fusion * bg_mask, Cb_B * bg_mask ) + wd["color_f_cr"] * ir_w * self.loss_func( Cr_fusion * bg_mask, Cr_B * bg_mask ) if A_is_color: loss_color += wd["color_f_cb"] * ir_w * self.loss_func( Cb_fusion * bg_mask, Cb_A * bg_mask ) + wd["color_f_cr"] * vi_w * self.loss_func( Cr_fusion * bg_mask, Cr_A * bg_mask ) else: loss_intensity = ( ( wd["inten_f_joint"] * self.loss_func(Y_fusion, Y_joint) if self.use_prior else 0.0 ) + ( wd["inten_f_ir"] * ir_w * self.loss_func(Y_fusion, Y_A) if self.boundary else 0.0 ) + ( wd["inten_f_vi"] * vi_w * self.loss_func(Y_fusion, Y_B) if self.boundary else 0.0 ) )
if self.color_loss: loss_color = 0.0 if B_is_color: loss_color += wd["color_f_cb"] * self.loss_func( Cb_fusion, Cb_B ) + wd["color_f_cr"] * self.loss_func(Cr_fusion, Cr_B) if A_is_color: loss_color += wd["color_f_cb"] * self.loss_func( Cb_fusion, Cb_A ) + wd["color_f_cr"] * self.loss_func(Cr_fusion, Cr_A)
loss_intensity = loss_intensity.nanmean() loss_fusion += loss_intensity loss["intensity_loss"] = loss_intensity if self.color_loss: loss_color = loss_color.nanmean() loss_fusion += loss_color loss["loss_color"] = loss_color
## lpips loss # * lpips loss is to enhance perceptual visuality if self.lpips: lpips_loss = 0.0 if exist_gt: lpips_loss += self.lpips_loss(img_fusion, fusion_gt) * wd["lpips_f_gt"]
# if self.still_boundary_loss_when_gt or (not exist_gt): # lpips_loss += self.lpips_loss(img_fusion, img_A) * wd['lpips_f_ir'] + \ # self.lpips_loss(img_fusion, img_B) * wd['lpips_f_vi'] # if self.use_prior: # if self.prior == 'max': # img_joint = torch.max(img_A, img_B) # elif self.prior == 'mean': # img_joint = (img_A + img_B) / 2 # lpips_loss += self.lpips_loss(img_fusion, img_joint) * wd['lpips_f_joint']
loss_fusion += lpips_loss loss["lpips_loss"] = lpips_loss
## grad loss # * gradient loss is to enhance the fused image # * keep it although the GT is provided if self.grad: if self.grad_only_on_Y: grad_A = self.grad_op(Y_A) grad_B = self.grad_op(Y_B) grad_fusion = self.grad_op(Y_fusion) else: grad_A = self.grad_op(img_A) grad_B = self.grad_op(img_B) grad_fusion = self.grad_op(img_fusion)
if self.grad_norm: grad_fusion = grad_norm(grad_fusion)
grad_joint = torch.max(grad_A, grad_B) loss_grad += wd["grad_f_joint"] * self.loss_func(grad_fusion, grad_joint) # if mask is not None: # mask_expand = mask # if grad_fusion.ndim > 4: # mask_expand = mask_expand.unsqueeze(1) # loss_grad += wd['grad_f_ir'] * self.loss_func(grad_fusion * mask_expand, grad_A * mask_expand)
# loss_grad += wd['grad_f_ir'] * self.loss_func(grad_fusion, grad_A) + \ # wd['grad_f_vi'] * self.loss_func(grad_fusion, grad_B)
loss_grad = loss_grad.nanmean()
loss_fusion += loss_grad loss.update({"loss_grad": loss_grad})
## ssim loss # * ssim loss is to enhance the fused image # * keep it although the GT is provided if self.ssim: w_A = grad_A.norm() w_B = grad_B.norm() Z = w_A + w_B w_A /= Z w_B /= Z
if self.grad_only_on_Y: ssim_A = self.ssim_func(Y_fusion, Y_A) ssim_B = self.ssim_func(Y_fusion, Y_B) else: ssim_A = self.ssim_func(img_fusion, img_A) ssim_B = self.ssim_func(img_fusion, img_B) loss_ssim = wd["ssim_f_joint"] * (w_A * ssim_A + w_B * ssim_B) loss_fusion += loss_ssim loss.update({"loss_ssim": loss_ssim})
## tv loss if self.tv: img_fusion_tv = img_fusion if self.grad_norm: img_fusion_tv = grad_norm(img_fusion_tv) tv_loss = wd["tv_f"] * self.tv_loss(img_fusion_tv).nanmean() loss_fusion += tv_loss loss.update({"tv_loss": tv_loss})
## correlation loss if self.correlation: img_fusion_tv = img_fusion if self.grad_norm: img_fusion_tv = grad_norm(img_fusion_tv) loss_corr = wd["crr_f"] * self.cc_loss(img_A, img_B, img_fusion_tv) loss_fusion += loss_corr loss.update({"loss_corr": loss_corr})
loss.update({"loss_fusion": loss_fusion})
return loss_fusion, loss
## EMMA stage two fusion training lossloss_cfg: drmffusion: latent_weighted: no grad_loss: yes color_loss: no # VIF: yes, others: no ssim_loss: yes prior: null boundary_loss: yes mask_loss: no # VIF: yes, others: no lpips_loss: no color_loss_bg_masked: yes tv_loss: no pseudo_l1_const: 0. # [0.002, 0.] correlation_loss: no reduce_label: no # VIF: no grad_only_on_Y: no grad_op: "sobel_add" still_boundary_loss_when_gt: yes grad_norm: no ssim_implm_by: torch ssim_window_size: 5 weight_dict: fusion_gt: 5.
inten_f_joint: 10. # prior used inten_f_ir: 10. inten_f_vi: 10.
color_f_cb: 2. color_f_cr: 2.
grad_f_joint: 20. # [5, 40] for VIF and MEF ssim_f_joint: 2.
lpips_f_gt: 1.
tv_f: 0.1 crr_f: 0.02
# unused lpips_f_joint: 0.2 lpips_f_ir: 0.2 lpips_f_vi: 0.2
filmfusion: prior: mean weight_grad: 1000 weight_ssim: 1. ssim_window_size: 5
l1ssim: weighted_r: [1.0, 0.1] implem_by: torch window_size: 5 grad_norm: no因此,式(19)-式(23)在仓库中并未采用逐式静态硬编码,而是通过统一损失框架按任务配置激活并加权;这也是其可覆盖 VIF/MEF/MFF/MIF/HMIF/Pansharpening 多任务训练流程的工程基础。
9. 实验设置与评测协议
9.1 任务覆盖与数据集
文中覆盖六类融合任务:
- VIF:MSRS、M3FD、TNO;
- MIF:Medical Harvard;
- MEF:SICE、MEFB;
- MFF:MFI-WHU、RealMFF;
- Pansharpening:WV3、GF2、QB;
- HMIF:Chikusei、Pavia。
这一覆盖范围保证结论不局限于单任务分布。尤其是同时包含跨传感器融合,如 VIF、Pansharpening、HMIF,以及同传感器参数差异融合,如 MEF、MFF,使方法泛化判断更可信。
9.2 基线与指标
文中基线覆盖 decomposition、task-designed、prior-based、architecture-designed、modality-guided 等多类方法。指标体系包含 MI、VIF、SF、Qcb、Qy、Qcv、Qabf、LPIPS,以及 SAM、ERGAS、Q2n、HQNR、PSNR、SSIM 等。
从评估逻辑看,这套指标组合同时约束了信息量、结构质量、感知一致性与遥感光谱质量,避免“单指标最优”误导。
9.3 语义输入配置与任务差异化处理
文中在语义输入上并非“一套配置跑全部任务”,而是根据任务属性做了条件化策略,这一点对复现实验非常关键:
- 在 VIF 场景中,M3FD 采用固定 prompt:People、Car、Bus、Lamp、Motorcycle、Truck,并执行 mask merging;
- MSRS 直接使用数据集给出的人工掩码;
- MEF/MFF/MIF 由于场景对象差异较大,采用 Florence 自动提取 prompt 以支持开集语义;
- Pansharpening 与 HMIF 因训练样本常用 64×64 小块,文中设置为仅保留语言引导,不使用 mask 引导。
这个配置策略背后有明确原则:当空间尺寸较小、对象边界不稳定或掩码噪声可能压过有效信息时,强行注入 mask 反而可能引入误导;而在目标显著、语义对象明确的场景,如 VIF,mask 的收益更容易释放。换言之,RWKVFusion 并非在所有任务中执行等强度语义注入,而是依据数据形态进行可解释的注入强度调度。
9.4 指标解释的组织顺序
文中指标较多,若不做层次化组织,容易出现“指标堆叠但结论模糊”的问题。本文采用如下解释顺序:
- 先看任务核心指标,如遥感任务的 SAM/ERGAS/Q2n/HQNR,VIF 的 MI/VIF/Q 系列;
- 再看感知相关指标,如 LPIPS 与结构保持,判断是否存在“分数上升但视觉退化”;
- 最后结合可视化图 Fig.8、Fig.9,验证指标变化是否对应可解释的视觉差异。
按这个顺序,RWKVFusion 的优势会更清晰:它在多数任务上不是单一指标尖峰,而是在“信息量、结构一致、下游可用性”三条轴上同时保持稳定。这也是图1 雷达图呈现外扩包络的根本原因。
10. 主实验结果:结论与证据链
10.1 VIF 与 MIF
说明:本文 Table II-Table X 均使用文中原表的独立裁剪小图,不使用整页截图与 Markdown 表格。

表 II:VIF 与 MIF 的定量结果。RWKVFusion 在 MSRS、M3FD、TNO 与 Medical Harvard 多数据集上保持较高占优比,优势并非来自单一数据集的偶然波动。
文中对 VIF 的报告显示:
- 在 MSRS 上,8 项指标中 7 项最优;
- 在 M3FD、TNO 上,多数指标第一或第二;
- 在 Medical Harvard 数据集 MIF 上,除 LPIPS 外其余指标领先。
文中节选值包括:M3FD 上 MI 2.57、VIF 0.78、Qabf 0.70、Qy 0.96;Medical Harvard 上 MI 2.02、VIF 0.57、SF 22.03、Qy 0.90。该证据链支持一个关键结论:语义引导并没有以牺牲纹理/结构指标为代价,而是在目标保持与视觉一致性之间取得了联合收益。
10.2 MEF 与 MFF

表 III:MEF 与 MFF 的定量结果。RWKVFusion 在多曝光与多焦点场景中保持较稳的综合领先,尤其在 SICE、MEFB 与 RealMFF 上优势更明显。

图8:M3FD、SICE、RealMFF 的可视化比较。RWKVFusion 在烟雾遮挡、曝光失衡、离焦区域等困难场景下,兼顾显著目标保留与背景纹理稳定。
结合 Table III 与图8 可得到更完整的证据链:
- 定量上,RWKVFusion 在 SICE、MEFB、MFI-WHU、RealMFF 的多数指标中处于领先;
- 定性上,方法在高亮区和暗区之间维持更一致的曝光平衡,并在焦内/焦外边界处保留更完整纹理。
文中节选值示例:SICE 上 MI 3.81、VIF 0.93、Qabf 0.78、Qy 0.96;RealMFF 上 MI 4.90、VIF 1.32、Qcb 0.74、LPIPS 0.200。
10.3 Pansharpening 与 HMIF

表 IV:WV3 上的 pansharpening 定量结果。RWKVFusion 在 RR 与 FR 协议下均表现突出,并兼顾光谱质量与结构保真。

表 V:Chikusei 与 Pavia 的 HMIF 定量结果。RWKVFusion 在主要质量指标上占优,同时保持较好的参数/FLOPs 效率。

图9:Pansharpening 与 HMIF 的误差图。RWKVFusion 在细节边缘与结构过渡区的残差更低,尤其在高频结构区域更明显。
文中给出的关键数值包括:
- WV3:SAM 2.78、ERGAS 2.03、Q2n 0.918、SCC 0.988、 0.016、 0.036、HQNR 0.949;
- Chikusei:PSNR 43.89、SAM 1.93、ERGAS 3.33、SSIM 0.963;
- Pavia:PSNR 36.06、SAM 3.95、ERGAS 3.07、SSIM 0.936。
文中同时指出,在 HMIF 对比中,RWKVFusion 相对 DHIF 仅使用约 8.41% 参数量和 0.67% FLOPs。该“性能 + 复杂度”联合结果,是其方法价值的重要支撑。
10.4 跨任务一致性与结果边界
把 VIF/MIF/MEF/MFF 与 Pansharpening/HMIF 放在一起看,可以观察到 RWKVFusion 的一个鲜明特征: 优势不依赖单一模态组合或单一退化类型。前四类任务主要考察自然图像层面的纹理与目标保留,后两类任务强调光谱一致性与空间细节重建;RWKVFusion 在这两类评价体系中都能给出正向结果,说明其主干设计并未绑定于某一任务先验。
但该一致性并不意味着“无条件全胜”。从文中呈现可见,个别指标上仍可能出现次优,这与任务差异、上游语义质量和评价指标偏好有关。这里应避免把文中结论简化成“全指标第一”,更准确的表述是:RWKVFusion 在跨任务综合表现和性能-复杂度平衡上具备显著优势。
11. 消融实验:机制有效性的因果验证

表 VI:算子替换、扫描策略、MFM 设计与 plain backbone 的主消融结果。RWKVFusion 的性能收益来自结构协同,而非单个组件的偶然增益。
11.1 BRWKV 与替代算子
文中将 BRWKV 替换为 flash attention、flatten attention、window attention、VMamba。结果显示 BRWKV 在大多数指标上更稳。该结论说明性能提升并非来自“简单扩大参数规模”,而是来自 WKV 机制与融合任务需求匹配。
11.2 ESS 扫描策略
消融显示:4 scans 并未稳定优于默认配置;8 scans 在部分指标略好但计算代价增大。结论是默认 ESS 在效率-性能之间达到更优折中。
11.3 MFM 结构替换
将 MFM 替换为简单 MLP 或 cross-attention 变体后,指标下降。该结果与图6 可视化一致,说明三路径语义注入 raw/modality + mask + text 具有不可替代的机制价值。
11.4 语义引导与 mask merging

表 VII:语义引导与 mask merging 消融结果。caption 与 merged mask 的组合优于单一路径与未合并掩码。
在 MSRS 上,文中报告:
- only caption:MI 3.20、VIF 0.76;
- only merged mask:MI 3.36、VIF 0.78;
- no guidance:MI 3.10、VIF 0.69;
- caption + unmerged mask:MI 3.38、VIF 0.84;
- caption + merged mask,默认设置:MI 3.42、VIF 0.87。
该组结果支持以下两条因果判断:
- caption 与 mask 是互补信号,不是冗余信号;
- mask merging 的收益是独立可验证的,不是与 caption 强绑定后的偶然增益。
11.5 Prompt 设定

表 VIII:auto-prompt 与 fixed-prompt 的消融结果。fixed-prompt 略优,但两种设定下方法整体都保持较强竞争力。
fixed-prompt 略优于 auto-prompt,但二者均优于多数对比方法。该现象说明:开放式语义引导是可行的,但上游提示词质量仍会影响最终融合上限。
11.6 Plain 与 Multi-scale
文中在 plain backbone 对照下显示多尺度结构更优,且参数量同量级。对于像素级融合任务,这说明“跨尺度上下文传递”仍是必要条件,而非可随意替代的结构装饰。
11.7 ERF 证据

图10:不同算子 CNN、Mamba、Attention、RWKV 的有效感受野比较。RWKVFusion 的 ERF 更广且响应更集中,支撑其“低代价全局建模”主张。
12. 下游任务验证:融合质量是否可迁移
下游任务部分的关键问题是:融合结果是否真正提升感知模型表现,而不仅是视觉观感更“清晰”。这里分别给出 Table IX 语义分割与 Table X 目标检测,并在对应小节逐项解读。
12.1 单目深度估计

图11:Depth Anything v2 的可视化结果对比。RWKVFusion 生成的融合图在轮廓连续性和远景层次上更有利于深度估计。
该部分文中给出的是可视化证据。统一量化深度指标,如 Abs Rel、RMSE,文中未给出/未报告。
12.2 语义分割

图12:SegFormer 分割结果对比。RWKVFusion 在小目标与边界区域的分割连续性更好,对 mIoU 提升有直接贡献。

表 IX:MSRS 数据集上的语义分割定量结果,含各类别 IoU、mIoU、mAcc。RWKVFusion 在总体分割指标上保持领先,即 mIoU 79.61、mAcc 88.72,并在关键类别上呈现更稳定表现。
文中表格结果显示:Proposed 的 mIoU 为 79.61、mAcc 为 88.72,高于 FILM 79.43/88.64 与 TextIF 79.27/88.30。这说明语义增强的融合结果并非仅提升视觉观感,而是能够实质性提升语义任务表现。
12.3 目标检测

图13:YOLOv5 检测可视化,Prediction 与 GT 对照。RWKVFusion 在多数检测场景下提供更高的目标可分辨性,特别是行人与车辆类别。

表 X:MSRS 数据集目标检测定量结果,People、Car、mAP@0.5、mAP0.5:0.9。RWKVFusion 在多数检测指标上占优,但在 mAP0.5:0.9 上并非绝对最优。
文中给出关键数值:Proposed 在 People 0.966、Car 0.847、mAP@0.5 0.907、mAP0.5:0.9 0.697。需要如实指出,SegMIF 在 mAP0.5:0.9 上为 0.703,高于 RWKVFusion。即 RWKVFusion 不是“所有检测指标绝对第一”,而是多数指标上的更优综合表现。
13. 局限性与进一步思考
基于主文与结果呈现方式,可以归纳出以下边界条件:
- 上游语义链路误差传导:caption、检测框、掩码质量会直接影响融合效果;
- 主干高效不等于全链路高效:主文重点报告主干复杂度优势,但包含 Florence/DINO/SAM 的统一端到端时延,文中未给出/未报告;
- 小尺寸遥感训练块中弱化掩码分支:在 pansharpening 与 HMIF 中,因训练块尺寸较小,文中给出的尺寸为 64×64,实践上省略 mask,仅保留语言引导;
- 深度任务量化不完整:深度估计仅给出可视化,统一数值指标文中未给出/未报告。
这些限制并不削弱文中主结论,但提示后续方向应聚焦于:更轻量的语义生成链路、端到端系统级时延评估、任务自适应的语义注入策略,以及更完整的跨任务量化协议。
13.1 面向复现与改进的具体启示
如果以“可复现可扩展”为目标,这项工作给出的直接启示至少有三点。第一,语义分支应被视为可替换部件,而不是固定实现。主文证明了“语言 + 掩码”这类条件本身有效,但并未限定必须使用 Florence/DINO/SAM 这一组合,因此后续工作可以在不改动主干的前提下,用更轻量的语义生成器替换上游链路。第二,mask 不是越多越好,关键是质量控制。文中通过 merged 与 unmerged 对照已经说明,不受控的掩码输入会稀释语义收益,甚至向主干注入噪声。第三,融合研究不应只停留在融合指标,必须绑定下游任务检验。本文在检测和分割上的结果提示我们:真正高价值的融合表示,应该在“视觉可读性”和“机器可判别性”两端同时成立。
从研究方法论看,RWKVFusion 也提供了一个可迁移范式:先在任务定义层明确缺口,再在结构层给出针对性设计,最后用主结果、消融、下游三层证据闭环验证。这个范式对于后续多模态低层视觉任务,不仅限于图像融合,同样具有参考意义。
14. 机制级复盘:从输入到输出的一次完整前向
为系统梳理本文的方法逻辑,可沿一次前向传播完整追踪信息流。这样做的价值在于,不少解读停留在模块名词罗列,如 RWKV、MFM、mask,但未回答“这些模块在计算图中的时序作用与相互约束关系”。RWKVFusion 的设计恰恰依赖严格时序:语义先生成、再注入、再跨尺度传播、最后重建。
14.1 输入阶段:多模态图像与语义条件的并行准备
在输入端,图像模态 进入融合分支;caption 与 mask 进入语义分支。两条分支虽为并行结构,但并不独立:语义分支输出是后续编码层的必需条件,因此在系统上属于前置条件准备,而非可选附加通道。这个设计与许多把语义当作后验打分器的融合方法有本质区别。
需要强调,文中没有把语义条件定义成单一向量,而是拆分为文本条件 与空间条件 。这种拆分背后的假设是:全局语义一致性和局部目标定位是两个不同层面的约束,不能由单一信号替代。后续表 VII 的结果也验证了这一点:只用 caption 或只用 merged mask 都弱于二者联合。
14.2 编码阶段:语义条件在多尺度主干中的层内注入
进入编码端后,RWKVFusion 并非“先完成融合再叠加语义”,而是在每个编码层通过 MFM 执行条件注入。其关键效果是把语义约束从输出端前移到特征形成端,减少后期补偿式修正造成的信息损失。
更具体地说,MFM 在每层做三件事:
- 保留并重校准原始模态信息,避免语义引导导致低层纹理被过度抑制;
- 利用掩码路径在空间上强化对象响应,把“哪里重要”显式落到特征图上;
- 通过文本交替拼接给出全局语义方向,把“该保留何种语义关系”加入序列建模。
这三步并不是并列拼贴,而是存在先后依赖:先有图像主路径,掩码对其进行空间约束,文本再做全局语义调制。若把顺序打乱,低层信息与高层语义的耦合稳定性会下降,这也能解释为何简单 MLP 或双 cross-attention 替代不能达到默认 MFM 的效果。
14.3 核心算子阶段:BRWKV 负责“全局关系建模 + 线性代价”
在编码层内部,BRWKV 分为空间混合和通道混合两部分。空间混合阶段通过 WKV 聚合全局信息,通道混合阶段进行非线性重标定。该结构与 Transformer block 的功能映射相似,但实现机制不同:RWKV 不显式构造完整注意力矩阵,而是依赖衰减记忆与递推读取。
从任务角度看,这一点对图像融合非常关键。融合任务不是高层语义分类,而是像素级重建;输入常是高分辨率图像,token 长度大,显存和 FLOPs 都更敏感。RWKV 的意义并非“理论上更优雅”,而是为像素级任务提供了更可执行的复杂度路径。
14.4 解码与重建阶段:语义不再重复注入的原因
主文中解码器不再继续注入 caption/mask,该设计具有明确针对性。解码器职责是将融合特征恢复到像素域并保留细节连续性;如果在该阶段继续强语义注入,容易把重建过程变成语义重写过程,导致纹理不自然或边缘伪影。
因此,语义主要在编码端“定向”,解码端主要“还原”。这是本文设计中一个容易忽略但非常重要的工程细节: 语义引导与图像重建分阶段治理,而不是在每个阶段都做同强度语义干预。
14.5 系统层视角:主干高效与全链路开销的关系
文中在主文中清晰给出了 BRWKV 主干的复杂度优势,并在遥感任务中给出参数/FLOPs 对比。但在系统层面仍应区分两件事:
- 主干前向开销,RWKVFusion 网络本体;
- 语义准备开销,Florence、DINO、SAM、T5。
前者文中报告充分,后者主文没有统一端到端时延统计。也就是说,“主干高效”这个结论是成立的,但“整链路总时延最优”在主文证据范围内尚不能直接推出。这一点在工程复现时尤其要注意。
15. 证据链交叉验证:主实验、消融、下游如何互相支撑
本工作把“结果—机制—迁移”三层证据连成闭环:主实验证明有效,消融解释来源,下游验证可迁移。需要注意的边界也很明确:部分下游指标并非绝对第一,深度估计只给出可视化证据。
16. 总结
RWKVFusion 的学术价值不在于“又一个融合网络”,而在于其把三个长期分离的问题放进了同一证据闭环:
- 在任务定义层面,用式(5)把语言与掩码变为融合前向条件;
- 在网络算子层面,用 BRWKV + ESS 构建线性主开销的全局建模路径;
- 在实证层面,用跨任务主实验、结构消融与下游验证证明收益来源。
进一步看,这篇工作真正值得借鉴的是方法论:先明确任务缺口,再给出结构设计,再用消融和下游任务建立因果与迁移证据。对后续研究而言,可延展方向包括:更轻量的语义生成链路、更完整的端到端时延评估、以及面向任务差异的自适应语义注入机制。
若将该工作置于近两年的融合研究脉络中考察,其关键贡献并不止于“语言引导”标签,而在于语义条件、主干效率与跨任务验证的同步落地。许多方法只在其中一维突出,而 RWKVFusion 的价值在于让三维同时成立并可复核。由此可见,评价重点不应停留在单一指标增量,而应落在研究设计是否形成可迁移、可解释、可工程化的完整闭环。该视角也有助于后续选题判断。
参考
- Cao Z.-H., Liang Y.-J., Deng L.-J., Vivone G., An Efficient Image Fusion Network Exploiting Unifying Language and Mask Guidance, IEEE TPAMI, 2025.
- 代码: https://github.com/294coder/RWKVFusion
文章分享
如果这篇文章对你有帮助,欢迎分享给更多人!