{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Part 1: Implementing an RNN from Scratch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import math\n", "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from d2l import torch as d2l\n", "\n", "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### One-hot encoding" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", "         0, 0, 0, 0],\n", "        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", "         0, 0, 0, 0]])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "F.one_hot(torch.tensor([0, 2]), len(vocab)) # expand [0, 2] into one-hot vectors of length len(vocab)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### A minibatch has shape (batch_size, num_steps)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 2, 28])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.arange(10).reshape((2, 5)) # (batch_size, n_step)\n", "F.one_hot(X.T, 28).shape # (n_step, batch_size, n_features)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Initializing the model parameters of the RNN" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def get_params(vocab_size, num_hiddens, device):\n", "    num_inputs = num_outputs = vocab_size\n", "    \n", "    def normal(shape):\n", "        return torch.randn(size=shape, device=device) * 0.01\n", "    \n", "    # hidden-layer parameters\n", "    W_xh = 
normal((num_inputs, num_hiddens))\n", "    W_hh = normal((num_hiddens, num_hiddens))\n", "    b_h = torch.zeros(num_hiddens, device=device)\n", "    # output-layer parameters\n", "    W_hq = normal((num_hiddens, num_outputs))\n", "    b_q = torch.zeros(num_outputs, device=device)\n", "    # attach gradients\n", "    params = [W_xh, W_hh, b_h, W_hq, b_q]\n", "    for param in params:\n", "        param.requires_grad_(True)\n", "    return params" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### init_rnn_state: returns the initial hidden state" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def init_rnn_state(batch_size, num_hiddens, device): # return the initial hidden state\n", "    return (torch.zeros((batch_size, num_hiddens), device=device), )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### rnn: computing the hidden state and output **at each time step**\n", "Hidden-state update: $$h_t = \\phi(W_{hh}h_{t-1}+W_{hx}x_{t}+b_{h})$$\n", "Output (no nonlinearity is applied): $$o_{t}=W_{ho}h_{t}+b_{o}$$" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def rnn(inputs, state, params):\n", "    W_xh, W_hh, b_h, W_hq, b_q = params\n", "    H, = state\n", "    outputs = [] # list of n_step tensors, each of shape (batch_size, n_outputs)\n", "    # inputs: (n_step, batch_size, n_features)\n", "    for X in inputs: # iterate over time steps\n", "        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h) # (batch_size, n_hiddens)\n", "        Y = torch.mm(H, W_hq) + b_q # (batch_size, n_outputs)\n", "        outputs.append(Y) \n", "    return torch.cat(outputs, dim=0), (H,) # after cat: (n_step * batch_size, n_outputs)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Wrapping these functions in a class" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class RNNModelScratch:\n", "    \"\"\"An RNN model implemented from scratch\"\"\"\n", "    def __init__(self, vocab_size, num_hiddens, device, get_params,\n", "                 init_state, forward_fn):\n", "        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens\n", "        self.params = 
get_params(vocab_size, num_hiddens, device) # get the initial model parameters\n", "        self.init_state, self.forward_fn = init_state, forward_fn # state-init function and forward function \n", "        # note: the forward function can be swapped for gru, lstm, etc.\n", "    \n", "    def __call__(self, X, state):\n", "        # input X: (batch_size, n_step)\n", "        # after transpose + one-hot, X: (n_step, batch_size, n_features)\n", "        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)\n", "        return self.forward_fn(X, state, self.params)\n", "    \n", "    def begin_state(self, batch_size, device):\n", "        return self.init_state(batch_size, self.num_hiddens, device)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Checking that the outputs have the correct shapes" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([10, 28]), 1, torch.Size([2, 512]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_hiddens = 512\n", "net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), \n", "                      get_params, init_rnn_state, rnn)\n", "# X: (2, 5), i.e. (batch_size, n_step)\n", "state = net.begin_state(X.shape[0], d2l.try_gpu())\n", "Y, new_state = net(X.to(d2l.try_gpu()), state)\n", "# Y: (batch_size * n_step, n_outputs) \n", "# new_state holds one tensor (the hidden state at the last time step)\n", "# new_state[0].shape: (batch_size, n_hiddens)\n", "Y.shape, len(new_state), new_state[0].shape" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Defining a prediction function to generate new characters after a prefix" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'time travellerrrrrrrrrrr'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def predict_ch8(prefix, num_preds, net, vocab, device):\n", "    \"\"\"Generate new characters following `prefix`\"\"\"\n", "    # create the initial hidden state\n", "    state = net.begin_state(batch_size=1, device=device) \n", "    outputs = [vocab[prefix[0]]] # integer index of the first character\n", "    # wrap the most recent prediction in a tensor: batch_size=1, n_step=1\n", "    get_input = lambda: 
torch.tensor([outputs[-1]], device=device).reshape((1, 1))\n", "    for y in prefix[1:]: # warm-up: feed the ground-truth characters\n", "        _, state = net(get_input(), state)\n", "        outputs.append(vocab[y])\n", "    for _ in range(num_preds): # predict num_preds steps\n", "        y, state = net(get_input(), state)\n", "        outputs.append(int(y.argmax(dim=1).reshape(1)))\n", "    return ''.join([vocab.idx_to_token[i] for i in outputs])\n", "\n", "predict_ch8('time traveller', 10, net, vocab, d2l.try_gpu())" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Gradient clipping\n", "$$\\mathbf{g}\\leftarrow \\min\\left(1, \\frac{\\theta}{\\|\\mathbf{g}\\|}\\right) \\mathbf{g}$$" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def grad_clipping(net, theta):\n", "    \"\"\"Clip gradients\"\"\"\n", "    if isinstance(net, nn.Module): # implemented with nn.Module\n", "        params = [p for p in net.parameters() if p.requires_grad]\n", "    else:\n", "        params = net.params\n", "    norm = torch.sqrt(sum(torch.sum(\n", "        (p.grad**2)) for p in params))\n", "    if norm > theta:\n", "        for param in params:\n", "            param.grad[:] *= theta / norm" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Inspecting the train_iter dataset" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 35]) torch.Size([32, 35])\n", "tensor([[ 1,  3,  5,  ...,  2,  1, 15],\n", "        [ 4,  6, 11,  ...,  5, 10,  8],\n", "        [ 3,  1,  4,  ...,  2,  8,  8],\n", "        ...,\n", "        [15,  7,  6,  ..., 21, 14,  3],\n", "        [10, 19,  8,  ..., 14,  8,  3],\n", "        [ 1, 13,  2,  ..., 10,  1,  4]])\n", "tensor([[ 3,  5, 13,  ...,  1, 15,  7],\n", "        [ 6, 11, 20,  ..., 10,  8,  1],\n", "        [ 1,  4,  6,  ...,  8,  8,  1],\n", "        ...,\n", "        [ 7,  6, 26,  ..., 14,  3, 21],\n", "        [19,  8,  3,  ...,  8,  3,  1],\n", "        [13,  2, 15,  ...,  1,  4,  6]])\n", " time traveller for so it will be c\n", "time traveller for so it will be co\n", "andpassed in our glasses our chairs\n", 
"ndpassed in our glasses our chairs \n", "\n", "onvenient to speak of himwas expoun\n", "nvenient to speak of himwas expound\n", "8\n" ] } ], "source": [ "count = 0\n", "for X, Y in train_iter:\n", "    if count == 0: # batch 0\n", "        print(X.shape, Y.shape)\n", "        print(X) # (batch_size(=32), n_step(=35))\n", "        print(Y) # (batch_size, n_step)\n", "        print(''.join([vocab.idx_to_token[i] for i in X[0]])) # print the sentence of sample 0\n", "        print(''.join([vocab.idx_to_token[i] for i in Y[0]])) # print the targets of sample 0\n", "        print(''.join([vocab.idx_to_token[i] for i in X[1]])) # print the sentence of sample 1\n", "        print(''.join([vocab.idx_to_token[i] for i in Y[1]])) # print the targets of sample 1\n", "        print()\n", "    if count == 1: # batch 1: continues batch 0 in time (sequential partitioning)\n", "        print(''.join([vocab.idx_to_token[i] for i in X[0]]))\n", "        print(''.join([vocab.idx_to_token[i] for i in Y[0]]))\n", "    count += 1\n", "print(count) # print the number of batches (= 8)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Defining a function to train the model for one epoch" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def train_epoch_ch8(net, train_iter, loss, updater, device,\n", "                    use_random_iter):\n", "    \"\"\"Train the model for one epoch (defined in Chapter 8)\"\"\"\n", "    state, timer = None, d2l.Timer()\n", "    metric = d2l.Accumulator(2)\n", "    for X, Y in train_iter:\n", "        if state is None or use_random_iter: # first batch, or batches are not temporally contiguous\n", "            state = net.begin_state(batch_size=X.shape[0], device=device) # initialize state\n", "        else:\n", "            if isinstance(net, nn.Module) and not isinstance(state, tuple):\n", "                # state is a tensor for nn.GRU\n", "                state.detach_() # stop gradients from propagating into earlier batches\n", "            else:\n", "                # state is a tuple of tensors for nn.LSTM and for our scratch model\n", "                for s in state:\n", "                    s.detach_()\n", "        y = Y.T.reshape(-1) # reshape targets: put n_step first, then flatten to 1-D\n", "        X, y = X.to(device), y.to(device)\n", "        y_hat,state = net(X, state)\n", "        l = loss(y_hat, y.long()).mean()\n", "        if isinstance(updater, torch.optim.Optimizer): # using a torch optimizer\n", "            updater.zero_grad()\n", "            l.backward()\n", "            grad_clipping(net, 1)\n", "            updater.step()\n", "        else:\n", "            l.backward()\n", "            grad_clipping(net, 1)\n", "            updater(batch_size=1)\n", "        metric.add(l * y.numel(), y.numel())\n", "    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### The training function supports both the scratch implementation and the high-level API" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def train_ch8(net, train_iter, vocab, lr, num_epochs, device, use_random_iter=False):\n", "    \"\"\"Train the model (defined in Chapter 8)\"\"\"\n", "    loss = nn.CrossEntropyLoss()\n", "    animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',\n", "                            legend=['train'], xlim=[10,num_epochs])\n", "    # initialize the optimizer\n", "    if isinstance(net, nn.Module):\n", "        updater = torch.optim.SGD(net.parameters(), lr)\n", "    else:\n", "        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)\n", "    predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)\n", "    # train and predict\n", "    for epoch in range(num_epochs):\n", "        ppl, speed = train_epoch_ch8(\n", "            net, train_iter, loss, updater, device, use_random_iter)\n", "        if (epoch + 1) % 10 == 0:\n", "            print(predict('time traveller'))\n", "            animator.add(epoch+1, [ppl])\n", "    print(f'perplexity {ppl:.1f}, {speed:.1f} tokens/sec on {str(device)}')\n", "    print(predict('time traveller'))\n", "    print(predict('traveller'))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Training the RNN model (sequential batch iteration)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "perplexity 1.0, 46320.2 tokens/sec on cuda:0\n", "time travelleryou can show black is white by argument said filby\n", "travelleryou can show black is white by argument said filby\n" ] }, { "data": { "image/svg+xml": [ "(perplexity training curve: SVG output stripped; Matplotlib v3.3.3, 2022-02-07T11:39:19)\r\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs, lr = 500, 1\n", "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Training the RNN model (random batch iteration)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "perplexity 1.3, 36524.6 tokens/sec on cuda:0\n", "time traveller held in his hand was a glitteringmetallic framewo\n", "travellerit s against reason said filbycan a cube that does\n" ] }, { "data": { "image/svg+xml": [ "(perplexity training curve: SVG output stripped; Matplotlib v3.3.3, 2022-02-07T11:42:07)\r\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,\n", "                      init_rnn_state, rnn)\n", "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(),\n", "          use_random_iter=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "While implementing the RNN model above **from scratch** is instructive, it is not convenient. In the next section we will see how to improve the RNN model." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Part 2: Concise Implementation of RNNs" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from d2l import torch as d2l\n", "\n", "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps) # load data" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Defining the model\n", "nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "num_hiddens = 256\n", "rnn_layer = nn.RNN(len(vocab), num_hiddens)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Initializing the hidden state with a tensor" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 32, 256])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state = torch.zeros((1, batch_size, num_hiddens))\n", "state.shape # (D * num_layers(=1), batch_size, num_hiddens) " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Given a hidden state and an input, we can compute the output with the updated hidden state" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.rand(size=(num_steps, batch_size, len(vocab))) # (n_step, batch_size, num_inputs)\n", "Y, state_new = rnn_layer(X, state)\n", "Y.shape, state_new.shape" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Defining the RNNModel class: a complete RNN model" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "class RNNModel(nn.Module):\n", "    def __init__(self, rnn_layer, vocab_size, **kwargs):\n", "        super(RNNModel, self).__init__(**kwargs)\n", "        self.rnn = rnn_layer\n", "        self.vocab_size = vocab_size\n", "        self.num_hiddens = self.rnn.hidden_size\n", "        if not self.rnn.bidirectional: # unidirectional\n", "            self.num_directions = 1\n", "            self.linear = nn.Linear(self.num_hiddens, self.vocab_size) # linear (output) layer\n", "        else: # bidirectional: forward and backward hidden states are concatenated\n", "            self.num_directions = 2\n", "            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)\n", "    \n", "    def forward(self, inputs, state):\n", "        X = F.one_hot(inputs.T.long(), self.vocab_size)\n", "        X = X.to(torch.float32)\n", "        Y, state = self.rnn(X, state)\n", "        output = self.linear(Y.reshape((-1, Y.shape[-1])))\n", "        return output, state\n", "    \n", "    def begin_state(self, device, batch_size=1):\n", "        if not isinstance(self.rnn, nn.LSTM):\n", "            # nn.GRU (and nn.RNN) use a tensor as the hidden state\n", "            return torch.zeros((self.num_directions * self.rnn.num_layers, \n", "                                batch_size, self.num_hiddens),\n", "                               device = device)\n", "        else:\n", "            # nn.LSTM uses a tuple as the hidden state\n", "            return (torch.zeros((\n", "                self.num_directions * self.rnn.num_layers,\n", "                batch_size, self.num_hiddens), device=device),\n", "                    torch.zeros((\n", "                self.num_directions * self.rnn.num_layers,\n", "                batch_size, self.num_hiddens), device=device)) # (h_n, c_n)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Training and predicting\n", "Before training the model, let us make predictions with a model that has random weights." ] }, { 
"cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'time travelleridandand'" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = d2l.try_gpu()\n", "net = RNNModel(rnn_layer, vocab_size=len(vocab))\n", "net = net.to(device)\n", "d2l.predict_ch8('time traveller', 10, net, vocab, device)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "perplexity 1.3, 152273.6 tokens/sec on cuda:0\n", "time traveller coud and inn weridit so mimens of the pramithtred\n", "traveller his fictses tor hime hal very is f enghas ow llow\n" ] }, { "data": { "image/svg+xml": [ "(perplexity training curve: SVG output stripped; Matplotlib v3.3.3, 2022-02-07T14:08:12)\r\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs, lr = 500, 1\n", "d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compared with the previous section, this model reaches a lower perplexity in a shorter time, because the deep learning framework's high-level API applies more optimizations to the code." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Summary\n", "- The high-level APIs of **deep learning frameworks** provide implementations of RNN layers.\n", "\n", "- An RNN layer from the high-level API returns an output and an updated hidden state; we **still need** to compute the model's **output layer** ourselves.\n", "\n", "- Compared with the scratch implementation, using the **high-level API** **speeds up training**." ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "deep2learn", "language": "python", "name": "deep2learn" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 4 }