输入预处理
在调用模型之前,Processor会先把文本和图像处理好,我们先来看看图像预处理过程:
图像预处理
假设每张图像(numpy格式)大小为1420 * 720 *3 (HWC)。每张图经过Image Processor的如下操作:
- smart_resize,图像的长宽都会变成
factor=grid_size * merge_size = 14 * 2的倍数, 即1428*728*3. - 把图像的通道维度换到第二维, 并且添加时间维度, 即
(1, 3, 1428, 728). - 在时间维度上复制一份, 变成
(2, 3, 1428, 728),记为(t, c, H, W) - 网格化, 时间维度一格是两张图, H维度一格是两个patch, 一个patch14个像素, W和H一样. 现在shape变为(1, 2, 3, grid_h // 2, 2, 14, grid_w // 2, 2, 14), 记为
(t, t_patch_size, c, h // merge_size, merge_size, patch_size, w // merge_size, merge_size, patch_size) - 调整view, 变成
(t, h // merge_size, w // merge_size, merge_size, merge_size, c, t_patch_size, patch_size, patch_size) - thw和后面的
merge_size * merge_size乘起来, 后三组也乘起来, 变成(t*h*w, 3*2*14*14=1176)
同时我们拿到每个图片的grid尺寸(t, h, w)。
Image Processor最后返回两个东西:
- pixel_values: (n, 1176), 是每个图片单独的patches在第一个维度上extend起来的东西
- image_grid_thw: (N, 3), N是图片张数, 这个东西包含每个图片的thw值.
文本预处理
假设我们的文本是a b c d <|vision_start|><|image_pad|><|vision_end|> e f g <|vision_start|><|image_pad|><|vision_end|>,在
处理的时候,<|image_pad|>会被扩展成t_grid * w_grid // MERGE_SIZE * h_grid // MERGE_SIZE个<|image_pad|>,最后tokenize成为input_ids。
最终我们拿到
input_ids:[B, seqlen]pixel_value:[n*t_grid*w_grid*h_gird, channel*temporal*w*h]grid_thw:[n, 3]其中n为图片个数,注意根据grid_thw我们可以知道pixel_value中每个patch对应哪一张图片的哪一个位置,这对我们后面计算二维ROPE很有用
prepare_inputs_for_genration
不知道这个函数有啥作用的话建议先去看GenerateMixin里面的generate方法。在generate方法里面会检查各种生成参数,设置KV Cache,设置Attention Mask,并根据你的
Generate Mode去使用对应的生成策略。我们使用的是Sample策略,也是最简单的一种。于是被分发到_sample方法中,在那里面我们会在调用模型forward方法之前先调用模型的prepare_inputs_for_generation方法来处理输入。
首先,获取缓存部分的input_ids,如果没有位置编码的话,要利用self.get_rope_index根据文本部分的input_ids和图片部分的image_grid_thw来计算位置编码的index。计算方案如下
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embeddin for text part.
Examples:
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [3, 4, 5, 6, 7]
text height position_ids: [3, 4, 5, 6, 7]
text width position_ids: [3, 4, 5, 6, 7]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
最后vision部分的index会和text部分的index在序列维度上拼接起来,形成一个形状为[3, B, seqlen]的position_ids,以及一个形状为[B, 1]的mrope_position_delta。
后者是序列中最大index和实际seqlen的差别(有图片的话就会有差别),肯定小于等于0。
注意如果使用StaticCache,那么attention_mask要做适配
最后把计算出来的position_ids,rope_deltas,attention_mask更新到model_inputs中并返回。
Qwen2VLForConditionalGeneration的forward方法
终于来到这里了,首先input_ids会被embed,变成inputs_embeds,形状为[B, seqlen, config.hidden_size=1536],然后pixel_values会被模型中的视觉头self.vision进行编码,形成image_embeds。这个self.vision是一个ViT网络
图像patches的编码–ViT
这一部分在Qwen2VisionTransformerPretrainedModel里面
对于输入的[num_patches, 3*2*14*14=1176]的图片,我们首先进行一次Conv3d的投影,得到[num_patches, embed_dim]的hidden_states,这里embed_dim=1280
随后我们计算图像上的相向量,
-
第一步是获得每个patch的行号和列号
pos_ids,这可以根据grid_thw很方便地计算出来,这个pos_ids的形状为[num_patches, 2] -
第二步是计算rotary embedding中的
mθ,对于每一个位置,有一个长度为head_dim // 4的inv_freq向量:1 / base**(torch.arange(0, head_dim // 2, 2, dtype=torch.float) / (head_dim // 2))。这里因为num_head=16,所以head_dim = 1280 / 16 = 80,inv_freq长度为20。 对于一维序列中的位置m,该位置的相向量为m * inv_freq。所以,输入一个长度seqlen,我们可以先预备好一个形状为[seqlen, head_dim//4]的相谱:seq = torch.arange(seqlen) freqs = torch.outer(seq, self.inv_freq) -
第三步是根据每个patch的行号列号去索引相谱,得到该patch两个方向的相向量
rotary_pos_embrotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)索引出来本来形状是
[num_patches, 2, head_dim // 4],但是把第一个维度之后全部flatten了,所以最后的形状是[num_patches, head_dim // 2 = 40],用θ来表示20个θ的话,那么现在rotary_pos_emb每个位置上是[x θ,y θ]补充说明:为什么freqs的长度是
head_dim // 4? 参考苏神的文章,对于二维位置编码,旋转矩阵是4*4的对角分块矩阵,左上角对应xθ,右下角对应yθ。那么对于一个维度为head_dim的feature,需要的θ个数就是head_dim // 4
接下来计算cu_seqlens,即每个图片的patch范围,在我们的例子里cu_seqlens = [0, 5304, 10608]
接下来过16层AttentionBlock,在每个AttentionBlock里面进行一次attention,一次FFN。这一部分在Qwen2VLVisionBlock里面。
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
hidden_states = hidden_states + self.attn(self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
每次Attention是由VisionSdpaAttention完成的,它的计算如下。可以看到,每一层都算了一次位置编码并且apply:
- 首先计算多头QKV,每个的形状都是
[seqlen, num_heads=16, head_dim=80] - 对QK增添位置编码,以Q为例。
- 把Q扩充一个维度,变成
[1, seqlen, num_heads, head_dim] - 利用
rotary_pos_emb计算cos和sin向量cos = freqs.cos() # [seqlen, head_dim // 2], [(cos(xθ), cos(yθ)), ...] sin = freqs.sin() cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() # [1, seqlen, 1, head_dim] sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - 利用公式
x_rotary = x * cos + rotate(x) * sin进行编码,这里每个头都被广播了,位置编码是一样的output = (tensor * cos) + (rotate_half(tensor) * sin) output = output.to(orig_dtype) - 再把第一个维度squeeze掉,最后得到的q和k形状都是
[seqlen, num_heads=16, head_dim=80]
- 把Q扩充一个维度,变成
- 计算
attention_mask,用于控制每张图片只和自己做attention,这里用到了cu_seqlensattention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - 最后进行经典的sdpa attention,注意要把
num_heads维度交换到前面去,最后再换回来q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output
最后进行patch的融合,2*2的相邻patch融合成一个。操作如下,self.ln_q是layernorm,self.mlp是一个两层全连接,中间用GELU激活。维度变化是[config.embed_dim=1280 * 4 -> config.embed_dim=1280 * 4 -> config.hidden_size=1536]
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
return x
ViT 最后输出的image_embeds的shape为[2652, 1536],其中2652 = 10608 // 4
图像embedding和文本embedding的融合
现在我们的inputs_embeds和image_embeds的hidden_size都统一为了config.hidden_size=1536,并且inputs_embeds中的image token恰好就有image_embeds第一维度那么多个。那么可以直接把image_embeds嵌入到对应位置,这在代码里是通过一个masked_scatter来实现的
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
print("image_mask.shape: ", image_mask.shape)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
把attention_mask传到相同device之后,就可以把图像和文本融合好的embedding送给传统的Transformer去做attention了。还记得之前算的position_ids,attention_mask吗,这里一并传入Transformer。
对全部embedding做计算–Qwen2VLModel
现在我们把融合之后的inputs_embeds送给Qwen2VLModel进行计算,在forward里面会先根据position_ids计算好rotary_emb,然后一层层送入decode layer进行一次SdpaAttention,具体的类是Qwen2VLSdpaAttention。
我们之前计算的
position_ids形状为[3, 1, seqlen],首先我们会在Qwen2VLRotaryEmbedding类里面计算positional_embedding,具体就是首先得到一个self.inv_freq,形状为[3, 1, dim=64],然后和position_ids做outter product,最后一维复制一遍,再取cos和sin,一并返回,维度均为[3, 1, seqlen, 128]
在SpdaAttention里面,首先计算QKV,这里KV的头数其实是
num_heads的一个因数,Q不变,是为GQA。然后通过apply_multimodal_rotary_pos_emb函数根据position_embeddings对QK进行位置编码。注意这里的细节:我们之前计算的position_ids是按照thw的顺序在第一个维度cat起来的。所以现在的position_embeddings也是这个顺序。
Qwen认为,64维的
inv_freq里面,一部分角度编码t,一部分角度编码h,一部分角度编码w。这个分割做在config.mrope_section里面。去看一下就知道,他设定的值为[16, 24, 24]。由于我们把虚部都统一放在维度的后一半儿,并且把cos,sin复制了一遍,所以也要先把mrope_section复制一遍,变成[16, 24, 24, 16, 24, 24]。接着分割cos和sin的最后一维,并按照[t, h, w, t, h, w]的顺序在第一维上选,重新在最后一维上拼接,并扩展head group的维度。最终形成[B, 1, seqlen, 128]的cos和sin。(不是,这也不是三位Rotary啊?正儿八经的三位Rotary应该是[tθ, hθ, wθ]的cos和sin来搞吧,它这是[tθ[:16], hθ[17:40], wθ[41: 64]]的cos和sin搞的)
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
Anyway,我们根据cos和sin就可以愉快地套公式编码了,编完之后返回QK。接着我们会把KV拿去更新KV Cache,KV Cache的update方法会返回拼上了缓存值的全量KV。随后把KV的Head Group展开(interleave_repeat)就可以拿去和Q做SpdaAttention啦~
每一层都会对QK进行位置编码,随后对KV Cache进行update,接着进行GQA,把头拼好,返回output和新的KV Cache。