融合XLnet与DMGAN的文本生成图像方法

2024-03-05 08:15赵泽纬车进吕文涵
液晶与显示 2024年2期
关键词:特征向量编码器单词

赵泽纬, 车进, 吕文涵

(宁夏大学 物理与电子电气工程学院, 宁夏 银川 750021)

1 引言

文本生成图像是一种跨模态的研究任务,这中间主要应用自然语言处理(NLP)和计算机视觉(CV)两个研究领域。变分编码器(VAE)[1]、自回归模型[2]、生成对抗网络(GANs)[3]等深度学习网络的引入,为文本生成图像奠定了基础。尽管变分编码器是第一个通过输入信息的潜在表达生成样本的深度图像生成模型,但是由于注入噪声和VAE模型重建不完整,生成的图像模糊。自回归生成模型如pixelRNN[4]、pixelCNN[5]和pixelCNN++[6]比VAE更加有效。由于没有额外的噪声,pixelRNN需要较长的训练时间,并且由于像素值计算错误,pixelCNN遗漏了某些像素。为了避免pixelCNN中的盲点问题,piexlCNN++使用层之间的残留连接。尽管如此,由于图像的顺序(像素到像素)生成,自回归生成模型缺乏可伸缩性[7]。此后,学者们使用生成对抗网络GAN生成与真实图片相似的图片。GAN网络由生成器和鉴别器两部分组成。先用生成器生成新的图片,然后用鉴别器鉴别生成的图片是生成的图像还是真实图片[3]。在注释良好的图像数据集上训练的GAN可以生成接近非常真实的新图像。GAN要学习高度复杂的数据分布,但由于不收敛和模态坍塌等原因,训练难度较大。Zhang等[8]提出了具有两级GAN结构的StackGAN。Stack-GAN中的第一阶段是生成低分辨率的只能看出物体大体形状和颜色的图像;第二阶段对此图像进行细化,生成高分辨率图像。StackGAN++[9]使用多个生成器和判别器生成256×256的图片。这两种GAN网络都不以全局句子向量为条件,因此图像生成缺少细粒度的单词级信息,生成的图像不能令人满意。为解决这一问题,Xu等[10]提出了AttnGAN,利用深度注意多模态相似度模型(DAMSM)和注意力机制来描绘图像的局部区域。虽然这些方法取得了显著进展,但仍存在生成图像质量取决于初始图像,以及无法细化输入句子中每个单词描述图像内容的不同层次信息两个问题。为此,Zhu等[11]提出了DMGAN,增加一种内存机制处理不良的初始图片,引入内存写入门,动态选择与生成图像相关的单词。但是DMGAN的文本编码器还是使用RNN编码器,由于RNN的顺序性质,在从单词嵌入中提取文本语义时,会忽略一些单词,导致图像属性的损失,使重要信息被省略,最终生成的图像和文本存在语义不一致的问题。

在图像生成的同时,文本编码方法也在日新月异,从最初Uchida等[12]提出的word2vec,简单地生成词向量,到经典神经网络的RNN[13-14]的提出,此网络拥有优秀的并行能力和双向提取文本特征能力。Ashish等[15]提出Transformer模型,使用自注意力机制对文本信息进行编码。随后加强版的Transformer即BERT[16]出现,将文本进行双向编码,能够更好地挖掘文本信息。但是BERT在建模时,过度简化了一些高阶特征及长距离token语义依赖。针对这些问题,Yang等[17]提出XLnet(Generalized Autoregressive Pretraining for Language Understanding)模型,利用XLnet将自回归模型固定的向前或者向后替换为最大序列的对数似然概率期望,使上下文的token都能被每个位置的token所使用,并使用一种乘积方式来分解预测tokens的联合概率,继而消除BERT的token之间的独立假设,实现对文本信息的进一步挖掘。

针对DMGAN文本编码阶段的不足,引入XLnet编码器对文本进行编码,使DMGAN模型在初始阶段获取更多的文本信息,有利于生成更高质量的图片,并在图像生成的初始阶段和图像细化阶段均加入高效通道注意力[18](ECA)来进一步提高生成图像质量。

2 相关知识

2.1 DMGAN

DMGAN在AttnGAN的基础上进行改进,用一种动态记忆模块替换AttnGAN中注意力机制,生成更加生动形象的图像。DMGAN体系结构主要有深度注意多模态相似度网络(DAMSM)和图像生成网络。

