LSTM,一个强大算法模型 !!

文摘   2024-07-31 16:26   北京  

哈喽,我是cos大壮!~

最近,还是有很多同学不是特别的理解 LSTM。今天想和大家再聊聊~

今儿和大家聊的是案例是:使用LSTM进行文本生成。

其实,大家可以把 LSTM 想象成一个超级记忆大师,他有一个特别的记忆系统,能把重要的信息记得很久,同时也能忘记一些不太重要的信息。

老规矩如果大家伙觉得近期文章还不错!欢迎大家点个赞、转个发,文末赠送《机器学习学习小册》
文末可取本文PDF版本~

正经解释:LSTM,全称是长短期记忆网络(Long Short-Term Memory network),是一种特殊的循环神经网络(RNN)。它被设计用来处理和预测那些与时间序列相关的问题,能够记住长期的依赖信息,解决传统RNN在长序列上记忆力差的问题。

LSTM有三个门(Gate),每个门都起到不同的作用:

  1. 遗忘门(Forget Gate):决定哪些信息需要丢弃。
  2. 输入门(Input Gate):决定哪些新的信息需要记住。
  3. 输出门(Output Gate):决定当前要输出什么信息。

这三个门一起工作,帮助LSTM在处理长序列数据时,选择性地记住或遗忘信息,使得它能很好地处理时间序列数据。

OK,有了以上的一些解释,下面我给大家再举例一个非常通俗易懂的案例,一看就懂!~

假设我们有一个故事:“小猫去森林里找她的朋友,她遇到了很多动物。她特别记得遇到了一只会说话的狐狸。狐狸给了她一张藏宝图,然后小猫跟着藏宝图找到了一个神秘的宝箱。”

我们希望让 LSTM 理解这个故事,并在你给它一个开头时生成更多的故事情节。

首先,我们将故事文本转换为数字形式,方便 LSTM 处理。

  • 原始文本 "小猫去森林里找她的朋友,她遇到了很多动物。她特别记得遇到了一只会说话的狐狸。狐狸给了她一张藏宝图,然后小猫跟着藏宝图找到了一个神秘的宝箱。"
  • 转换为数字:使用字符到数字的映射将文本转化为数字序列。

故事输入:当你给 LSTM 一个开头,比如“小猫发现了一只”,LSTM 要决定如何继续这个故事。它会用以下步骤来生成文本:

  1. 遗忘门:决定在处理当前句子时,应该忘记多少之前的信息。例如,如果故事中小猫之前遇到了一只乌鸦,但这只乌鸦现在不再重要,LSTM 会“遗忘”这部分信息。

  2. 输入门:决定当前句子中哪些信息是新的和重要的。例如,当小猫发现了一只新动物时,LSTM 会把这部分信息“记住”,因为它可能对后续的故事情节有帮助。

  3. 细胞状态更新:更新记忆。结合遗忘和输入门的决定,LSTM 会更新内部的状态,以便能够记住重要的信息,同时忘记不重要的。

  4. 输出门:决定生成什么样的下一个字符或单词。基于更新后的记忆,LSTM 选择最适合的字符或单词来继续故事。例如,LSTM 可能会生成“发现了一只会唱歌的小鸟”。

生成故事

初始输入:“小猫发现了一只”

LSTM 会:

  • 检查之前的故事记忆,决定哪些部分需要保留(比如,小猫之前遇到的动物);
  • 处理新输入,决定新的信息(比如,小猫发现了一只新的动物);
  • 生成下一个字符或单词,可能是“狐狸”或“小鸟”;

LSTM 可能生成的故事继续部分是:

“小猫发现了一只会唱歌的小鸟。小鸟告诉小猫,她需要找到一棵古老的树来解开藏宝图的秘密。”

假设我们使用 Python 和 TensorFlow/Keras 来实现 LSTM 模型生成故事:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Embedding
import matplotlib.pyplot as plt

# 准备数据
text = (
    "小猫去森林里找她的朋友,她遇到了很多动物。"
    "她特别记得遇到了一只会说话的狐狸。"
    "狐狸给了她一张藏宝图,然后小猫跟着藏宝图找到了一个神秘的宝箱。"
    "后来,小猫又遇到了一个会唱歌的小鸟。小鸟告诉她宝箱的秘密。"
    "小猫在森林里遇到了很多朋友,他们一起分享了宝箱里的宝藏。"
    "从此以后,小猫经常和她的朋友们一起去冒险,发现了更多的宝藏和秘密。"
    "他们的故事传遍了整个森林,成为了大家津津乐道的话题。"
)

vocab = sorted(set(text))
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])

# 生成训练样本和标签
seq_length = 10
examples_per_epoch = len(text) // (seq_length + 1)
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

def split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = chunk[1:]
    return input_text, target_text

dataset = sequences.map(split_input_target)
BATCH_SIZE = 64
BUFFER_SIZE = 10000
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True).repeat()

# 定义模型
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 512

model = Sequential([
    Embedding(vocab_size, embedding_dim),
    LSTM(rnn_units, return_sequences=True, stateful=False, recurrent_initializer='glorot_uniform'),
    Dense(vocab_size)
])

def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

model.compile(optimizer='adam', loss=loss)

EPOCHS = 10
history = model.fit(dataset, epochs=EPOCHS, steps_per_epoch=examples_per_epoch)

# 文本生成函数
def generate_text(model, start_string, num_generate=100):
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)
    
    text_generated = []
    temperature = 1.0
    
    # 重置 LSTM 层的状态
    lstm_layer = model.layers[1]
    lstm_layer.reset_states()
    
    for i in range(num_generate):
        predictions = model(input_eval)
        predictions = tf.squeeze(predictions, 0)
        predictions = predictions / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()
        
        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])
    
    return start_string + ''.join(text_generated)

# 生成文本
print(generate_text(model, start_string="小猫发现了一只"))

LSTM 就像是一个记忆和遗忘的智能助手,帮助你记住重要的信息(比如故事中的关键情节)并且自动生成接下来的内容。通过门控机制,LSTM 能够在处理长序列数据时保持长时间的记忆,同时抛弃无关的信息,从而生成连贯的文本。

最后

大家有问题可以直接在评论区留言即可~

喜欢本文的朋友可以收藏、点赞、转发起来!
需要本文PDF的同学,扫码备注「案例汇总」即可~ 
关注本号,带来更多算法干货实例,提升工作学习效率!
最后,给大家准备了《机器学习学习小册》PDF版本16大块的内容,124个问题总结

推荐阅读

原创、超强、精华合集
100个超强机器学习算法模型汇总
机器学习全路线
机器学习各个算法的优缺点
7大方面,30个最强数据集
6大部分,20 个机器学习算法全面汇总
铁汁,都到这了,别忘记点赞呀~

深夜努力写Python
Python、机器学习算法
 最新文章