8. ニューラルネットワーク (2)#

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation
from IPython.display import HTML

\(\def\bm{\boldsymbol}\)前章では、多層のニューラルネットワークを一般的に表現する方法を説明した。また、論理和(OR)や論理積(AND)などの論理結合子に対応する単層ニューラルネットワーク(線形分類器)のパラメータを学習で求め、その組み合わせで排他的論理和(XOR)の入出力を再現する多層ニューラルネットワークを構成する例を示した。排他的論理和の例では、入力と出力の組み合わせの数が少ないうえ、入力を出力に対応付ける背景知識(論理)が明確であった。ところが、機械学習が用いられるシナリオは、入力を出力に対応づけるメカニズムが不明な状況において、訓練データのみを用いて入力から出力を計算する写像を獲得することである。単層ニューラルネットワークの学習方法は回帰や線形分類などで説明した通りであるが、多層ニューラルネットワークの学習はどのようにすれば良いのだろうか?

本章では、ニューラルネットワークのモデルのパラメータを確率的勾配降下法で汎用的に学習する手法として、計算グラフ(computation graph)による自動微分(automatic differentiation)を紹介する。その後、PyTorchによる実装を紹介する。

まず、確率的勾配降下法の更新式を思い出してみたい。モデルに含まれる全てのパラメータを1次元に並べてベクトルとして一般的に表現したものを\(\bm{\theta}\)と書くことにする。例えば、式(7.24)の例では、行列\(\bm{W}^{(1)}, \bm{W}^{(2)}\)の全ての要素、ベクトル\(\bm{b}^{(1)}\)の全ての要素、スカラー\(b^{(2)}\)を並べたベクトルが\(\bm{\theta}\)である。ある訓練事例\((\bm{x}, \bm{y})\)に対する損失関数を\(\hat{l}_{\bm{x}, \bm{y}}(\bm{\theta}^{(t)})\)と定義する。\(t\)回目の反復時のパラメータの値を\(\bm{\theta}^{(t)}\)と書くことにすると、パラメータの更新式は、

(8.1)#\[\begin{align} \bm{\theta}^{(t+1)} = \bm{\theta}^{(t)} - \eta_t \nabla \hat{l}_{\bm{x}, \bm{y}}(\bm{\theta}^{(t)}) \end{align}\]

パラメータベクトル\(\bm{\theta}\)の要素数を\(K\)とすると、勾配\(\nabla \hat{l}_{\bm{x}, \bm{y}}(\bm{\theta})\)は、

(8.2)#\[\begin{align} \nabla \hat{l}_{\bm{x}, \bm{y}}(\bm{\theta}) = \begin{pmatrix} \frac{\partial \hat{l}_{\bm{x}, \bm{y}}(\bm{\theta})}{\partial \theta_1} & \frac{\partial \hat{l}_{\bm{x}, \bm{y}}(\bm{\theta})}{\partial \theta_2} & \dots & \frac{\partial \hat{l}_{\bm{x}, \bm{y}}(\bm{\theta})}{\partial \theta_K} \end{pmatrix} \end{align}\]

である。ゆえに、損失関数\(\hat{l}_{\bm{x}, \bm{y}}(\bm{\theta})\)に対して、全てのパラメータ\(\theta_k\) (\(k \in \{1, 2, \dots, K\}\)) に関する偏微分を求めることができれば、確率的勾配降下法を適用できる。これまでは、損失関数の定義から合成関数の微分や連鎖律を駆使して偏微分を求め、回帰では式(4.32)、線形二値分類では式(5.31)、線形多値分類では式(6.33)を得た。しかし、多層ニューラルネットワークになると損失値を計算するまでに関数の合成が何回も繰り返されるため、偏微分を手作業で求めるのは大変である。また、用いるニューラルネットワークの形状や活性化関数によって偏微分を求める式が変わってしまうため、全ての場合に対して手作業で偏微分を求めるのは現実的ではない。そこで、ニューラルネットワークの学習では損失関数を計算グラフとして表現し、自動微分で偏微分を求めるのが定石となっている。これは、ニューラルネットワークの学習方法として考案された誤差逆伝播法(backpropagation)と本質的に同じものである。

8.1. 計算グラフと自動微分#

