大家好,今儿再来和大家聊聊LSTM~
这个话题很重要,很多同学私信聊到了。
今天我们用非常简单的大白话来解释「LSTM」(长短期记忆网络,Long Short-Term Memory)是什么,它是怎么工作的,为什么有用,最后通过一个简单的例子让你更明白。
什么是LSTM?
LSTM是一种特殊的神经网络,它特别擅长处理序列数据,比如时间序列、文本、音频等。它解决了传统神经网络在处理这种数据时“短期记忆不好”的问题。
普通神经网络的问题:当数据有时间顺序时,普通的神经网络容易忘记之前的信息。比如说,你在读一段文章,读到第10句话时,普通神经网络可能已经不记得前面几句话了,所以做决策会不准。 LSTM的优点:LSTM可以“记住”长期的信息(比如很久之前的内容),同时也能忘记不重要的信息(比如无关紧要的内容),所以它在处理这种有顺序的数据时表现特别好。
LSTM的三个「门」是什么?
LSTM有三个「门」来控制信息的流动,这些门就像“水龙头”,可以让信息通过、堵住,或者选择性地让部分信息通过。
1. 遗忘门(Forget Gate):
这个门决定哪些过去的信息要“忘记”。 比如,当处理一段新闻文章时,开头的天气描述可能不重要,LSTM会选择忘掉。
2. 输入门(Input Gate):
这个门决定哪些新的信息要“记住”。 比如,在读一段故事时,突然引入了一个重要的新角色,这个门会决定要把这个新角色的信息记下来。
3. 输出门(Output Gate):
这个门决定哪些信息应该输出,也就是当前应该关注的重点是什么。 比如,你现在读到文章的高潮部分,LSTM会选择把最重要的信息输出,帮助它更好地预测下一步内容。
为什么LSTM好用?
LSTM可以像一个有记忆力的小助手,帮你在处理复杂信息时记住有用的东西,忘记没用的东西,还能灵活应对变化的内容。这就像是你在看一部电视剧,LSTM帮你记住主要剧情,而不会因为剧中的无关细节而让你忘记主线。
例子:预测下一句话
假设我们有一段文本:“今天我去公园,天气很好,我带了一本书准备____。”
普通神经网络:可能只记住了最后几个词“带了一本书准备”,它不知道你前面提到的“公园”和“天气很好”,所以它很可能会预测错误,比如猜“带了一本书准备学习”。
LSTM:它不仅记住了“带了一本书准备”,还记得你一开始提到的“去公园”和“天气很好”,所以它更有可能做出合理的预测,比如“带了一本书准备阅读”或者“准备在草地上看书”。
LSTM的核心在于它能够在处理一段时间序列(比如文本或时间序列数据)时,不仅能记住长期的有用信息,还能选择性地忘记不重要的信息。这就像你的大脑,能够记住重要的事情,并根据当下的情况做出合理的判断。
接下来的内容,咱们分为2部分,一部分分享原理性的内容,另外一个部分分享一个完整的案例~
核心原理
1. LSTM 公式推导
LSTM是通过一系列门控机制来处理输入数据。假设我们有输入数据 和上一个时刻的隐藏状态 ,记忆单元的状态 ,那么LSTM通过以下公式进行计算:
遗忘门(Forget Gate)
遗忘门控制要“忘记”多少前一时刻的记忆 。计算公式如下:
其中:
是遗忘门的输出,值在 到 之间。 是权重矩阵, 是偏置向量。 是sigmoid激活函数,输出值 越接近 ,表示“保留”的信息越多。
输入门(Input Gate)
输入门控制要“记住”多少当前的输入信息 。这个过程分为两步:
生成一个候选状态 :
然后通过输入门控制要保留多少候选信息:
最后结合遗忘门的输出,计算新的记忆单元状态 :
输出门(Output Gate)
输出门决定了要输出的隐藏状态 ,这是模型的最终输出:
每一步都围绕着这些「门」的控制,遗忘旧信息,选择性记忆新信息,并输出重要的隐藏状态。
完整案例
接下来我们用Python来实现LSTM网络。为了展示效果,我们使用虚拟时间序列数据,模拟一些简单的预测任务。并且我们会用图形来展示各个门的输出、记忆状态的变化以及预测结果等。
以下所有内容,手动实现,更好的理解原理~
import numpy as np
import matplotlib.pyplot as plt
# 设置随机种子
np.random.seed(0)
# 激活函数
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def tanh(x):
return np.tanh(x)
# LSTM的类
class LSTM:
def __init__(self, input_size, hidden_size):
# 权重初始化(随机小值)
self.W_f = np.random.randn(hidden_size, input_size + hidden_size) * 0.1
self.b_f = np.zeros((hidden_size, 1))
self.W_i = np.random.randn(hidden_size, input_size + hidden_size) * 0.1
self.b_i = np.zeros((hidden_size, 1))
self.W_C = np.random.randn(hidden_size, input_size + hidden_size) * 0.1
self.b_C = np.zeros((hidden_size, 1))
self.W_o = np.random.randn(hidden_size, input_size + hidden_size) * 0.1
self.b_o = np.zeros((hidden_size, 1))
self.hidden_size = hidden_size
def step(self, x_t, h_prev, C_prev):
# 拼接输入和上一个隐藏状态
combined = np.vstack((h_prev, x_t))
# 计算遗忘门
f_t = sigmoid(np.dot(self.W_f, combined) + self.b_f)
# 计算输入门
i_t = sigmoid(np.dot(self.W_i, combined) + self.b_i)
# 生成候选记忆单元
C_tilde = tanh(np.dot(self.W_C, combined) + self.b_C)
# 更新记忆单元
C_t = f_t * C_prev + i_t * C_tilde
# 计算输出门
o_t = sigmoid(np.dot(self.W_o, combined) + self.b_o)
# 更新隐藏状态
h_t = o_t * tanh(C_t)
return h_t, C_t, f_t, i_t, o_t
# 生成虚拟时间序列数据
time_steps = 100
x_data = np.sin(np.linspace(0, 3 * np.pi, time_steps)) # 正弦波作为输入
# 输入维度和隐藏层大小
input_size = 1
hidden_size = 10
# 初始化LSTM
lstm = LSTM(input_size, hidden_size)
# 初始化隐藏状态和记忆单元
h_t = np.zeros((hidden_size, 1))
C_t = np.zeros((hidden_size, 1))
# 保存门和状态的值以便后续绘图
h_states = []
C_states = []
f_gates = []
i_gates = []
o_gates = []
# 开始逐步处理时间序列
for t in range(time_steps):
x_t = np.array([[x_data[t]]]) # 当前输入(维度为1)
h_t, C_t, f_t, i_t, o_t = lstm.step(x_t, h_t, C_t)
# 保存每个时刻的状态
h_states.append(h_t)
C_states.append(C_t)
f_gates.append(f_t)
i_gates.append(i_t)
o_gates.append(o_t)
# 转换为NumPy数组以便绘图
h_states = np.squeeze(np.array(h_states))
C_states = np.squeeze(np.array(C_states))
f_gates = np.squeeze(np.array(f_gates))
i_gates = np.squeeze(np.array(i_gates))
o_gates = np.squeeze(np.array(o_gates))
# 开始绘图
plt.figure(figsize=(14, 10))
# 图1:输入数据的时间序列
plt.subplot(4, 1, 1)
plt.plot(x_data, label='Input (sin wave)', color='red')
plt.title('Input Time Series')
plt.legend()
# 图2:隐藏状态的变化
plt.subplot(4, 1, 2)
for i in range(hidden_size):
plt.plot(h_states[:, i], label=f'hidden state {i+1}')
plt.title('Hidden States Over Time')
plt.legend()
# 图3:遗忘门的输出
plt.subplot(4, 1, 3)
for i in range(hidden_size):
plt.plot(f_gates[:, i], label=f'forget gate {i+1}')
plt.title('Forget Gates Over Time')
plt.legend()
# 图4:记忆单元状态变化
plt.subplot(4, 1, 4)
for i in range(hidden_size):
plt.plot(C_states[:, i], label=f'cell state {i+1}')
plt.title('Cell States Over Time')
plt.legend()
plt.tight_layout()
plt.show()
1. LSTM实现:我们手动实现了一个LSTM网络,包括遗忘门、输入门、候选记忆状态、输出门和最终隐藏状态的更新。
2. 虚拟数据:使用正弦波作为输入数据。
3. 状态保存:每个时间步,我们记录了隐藏状态、记忆单元状态和各个门的输出,以便在图中展示。
下面,咱们来聊聊数据分析展示的图形要说明的问题:
1. 图1:输入数据的时间序列:显示我们使用的正弦波输入。
2. 图2:隐藏状态的变化:展示了LSTM隐藏状态随时间的变化。
3. 图3:遗忘门的输出:展示了每个时刻遗忘门的输出值,反映LSTM是否“遗忘”之前的信息。
4. 图4:记忆单元状态变化:显示了每个时间步LSTM记忆单元状态的变化,反映LSTM如何保留或更新长期信息。
通过手动实现LSTM,我们展示了每个门的工作原理以及它们对隐藏状态和记忆状态的影响。分析这些状态图,有助于大家理解LSTM如何处理序列数据,并且能够灵活地记住重要信息,忽略不相关的信息。