最適化手法について学ぶ - Nelder-Mead法

最近は専ら最尤推定と格闘しています。最尤推定において、必要なのが尤度関数の最大化。
なんとなく収束する手法を選択して目をを向けてきたものの、そろそろちゃんと理解する必要性を感じたため第一弾として、Nelder-Mead法を理解していこうと思います。

因みにNelder-Mead法は、ほかにも滑降シンプレックス法、アメーバ法、ポリトープ法等々、様々な呼び方があったりします。
これらはすべて同一の手法を指しているので気を付けましょう。

ざっくりどんな手法なのか

「解を含む空間を囲む多面体を縮小していく」動作を繰り返していき最適解を求めます。 操作としては下記のいずれかを繰り返すことになります。

  • 反射後の点における関数値が最良点より良い場合、反射
  • 膨張後の点における関数値が最良点より良い場合、膨張
  • 縮小後の点における関数値が最良点より良い場合、縮小
  • 反射・膨張・縮小のいずれも関数値が最良点より悪化する場合、多面体の体積を縮小

一般的に堅牢なアルゴリズムであり、Rのoptim()メソッドでは、デフォルトのアルゴリズムとして実装されています。ただし、導関数の数値計算が信頼できる場合、一次および二次導関数を使用する他のアルゴリズムの方が、一般的により良いパフォーマンスを示すことが多いです。

具体的にどんな手法なのか

m 個のパラメータ(変数)を含む関数(m 次元関数)を最適化するためには、m 次元の空間(超空間)で (m + 1) 個の頂点を持つ多角形あるいは多面体(超多面体)を変形・移動する操作を行います。 以下は二次元平面上で考えた場合の一回の操作で移動先となり得る頂点を示しています。

  1. 各頂点 vertex の座標を p_i (i=1, 2, ..., m+1) とします。はじめに頂点の「初期配置」を決めます。初期配置はユーザーが定義することが多いですが、m(m+1)個の数値を指定することになり煩雑なため、2m個の数値のみ指定させ、 \vec{p_1}= (\tilde{p_1}+\Delta\tilde{p_1}, ⋯, \tilde{p_m}),...,\vec{p_m}=(\tilde{p_1}, ⋯, \tilde{p_m}+\Delta\tilde{p_m}), \vec{p_{m+1}}=(\tilde{p_1}, ⋯, \tilde{p_m}) とすることも現実的な方法となります。

  2. 頂点を関数値に応じて序列化します。 例えば下記のように序列化した場合、その値によって、最良点(f (p_1))、第二最悪点(f (p_m))、最悪点(f(p _ {m+1}))が定まります。 f(\vec{p_1})≤f(\vec{p_2})≤⋯≤f(\vec{p_m})≤f (\vec{p_{m+1}}) 以降は以下の手順を繰り返します。

  3. 重心( \vec{p _ 0}= \frac{1}{m} \sum ^ m_{i=1} \vec{p_i})を求めます。

  4. 最悪点(\vec{p _ {m+1}})を重心に対して反射点(\vec{p _ {refl}})を求めます。
    \vec{p _ {refl}} = \vec{p _ 0} + (\vec{p _ 0} − \vec{p _ {m+1}}) = 2\vec{p _ 0} − \vec{p _ {m+1}}

  5. 反射点が今までの最良点よりは悪いが、第二最悪点より良い場合(f( \vec{p _ 1}) \lt f( \vec{p _ {refl}}) \lt f(\vec {p _ m})) 今までの最悪点 を捨て、反射点に入れ替えます。このとき今までの第二最悪点が新しい最悪点になります。

  6. 反射点(\vec{p _ {refl}})が今までの最良点(\vec{p _ 1})より良い場合(f(\vec{p _ {refl}}) \lt f(\vec{p _ 1})) 反射拡大点(\vec{p _ {expand}})を求め、関数値を計算します。
    \vec{p _ {expand}} = \vec{p _ 0} + 2(\vec{p _ 0} − \vec{p _ {m+1}})=3\vec{p _ 0}−2\vec{p _ {m+1}}

    1. 反射拡大点(\vec{p _ {expand}})が反射点(\vec{p _ {refl}}) より良い場合(f(\vec{p _ {expand}}) \lt f(\vec{p _ {refl}}) \lt f(\vec{p _ 1})) 最悪点(\vec{p _ {m+1}})と反射点(\vec{p_{refl}}) を捨て,反射拡大点 を新しいグループの最良点(\vec{p _ 1} )として採用します。
    2. 反射拡大点(\vec{p _ {expand}})が反射点 (\vec{p _ {refl}})より悪い場合(f(\vec{p _ {refl}}) \leq f(\vec{p _ {expand}})) 最悪点(\vec{p _ {m+1}})と反射拡大点(\vec{p _ {expand}})を捨て,反射点 (\vec{p_{refl}})を新しいグループの最良点(\vec{p_1} )として採用します。
  7. 反射点(\vec{p _ {refl}})が第二最良点(\vec{p _ m})より悪い場合(f(\vec{p _ m})\leq f(\vec{p _ {refl}})) 収縮点(\vec{p _ {contract}})を求め、関数値を計算します。
    \vec{p_{contract}} = \vec{p _ 0} − 0.5 ( \vec{p _ 0} −  \vec{p _ {m+1}}) = 0.5 \vec{p _ 0} + 0.5 \vec{p _ {m+1}}

    1. 収縮点(\vec{p _ {contract}})が最悪点(\vec{p _ {m+1}})より良い場合(f (p _ {contract}) \leq f (p _ {m+1})) 最悪点(\vec{p _ {m+1}})と反射点(\vec{p _ {refl}})を捨て、\vec{p _ {contract}}を採用します。
    2. 収縮点(\vec{p _ {contract}})が最悪点 (\vec{p_{m+1}})より悪い場合(f (p _ {m+1}) \lt f (p _ {contract})) 最良点 を中心に縮小を施します。
      \vec{p _ i} ← \vec{p _ i} − 0.5 (\vec{p_i} − \vec{p _ 1}) = 1.5 \vec{p _  i} − 0.5 \vec{p _ 1}