計算グラフ(computation graph)は、計算式の中にある変数や定数を葉ノード、計算を行う順番に沿って演算子を内部ノードとして配置し、ボトムアップに計算を行うことで、全体の計算結果を根ノードで得られるようにした木構造である。このように説明すると難しそうであるが、例を見れば単純明快である。

8.1.1. 前向き計算#

以下に\(y=ax+b\)の計算グラフを示した。ここで、\((a, b) = (-2, 4), x = 3\)であるとき、\(y\)を計算するには計算グラフの葉からボトムアップに演算子(内部ノード)を適用し、途中の計算結果を保持しながら根ノード\(y\)まで計算を進めればよい。以下の図では、この計算過程をアニメーションで見ることができる。このアニメーションでは葉が左側、根が右側に配置されているため、左から右に向かって計算が進む。この計算過程を前向き計算(forward computation)と呼ぶ。

Hide code cell source
def find_renderer(fig):
    if hasattr(fig.canvas, "get_renderer"):
        renderer = fig.canvas.get_renderer()
    else:
        import io
        fig.canvas.print_pdf(io.BytesIO())
        renderer = fig._cachedRenderer
    return renderer

def draw_variable(ax, x, y, s, requires_grad=False):
    bbox = bbox=dict(boxstyle='circle', fill=False, linewidth=0.5) if requires_grad else None
    return ax.text(x, y, s, bbox=bbox, ha='center', va='center', fontsize=config['variable-font-size'])
    
def draw_operator(ax, x, y, s):
    return ax.text(x, y, s, bbox=dict(boxstyle='square', fill=False, linewidth=0.5), ha='center', va='center', fontsize=config['operator-font-size'])

def draw_arrow(ax, x0, x1, y0, y1, style='arrow'):
    cstyle = 'arc3'
    if style == 'below':
        cstyle = 'angle,angleA=0,angleB=90,rad=0.0'
    elif style == 'above':
        cstyle = 'angle,angleA=0,angleB=-90,rad=0.0'        
    ax.annotate(
        "",
        xy=(x1, y1), xycoords='data',
        xytext=(x0, y0), textcoords='data',
        arrowprops=dict(arrowstyle="-|>,head_length=0.2,head_width=0.1", connectionstyle=cstyle, facecolor='black', linewidth=0.5),
        )

def write_message(msg):
    A = []
    p = msg.find(r' \\ ')
    if p != -1:
        msg1 = msg[:p]
        msg2 = msg[p+4:]
        A.append(ax.text(config['message-x'], config['message-y'], msg1, ha='center', va='center', fontsize=config['message-font-size']))
        A.append(ax.text(config['message-x'], config['message-y']-0.5, msg2, ha='center', va='center', fontsize=config['message-font-size']))
    else:
        A.append(ax.text(config['message-x'], config['message-y'], msg, ha='center', va='center', fontsize=config['message-font-size']))
    return A

def write_value(x, y, v, name, updated=False, backward=False):
    A = []
    c = 'tab:red' if updated else ('tab:blue' if backward else 'tab:gray')
    if name is not None:
        A.append(ax.text(x, y+0.2, name, ha='center', va='center', fontsize=config['value-font-size'], color=c))
    A.append(ax.text(x, y, v, ha='center', va='center', fontsize=config['value-font-size'], color=c))
    return A

def forward(ax, S, t):    
    artists = []
    backward = False
    for i in range(1, t+1):
        s = S[i-1]
        if s is None or len(s) < 4:
            continue
        if s[0] == 'Backward pass':
            backward = True
        name = s[4] if len(s) >= 5 else None
        artists += write_value(s[1], s[2], s[3], name, t==i, backward)
            
    s = S[t-1]
    if s is not None:
        artists += write_message(s[0])
    
    return artists

config = {
    'variable-font-size': 10,
    'operator-font-size': 8,
    'message-font-size': 10,
    'value-font-size': 10,
    'message-x': 3,
    'message-y': 1,
}

def draw_computation_graph(ax):
    draw_variable(ax, 1, 4, '$a$', requires_grad=True)
    draw_variable(ax, 0.5, 4, '$-2$')
    draw_variable(ax, 1, 3, '$x$', requires_grad=True)
    draw_variable(ax, 0.5, 3, '$3$')
    draw_variable(ax, 1, 2, '$b$', requires_grad=True)
    draw_variable(ax, 0.5, 2, '$4$')
    draw_variable(ax, 3.3, 4, r'$\alpha = ax$')
    draw_variable(ax, 5.5, 2.7, r'$y$')
    draw_operator(ax, 2.5, 3.5, r'$\times$')
    draw_operator(ax, 4, 2.7, '$+$')
    draw_arrow(ax, 1.15, 2.5, 3, 3.4, 'below')
    draw_arrow(ax, 1.15, 2.5, 4, 3.6, 'above')
    draw_arrow(ax, 1.15, 4, 2, 2.6, 'below')
    draw_arrow(ax, 2.6, 4, 3.5, 2.8, 'above')
    draw_arrow(ax, 4.1, 5.4, 2.7, 2.7)

S = [
    [r'Forward pass'],
    [r'$a \cdot x = (-2) \times 3 = -6$', 3.3, 3.7, r'-6'],
    [r'$ax + b = -6 + 4 = -2$', 4.7, 2.9, '-2'],
    None,
    [r'Backward pass', 4.7, 2.4, '1'],
    [r'$y = \alpha + b$, $\frac{\partial y}{\partial \alpha} = 1$', 3.3, 3.2, r'$\frac{\partial y}{\partial \alpha} = 1$'],
    [r'$y = \alpha + b$, $\frac{\partial y}{\partial b} = 1$', 2.7, 1.7, r'$\frac{\partial y}{\partial b} = 1$'],
    [r'$\alpha = ax$, $\frac{\partial \alpha}{\partial a} = x$, $\frac{\partial y}{\partial a} = \frac{\partial y}{\partial \alpha}\cdot\frac{\partial \alpha}{\partial a}=1 \times 3 = 3$', 1.8, 3.7, r'$\frac{\partial y}{\partial a}=3$'],
    [r'$\alpha = ax$, $\frac{\partial \alpha}{\partial x} = a$, $\frac{\partial y}{\partial x} = \frac{\partial y}{\partial \alpha}\cdot\frac{\partial \alpha}{\partial x}=1 \times (-2) = -2$', 1.9, 2.7, r'$\frac{\partial y}{\partial x}=-2$'],
    None,
]

fig, ax = plt.subplots(dpi=200, figsize=(4, 3), frameon=False)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
plt.box(False)
plt.axis('off')
ax.set_aspect('equal')
ax.set_xlim(0.2, 5.8)
ax.set_ylim(0.4, 4.5)
draw_computation_graph(ax)

#forward(ax, S, 10)
#plt.show()

A = []
for t in range(3+1):
    A.append(forward(ax, S, t))
    
ani = matplotlib.animation.ArtistAnimation(fig, A, interval=2000)
html = ani.to_jshtml()
plt.close(fig)
HTML(html)

8.1.2. 後ろ向き自動微分#

さて、本題はここからである。この計算グラフを使うと、根ノードを任意のノードで偏微分した値を一定の手続きに沿って求めることができる。以下のアニメーションでは、\(y = ax + b\)\(a, b, x\)に関して偏微分した結果を求めている。

Hide code cell source
fig, ax = plt.subplots(dpi=200, figsize=(4, 3), frameon=False)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
plt.box(False)
plt.axis('off')
ax.set_aspect('equal')
ax.set_xlim(0.2, 5.8)
ax.set_ylim(0.4, 4.5)
draw_computation_graph(ax)

A = []
for t in range(4, len(S)+1):
    A.append(forward(ax, S, t))
    
ani = matplotlib.animation.ArtistAnimation(fig, A, interval=2000)
html = ani.to_jshtml()
plt.close(fig)
HTML(html)

後ろ向き自動微分では、根から葉に向かって偏微分の値を計算していく。このアニメーションでは、葉が左側、根が右側に配置されているので、右から左へと偏微分の計算が行われていく。そのため、この計算過程は後ろ向き自動微分(backward automatic differentiation)やリバースモード自動微分(reverse-mode automatic differentiation)と呼ばれている。

後ろ向き自動微分では、根に対応する部分に\(1\)を書くことからスタートする。続けて、計算グラフ上の根(右側)から葉(左側)に向かって偏微分の値を計算していく。計算グラフ中のノードを右から左に通過するたびに、ノードの右側を左側で偏微分した値と、ノードの右側で既に計算されている偏微分値の積を計算し、左側に書き込んでいく。この手続きを再帰的に繰り返していくと、全てのノードに関する偏微分の値を計算できる。

このアニメーションの各コマで求められた偏微分の計算過程を以下に示す。

(8.3)#\[\begin{align} \frac{\partial y}{\partial y} &= 1 \\ \frac{\partial y}{\partial \alpha} &= \frac{\partial }{\partial \alpha} (\alpha + b) = 1 \\ \frac{\partial y}{\partial b} &= \frac{\partial }{\partial b} (\alpha + b) = 1 \\ \frac{\partial y}{\partial a} &= \frac{\partial y}{\partial \alpha} \frac{\partial \alpha}{\partial a} = \frac{\partial y}{\partial \alpha} \frac{\partial }{\partial a} (ax) = 1 \times x = 3 \\ \frac{\partial y}{\partial x} &= \frac{\partial y}{\partial \alpha} \frac{\partial \alpha}{\partial x} = \frac{\partial y}{\partial \alpha} \frac{\partial }{\partial x} (a) = 1 \times a = -2 \\ \end{align}\]

ここで注目して欲しいのは、\(\frac{\partial y}{\partial a}\)\(\frac{\partial y}{\partial x}\)の計算方法(下から1行目と2行目)である。この計算方法では\(y = \alpha + b, \alpha = ax\)の合成関数であると考え、連鎖律を用いて求めている。このとき、\(\frac{\partial y}{\partial \alpha} = 1\)は2行目で求めてあるので、\(\alpha = ax\)の部分だけを考えればよい。\(\frac{\partial \alpha}{\partial a}\)\(\frac{\partial \alpha}{\partial x}\)を求めたのち、\(\frac{\partial y}{\partial \alpha}\)との積を、それぞれ、計算している。

8.1.3. 損失関数を自動微分する例#

後ろ向き自動微分の実行過程は明確ではあるが、この単純な例ではその威力が分かりにくいかもしれない。より実践的な例として、単層ニューラルネットワーク(線形二値分類モデル)対して、二値クロスエントロピーで定義された損失関数に基づき、パラメータの勾配を求める例を考えよう。特徴ベクトルの要素数を\(d\)、訓練事例を\((\bm{x}, y) \in \mathbb{R}^d \times \{0, 1\}\)、単層ニューラルネットワークのパラメータを\(\bm{w} \in \mathbb{R}^d\)、活性化関数にシグモイド関数\(\sigma\)を用いることにすると、損失関数は\(\log \hat{l}_{\bm{x}, y}(\bm{w})\)で表される。

(8.4)#\[\begin{align} \hat{l}_{\bm{x}, y}(\bm{w}) &= -y \log p - (1-y) \log (1-p) \\ p &= \sigma(\bm{x}^\top \bm{w}) = \frac{1}{1 + e^{-\bm{x}^\top \bm{w}}} \end{align}\]

ここでは、選ばれた訓練事例は正例(\(y = 1\))とする。

(8.5)#\[\begin{align} \hat{l}_{\bm{x}, 1}(\bm{w}) = -\log p = -\log \sigma(\bm{x}^\top \bm{w}) = -\log \frac{1}{1 + e^{-\bm{x}^\top \bm{w}}} \end{align}\]

さらに、特徴ベクトルの次元数\(d = 2\)の場合を考えると、損失関数は、

(8.6)#\[ \begin{align} \hat{l}_{x_1, x_2}(w_1, w_2) = -\log \frac{1}{1 + e^{-(w_1 x_1 + w_2 x_2)}} \end{align} \]

\((x_1, x_2) = (1, -1), (w_1, w_2) = (1, 0.5)\)であるとき、計算グラフを使って損失関数の値を求める過程と、\(w_1, w_2\)に関する偏微分を求める過程をアニメーションで示す。

Hide code cell source
config = {
    'variable-font-size': 6,
    'operator-font-size': 6,
    'message-font-size': 7,
    'value-font-size': 5,
    'message-x': 6.0,
    'message-y': 4.1,
}

fig, ax = plt.subplots(dpi=200, figsize=(4, 2))
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
plt.box(False)
plt.axis('off')
ax.set_aspect('equal')
ax.set_xlim(0.2, 10.3)
ax.set_ylim(0.4, 4.7)

draw_variable(ax, 1, 4, '$w_1$', requires_grad=True)
draw_variable(ax, 0.5, 4, '$1$')
draw_variable(ax, 1, 3, '$x_1$')
draw_variable(ax, 0.5, 3, '$1$')
draw_variable(ax, 1, 2, '$w_2$', requires_grad=True)
draw_variable(ax, 0.5, 2, '$0.5$')
draw_variable(ax, 1, 1, '$x_2$')
draw_variable(ax, 0.5, 1, '$-1$')
draw_variable(ax, 4, 1.5, '$-1$')
draw_variable(ax, 6, 1.5, '$1$')
draw_variable(ax, 9, 1.5, '$-1$')
draw_operator(ax, 2, 1.5, r'$\times$')
draw_operator(ax, 2, 3.5, r'$\times$')
draw_operator(ax, 3, 2.5, '$+$')
draw_operator(ax, 4, 2.5, r'$\times$')
draw_operator(ax, 5, 2.5, r'$\exp$')
draw_operator(ax, 6, 2.5, '$+$')
draw_operator(ax, 7, 2.5, '$1/a$')
draw_operator(ax, 8, 2.5, r'$\log$')
draw_operator(ax, 9, 2.5, r'$\times$')
draw_variable(ax, 10, 2.5, '$\hat{l}$')
draw_arrow(ax, 1.15, 2, 1, 1.4, 'below')
draw_arrow(ax, 1.15, 2, 2, 1.6, 'above')
draw_arrow(ax, 1.15, 2, 3, 3.4, 'below')
draw_arrow(ax, 1.15, 2, 4, 3.6, 'above')
draw_arrow(ax, 2.1, 3, 1.5, 2.4, 'below')
draw_arrow(ax, 2.1, 3, 3.5, 2.6, 'above')
draw_arrow(ax, 3.1, 3.9, 2.5, 2.5)
draw_arrow(ax, 4, 4, 1.7, 2.4)
draw_arrow(ax, 4.1, 4.8, 2.5, 2.5)
draw_arrow(ax, 5.2, 5.9, 2.5, 2.5)
draw_arrow(ax, 6, 6, 1.7, 2.4)
draw_arrow(ax, 6.1, 6.8, 2.5, 2.5)
draw_arrow(ax, 7.2, 7.8, 2.5, 2.5)
draw_arrow(ax, 8.2, 8.9, 2.5, 2.5)
draw_arrow(ax, 9, 9, 1.7, 2.4)
draw_arrow(ax, 9.1, 9.9, 2.5, 2.5)

