I am going back to basics. It’s been a while since I’ve written any code, and it is an itch I need to scratch: today, I am going to (re)implement the k-means algorithm in Clojure. K-means is an unsupervised clustering algorithm that partitions a cloud of vector data into k clusters for a given number k. The idea is simple: start with k centroids, group the points according to which centroid they are closest to, and then recalculate the centroid of each group. Repeat until the centroids have sufficiently stabilized.
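To make the assign-and-update loop concrete before building the real thing, here is a minimal sketch of it on one-dimensional numbers. The names `points`, `centers`, `nearest`, `mean`, and `step` are made up for this illustration and are not part of the implementation that follows.

```clojure
;; Toy 1-D data and two hypothetical starting centroids.
(def points [1.0 2.0 9.0 10.0])
(def centers [1.0 2.0])

(defn nearest
  "Returns the element of cs with the smallest squared distance to x."
  [x cs]
  (apply min-key #(let [d (- x %)] (* d d)) cs))

(defn mean
  "Arithmetic mean of a non-empty collection of numbers."
  [xs]
  (/ (reduce + xs) (count xs)))

(defn step
  "One k-means step: group pts by nearest center, then average each group."
  [cs pts]
  (->> pts
       (group-by #(nearest % cs))
       vals
       (mapv mean)))

;; Starting centers, then the result of one and two steps.
(take 3 (iterate #(step % points) centers))
;; => ([1.0 2.0] [1.0 7.0] [1.5 9.5])
```

A third step leaves `[1.5 9.5]` unchanged, which is the "centroids stabilize" stopping condition in miniature.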
We first need a function that calculates the centroid of a given group of data points:
(defn centroid [xs]
  ;; coordinate-wise mean of the vectors in xs
  (let [n (count xs)]
    (->> xs
         (reduce #(map + %1 %2))   ; sum the vectors coordinate by coordinate
         (map #(/ % n)))))         ; divide each coordinate by the number of points
#'user/centroid

Now we need a function that takes two arguments: a vector x and a list of vectors xs. It returns the vector in xs that is closest to x.
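(defn dist [xs ys]
  ;; squared Euclidean distance; skipping the square root does not change which vector is closest
  (reduce + (map (fn [x y] (let [u (- x y)] (* u u))) xs ys)))

(defn closest [x xs]
  ;; keep whichever of u and v is nearer to x while scanning xs
  (reduce (fn [u v] (if (< (dist x u) (dist x v)) u v)) xs))
#'user/dist
#'user/closest

The k-means algorithm groups the points with respect to a list of centroids, and then recalculates the positions of the centroids: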
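(defn k-means [k data tolerance epochs]
  ;; the first k data points serve as the initial centers
  (loop [centers (take k data)
         n 0]
    (let [new (->> data
                   (group-by (fn [x] (closest x centers)))  ; assign each point to its nearest center
                   vals
                   (map centroid)                           ; recompute one centroid per group
                   (into []))
          ;; total squared movement, pairing old and new centers by position
          err (reduce + (map dist centers new))]
      (if (or (> tolerance err)
              (> n epochs))
        [(into [] new) err n]   ; result: [centroids error iteration-count]
        (recur new (inc n))))))
#'user/k-means

To test the code, we need a list of data points: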
(defn random-vector [xs sigma]
  ;; shift each coordinate of xs by uniform noise drawn from [-sigma, sigma)
  (map #(+ (- (rand (* 2 sigma)) sigma) %) xs))

;; three clouds of 10-dimensional points centered at 0.0, -1.0 and 2.0
(def data (concat (repeatedly 6500 (fn [] (random-vector (repeat 10 0.0) 0.2)))
                  (repeatedly 7750 (fn [] (random-vector (repeat 10 -1.0) 0.75)))
                  (repeatedly 8000 (fn [] (random-vector (repeat 10 2.0) 0.5)))))

(k-means 3 (into [] data) 1e-1 3000)
#'user/random-vector
#'user/data
[[(4.919285050921601E-4 -4.266121558193015E-4 -0.0010858441460167305 0.0013142832807452201 1.9292508351109687E-4 -0.0013877959463485913 1.8175031110970964E-4 0.0014874848400147044 -0.002005580085078893 -0.0016259257736556466)
  (-1.002152038406233 -1.0026026033807856 -0.9970592926719917 -1.0019272862513706 -0.9999133579141265 -0.9988884814285772 -0.9964083627021618 -0.9972992209638767 -1.000477954379125 -1.0043421065924427)
  (2.0008453816213505 1.9993503658791387 2.002264972768322 1.99713372344871 1.9943595404870138 1.996898772420761 1.9977224459186604 2.002879114924502 2.0010176551762173 2.0002044154827097)]
 3.3292170836126408E-6
 2]

Not bad: every coordinate of the three centroids lands within about 0.006 of the true centers 0.0, -1.0 and 2.0.
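Once the centroids are known, a natural next step is to label a point by its nearest centroid. The sketch below is one way that might look. It is self-contained, so it repeats the squared-distance formula under the hypothetical name `sq-dist`. `found-centers` holds rounded stand-ins for the centroids printed above rather than the exact output, and `label` is likewise a made-up helper.

```clojure
;; Rounded stand-ins for the three centroids found above (an assumption for brevity).
(def found-centers
  [(vec (repeat 10 0.0))
   (vec (repeat 10 -1.0))
   (vec (repeat 10 2.0))])

(defn sq-dist
  "Squared Euclidean distance between two equal-length vectors."
  [xs ys]
  (reduce + (map (fn [x y] (let [u (- x y)] (* u u))) xs ys)))

(defn label
  "Index of the centroid in cs that is nearest to x (hypothetical helper)."
  [x cs]
  (apply min-key #(sq-dist x (nth cs %)) (range (count cs))))

(label (repeat 10 1.8) found-centers)   ; => 2
(label (repeat 10 -0.7) found-centers)  ; => 1
```

Returning an index instead of the centroid itself keeps the labels compact, which is convenient if you want to count cluster sizes with `frequencies`.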