十分钟学会用 Keras 实现序列到序列学习

注意:这篇文章写于 2017 年。有关此处使用代码的最新版本,请参阅本教程

我经常看到这个问题——如何在 Keras 中实现 RNN 序列到序列学习?以下是一个简短的介绍。

请注意,本文假设您已经对循环网络和 Keras 有一定的了解。


什么是序列到序列学习?

序列到序列学习 (Seq2Seq) 是关于训练模型将序列从一个领域(例如英语句子)转换为另一个领域(例如翻译成法语的相同句子)。

"the cat sat on the mat" -> [Seq2Seq model] -> "le chat etait assis sur le tapis"

这可用于机器翻译或自由问答(针对自然语言问题生成自然语言答案)——总的来说,它适用于任何需要生成文本的情况。

有多种方法可以处理此任务,可以使用 RNN 或使用一维卷积网络。这里我们将重点关注 RNN。

简单情况:当输入序列和输出序列长度相同时

当输入序列和输出序列长度相同时,您可以简单地使用 Keras LSTM 或 GRU 层(或其堆栈)来实现此类模型。此示例脚本就是这样一种情况,它展示了如何训练 RNN 学习将编码为字符串的数字相加。

Seq2seq inference

这种方法的一个缺点是,它假设在给定 input[...t] 的情况下,可以生成 target[...t]。这在某些情况下有效(例如,添加数字字符串),但在大多数情况下无效。一般情况下,需要有关整个输入序列的信息才能开始生成目标序列。

一般情况:典型的序列到序列

一般情况下,输入序列和输出序列的长度不同(例如机器翻译),并且需要整个输入序列才能开始预测目标。这需要更高级的设置,这就是人们在没有进一步上下文的情况下提到“序列到序列模型”时通常所指的内容。其工作原理如下

  • 一个 RNN 层(或其堆栈)充当“编码器”:它处理输入序列并返回其自身的内部状态。请注意,我们丢弃了编码器 RNN 的输出,只保留状态。此状态将作为下一步中解码器的“上下文”或“条件”。
  • 另一个 RNN 层(或其堆栈)充当“解码器”:它经过训练,可以在给定目标序列先前字符的情况下,预测目标序列的下一个字符。具体来说,它被训练成将目标序列转换为相同的序列,但在时间上偏移了一个时间步,这种训练过程在这种情况下称为“教师强制”。重要的是,编码器使用编码器的状态向量作为初始状态,这就是解码器如何获得有关其应该生成什么的信息。实际上,解码器学习在以输入序列为条件的情况下,根据 targets[...t] 生成 targets[t+1...]

Seq2seq inference

在推理模式下,即当我们想要解码未知输入序列时,我们会经历一个略有不同的过程

  • 1) 将输入序列编码为状态向量。
  • 2) 从大小为 1 的目标序列开始(仅包含序列开始字符)。
  • 3) 将状态向量和 1 个字符的目标序列馈送到解码器,以生成下一个字符的预测。
  • 4) 使用这些预测对下一个字符进行采样(我们简单地使用 argmax)。
  • 5) 将采样字符追加到目标序列
  • 6) 重复,直到我们生成序列结束字符或达到字符限制。

Seq2seq inference

同样的过程也可以用来训练没有“教师强制”的 Seq2Seq 网络,即通过将解码器的预测重新注入解码器。

Keras 示例

让我们用实际代码来说明这些想法。

对于我们的示例实现,我们将使用一个由英语句子及其法语翻译组成的配对数据集,您可以从 manythings.org/anki/ 下载。要下载的文件名为 fra-eng.zip。我们将实现一个字符级序列到序列模型,逐个字符地处理输入并逐个字符地生成输出。另一种选择是词级模型,它在机器翻译中更为常见。在这篇文章的最后,您会发现一些关于使用 Embedding 层将我们的模型转换为词级模型的说明。

我们示例的完整脚本可以在 GitHub 上找到

以下是我们流程的总结

  • 1) 将句子转换为 3 个 Numpy 数组,encoder_input_datadecoder_input_datadecoder_target_data
    • encoder_input_data 是一个形状为 (num_pairs, max_english_sentence_length, num_english_characters) 的三维数组,包含英语句子的独热向量化表示。
    • decoder_input_data 是一个形状为 (num_pairs, max_french_sentence_length, num_french_characters) 的三维数组,包含法语句子的独热向量化表示。
    • decoder_target_datadecoder_input_data 相同,但偏移了一个时间步decoder_target_data[:, t, :] 将与 decoder_input_data[:, t + 1, :] 相同。
  • 2) 训练一个基本的基于 LSTM 的 Seq2Seq 模型,以根据 encoder_input_datadecoder_input_data 预测 decoder_target_data。我们的模型使用教师强制。
  • 3) 解码一些句子以检查模型是否正常工作(即将 encoder_input_data 中的样本转换为 decoder_target_data 中的对应样本)。

因为训练过程和推理过程(解码句子)截然不同,所以我们对两者使用不同的模型,尽管它们都利用了相同的内部层。

这是我们的训练模型。它利用了 Keras RNN 的三个关键特性

  • return_state 构造函数参数,配置 RNN 层以返回一个列表,其中第一个条目是输出,后面的条目是内部 RNN 状态。这用于恢复编码器的状态。
  • inital_state 调用参数,指定 RNN 的初始状态。这用于将编码器状态作为初始状态传递给解码器。
  • return_sequences 构造函数参数,配置 RNN 以返回其完整的输出序列(而不仅仅是最后一个输出,这是默认行为)。这在解码器中使用。
from keras.models import Model
from keras.layers import Input, LSTM, Dense

# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the 
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

我们用两行代码训练模型,同时监控 20% 的样本的保留集上的损失。

# Run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)

在 MacBook CPU 上运行一个小时左右后,我们就可以进行推理了。为了解码测试句子,我们将重复执行以下操作

  • 1) 编码输入句子并检索初始解码器状态
  • 2) 使用此初始状态和“序列开始”标记作为目标,运行解码器的一个步骤。输出将是下一个目标字符。
  • 3) 追加预测的目标字符并重复。

以下是我们的推理设置

encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

我们使用它来实现上面描述的推理循环

def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence

我们得到了一些不错的结果——这并不奇怪,因为我们解码的是从训练测试中获取的样本。

Input sentence: Be nice.
Decoded sentence: Soyez gentil !
-
Input sentence: Drop it!
Decoded sentence: Laissez tomber !
-
Input sentence: Get out!
Decoded sentence: Sortez !

至此,我们对 Keras 中的序列到序列模型的十分钟介绍就结束了。提醒:此脚本的完整代码可以在 GitHub 上找到

参考文献


额外常见问题解答

如果我想使用 GRU 层而不是 LSTM,该怎么办?

实际上要简单一些,因为 GRU 只有一个状态,而 LSTM 有两个状态。以下是调整训练模型以使用 GRU 层的方法

encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = GRU(latent_dim, return_state=True)
encoder_outputs, state_h = encoder(encoder_inputs)

decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_gru = GRU(latent_dim, return_sequences=True)
decoder_outputs = decoder_gru(decoder_inputs, initial_state=state_h)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

如果我想使用带有整数序列的词级模型,该怎么办?

如果您的输入是整数序列(例如,表示单词序列,由它们在字典中的索引编码),该怎么办?您可以通过 Embedding 层嵌入这些整数标记。方法如下

# Define an input sequence and process it.
encoder_inputs = Input(shape=(None,))
x = Embedding(num_encoder_tokens, latent_dim)(encoder_inputs)
x, state_h, state_c = LSTM(latent_dim,
                           return_state=True)(x)
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
x = Embedding(num_decoder_tokens, latent_dim)(decoder_inputs)
x = LSTM(latent_dim, return_sequences=True)(x, initial_state=encoder_states)
decoder_outputs = Dense(num_decoder_tokens, activation='softmax')(x)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Compile & run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
# Note that `decoder_target_data` needs to be one-hot encoded,
# rather than sequences of integers like `decoder_input_data`!
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)

如果我不想在训练中使用教师强制,该怎么办?

在某些特殊情况下,您可能无法使用教师强制,因为您无法访问完整的目标序列,例如,如果您正在对非常长的序列进行在线训练,在这种情况下,缓冲完整的输入-目标对是不可能的。在这种情况下,您可能希望通过将解码器的预测重新注入解码器的输入来进行训练,就像我们对推理所做的那样。

您可以通过构建一个对输出重新注入循环进行硬编码的模型来实现这一点

from keras.layers import Lambda
from keras import backend as K

# The first part is unchanged
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
states = [state_h, state_c]

# Set up the decoder, which will only process one timestep at a time.
decoder_inputs = Input(shape=(1, num_decoder_tokens))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')

all_outputs = []
inputs = decoder_inputs
for _ in range(max_decoder_seq_length):
    # Run the decoder on one timestep
    outputs, state_h, state_c = decoder_lstm(inputs,
                                             initial_state=states)
    outputs = decoder_dense(outputs)
    # Store the current prediction (we will concatenate all predictions later)
    all_outputs.append(outputs)
    # Reinject the outputs as inputs for the next loop iteration
    # as well as update the states
    inputs = outputs
    states = [state_h, state_c]

# Concatenate all predictions
decoder_outputs = Lambda(lambda x: K.concatenate(x, axis=1))(all_outputs)

# Define and compile model as previously
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

# Prepare decoder input data that just contains the start character
# Note that we could have made it a constant hard-coded in the model
decoder_input_data = np.zeros((num_samples, 1, num_decoder_tokens))
decoder_input_data[:, 0, target_token_index['\t']] = 1.

# Train model as previously
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)

如果您还有其他问题,请在 Twitter 上联系