参考: 第3部 最適化とモンテカルロ法 - 先進セラミックス研究センター

Nelder-Mead法の実装

Rosenbrock Banana functionに対して最適解を求める例を示します。

PythonでNelder-Mead法

scipy.optimizeのminimizeメソッドを使用します。

まずRosenbrock's banana functionを定義します。 scipy.minimizeの仕様から、目的関数の変数は1つのリストにまとめてあげる必要があります。

def objective_function(X, a, b):
    """Rosenbrock's banana function 
    Args:
        X: 
            X[0]:x
            X[1]:y
        a: param1
        b: param2
    """
    return (a - X[0])**2 + b * (X[1] - X[0]**2)**2

a, bに適当な値を入れ、最適化を行います。 Nelder-Mead法で実行するためには、method="Nelder-Mead"を指定してあげる必要があります。

a = math.sqrt(2)
b = 100
arg = (a, b)
res = minimize(objective_function, [-1, 1.2], args=arg, method="Nelder-Mead")

結果は下記の通り。最適化後の変数はres.xにリスト形式で入ります。

 final_simplex: (array([[1.41420186, 1.99996718],
       [1.41423414, 2.00005711],
       [1.414233  , 2.00005669]]), array([1.44630178e-10, 5.45553803e-10, 6.71605841e-10]))
           fun: 1.446301782885471e-10
       message: 'Optimization terminated successfully.'
          nfev: 178
           nit: 93
        status: 0
       success: True
             x: array([1.41420186, 1.99996718])

callbackに値をprint、plotする関数をかませることで、最適化のプロセスを可視化することもできます。

count = 0
plt.figure()
plt.style.use("ggplot")
plt.xlim(1, 35)
print('count\tx\ty\tf')

def cbf(X):
    global count
    count += 1
    f = objective_function(X, a, b)
    print('%d\t%f\t%f\t%f' % (count, X[0], X[1], f))
    plt.scatter(count, f, color='black')
res = minimize(objective_function, [-1, 1.2], args=arg, callback=cbf)

結果はこんな感じ。x\sqrt{2} (a)y2 (a ^ 2)、目的関数の値が0に近づいていくことが分かります。

count  x   y   f
1  -1.061164  1.167454   6.298758
2  -0.693743  0.426743   4.740897
3  -0.727047  0.500402   4.664489
4  -0.576543  0.304110   4.043153
5  -0.471721  0.181918   3.721611
6  -0.342229  0.079586   3.225976
7  -0.206316  0.037086   2.629120
8  -0.128193  -0.012000  2.459863
9  -0.010856  -0.045612  2.239946
10 0.070591   -0.024190  1.890429
11 0.196722   0.014245   1.542087
12 0.314848   0.112355   1.226096
13 0.423797   0.163800   1.005900
14 0.483925   0.204945   0.950926
15 0.529206   0.253353   0.854560
16 0.590590   0.334425   0.699010
17 0.674162   0.432513   0.595998
18 0.772228   0.579052   0.442017
19 0.854807   0.729481   0.313083
20 0.911311   0.817831   0.268928
21 1.005250   0.986956   0.222808
22 1.031241   1.054466   0.154756
23 1.133933   1.271208   0.099863
24 1.176896   1.381381   0.057691
25 1.234544   1.514569   0.041364
26 1.281671   1.636236   0.021720
27 1.313275   1.725268   0.010222
28 1.349919   1.819228   0.005067
29 1.380297   1.901651   0.002424
30 1.394576   1.944599   0.000392
31 1.410126   1.987898   0.000048
32 1.413111   1.996834   0.000001
33 1.414211   1.999982   0.000000
34 1.414201   1.999965   0.000000
35 1.414205   1.999977   0.000000

参考:

RでNelder-Mead法

statsのoptimメソッドを使用します。

Pythonの場合と同様に、Rosenbrock's banana functionを定義します。optimの仕様から、目的関数の変数は1つのリストにまとめてあげる必要があります。

objectiveFunction <- function(x, a, b) {
  # Rosenbrock's banana function 
  #
  #   Args:
  #     X: 
  #       x[1]:x
  #       x[2]:y
  #     a: param1
  #     b: param2
  #
  return ((a - x[1])^2 + b * (x[2] - x[1] * x[1])^2)
}

結果は下記の通り。最適化後の変数はres$perにベクトル形式で入ります。

> res$per
1.414063101427741.99957098984883

Rのoptimメソッドでは、callback関数を与えることはできませんが、以下のように関数のラッパーを作ってあげることで、同様に最適化のプロセスを可視化することができます。

# 結果を格納するDFを作成
dfNames <- c("count", "x", "y", "f")
emptyTable <- 
    data.frame(
        matrix(integer(), ncol = 4, nrow = 0), stringsAsFactors = FALSE
    ) %>% 
    setNames(nm = c(dfNames))

count <- 0
df <- emptyTable
cbf <- function(X, a, b){
    f <- objectiveFunction(X, a, b)
    count <<- count + 1
    res_vec <- c(count, X[1], X[2], f)
    df[nrow(df) + 1,] <<- res_vec
    return (f)
}
res <- optim(par=c(-1, 1.2), fn=cbf, a=a, b=b,method="Nelder-Mead")
ggplot(data = df, aes(x = count, y = f)) +
  geom_point()

結果はこんな感じ。同じくx\sqrt{2} (a)y2 (a ^ 2)、目的関数の値が0に近づいていくことが分かります。

> df
    count   x   y   f
<dbl> <dbl> <dbl> <dbl>
1  1  -1.000000    1.200000   9.828427
2  2  -0.880000    1.200000   23.376952
3  3  -1.000000    1.320000   16.068427
4  4  -1.120000    1.320000   6.852574
5  5  -1.240000    1.380000   9.528626
6  6  -1.120000    1.200000   6.718174
7  7  -1.180000    1.140000   13.100520
8  8  -1.240000    1.320000   11.779826
9  9  -1.060000    1.230000   7.253829
10 10 -1.180000    1.290000   7.778520
11 11 -1.090000    1.245000   6.594847
12 12 -1.090000    1.125000   6.669247
13 13 -1.097500    1.173750   6.403300
14 14 -1.067500    1.218750   6.786067
15 15 -1.106875    1.204687   6.397850
16 16 -1.114375    1.133437   7.568689
17 17 -1.096094    1.217109   6.326254
18 18 -1.105469    1.248047   6.416325
19 19 -1.099492    1.192324   6.346136
20 20 -1.088711    1.204746   6.302479
21 21 -1.079629    1.204775   6.372733
22 22 -1.085313    1.229531   6.514176
23 23 -1.095947    1.201626   6.300935
24 24 -1.088564    1.189263   6.265738
25 25 -1.084800    1.175339   6.245278
26 26 -1.092036    1.172219   6.322593
27 27 -1.089542    1.196614   6.277841
28 28 -1.078395    1.170328   6.218561
29 29 -1.069619    1.154679   6.180647
30 30 -1.064876    1.133404   6.145917
⋮ ⋮ ⋮ ⋮ ⋮
164    164    1.414092   2.011432   1.386700e-02
165    165    1.395783   1.947526   3.864282e-04
166    166    1.366518   1.863815   3.540270e-03
167    167    1.404018   1.973967   8.325944e-04
168    168    1.412548   1.994520   6.209333e-05
169    169    1.425195   2.028295   9.529404e-04
170    170    1.404312   1.968080   1.708210e-03
171    171    1.404092   1.972495   2.067683e-04
172    172    1.420856   2.019489   8.719178e-05
173    173    1.414588   2.001498   1.942470e-05
174    174    1.423044   2.023524   3.120354e-04
175    175    1.408830   1.985252   4.931526e-05
176    176    1.410870   1.992230   2.918679e-04
177    177    1.412128   1.993948   6.853317e-06
178    178    1.417887   2.010194   1.781671e-05
179    179    1.415622   2.003959   2.062402e-06
180    180    1.413163   1.996408   3.957084e-05
181    181    1.414232   2.000226   3.051777e-06
182    182    1.417726   2.010237   2.076580e-05
183    183    1.413528   1.998020   6.324142e-07
184    184    1.414918   2.001753   6.297752e-06
185    185    1.414403   2.000608   5.385065e-07
186    186    1.412309   1.994669   3.914495e-06
187    187    1.414794   2.001636   3.398030e-07
188    188    1.415670   2.004224   3.189194e-06
189    189    1.414063   1.999571   2.383910e-08
190    190    1.414454   2.000600   6.910980e-07
191    191    1.414416   2.000606   1.516222e-07
192    192    1.413685   1.998540   4.000382e-07
193    193    1.414517   2.000862   9.418494e-08

scipy.optimizeの場合と異なり、収束までに193ループを必要としました。この辺の違いは収束の定義なんかが関係しているのでしょう。

$par
1.414063101427741.99957098984883
$value
2.38390982159864e-08
$counts
function193gradient<NA>
$convergence
0
$message
NULL

参考: