import torch
from .base import Dataset
class SimpleDataset(Dataset):
    """Synthetic Gaussian-blob dataset for testing clustering algorithms.

    Cluster ``k`` (``0 <= k < n_clusters``) is sampled from a standard normal
    distribution with every coordinate shifted by ``k * sigma``, so the
    cluster centres are evenly spaced along the all-ones diagonal.
    """

    def __init__(self, n_clusters, device='cpu', feature=10, sigma=10):
        """Configure the dataset factory.

        Args:
            n_clusters (int): Number of clusters to generate.

        Kwargs:
            device (str or torch.device): Device on which tensors are created.
            feature (int): Dimensionality of each data point.
            sigma (float): Per-coordinate offset between consecutive cluster
                centres. Points inside a cluster have unit variance, so a
                larger value gives better-separated (easier) clusters.
        """
        super().__init__()
        self.n_clusters = n_clusters
        self.device = device
        self.feature = feature
        self.sigma = sigma

    def __call__(self, n):
        """Generate ``n`` labelled points.

        Points are split as evenly as possible across clusters. When ``n`` is
        not a multiple of ``n_clusters``, the first ``n % n_clusters``
        clusters each receive one extra point, so exactly ``n`` points are
        returned.

        Args:
            n (int): Total number of data points to generate.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``X`` of shape ``(n, feature)``
            with torch's default float dtype, and ``y`` of shape ``(n,)`` with
            dtype ``torch.long`` holding each point's cluster index. Rows are
            grouped by cluster in ascending label order.
        """
        # divmod keeps the remainder that plain floor division would drop;
        # for divisible n the sizes (and RNG draws) match the old behaviour.
        base, extra = divmod(n, self.n_clusters)
        sizes = [base + 1 if idx < extra else base
                 for idx in range(self.n_clusters)]

        X = torch.cat([
            torch.randn(size, self.feature, device=self.device) + idx * self.sigma
            for idx, size in enumerate(sizes)
        ])
        y = torch.cat([
            torch.full((size,), idx, dtype=torch.long, device=self.device)
            for idx, size in enumerate(sizes)
        ])
        return X, y