陳鍾誠

Version 1.0

反傳遞演算法 – 手算案例

為了講清楚反傳遞算法,我們必須先發明一點數學符號!

讓我們先回頭看看梯度中的基本元素,也就是偏微分,其定義是:

\frac{\partial }{\partial x_i} f(x) = \lim_{h \to 0} \frac{f(x_1, ..., x_i+h, ...., x_n)-f(x_1, ..., x_i, ...., x_n)}{h}

舉例而言,假如對 $f(x,y) = x^2+y^2$ 這個函數而言,其對 x 的偏微分就是:

\frac{\partial }{\partial x} f(x,y) = \lim_{h \to 0} \frac{f(x+h,y)-f(x,y)}{h}

而對 y 的偏微分就是:

\frac{\partial }{\partial y} f(x,y) = \lim_{h \to 0} \frac{f(x,y+h)-f(x,y)}{h}

以上的數學符號源自《萊布尼茲》

簡易案例

讓我們考慮一個兩層式網路如下圖,該網路是計算 f = (x+y) * z 這個算式。

其中的 q = x+y, 而 f = q*z。

反傳遞的原理主要來自偏微分的鏈鎖規則,我們可以用以下數學式描述 f, q, x 之間的梯度關係。

\frac{\partial{f(q,z)}}{\partial{x}} = \frac{\partial{q(x,y)}}{\partial{x}} \frac{\partial{f(q,z)}}{\partial{q}}

但是其中的 ${\partial{x}}$ 並非偏微分,而是 $\frac{\partial{f(q,z)}}{\partial{x}}$ 才是 f 函數對 的偏微分,這樣寫起來不僅冗長,而且會引導我們一直去把 ${\partial{x}}$ 想成偏微分 (梯度向量的其中一個軸),因而會造成很多誤解!

為了避免誤解,我們採用 $g^x_f=\frac{\partial{f}}{\partial{x}}$ 這樣的表達形式,於是可以有下列偏微分式:

g^x_f=\frac{\partial{f}}{\partial{x}}
g^y_f=\frac{\partial{f}}{\partial{y}}
g^q_f=\frac{\partial{f}}{\partial{q}}
g^z_f=\frac{\partial{f}}{\partial{z}}

然後我們可以改寫鏈鎖規則成為以 g 為主的形式:

萊布尼茲形式 :

\frac{\partial{f(q,z)}}{\partial{x}} = \frac{\partial{q(x,y)}}{\partial{x}} \frac{\partial{f(q,z)}}{\partial{q}}

以 g 為主的形式:

g^x_f = g^q_f * g^x_q

這樣我們就可以寫出下列兩組關係式:

g^x_f = g^q_f * g^x_q
g^y_f = g^q_f * g^y_q

由於 f=q*z, q=x+y ,因此我們可以計算出下列算式:

g^q_f = z
g^x_q = 1
g^y_q = 1

所以我們得到

g^x_f = g^q_f * g^x_q = z * 1
g^y_f = g^q_f * g^y_q = z * 1

如此只要把 z 值帶入就能計算出梯度 $g^x_f$ 與 $g^y_f$ 了。

透過這種方式,我們可以一層一層的算回去,得到 f 對任意變數的梯度。

更複雜的案例

f(x,y) = ((2*x)+(y+1))^2

在 x=3, y=2 時,正向傳遞後再反向傳遞的結果為:

運算式正向傳遞閘的梯度反向傳遞
x = 3x=3$g^x_f = ??$36
y = 2y=2$g^y_f = ??$18
p = 2xp=6$g^x_p = 2$$g^x_f = g^p_f*g^x_p=18*2=36$
q = y+1q=3$g^y_q = 1$$g^y_f = g^q_f*g^y_q=18*1=18$
r = p+q = 2x+y+1r=9$g^q_r = 1$ ; $g^p_r = 1$$g^q_f = g^r_f*g^q_r=18*1$ ; $g^p_f=g^r_f*g^p_r=18*1$
$f = r*r = (2x+y+1)^2$f=9*9$g^r_f = 2r=18$$g^r_f = g^r_f*g^f_f=18$
f = ff=81$g^f_f = 1$
2x  => p
       + => r*r => f
y+1 => q
g^x_f = g^r_f * g^p_r * g^x_p = 1*18*2=36
g^y_f = g^r_f * g^q_r * g^y_q = 1*18*1=18

檢驗:

正向: $f(x,y) = ((2*x)+(y+1))^2 = (2*3+2+1)^2 = 9^2 = 81$

反向:

$g^x_f = 8x + 4y + 4 = 8*3 + 4*2 + 4 = 36$

$g^y_f = 4x + 2y + 2 = 4*3 + 2*2 + 2 = 18$