S = [
    [r'Forward pass'],
    [r'$w_1 x_1 = 1 \times 1 = 1$', 2.55, 3.7, '1', r'$\alpha$'],
    [r'$w_2 x_2 = 0.5 \times -1 = -0.5$', 2.55, 1.7, '-0.5', r'$\beta$'],
    [r'$1 + (-0.5) = 0.5$', 3.5, 2.8, '0.5', r'$\mu$'],
    [r'$0.5 \times (-1) = -0.5$', 4.5, 2.8, '-0.5', r'$\nu$'],
    [r'$\exp (-0.5) = 0.6065$', 5.5, 2.8, '0.6065', r'$\xi$'],
    [r'$0.6065 + 1 = 1.6065$', 6.5, 2.8, '1.6065', r'$\pi$'],
    [r'$\frac{1}{1.6065} = 0.6225$', 7.5, 2.8, '0.6225', r'$\rho$'],
    [r'$\log (0.6225) = -0.4740$', 8.5, 2.8, '-0.4740', r'$\phi$'],
    [r'$(-0.4740) \times (-1) = 0.4740$', 9.5, 2.8, '0.4740'],
    None,
    [r'Backward pass', 9.5, 2.1, '1'],
    [r'$\frac{\partial \hat{l}}{\partial \phi} = \frac{\partial }{\partial \phi} (-\phi) = -1$', 8.5, 2.1, '-1'],
    [r'$\frac{\partial \phi}{\partial \rho} = \frac{\partial }{\partial \rho} (\log \rho) = \frac{1}{\rho} = \frac{1}{0.6225} = 1.6065$ \\ $\frac{\partial \hat{l}}{\partial \rho} = \frac{\partial \hat{l}}{\partial \phi} \frac{\partial \phi}{\partial \rho} = -1 \times 1.6065 = -1.6065$', 7.5, 2.1, '-1.6065'],
    [r'$\frac{\partial \rho}{\partial \pi} = \frac{\partial }{\partial \pi} (\frac{1}{\pi}) = -\frac{1}{\pi^2} = -\frac{1}{1.6065^2} = -0.3875$ \\ $\frac{\partial \hat{l}}{\partial \pi} = \frac{\partial \hat{l}}{\partial \rho} \frac{\partial \rho}{\partial \pi} = -1.6065 \times (-0.3875) = 0.6224$', 6.5, 2.1, '0.6224'],
    [r'$\frac{\partial \pi}{\partial \xi} = \frac{\partial }{\partial \xi} (\xi + 1) = 1$ \\ $\frac{\partial \hat{l}}{\partial \xi} = \frac{\partial \hat{l}}{\partial \pi} \frac{\partial \pi}{\partial \xi} = 0.6224 \times 1 = 0.6224$', 5.5, 2.1, '0.6224'],
    [r'$\frac{\partial \xi}{\partial \nu} = \frac{\partial }{\partial \nu} (e^{\nu}) = e^{\nu} = e^{-0.5} = 0.6065$ \\ $\frac{\partial \hat{l}}{\partial \nu} = \frac{\partial \hat{l}}{\partial \xi} \frac{\partial \xi}{\partial \nu} = 0.6224 \times 0.6065 = 0.3775$', 4.5, 2.1, '0.3775'],
    [r'$\frac{\partial \nu}{\partial \mu} = \frac{\partial }{\partial \mu} (-\mu) = -1$ \\ $\frac{\partial \hat{l}}{\partial \mu} = \frac{\partial \hat{l}}{\partial \nu} \frac{\partial \nu}{\partial \mu} = 0.3775 \times (-1) = -0.3775$', 3.5, 2.1, '-0.3775'],
    [r'$\frac{\partial \mu}{\partial \alpha} = \frac{\partial }{\partial \alpha} (\alpha + \beta) = 1$ \\ $\frac{\partial \hat{l}}{\partial \alpha} = \frac{\partial \hat{l}}{\partial \mu} \frac{\partial \mu}{\partial \alpha} = -0.3775 \times 1 = -0.3775$', 2.55, 3.3, '-0.3775'],
    [r'$\frac{\partial \mu}{\partial \beta} = \frac{\partial }{\partial \beta} (\alpha + \beta) = 1$ \\ $\frac{\partial \hat{l}}{\partial \beta} = \frac{\partial \hat{l}}{\partial \mu} \frac{\partial \mu}{\partial \beta} = -0.3775 \times 1 = -0.3775$', 2.55, 1.3, '-0.3775'],
    [r'$\frac{\partial \alpha}{\partial w_1} = \frac{\partial }{\partial w_1} (w_1x_1) = x_1 = 1$ \\ $\frac{\partial \hat{l}}{\partial w_1} = \frac{\partial \hat{l}}{\partial \alpha} \frac{\partial \alpha}{\partial w_1} = -0.3775 \times 1 = -0.3775$', 1.55, 3.8, '-0.3775'],
    [r'$\frac{\partial \alpha}{\partial x_1} = \frac{\partial }{\partial w_1} (w_1x_1) = w_1 = 1$ \\ $\frac{\partial \hat{l}}{\partial x_1} = \frac{\partial \hat{l}}{\partial \alpha} \frac{\partial \alpha}{\partial x_1} = -0.3775 \times 1 = -0.3775$', 1.55, 2.8, '-0.3775'],
    [r'$\frac{\partial \beta}{\partial w_2} = \frac{\partial }{\partial w_2} (w_2x_2) = x_2 = -1$ \\ $\frac{\partial \hat{l}}{\partial w_2} = \frac{\partial \hat{l}}{\partial \beta} \frac{\partial \beta}{\partial w_2} = -0.3775 \times (-1) = 0.3775$', 1.55, 1.8, '0.3775'],
    [r'$\frac{\partial \beta}{\partial x_2} = \frac{\partial }{\partial w_2} (w_2x_2) = w_2 = 0.5$ \\ $\frac{\partial \hat{l}}{\partial x_2} = \frac{\partial \hat{l}}{\partial \beta} \frac{\partial \beta}{\partial x_2} = -0.3775 \times 0.5 = -0.1887$', 1.55, 0.8, '-0.1887'],
    None,
]

A = []
for t in range(len(S)+1):
    A.append(forward(ax, S, t))
    
ani = matplotlib.animation.ArtistAnimation(fig, A, interval=2000)
html = ani.to_jshtml()
plt.close(fig)
HTML(html)