kmeansをフルスクラッチ実装

kmeansとは

教師なし学習を用いたクラスタリング手法の1つです。事前情報としてクラスター数(K)を定義してあげることで、指定した数のクラスターにデータを分類することが可能です。
kmeansの具体的なアルゴリズムは下記の通りになります。

  1. 各データ  {\displaystyle x_{i}(i=1,\dotsc ,n)} に対してランダムにクラスタを割り振る。
  2. 割り振ったデータをもとに各クラスタの中心  {\displaystyle V_{j}(j=1,\dotsc ,k)} を計算する。計算は通常割り当てられたデータの各要素の算術平均が使用されるが、必須ではない。
  3. {\displaystyle x_{i}}  と各 {\displaystyle V_{j}} との距離を求め、 {\displaystyle x_{i}} を最も近い中心のクラスタに割り当て直す。
  4. 上記の処理で全ての {\displaystyle x_{i}} のクラスタの割り当てが変化しなかった場合、あるいは変化量が事前に設定した一定の閾値を下回った場合に、収束したと判断して処理を終了する。そうでない場合は新しく割り振られたクラスタから {\displaystyle V_{j}} を再計算して上記の処理を繰り返す。
    引用:k平均法 - Wikipedia

初期値依存しているという性質があるものの、非階層クラスター分析の手法としてシンプルかつ優秀な手法であり、 single cell RNA-seq解析では、セルタイプのクラスタリングを行う際、この手法を応用したものが頻繁に使用されています。

pythonではscikit-learnの sklearn.cluster.KMeansクラスを用いることで簡単に実装が可能です。
Rではstatsパッケージ中のkmeans関数を用いることで同様に実装することができます。

今回はこういったパッケージを使用せずにフルスクラッチで実装している例があったので、なぞりながら実装してみたいと思います。

enhancedatascience.com

プログラムとして実装する際には、以下の4つの要素が必要になります。

  1. 中心点の初期化
  2. クラスタの割り当て
  3. 中心点の更新
  4. 停止する際の閾値定義

該当する部分についてそれぞれ見ていきます。

中心点の初期化

sample.int関数を使用して、データフレーム中からK個のランダムな行を取得しています。 current_stop_critはイテレーションの度に更新していく閾値になります。
この値が閾値より低くなったらイテレーションを停止させます。

# 中心点の初期化
  centroids=data[sample.int(nrow(data),K),]

## 初期化をやめるまでの閾値を定義(とりあえず大きい値を入れておく)
  current_stop_crit=10e10

## 割り当てられた中心点を格納するベクトル
  cluster=rep(0,nrow(data))

## アルゴリズムが収束したか否か
  converged=F
  it=1

クラスタの割り当て

各イテレーションごとに、すべてのデータ点は最も近い中心点が属するクラスターとして割り当てられます。
今回は距離として、ユークリッド距離を使用します。

### 観測値の数だけイテレーション
    for (i in 1:nrow(data))
    {

## 最小距離を定義(とりあえず大きい値を入れておく)
   
   min_dist=10e10

## 中心の数だけイテレーション
      for (centroid in 1:nrow(centroids))
      {

## L2ノルムを算出
        distance_to_centroid=sum((centroids[centroid,]-data[i,])^2)

## 1. 中心点がこれまでの最小距離よりデータ点に近かった場合、
        if (distance_to_centroid<=min_dist)
        {

## 2. そのデータ点を同じクラスタとして割り当てる
          cluster[i]=centroid
          min_dist=distance_to_centroid
        }
      }
    }

中心点の更新

各データ点について最も近い中心点が割り当てられたら、今度はそれらのデータ点を使用して新しい中心点の座標を算出、更新します。

## 各中心点についてイテレーション
    for (i in 1:nrow(centroids))
    {

## 中心点の新しい座標は、クラスターごとの各点の平均値となる。
      centroids[i,]=apply(data[cluster==i,],2,mean)
    }

停止する際の閾値定義

際限なくkmeansを実行するわけにはいかないため、中心点が1イテレーション前の中心点からほとんど動かなかった(中心点が収束した)場合、そこで停止するようにします。
この「ほとんど」の閾値をstop_crit として定義しておき、イテレーションごとに値を比較します。

## current_stop_critがstop_critを上回っている限り中心点を更新し続ける
  while(current_stop_crit>=stop_crit & converged==F)
  {
    it=it+1
    if (current_stop_crit<=stop_crit)
    {
      converged=T
    }
    old_centroids=centroids
 
## 現在の中心点と1イテレーション前の中心点の距離(current_stop_crit)を算出
current_stop_crit=mean((old_centroids-centroids)^2)

関数全体像

上述した流れを関数として定義すると、以下のようになります。

kmeans=function(data,K=4,stop_crit=10e-5)
{
  #Initialisation of clusters
  centroids=data[sample.int(nrow(data),K),]
  current_stop_crit=1000
  cluster=rep(0,nrow(data))
  converged=F
  it=1
  while(current_stop_crit>=stop_crit & converged==F)
  {
    it=it+1
    if (current_stop_crit<=stop_crit)
    {
      converged=T
    }
    old_centroids=centroids
    ##Assigning each point to a centroid
    for (i in 1:nrow(data))
    {
      min_dist=10e10
      for (centroid in 1:nrow(centroids))
      {
        distance_to_centroid=sum((centroids[centroid,]-data[i,])^2)
        if (distance_to_centroid<=min_dist)
        {
          cluster[i]=centroid
          min_dist=distance_to_centroid
        }
      }
    }
    ##Assigning each point to a centroid
    for (i in 1:nrow(centroids))
    {
      centroids[i,]=apply(data[cluster==i,],2,mean)
    }
    current_stop_crit=mean((old_centroids-centroids)^2)
  }
  return(list(data=data.frame(data,cluster),centroids=centroids))
}

実行

mvrnorm関数を使用してランダム値を生成し、データセットとして使用しています。
また、描画の際はわかりやすいように中心点のサイズ及びα値を調整しています。

# パッケージの読み込み
require(MASS)
require(ggplot2)
set.seed(12345678)

# データセットの作成
set1=mvrnorm(n = 300, c(-4,10), matrix(c(1.5,1,1,1.5),2))
set2=mvrnorm(n = 300, c(5,7), matrix(c(1,2,2,6),2))
set3=mvrnorm(n = 300, c(-1,1), matrix(c(4,0,0,4),2))
set4=mvrnorm(n = 300, c(10,-10), matrix(c(4,0,0,4),2))
set5=mvrnorm(n = 300, c(3,-3), matrix(c(4,0,0,4),2))
DF=data.frame(rbind(set1,set2,set3,set4,set5),cluster=as.factor(c(rep(1:5,each=300))))

# kmeansの実行
res=kmeans(DF[1:2],K=5)
res$centroids$cluster=1:5
res$data$isCentroid=F
res$centroids$isCentroid=T
data_plot=rbind(res$centroids,res$data)
ggplot(data_plot,aes(x=X1,y=X2,color=as.factor(cluster),size=isCentroid,alpha=isCentroid))+geom_point()

f:id:kimoppy126:20190618185228p:plain

おまけ

書いた後に気づいたのですが1年前にもパッケージ使ってkmeans試してました。。

www.kimoton.com