Source code for torchcluster.zoo.k_means

import torch
import numpy as np
from .base import Cluster
from ..helper.distance import setwise_distance

class KMeans(Cluster):
    """K-Means clustering on torch tensors.

    Centroids are seeded from randomly chosen data points and then refined
    by alternating nearest-centroid assignment and centroid re-estimation
    until the centroids stop moving (see ``tol``).
    """

    def __init__(self, n_clusters, tol=1e-4):
        """Configure the K-Means clusterer.

        Args:
            n_clusters (int) - How many clusters in result.

        Kwargs:
            tol (float) - stop updating when the squared sum of the
                per-centroid shifts is smaller than tol.
        """
        super(KMeans, self).__init__()
        self.n_clusters = n_clusters
        self.tol = tol

    def __call__(self, x):
        """Cluster the rows of ``x``.

        Args:
            x (Tensor) - Data points of number n by feature dim m.

        Returns:
            tuple - ``(result, state)``: ``result`` holds, for every row of
            ``x``, the index of its nearest centroid; ``state`` holds the
            ``n_clusters`` centroids.
        """
        n_points = len(x)
        # Seed with distinct points whenever possible: sampling with
        # replacement (numpy's default) can pick the same point twice, which
        # leaves two identical centroids and an effectively empty cluster.
        # Fall back to replacement only when there are fewer points than
        # clusters, so such calls keep working as before.
        seeds = np.random.choice(
            n_points,
            self.n_clusters,
            replace=n_points < self.n_clusters,
        )
        state = x[seeds]

        while True:
            pre_state = state.clone()

            # The code below needs one distance per (point, centroid) pair
            # laid out as (n_points, n_clusters). A blanket squeeze() would
            # also drop those axes when n_points == 1 or n_clusters == 1,
            # so reshape to the expected layout explicitly instead.
            dis = setwise_distance(x, state).reshape(n_points, self.n_clusters)
            result = torch.argmin(dis, dim=1)

            for i in range(self.n_clusters):
                # A boolean mask keeps the selection 2-D regardless of how
                # many points (0, 1 or many) belong to the cluster.
                items = x[result == i]
                if items.size(0):
                    state[i] = items.mean(dim=0)
                # An empty cluster keeps its previous centroid; taking the
                # mean of zero rows would produce NaN.

            # eps=0.0: the default eps (1e-6) is added to the difference
            # before the norm, so the measured shift would never reach zero
            # and a small ``tol`` could keep the loop running forever.
            shift = torch.pairwise_distance(pre_state, state, eps=0.0)
            total = torch.pow(torch.sum(shift), 2.0)
            if total < self.tol:
                return result, state