最近は専ら最尤推定と格闘しています。最尤推定において、必要なのが尤度関数の最大化。
なんとなく収束する手法を選択して目をを向けてきたものの、そろそろちゃんと理解する必要性を感じたため第一弾として、Nelder-Mead法を理解していこうと思います。
因みにNelder-Mead法は、ほかにも滑降シンプレックス法、アメーバ法、ポリトープ法等々、様々な呼び方があったりします。
これらはすべて同一の手法を指しているので気を付けましょう。
ざっくりどんな手法なのか
「解を含む空間を囲む多面体を縮小していく」動作を繰り返していき最適解を求めます。 操作としては下記のいずれかを繰り返すことになります。
- 反射後の点における関数値が最良点より良い場合、反射
- 膨張後の点における関数値が最良点より良い場合、膨張
- 縮小後の点における関数値が最良点より良い場合、縮小
- 反射・膨張・縮小のいずれも関数値が最良点より悪化する場合、多面体の体積を縮小
一般的に堅牢なアルゴリズムであり、Rのoptim()メソッドでは、デフォルトのアルゴリズムとして実装されています。ただし、導関数の数値計算が信頼できる場合、一次および二次導関数を使用する他のアルゴリズムの方が、一般的により良いパフォーマンスを示すことが多いです。
具体的にどんな手法なのか
個のパラメータ(変数)を含む関数( 次元関数)を最適化するためには、 次元の空間(超空間)で () 個の頂点を持つ多角形あるいは多面体(超多面体)を変形・移動する操作を行います。 以下は二次元平面上で考えた場合の一回の操作で移動先となり得る頂点を示しています。
各頂点 vertex の座標を とします。はじめに頂点の「初期配置」を決めます。初期配置はユーザーが定義することが多いですが、個の数値を指定することになり煩雑なため、個の数値のみ指定させ、 とすることも現実的な方法となります。
頂点を関数値に応じて序列化します。 例えば下記のように序列化した場合、その値によって、最良点()、第二最悪点()、最悪点()が定まります。 以降は以下の手順を繰り返します。
重心()を求めます。
最悪点()を重心に対して反射点()を求めます。
反射点が今までの最良点よりは悪いが、第二最悪点より良い場合() 今までの最悪点 を捨て、反射点に入れ替えます。このとき今までの第二最悪点が新しい最悪点になります。
反射点()が今までの最良点()より良い場合() 反射拡大点()を求め、関数値を計算します。
- 反射拡大点()が反射点() より良い場合() 最悪点()と反射点() を捨て,反射拡大点 を新しいグループの最良点( )として採用します。
- 反射拡大点()が反射点 ()より悪い場合() 最悪点()と反射拡大点()を捨て,反射点 ()を新しいグループの最良点( )として採用します。
反射点()が第二最良点()より悪い場合() 収縮点()を求め、関数値を計算します。
- 収縮点()が最悪点()より良い場合() 最悪点()と反射点()を捨て、を採用します。
- 収縮点()が最悪点 ()より悪い場合()
最良点 を中心に縮小を施します。
参考: Saša Singer and John Nelder (2009). Nelder-Mead algorithm
Nelder-Mead法の実装
Rosenbrock Banana functionに対して最適解を求める例を示します。
PythonでNelder-Mead法
scipy.optimizeのminimizeメソッドを使用します。
まずRosenbrock's banana functionを定義します。
scipy.minimize
の仕様から、目的関数の変数は1つのリストにまとめてあげる必要があります。
def objective_function(X, a, b): """Rosenbrock's banana function Args: X: X[0]:x X[1]:y a: param1 b: param2 """ return (a - X[0])**2 + b * (X[1] - X[0]**2)**2
a, bに適当な値を入れ、最適化を行います。
Nelder-Mead法で実行するためには、method="Nelder-Mead"
を指定してあげる必要があります。
a = math.sqrt(2) b = 100 arg = (a, b) res = minimize(objective_function, [-1, 1.2], args=arg, method="Nelder-Mead")
結果は下記の通り。最適化後の変数はres.x
にリスト形式で入ります。
final_simplex: (array([[1.41420186, 1.99996718], [1.41423414, 2.00005711], [1.414233 , 2.00005669]]), array([1.44630178e-10, 5.45553803e-10, 6.71605841e-10])) fun: 1.446301782885471e-10 message: 'Optimization terminated successfully.' nfev: 178 nit: 93 status: 0 success: True x: array([1.41420186, 1.99996718])
callback
に値をprint、plotする関数をかませることで、最適化のプロセスを可視化することもできます。
count = 0 plt.figure() plt.style.use("ggplot") plt.xlim(1, 35) print('count\tx\ty\tf') def cbf(X): global count count += 1 f = objective_function(X, a, b) print('%d\t%f\t%f\t%f' % (count, X[0], X[1], f)) plt.scatter(count, f, color='black') res = minimize(objective_function, [-1, 1.2], args=arg, callback=cbf)
結果はこんな感じ。が、が、目的関数の値が0に近づいていくことが分かります。
count x y f 1 -1.061164 1.167454 6.298758 2 -0.693743 0.426743 4.740897 3 -0.727047 0.500402 4.664489 4 -0.576543 0.304110 4.043153 5 -0.471721 0.181918 3.721611 6 -0.342229 0.079586 3.225976 7 -0.206316 0.037086 2.629120 8 -0.128193 -0.012000 2.459863 9 -0.010856 -0.045612 2.239946 10 0.070591 -0.024190 1.890429 11 0.196722 0.014245 1.542087 12 0.314848 0.112355 1.226096 13 0.423797 0.163800 1.005900 14 0.483925 0.204945 0.950926 15 0.529206 0.253353 0.854560 16 0.590590 0.334425 0.699010 17 0.674162 0.432513 0.595998 18 0.772228 0.579052 0.442017 19 0.854807 0.729481 0.313083 20 0.911311 0.817831 0.268928 21 1.005250 0.986956 0.222808 22 1.031241 1.054466 0.154756 23 1.133933 1.271208 0.099863 24 1.176896 1.381381 0.057691 25 1.234544 1.514569 0.041364 26 1.281671 1.636236 0.021720 27 1.313275 1.725268 0.010222 28 1.349919 1.819228 0.005067 29 1.380297 1.901651 0.002424 30 1.394576 1.944599 0.000392 31 1.410126 1.987898 0.000048 32 1.413111 1.996834 0.000001 33 1.414211 1.999982 0.000000 34 1.414201 1.999965 0.000000 35 1.414205 1.999977 0.000000
参考:
RでNelder-Mead法
statsのoptimメソッドを使用します。
Pythonの場合と同様に、Rosenbrock's banana functionを定義します。optim
の仕様から、目的関数の変数は1つのリストにまとめてあげる必要があります。
objectiveFunction <- function(x, a, b) { # Rosenbrock's banana function # # Args: # X: # x[1]:x # x[2]:y # a: param1 # b: param2 # return ((a - x[1])^2 + b * (x[2] - x[1] * x[1])^2) }
結果は下記の通り。最適化後の変数はres$per
にベクトル形式で入ります。
> res$per 1.41406310142774・1.99957098984883
Rのoptim
メソッドでは、callback関数を与えることはできませんが、以下のように関数のラッパーを作ってあげることで、同様に最適化のプロセスを可視化することができます。
# 結果を格納するDFを作成 dfNames <- c("count", "x", "y", "f") emptyTable <- data.frame( matrix(integer(), ncol = 4, nrow = 0), stringsAsFactors = FALSE ) %>% setNames(nm = c(dfNames)) count <- 0 df <- emptyTable cbf <- function(X, a, b){ f <- objectiveFunction(X, a, b) count <<- count + 1 res_vec <- c(count, X[1], X[2], f) df[nrow(df) + 1,] <<- res_vec return (f) } res <- optim(par=c(-1, 1.2), fn=cbf, a=a, b=b,method="Nelder-Mead") ggplot(data = df, aes(x = count, y = f)) + geom_point()
結果はこんな感じ。同じくが、が、目的関数の値が0に近づいていくことが分かります。
> df count x y f <dbl> <dbl> <dbl> <dbl> 1 1 -1.000000 1.200000 9.828427 2 2 -0.880000 1.200000 23.376952 3 3 -1.000000 1.320000 16.068427 4 4 -1.120000 1.320000 6.852574 5 5 -1.240000 1.380000 9.528626 6 6 -1.120000 1.200000 6.718174 7 7 -1.180000 1.140000 13.100520 8 8 -1.240000 1.320000 11.779826 9 9 -1.060000 1.230000 7.253829 10 10 -1.180000 1.290000 7.778520 11 11 -1.090000 1.245000 6.594847 12 12 -1.090000 1.125000 6.669247 13 13 -1.097500 1.173750 6.403300 14 14 -1.067500 1.218750 6.786067 15 15 -1.106875 1.204687 6.397850 16 16 -1.114375 1.133437 7.568689 17 17 -1.096094 1.217109 6.326254 18 18 -1.105469 1.248047 6.416325 19 19 -1.099492 1.192324 6.346136 20 20 -1.088711 1.204746 6.302479 21 21 -1.079629 1.204775 6.372733 22 22 -1.085313 1.229531 6.514176 23 23 -1.095947 1.201626 6.300935 24 24 -1.088564 1.189263 6.265738 25 25 -1.084800 1.175339 6.245278 26 26 -1.092036 1.172219 6.322593 27 27 -1.089542 1.196614 6.277841 28 28 -1.078395 1.170328 6.218561 29 29 -1.069619 1.154679 6.180647 30 30 -1.064876 1.133404 6.145917 ⋮ ⋮ ⋮ ⋮ ⋮ 164 164 1.414092 2.011432 1.386700e-02 165 165 1.395783 1.947526 3.864282e-04 166 166 1.366518 1.863815 3.540270e-03 167 167 1.404018 1.973967 8.325944e-04 168 168 1.412548 1.994520 6.209333e-05 169 169 1.425195 2.028295 9.529404e-04 170 170 1.404312 1.968080 1.708210e-03 171 171 1.404092 1.972495 2.067683e-04 172 172 1.420856 2.019489 8.719178e-05 173 173 1.414588 2.001498 1.942470e-05 174 174 1.423044 2.023524 3.120354e-04 175 175 1.408830 1.985252 4.931526e-05 176 176 1.410870 1.992230 2.918679e-04 177 177 1.412128 1.993948 6.853317e-06 178 178 1.417887 2.010194 1.781671e-05 179 179 1.415622 2.003959 2.062402e-06 180 180 1.413163 1.996408 3.957084e-05 181 181 1.414232 2.000226 3.051777e-06 182 182 1.417726 2.010237 2.076580e-05 183 183 1.413528 1.998020 6.324142e-07 184 184 1.414918 2.001753 6.297752e-06 185 185 1.414403 2.000608 5.385065e-07 186 186 1.412309 1.994669 3.914495e-06 187 187 1.414794 2.001636 3.398030e-07 188 188 1.415670 2.004224 3.189194e-06 189 189 1.414063 1.999571 2.383910e-08 190 190 1.414454 2.000600 6.910980e-07 191 191 1.414416 2.000606 1.516222e-07 192 192 1.413685 1.998540 4.000382e-07 193 193 1.414517 2.000862 9.418494e-08
scipy.optimizeの場合と異なり、収束までに193ループを必要としました。この辺の違いは収束の定義なんかが関係しているのでしょう。
$par 1.414063101427741.99957098984883 $value 2.38390982159864e-08 $counts function193gradient<NA> $convergence 0 $message NULL
参考: