哈喽,我是cos大壮!~
最近,还是有很多同学不是特别的理解 LSTM。今天想和大家再聊聊~
今儿和大家聊的是案例是:使用LSTM进行文本生成。
其实,大家可以把 LSTM 想象成一个超级记忆大师,他有一个特别的记忆系统,能把重要的信息记得很久,同时也能忘记一些不太重要的信息。
正经解释:LSTM,全称是长短期记忆网络(Long Short-Term Memory network),是一种特殊的循环神经网络(RNN)。它被设计用来处理和预测那些与时间序列相关的问题,能够记住长期的依赖信息,解决传统RNN在长序列上记忆力差的问题。
LSTM有三个门(Gate),每个门都起到不同的作用:
遗忘门(Forget Gate):决定哪些信息需要丢弃。 输入门(Input Gate):决定哪些新的信息需要记住。 输出门(Output Gate):决定当前要输出什么信息。
这三个门一起工作,帮助LSTM在处理长序列数据时,选择性地记住或遗忘信息,使得它能很好地处理时间序列数据。
OK,有了以上的一些解释,下面我给大家再举例一个非常通俗易懂的案例,一看就懂!~
假设我们有一个故事:“小猫去森林里找她的朋友,她遇到了很多动物。她特别记得遇到了一只会说话的狐狸。狐狸给了她一张藏宝图,然后小猫跟着藏宝图找到了一个神秘的宝箱。”
我们希望让 LSTM 理解这个故事,并在你给它一个开头时生成更多的故事情节。
首先,我们将故事文本转换为数字形式,方便 LSTM 处理。
原始文本: "小猫去森林里找她的朋友,她遇到了很多动物。她特别记得遇到了一只会说话的狐狸。狐狸给了她一张藏宝图,然后小猫跟着藏宝图找到了一个神秘的宝箱。"
转换为数字:使用字符到数字的映射将文本转化为数字序列。
故事输入:当你给 LSTM 一个开头,比如“小猫发现了一只”,LSTM 要决定如何继续这个故事。它会用以下步骤来生成文本:
遗忘门:决定在处理当前句子时,应该忘记多少之前的信息。例如,如果故事中小猫之前遇到了一只乌鸦,但这只乌鸦现在不再重要,LSTM 会“遗忘”这部分信息。
输入门:决定当前句子中哪些信息是新的和重要的。例如,当小猫发现了一只新动物时,LSTM 会把这部分信息“记住”,因为它可能对后续的故事情节有帮助。
细胞状态更新:更新记忆。结合遗忘和输入门的决定,LSTM 会更新内部的状态,以便能够记住重要的信息,同时忘记不重要的。
输出门:决定生成什么样的下一个字符或单词。基于更新后的记忆,LSTM 选择最适合的字符或单词来继续故事。例如,LSTM 可能会生成“发现了一只会唱歌的小鸟”。
生成故事
初始输入:“小猫发现了一只”
LSTM 会:
检查之前的故事记忆,决定哪些部分需要保留(比如,小猫之前遇到的动物); 处理新输入,决定新的信息(比如,小猫发现了一只新的动物); 生成下一个字符或单词,可能是“狐狸”或“小鸟”;
LSTM 可能生成的故事继续部分是:
“小猫发现了一只会唱歌的小鸟。小鸟告诉小猫,她需要找到一棵古老的树来解开藏宝图的秘密。”
假设我们使用 Python 和 TensorFlow/Keras 来实现 LSTM 模型生成故事:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Embedding
import matplotlib.pyplot as plt
# 准备数据
text = (
"小猫去森林里找她的朋友,她遇到了很多动物。"
"她特别记得遇到了一只会说话的狐狸。"
"狐狸给了她一张藏宝图,然后小猫跟着藏宝图找到了一个神秘的宝箱。"
"后来,小猫又遇到了一个会唱歌的小鸟。小鸟告诉她宝箱的秘密。"
"小猫在森林里遇到了很多朋友,他们一起分享了宝箱里的宝藏。"
"从此以后,小猫经常和她的朋友们一起去冒险,发现了更多的宝藏和秘密。"
"他们的故事传遍了整个森林,成为了大家津津乐道的话题。"
)
vocab = sorted(set(text))
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])
# 生成训练样本和标签
seq_length = 10
examples_per_epoch = len(text) // (seq_length + 1)
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)
def split_input_target(chunk):
input_text = chunk[:-1]
target_text = chunk[1:]
return input_text, target_text
dataset = sequences.map(split_input_target)
BATCH_SIZE = 64
BUFFER_SIZE = 10000
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True).repeat()
# 定义模型
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 512
model = Sequential([
Embedding(vocab_size, embedding_dim),
LSTM(rnn_units, return_sequences=True, stateful=False, recurrent_initializer='glorot_uniform'),
Dense(vocab_size)
])
def loss(labels, logits):
return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
model.compile(optimizer='adam', loss=loss)
EPOCHS = 10
history = model.fit(dataset, epochs=EPOCHS, steps_per_epoch=examples_per_epoch)
# 文本生成函数
def generate_text(model, start_string, num_generate=100):
input_eval = [char2idx[s] for s in start_string]
input_eval = tf.expand_dims(input_eval, 0)
text_generated = []
temperature = 1.0
# 重置 LSTM 层的状态
lstm_layer = model.layers[1]
lstm_layer.reset_states()
for i in range(num_generate):
predictions = model(input_eval)
predictions = tf.squeeze(predictions, 0)
predictions = predictions / temperature
predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()
input_eval = tf.expand_dims([predicted_id], 0)
text_generated.append(idx2char[predicted_id])
return start_string + ''.join(text_generated)
# 生成文本
print(generate_text(model, start_string="小猫发现了一只"))
LSTM 就像是一个记忆和遗忘的智能助手,帮助你记住重要的信息(比如故事中的关键情节)并且自动生成接下来的内容。通过门控机制,LSTM 能够在处理长序列数据时保持长时间的记忆,同时抛弃无关的信息,从而生成连贯的文本。
最后
大家有问题可以直接在评论区留言即可~
喜欢本文的朋友可以收藏、点赞、转发起来! 需要本文PDF的同学,扫码备注「案例汇总」即可~ 关注本号,带来更多算法干货实例,提升工作学习效率! 最后,给大家准备了《机器学习学习小册》PDF版本,16大块的内容,124个问题总结! 推荐阅读
原创、超强、精华合集 100个超强机器学习算法模型汇总 机器学习全路线 机器学习各个算法的优缺点 7大方面,30个最强数据集 6大部分,20 个机器学习算法全面汇总 铁汁,都到这了,别忘记点赞呀~