DAMSM计算DMGAN模型生成的图像与输入文本在单词级别上的相似性。训练DAMSM使图像文本相似度最大化。图像生成网络包含初始图像生成和细化图像两部分。初始图像生成阶段是指文本通过文本编码器获得其语句特征,然后将语句特征和随机噪声融合,再通过一个全连接层和4个上采样层生成初始图像;图像细化阶段由记忆写入、键寻址、键值读取和键值响应4个部分组成。每次细化后图像像素值翻倍,本文最后实现像素值为256×256的图像。

2.2 XLnet

XLnet是一种双向捕获上下文的自回归模型。该模型训练句子中对所有可能的单词排列,而不是默认的从右到左或者从左向右排列。XLnet将BERT双向编码的优点和LSTMs等序列模型的递归函数结合,解决了BERT固定句子大小的限制。XLnet将文本视为块状,概率预测信息在这些块中传递,实现对每块信息的内容和位置的预测。

2.3 ECA通道注意力机制

ECA通道注意力机制是在SE通道注意力机制上的改进。SE通道注意力机制会对输入特征图进行压缩,不利于学习通道之间的依赖关系。为了避免降维,ECA通道注意力机制用一维卷积实现了局部通道交互,具体操作分以下3步:(1)对输入的特征图进行全局平均池化操作;(2)进行一维卷积操作,然后用sigmoid函数进行激活得到各个通道的权重;(3)将权重和原始输入特征图进行相乘操作,得到输出特征图。ECA通道注意力机制示意图如图1所示。

图1 ECA通道注意力机制Fig.1 ECA channel attention mechanism

3 网络模型

本文提出的XLnet与DMGAN融合模型的结构如图2所示,黄色区域为改进部分。首先应用AttnGAN[16]模型中的深度注意多模态相似度网络(DAMSM)计算细粒度图像到文本的匹配损失。在训练DAMSM网络时,将其原有的编码器替换为XLnet文本编码器,图像编码器保持不变,计算DAMSM损失并加入到DMGAN模型的生成器损失中,后续使用生成对抗网络生成像素为64×64的初始图像,最后利用动态内存将图像进行两次细化,分别生成像素为128×128和256×256的图像。

图2 XLnet-DMGAN融合网络结构Fig.2 Converged network architecture of XLnet-DMGAN

3.1 深度注意多模态相似度网络

文本字词特征用XLnet文本编码器提取,图像特征用inception-v3[19]图像编码器提取,将提取到的特征转换到公共空间进行训练,表达式如式(1)所示:

其中:f是局部特征提取矩阵,其维度为768×289,768是局部特征向量的维数,289是图像中子区域个数是全局特征提取矩阵,其维度为2 048;W是感知层,将图像特征和文本特征转换到共同语义空间;v和vˉ分别是图像局部和全局特征向量转化到公共区域的向量。

字词特征维度为768×T,768是单词特征向量维数,T是文本单词个数。首先计算句子中每个单词和图像中的子区域的相似矩阵,其表达式为:

式中:s∈RT×289和si,j是指句子的第i个单词和图像的第j个子区域点积相似度,v为图像局部特征转换到语义空间的向量,e是字词特征向量。

接着建立一个注意力模型计算图像相关区域和句子第i个单词的动态表示ci,其具体表达式如式(3)所示:

其中:vj和αj分别是第j个图像子区域特征和针对第j个图像子区域的注意力权重;ci为所有区域视觉向量的加权总和,也就是句子第i个单词相关的图像子区域的动态特征;γ1为参数。

然后通过ci和字词特征e的余弦相似确定第i个单词和图像之间的相关性,表达式如式(4)所示

一个图像(Q)和其对应的一个文本描述(D)之间的注意力驱动的图像-文本匹配得分定义为:

式中,γ2为参数,决定最相关的单词到区域上下文对的重要性放大多少。

因此,对于每一个batch的图像Qi和文本Di组成的其相匹配的后验概率为:

式中,γ3为实验确定的平滑因子。在所有句子中,只有Di匹配图像Qi,其余的M-1字词都视为不匹配的描述。

字词级别的文本匹配图像损失函数采用负对数后验概率,其表达式如式(7)所示:

式中,w为word,即单词。对应可得P(Di|Qi)的损失函数如式(8)所示:

将式(5)重新定义为:

3.2 初始图像生成阶段

在初始图像生成阶段,由给定的文本通过文本编码器生成语句特征向量和字词特征向量。本文的文本编码器使用的是新提出的XLnet文本编码器。语句特征向量s是包含整个文本语句特征的向量,该向量用于初始图像的生成。字词特征是包含单词个数的字词特征向量,该向量用于提升初始生成图像的分辨率。文本编码器编码得到的语句特征向量s需要先进行条件增强,首先从语句特征向量s的高斯分布中的到它的平均协方差矩阵μ(s)和对角协方差矩阵σ(s),然后计算特征向量c0.(c0=μ(s)⊕σ(s)⊗ε,ε~N(0,1)),再将c0和一个正态分布中随机取样的噪声Z拼接得到进行一次全连接操作和4次上采样操作得到初始特征图像R0,最后通过ECA通道注意力卷积模块和一次3×3卷积块生成初始图像。

原始的RNN编码器只能从左向右或者从右向左编码,这使得从embedding层中提取文本语义的过程会忽略一些单词和曲解语句信息,导致图像属性的损失。针对这些问题,提出一种基于XLnet的编码器的文本编码器,实现对文本信息的深度挖掘。融合XLnet编码后,整体文本编码结构如图3所示。整个图像编码器由5部分组成,分别是输入、文本预处理、XLnet预训练编码器和输出。具体实现细节如下:

图3 文本编码器流程框图Fig.3 Flow diagram of text encoder

(1)导入pytorch_transformers库中的XLNet-Model类和XLNetTokenizer类。XLNetModel是PyTorch提供的XLNet模型网络结构,XLNet-Tokenizer是XLNet模型的分词工具,存储模型的词汇表并提供用于编码/解码需要的token embedding。

(2)数据预处理阶段:对读取的文本进行处理,需要用到XLNetTokenizer类中基于SentencePiece构造的tokenizer方法,数据预处理方法见算法1。例如文本text=[‘这个鸟有白色翅膀和白色腹部’],用tokenizer对句子分词后得到tokens=[‘这个’,‘鸟’,‘有’,‘白色’,‘翅膀’,‘和’,‘白色’,‘腹部’]。tokenizer将文本划分成8个词组成的序列,接着对tokens计数,返回一个字典类型的数据,键是元素,值是元素出现的次数,即{‘这个’:1,‘鸟’:1,‘有’:1,‘白色’:2,‘翅膀’:1,‘和’:1,‘腹部’:1}。接着对照加载的token embedding词表找到词组索引。token embedding是包含实例化标记程序所需的词汇表,比如“这个”索引值为3 683,依次类推。

(3)构建XLNetModel阶段:对于阶段(2)中获取到的词组索引token_index,构建一个XLNet模型计算词组的字词向量表达。XLNetModel是PyTorch提供的XLNet模型网络结构,构建XLNet模型训练字词向量方法如算法2所示。

对于构造的XLNet模型,初始化其Embedding矩阵shape=(32 000,768),由32 000个维度为768的特征向量组成。对于由阶段(2)得到的词组索引,对照初始化的词表,根据索引值查找到其对应的特征向量,最终对其加权平均得到XLNet模型训练的字词向量。

对于文本text,经过XLNet模型对每个词组的上下文内容学习,得到一个Word Embedding矩阵:shape=(8,768),简单理解为将每个词语映射到一个768维的矩阵中。例如“翅膀”经过XLNet模型对其上下文的学习得到的字词特征向量为[-9.515 4e-02,-7.279 3e-02,-2.319 0e-01,…,-3.872 8e-05,-9.983 7e-02,-1.942 1e-04],通过学习字词特征向量及字词的位置,生成相应的语句特征Sentence Embedding,矩阵shape=(1,768)。例如“这个鸟有白色翅膀和白色腹部”经过XLNet模型对其上下文的学习得到语句特征向量为[-8.946 8e-01,-3.181 3e-01,-7.396 6e-01,…,-2.104 8e-02,-6.997 3e-01,-7.210 0e-01],[-9.401 6e-04,3.475 4e-02,1.271 4e-01,…,-4.303 3e-03,6.347 8e-01,-7.069 6e-01],[-2.299 2e-02,2.675 0e-04,-6.227 4e-03,…,-8.204 7e-04,2.669 5e-01,2.899 6e-20],…,[5.412 2e-01,-1.777 2e-11,9.400 5e-01,…,9.603 8e-01,3.736 8e-01,6.806 8e-06],[3.825 3e-04,5.585 0e-05,4.050 8e-01,…,7.618 3e-01,5.420 4e-01,8.125 2e-03],[-7.615 9e-01,-8.309 0e-09,5.942 4e-07,…,5.821 6e-01,4.719 8e-01,-5.940 6e-01]]。

3.3 图像细化阶段

在图像细化阶段,将更多细粒度的信息添加到模糊初始图像中,生成较上一阶段逼真的图像xi:xi=Gi(Ri-1,W),其中Ri-1为上一阶段的图像特征。细化阶段主要由内存写入、键寻址、V值读取、响应和ECA通道注意力5个部分组成。首先内存写入功能将文本内容存储到键值结构化存储器内,以便检索。通过键寻址和V读取操作从内存模块中读取特征,细化初始生成图像质量。再采用V响应操作控制图像特征的融合;最后将融合后的特征图像通过ECA通道注意力加权融合。细化阶段可以重复多次(本文重复两次)以检索更相关的信息,并生成具有更细粒度细节的高分辨率图像。

3.3.1 动态内存

从给定的输入词W、图像x和图像特征Ri进行计算:

其中:T为单词数,Nw为单词特征的维数,N为图像像素数,图像像素特征为Nr维向量。

细化阶段包含内存写入、键寻址、V值读取和响应。内存写入主要是指对先验知识进行编码。内存写入将文本特征经过一次1×1卷积运算嵌入到n维的记忆特征空间中,具体表达公式如式(13)所示:

式中,M(· )表示1×1卷积。键寻址主要是使用键存储器检索相关的存储器,计算每个内存插槽的权重作为内存插槽mi和图像特征ri:

式中:αi,j为第i个记忆体与第j个图像特征之间的相似概率,φK(· )为将记忆内存特征映射到维数Nr的一个内存访问进程,φK(· )表示1×1卷积。V值读取是指输出存储器表示被定义为根据相似概率对值存储器进行加权求和,具体表达式如式(15)所示:

式中,φV(· )为将内存特征映射到维度Nr的值内存访问过程。φV(· )实现为1×1卷积。在接收到输出存储器后,将当前图像和输出图像相结合以提供一个新的图像特征。一种简单的方法就是单纯地将图像特征和输出表示连接起来,得到全新的图像特征,其表达式如式(16)所示:

式中,[· ,· ]表示拼接操作。然后,利用一个上采样块和几个残留块,得到一个较高分辨率的图像特征。上采样块由一个上采样层和一个3×3卷积组成。最后,利用3×3卷积从新的图像特征中得到细化的图像x。

3.3.2 内存写入门

内存写入门允许DMGAN模型选择相关的单词来细化初始图像,它将最后阶段的图像特征与单词特征相结合,计算出单词的重要性,其公式如式(17)所示:

式中:σ是sigmoid函数,A是一个1×Nw矩阵,B是一个1×Nr矩阵。结合图像和文本特征编写内存插槽mi∈RNm。Mw(· )和Mr(· )表示1×1卷积运算。Mw(· )和Mr(· )将图像特征和字词特征拼接起来进行输入。

3.3.3 响应门

利用自适应门控机制动态控制信息流,更新图像特征,其表达式如式(18)所示:

3.3.4 生成器

生成器的目标函数可以表示为:

式中:λ1和λ2分别为条件增强损失和DAMSM损失的权重,G0表示初始生成过程的生成器,Gi表示图像细化阶段第i次迭代的生成器。

对抗损失Gi的定义如式(20)所示:

式中第一项是无条件损失,使生成的伪图像尽可能真实;第二项是条件损失,使生成的伪图像匹配输入的句子。

每个鉴别器Di的对抗损失定义为:

式中上半部分为无条件损失,用于将生成的伪图像与真实图像区分开来;下半部分为条件损失,决定生成的伪图像与输入句子是否相符。

条件增强(CA)损失描述了训练数据的标准高斯分布和高斯分布之间的KL散度:

式中,μ(s)和∑(s)为句子特征的均值和对角协方差矩阵。μ(s)和∑(s)通过全连接层计算。

DAMSM损失用来衡量图像和文本描述之间的匹配程度,其相关理论及公式在2.1节已详细介绍,不再赘述。

4 实验与结果分析

4.1 实验环境及数据集

本文所做实验的软硬件环境如下:系统为Ubuntu 20.04,CPU为Intel(R) Xeon(R) Platinum 8350C,GPU为 GeForce RTX 3090(24G),Cuda版本为11.3,Python版本为3.9,所用的深度学习框架为Pytorch。

为了验证DMGAN和XLnet融合的网络图像生成能力,选用CUB数据集进行实验。CUB[20]数据集中鸟类包含了200个类别,每个类别平均约有60张鸟的图片,共11 788张。使用其中的8 855张图片进行训练,余下的2 933张图片用于测试。每张图片均有10个描述语句。

4.2 实验设置

本文的文本编码器选用的是xlnet-base-cased模型,设置其维度为768×300,与DMGAN模型的维度一致。

模型的训练分为两步:先进行DAMSM语义一致性网络的训练,训练生成xlnet-rnn-encoder和cnn-encoder两个编码器权重;然后将训练好的两个权重导入到DMGAN模型中,进行DMGAN生成对抗网络的训练。

具体参数选择配置如下:优化器方面选择的是Adam优化器,学习率α=0.000 2,轮数epochs=800,batchsize=20,参数λ=5。

4.3 评价指标

本文采用IS[21](inception score)和FID[22](fréchet inception distance)两种评价指标对实验结果进行量化评价。

IS通过预训练的inception_v3网络表示条件类分布和边缘类分布之间的KL散度,较高的IS值表示生成的图像多样性强,图像品质好,而且能明显鉴别出图像所属的类别,其具体的计算公式如式(23)所示:

式中:x为生成样本的图像,y为算法预测出的标签;DKL为计算(P(y|x)和P(y)的KL散度。

FID计算出生成图像和真实图像之间的Fréchet距离。FID越小,代表生成图片越接近真实图片。其具体的计算公式如式(24)所示:

式中:μr、μg分别为真实样本特征均值和生成样本特征均值,Tr建立了对真实数据和生成样本数据之间的协方差矩阵的求迹。

4.4 结果分析

4.4.1 指标对比

实验中,前期DAMSM语义一致性预训练损失变化曲线如图4所示。本文可视化了预训练过程中式(11)的DAMSM loss。得益于XLnet文本编码器,预训练的DAMSM语义一致性训练在进行到50个epoch时已经收敛,而原始的网络需要到150个epoch才开始收敛。收敛后,融合XLnet文本编码器的网络损失波动幅度小于原始网络的损失,说明该网络收敛速度快,并且稳定性也优于原始网络。

图4 DAMSM语义一致性训练损失变化图Fig.4 Loss change graph of DAMSM semantic consistency training

生成网络在CUB数据集上训练了800个epoch,生成约3 000张测试集图片,量化评价指标对比见表1。如表1所示,融合后网络模型得到的IS值达到5.22±0.18,较最初DMGAN模型的IS值4.75±0.07提升了0.47,与其他具有代表性的模型相比,效果也最好。本文所提模型的FID仅为13.31,较最初DMGAN模型的FID值16.09下降了2.78,说明融合后模型生成的图像在视觉上更加贴近真实图片,细节处处理得更好。

表1 各项评价指标对比Tab.1 Comparison of evaluation indicators

IS值的变化曲线如图5所示。可以看出在200轮之后,本文所提方法明显优于原始DMGAN模型。FID值的变化曲线如图6所示,可以看出在520轮之后,本文所提方法明显优于原始DMGAN模型。

图5 IS值的变化曲线Fig.5 Variation curves of IS value

图6 FID值的变化曲线Fig.6 Variation curves of FID value

由表1、图5、图6可以看出,本文利用XLnet编码器进行文本编码,深度理解文本信息,使得初始阶段生成的图像特征融合了更多的语义信息,后续模型生成了高保真、高质量且语义一致性的图像。

4.4.2 生成图像对比

将本文与几种具有代表性的文本生成图像的实验效果进行对比,参与对比的网络有Stack-GAN、AttnGAN、DMGAN,均使用公开模型且在同一环境下训练所得结果。训练测试结果如图7所示。在生成的图像中,从第一、二列中的测试结果可以看出,原始DMGAN模型生成的鸟存在整体、尾部、脚爪的失真;从第三、四列中测试结果可以看出,原始DMGAN模型生成的鸟存在空间结构不合理的问题,第三列中的翅膀和第四列中的脖颈均不符合正常鸟的结构。上述问题已用白色椭圆框标出。融合XLnet的编码器后,图像的语义一致性和整体形态都有明显提升。ECA通道注意力专注于生成图像的空间结构和具体细节之处,使得生成的图像更加符合真实图像,图像的美感也得到了提升,定性地表明了本文所融合模型优于原始模型。

图7 生成图像对比Fig.7 Comparison of generated images

为了详细分析文本生成图像的过程,例如文本text=[‘这种鸟是白色和黄色的,有一个非常短的喙。’],从开始生成的噪声图像到第一阶段结束时的64×64图像进行逐步展示,如图8所示。图像细化工作如图9所示。为了验证生成图像的多样性,图10使用相同文本描述和多个噪声向量生成多个现状和背景不同的图像。

图8 初始图像生成细节。(a)初始噪声图像;(b)经过全连接和4个上采样之后的图像;(c)经过ECA通道注意力和3×3卷积的初始图像。Fig.8 Generated details of initial image. (a) Initial noise image; (b) Image after full connection and four upsampling; (c) Initial image after ECA channel attention and 3×3 convolution.

图9 细化图像细节。(a)64×64图像;(b)128×128图像;(c)256×256图像;(d)注意力权重。Fig.9 Refined details of the image.(a) 64×64 image;(b) 128×128 image; (c) 256×256 image; (d)Attention weights.

图10 相同文本生成图像展示Fig.10 Generated images from the same text

4.5 消融实验

为了验证本文所提出的融合模型在文本生成图像的出色表现,本文进行了XLnet编码器模块和高效通道注意力的消融实验,结果如表2所示,其中,ECA为图像生成过程中添加的通道注意力模块。从表2可以看出,两个模块的结合得到了最优的实验结果,验证了两个模块的有效性。

表2 消融实验结果Tab.2 Results of ablation experiments

5 结论

针对文本生成图像任务中语义不匹配、图像细节损失、图像空间结构不合理等问题,本文提出了一种改进的DMGAN模型。首先引入NLP领域中的XLnet模型对文本进行编码,捕获上下文内容,实现了对文本信息的深度挖掘;其次在DMGAN模型的初始图像生成和图像细化两个阶段均加入高效通道注意力机制,提高了模型的泛化能力,模型的收敛速度和稳定性也得到了大幅提升。最后在公开数据集CUB上进行了实验验证,对比原始DMGAN模型,本文所提模型的IS指标提升了0.47,FID指标降低了2.78。结果表明,改进后的DMGAN模型提高了生成图像的质量、增强了生成图像的多样性,在文本生成图像领域具有一定的实际应用价值,如在行人重识别方面的应用[25]。

在未来的实验中,可以选择Transformer中的其他预训练模型,如预训练模型超大且拥有1 750亿预训练参数的GPT-3、轻量化的预训练模型ALBERT等;选择更好的生成对抗网络,如DFGAN,ManiGAN等做相应的融合实验。上述只是针对文本编码器的融合,也可以做图像编码器的融合,如将VIT(Vision Transformer)的预训练模型融合到生成对抗网络中的图像编码器中,验证VIT在文本生成图像这个下游任务中是否具有良好的表现,也可以同时将自然语言处理处理中的Transformer预训练模型和计算机视觉中的Transformer预训练模型一起融合到生成对抗网络中。

猜你喜欢
特征向量编码器单词
二年制职教本科线性代数课程的几何化教学设计——以特征值和特征向量为例
克罗内克积的特征向量
单词连一连
基于FPGA的同步机轴角编码器
看图填单词
一类特殊矩阵特征向量的求法
EXCEL表格计算判断矩阵近似特征向量在AHP法检验上的应用
看完这些单词的翻译,整个人都不好了
基于PRBS检测的8B/IOB编码器设计
JESD204B接口协议中的8B10B编码器设计