Paper link: Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition

Speech Transformer is simply the Transformer applied to speech recognition.
As shown in the figure below, its structure is essentially identical to that of the standard Transformer.
If you are not familiar with the Transformer, see the post "Transformer——Attention is all you need" for a detailed explanation; it will not be repeated here.

[Figure: Speech-Transformer model architecture]

The highlight of Speech Transformer is the 2D Multi-Head Attention in the encoder.
Since the input for speech recognition is usually a 2D spectrogram, ordinary attention can only extract features along the time axis.

As shown in the figure, since this is a self-attention layer, it has a single input: a feature map of shape (height, width, channels).
First, three different convolutions are applied to the input to obtain the Query, Key and Value.
Then a convolution with c filters is applied to each of Q, K and V, where c is the number of heads; the results are denoted Q', K', V'.

Scaled Dot-Product Attention is applied to each of the c channels of Q', K', V' separately and the outputs are concatenated; this is the attention over the time dimension.
Similarly, Q', K', V' are first transposed, attention is applied, and the concatenated output is transposed back; this is the attention over the frequency dimension.
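
For each head, the attention itself is the standard scaled dot-product attention from "Attention is all you need":

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the dimension of the keys.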

Finally, the time and frequency attention results are concatenated and one more convolution is applied to bring the number of output channels back to that of the input; a summary of the whole pass is given below.
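
Writing Q'_i, K'_i, V'_i for the i-th of the c channels, one pass of 2D attention can be summarized as:

time_i = Attention(Q'_i, K'_i, V'_i)
freq_i = Attention(Q'_i^T, K'_i^T, V'_i^T)^T,  i = 1, ..., c
output = Conv(concat(time_1, ..., time_c, freq_1, ..., freq_c))

where the final convolution restores the channel count of the input.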

[Figure: 2D Attention module]

Below is a simple Keras implementation of 2D Attention; the complete Speech-Transformer code can be found in SpeechTransformer.py.

Since the inputs to the Multi-Head Attention blocks in the encoder are all 2D, and only the additional module in the first figure may receive a 3D input, the implementation below assumes input and output tensors of shape (batch_size, h, w).

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, Conv2D, Attention


class MultiHeadAttention2D(Layer):
    def __init__(self, num_heads):
        """A simple implementation of the 2D Multi-Head Attention proposed in the paper
        "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition"

        Note that MultiHeadAttention2D can only be used as self-attention,
        since it contains a transpose operation in the "frequency attention".

        Argument:
        :param num_heads: number of attention heads, i.e. number of filters of the convolutions
        """
        super().__init__()

        self.num_heads = num_heads
        self.conv_Q = Conv2D(num_heads, 5, padding='same')
        self.conv_V = Conv2D(num_heads, 5, padding='same')
        self.conv_K = Conv2D(num_heads, 5, padding='same')
        self.conv_out = Conv2D(1, 5, padding='same')

    def call(self, query, value, key=None):
        """
        :param query: Query Tensor of shape (batch_size, Tq, dim)
        :param value: Value Tensor of shape (batch_size, Tv, dim)
        :param key: Key Tensor of shape (batch_size, Tv, dim). If not
            given, 'value' will be used for both 'key' and 'value'
        """
        if key is None:
            key = value

        # expand a (channel) dimension to apply convolution
        # shape (batch_size, T, dim) -> (batch_size, T, dim, 1)
        query = K.expand_dims(query, axis=-1)
        value = K.expand_dims(value, axis=-1)
        key = K.expand_dims(key, axis=-1)

        # shape (batch_size, T, dim, num_heads)
        feat_Q = self.conv_Q(query)
        feat_V = self.conv_V(value)
        feat_K = self.conv_K(key)

        # Separate the feature maps by channel,
        # producing a list of tuples of length num_heads
        # like [(Q1, V1, K1), (Q2, V2, K2), ..., (Qn, Vn, Kn)]
        combined = [
            (feat_Q[:, :, :, i], feat_V[:, :, :, i], feat_K[:, :, :, i])
            for i in range(self.num_heads)
        ]

        # transpose the feature maps to apply frequency attention
        combined_transpose = [
            (K.permute_dimensions(feat_query, (0, 2, 1)),
             K.permute_dimensions(feat_value, (0, 2, 1)),
             K.permute_dimensions(feat_key, (0, 2, 1)))
            for feat_query, feat_value, feat_key in combined
        ]

        # attention along the time axis, one head per channel
        out_temporal_atten = [
            Attention()([feat_query, feat_value, feat_key])
            for feat_query, feat_value, feat_key in combined
        ]

        # attention along the frequency axis on the transposed maps
        out_frequency_atten = [
            Attention()([feat_query_trans, feat_value_trans, feat_key_trans])
            for feat_query_trans, feat_value_trans, feat_key_trans in combined_transpose
        ]

        # concatenate the per-head outputs back along the channel axis
        feat_time = K.concatenate([K.expand_dims(feat, -1) for feat in out_temporal_atten], axis=-1)
        feat_freq = K.concatenate([K.expand_dims(feat, -1) for feat in out_frequency_atten], axis=-1)

        # transpose the frequency branch back, concatenate both branches,
        # then project back to a single channel and drop it
        feat = K.concatenate([feat_time, K.permute_dimensions(feat_freq, (0, 2, 1, 3))], axis=-1)
        outputs = self.conv_out(feat)
        outputs = K.squeeze(outputs, axis=-1)

        return outputs
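
For completeness, here is a minimal usage sketch, assuming TF2 eager execution; the batch size, time steps and number of frequency bins are just illustrative values, not from the paper.

import numpy as np
import tensorflow as tf

# a dummy batch of 2 "spectrograms": 100 time steps x 80 frequency bins
inputs = tf.constant(np.random.rand(2, 100, 80), dtype=tf.float32)

atten_2d = MultiHeadAttention2D(num_heads=4)
outputs = atten_2d(inputs, inputs)  # self-attention: query = value (= key)

print(outputs.shape)  # (2, 100, 80), same shape as the input

One caveat: Keras' built-in Attention layer computes plain (Luong-style) dot-product attention, optionally with a learned scale when use_scale=True, so it does not apply the exact 1/sqrt(d_k) scaling used in the paper.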