kmeansとは
教師なし学習を用いたクラスタリング手法の1つです。事前情報としてクラスター数(K)を定義してあげることで、指定した数のクラスターにデータを分類することが可能です。
kmeansの具体的なアルゴリズムは下記の通りになります。
- 各データ
に対してランダムにクラスタを割り振る。
- 割り振ったデータをもとに各クラスタの中心
を計算する。計算は通常割り当てられたデータの各要素の算術平均が使用されるが、必須ではない。
- 各
と各
との距離を求め、
を最も近い中心のクラスタに割り当て直す。
- 上記の処理で全ての
のクラスタの割り当てが変化しなかった場合、あるいは変化量が事前に設定した一定の閾値を下回った場合に、収束したと判断して処理を終了する。そうでない場合は新しく割り振られたクラスタから
を再計算して上記の処理を繰り返す。
引用:k平均法 - Wikipedia
初期値依存しているという性質があるものの、非階層クラスター分析の手法としてシンプルかつ優秀な手法であり、 single cell RNA-seq解析では、セルタイプのクラスタリングを行う際、この手法を応用したものが頻繁に使用されています。
pythonではscikit-learnの sklearn.cluster.KMeans
クラスを用いることで簡単に実装が可能です。
Rではstatsパッケージ中のkmeans
関数を用いることで同様に実装することができます。
今回はこういったパッケージを使用せずにフルスクラッチで実装している例があったので、なぞりながら実装してみたいと思います。
プログラムとして実装する際には、以下の4つの要素が必要になります。
- 中心点の初期化
- クラスタの割り当て
- 中心点の更新
- 停止する際の閾値定義
該当する部分についてそれぞれ見ていきます。
中心点の初期化
sample.int
関数を使用して、データフレーム中からK個のランダムな行を取得しています。
current_stop_crit
はイテレーションの度に更新していく閾値になります。
この値が閾値より低くなったらイテレーションを停止させます。
# 中心点の初期化 centroids=data[sample.int(nrow(data),K),] ## 初期化をやめるまでの閾値を定義(とりあえず大きい値を入れておく) current_stop_crit=10e10 ## 割り当てられた中心点を格納するベクトル cluster=rep(0,nrow(data)) ## アルゴリズムが収束したか否か converged=F it=1
クラスタの割り当て
各イテレーションごとに、すべてのデータ点は最も近い中心点が属するクラスターとして割り当てられます。
今回は距離として、ユークリッド距離を使用します。
### 観測値の数だけイテレーション for (i in 1:nrow(data)) { ## 最小距離を定義(とりあえず大きい値を入れておく) min_dist=10e10 ## 中心の数だけイテレーション for (centroid in 1:nrow(centroids)) { ## L2ノルムを算出 distance_to_centroid=sum((centroids[centroid,]-data[i,])^2) ## 1. 中心点がこれまでの最小距離よりデータ点に近かった場合、 if (distance_to_centroid<=min_dist) { ## 2. そのデータ点を同じクラスタとして割り当てる cluster[i]=centroid min_dist=distance_to_centroid } } }
中心点の更新
各データ点について最も近い中心点が割り当てられたら、今度はそれらのデータ点を使用して新しい中心点の座標を算出、更新します。
## 各中心点についてイテレーション for (i in 1:nrow(centroids)) { ## 中心点の新しい座標は、クラスターごとの各点の平均値となる。 centroids[i,]=apply(data[cluster==i,],2,mean) }
停止する際の閾値定義
際限なくkmeansを実行するわけにはいかないため、中心点が1イテレーション前の中心点からほとんど動かなかった(中心点が収束した)場合、そこで停止するようにします。
この「ほとんど」の閾値をstop_crit
として定義しておき、イテレーションごとに値を比較します。
## current_stop_critがstop_critを上回っている限り中心点を更新し続ける while(current_stop_crit>=stop_crit & converged==F) { it=it+1 if (current_stop_crit<=stop_crit) { converged=T } old_centroids=centroids ## 現在の中心点と1イテレーション前の中心点の距離(current_stop_crit)を算出 current_stop_crit=mean((old_centroids-centroids)^2)
関数全体像
上述した流れを関数として定義すると、以下のようになります。
kmeans=function(data,K=4,stop_crit=10e-5) { #Initialisation of clusters centroids=data[sample.int(nrow(data),K),] current_stop_crit=1000 cluster=rep(0,nrow(data)) converged=F it=1 while(current_stop_crit>=stop_crit & converged==F) { it=it+1 if (current_stop_crit<=stop_crit) { converged=T } old_centroids=centroids ##Assigning each point to a centroid for (i in 1:nrow(data)) { min_dist=10e10 for (centroid in 1:nrow(centroids)) { distance_to_centroid=sum((centroids[centroid,]-data[i,])^2) if (distance_to_centroid<=min_dist) { cluster[i]=centroid min_dist=distance_to_centroid } } } ##Assigning each point to a centroid for (i in 1:nrow(centroids)) { centroids[i,]=apply(data[cluster==i,],2,mean) } current_stop_crit=mean((old_centroids-centroids)^2) } return(list(data=data.frame(data,cluster),centroids=centroids)) }
実行
mvrnorm
関数を使用してランダム値を生成し、データセットとして使用しています。
また、描画の際はわかりやすいように中心点のサイズ及びα値を調整しています。
# パッケージの読み込み require(MASS) require(ggplot2) set.seed(12345678) # データセットの作成 set1=mvrnorm(n = 300, c(-4,10), matrix(c(1.5,1,1,1.5),2)) set2=mvrnorm(n = 300, c(5,7), matrix(c(1,2,2,6),2)) set3=mvrnorm(n = 300, c(-1,1), matrix(c(4,0,0,4),2)) set4=mvrnorm(n = 300, c(10,-10), matrix(c(4,0,0,4),2)) set5=mvrnorm(n = 300, c(3,-3), matrix(c(4,0,0,4),2)) DF=data.frame(rbind(set1,set2,set3,set4,set5),cluster=as.factor(c(rep(1:5,each=300)))) # kmeansの実行 res=kmeans(DF[1:2],K=5) res$centroids$cluster=1:5 res$data$isCentroid=F res$centroids$isCentroid=T data_plot=rbind(res$centroids,res$data) ggplot(data_plot,aes(x=X1,y=X2,color=as.factor(cluster),size=isCentroid,alpha=isCentroid))+geom_point()
おまけ
書いた後に気づいたのですが1年前にもパッケージ使ってkmeans試してました。。