今天把自己写的一个机器学习算法库中的K-means算法整理了一下。因为这个算法与其他算法相比相对独立,可以单独贴出来,不会引用太多其他类(还是有少量引用,不过只用到一些简单的功能,看类名就知道是什么意思了)。
基本功能和规则为:
1.当然是进行k-means算法,对数据集(这里使用二维数组来表示数据集,行数为数据总数,列数为数据维度)进行N维聚类
2.可以指定收敛的阈值(convergenceDis,默认为0.0001)
3.为避免陷入局部最小,可以通过设定replicates的数值来指定重复运行次数;默认为0,即只运行一次聚类过程
4.测试数据格式为每一行代表一个输入,用空格分隔输入的各个维度,为了计算结果不出太大意外,建议对原始数据进行归一化
首先上骨架代码:


import java.util.Random;
import org.tadoo.ml.exception.ClusterException;
import org.tadoo.ml.util.ArrayCompute;
import org.tadoo.ml.util.Utils;
/**
 * K-means clustering over a dense data set.
 *
 * <p>The data set is a two-dimensional array: one row per sample, one column
 * per dimension. Call {@link #train()} and then read the result through
 * {@link #getCenters()}, {@link #getTotalSumOfdistances()} and
 * {@link #getIter()}. When {@code replicates} is greater than 1 the whole
 * procedure is repeated that many times from fresh random centers; every run
 * is recorded in {@link #getKmcresults()} and the run with the smallest
 * distance sum is the one exposed through the getters.</p>
 *
 * <p>time:2011-6-1</p>
 * @author T. QIN
 */
public class KmeansCluster
{
    /** Samples to cluster; rows are samples, columns are dimensions. */
    private double[][] dataSet = null;

    /** Number of clusters. */
    private int k = 0;

    /** Current cluster centers, one row per cluster. */
    private double[][] centers = null;

    /** Sum over all samples of the distance to the nearest center. */
    private double totalSumOfdistances = 0;

    /** Set once every updated center moved by at most {@link #convergenceDis}. */
    private boolean convergence = false;

    /** Number of iterations performed by the run exposed through the getters. */
    private int iter;

    /** Per-dimension movement threshold below which a center counts as converged. */
    private double convergenceDis = 0.0001;

    /** Number of independent runs; values of 1 or less mean a single run. */
    private int replicates = 0;

    /** Per-run results, filled only when {@code replicates > 1}. */
    private KMCResult[] kmcresults = null;

    /**
     * Creates a clusterer for the given data.
     *
     * @param x data set, one row per sample; must be non-null and non-empty
     * @param k number of clusters; must be positive
     * @throws ClusterException if the data set is null or empty, or if
     *         {@code k} is not positive
     */
    public KmeansCluster(double[][] x, int k) throws ClusterException
    {
        if (x == null || x.length == 0)
        {
            throw new ClusterException("输入数据不可为空。");
        }
        if (k <= 0)
        {
            throw new ClusterException("聚类个数k必须大于0。");
        }
        this.dataSet = x;
        this.k = k;
        this.centers = new double[k][dataSet[0].length];
    }

    /**
     * Picks the initial centers from the data set at random.
     *
     * <p>When there are at least {@code k} samples the centers are drawn
     * without replacement, so the same sample is never used for two centers.
     * With fewer samples than clusters the draw falls back to sampling with
     * replacement.</p>
     */
    private void initKCenters()
    {
        int rows = dataSet.length;
        int cols = dataSet[0].length;
        Random r = new Random();
        // Allocate per run: keeps earlier results intact and follows any
        // change made through setK()/setDataSet() since construction.
        centers = new double[k][cols];
        if (k <= rows)
        {
            // Partial Fisher-Yates shuffle: positions [0, i) hold the rows
            // already chosen, positions [i, rows) the rows still available.
            int[] order = new int[rows];
            for (int i = 0; i < rows; i++)
            {
                order[i] = i;
            }
            for (int i = 0; i < k; i++)
            {
                int pick = i + r.nextInt(rows - i);
                int chosen = order[pick];
                order[pick] = order[i];
                order[i] = chosen;
                System.arraycopy(dataSet[chosen], 0, centers[i], 0, cols);
            }
        }
        else
        {
            for (int i = 0; i < k; i++)
            {
                System.arraycopy(dataSet[r.nextInt(rows)], 0, centers[i], 0, cols);
            }
        }
    }

    /**
     * Runs the clustering.
     *
     * <p>With {@code replicates > 1} the procedure is run that many times.
     * All runs are stored in {@link #getKmcresults()}, and the centers,
     * distance sum and iteration count of the run with the smallest distance
     * sum are kept as the state of this object. Otherwise a single run is
     * performed.</p>
     *
     * @throws ClusterException if the data set or {@code k} was made invalid
     *         through a setter after construction
     */
    public void train()
    {
        ensureTrainable();
        if (replicates > 1)
        {
            KMCResult[] runs = new KMCResult[replicates];
            int best = 0;
            for (int i = 0; i < replicates; i++)
            {
                beginTrain();
                runs[i] = new KMCResult();
                runs[i].centers = this.centers;
                runs[i].sum = this.totalSumOfdistances;
                runs[i].iters = this.iter;
                if (runs[i].sum < runs[best].sum)
                {
                    best = i;
                }
            }
            kmcresults = runs;
            // Expose the best run; copy so that later edits to the centers
            // of this object do not alter the recorded result.
            double[][] bestCenters = runs[best].centers;
            double[][] copy = new double[bestCenters.length][];
            for (int i = 0; i < bestCenters.length; i++)
            {
                copy[i] = bestCenters[i].clone();
            }
            this.centers = copy;
            this.totalSumOfdistances = runs[best].sum;
            this.iter = runs[best].iters;
        }
        else
        {
            beginTrain();
        }
    }

    /** Rejects a data set or cluster count that cannot be trained on. */
    private void ensureTrainable()
    {
        if (dataSet == null || dataSet.length == 0)
        {
            throw new ClusterException("输入数据不可为空。");
        }
        if (k <= 0)
        {
            throw new ClusterException("聚类个数k必须大于0。");
        }
    }

    /**
     * Performs one complete run: random initialisation followed by
     * assign/update iterations until every updated center has converged.
     */
    private void beginTrain()
    {
        // assignment[i] is the index of the center sample i belongs to.
        int[] assignment = new int[dataSet.length];
        iter = 0;
        initKCenters();
        convergence = false;
        while (!convergence)
        {
            totalSumOfdistances = assignToNearestCenter(assignment);
            convergence = updateCenters(assignment);
            iter++;
        }
    }

    /**
     * Assigns every sample to its nearest center.
     *
     * @param assignment receives, per sample, the index of the nearest center
     * @return the sum over all samples of the distance to the nearest center
     */
    private double assignToNearestCenter(int[] assignment)
    {
        double sum = 0;
        for (int i = 0; i < dataSet.length; i++)
        {
            double minDistance = Double.MAX_VALUE;
            for (int j = 0; j < this.k; j++)
            {
                double currentDis = Utils.distance(dataSet[i], centers[j]);
                if (currentDis < minDistance)
                {
                    minDistance = currentDis;
                    assignment[i] = j;
                }
            }
            sum += minDistance;
        }
        return sum;
    }

    /**
     * Moves every non-empty cluster's center to the mean of its samples.
     *
     * @param assignment per-sample index of the owning center
     * @return {@code true} when every center that was updated moved by no
     *         more than the convergence threshold in every dimension
     */
    private boolean updateCenters(int[] assignment)
    {
        int cols = dataSet[0].length;
        double[][] sums = new double[k][cols];
        int[] counts = new int[k];
        // Single pass over the samples, accumulating in place.
        for (int i = 0; i < assignment.length; i++)
        {
            int owner = assignment[i];
            double[] acc = sums[owner];
            double[] point = dataSet[i];
            for (int d = 0; d < cols; d++)
            {
                acc[d] += point[d];
            }
            counts[owner]++;
        }
        boolean allConverged = true;
        for (int i = 0; i < this.k; i++)
        {
            if (counts[i] == 0)
            {
                // An empty cluster keeps its previous center and does not
                // take part in the convergence decision.
                continue;
            }
            double[] mean = sums[i];
            for (int d = 0; d < cols; d++)
            {
                mean[d] /= counts[i];
            }
            if (!isCenterConvergence(centers[i], mean))
            {
                allConverged = false;
            }
            centers[i] = mean;
        }
        return allConverged;
    }

    /**
     * Tells whether a center has converged.
     *
     * @param center the center before the update
     * @param pCenter the center after the update
     * @return {@code false} as soon as one dimension differs by more than
     *         the convergence threshold, {@code true} otherwise
     */
    private boolean isCenterConvergence(double[] center, double[] pCenter)
    {
        for (int i = 0; i < center.length; i++)
        {
            if (Math.abs(center[i] - pCenter[i]) > convergenceDis)
            {
                return false;
            }
        }
        return true;
    }

    /**
     * @return the data set being clustered
     */
    public double[][] getDataSet()
    {
        return dataSet;
    }

    /**
     * @param dataSet the data set to cluster on the next {@link #train()}
     */
    public void setDataSet(double[][] dataSet)
    {
        this.dataSet = dataSet;
    }

    /**
     * @return the number of clusters
     */
    public int getK()
    {
        return k;
    }

    /**
     * @param k the number of clusters to use on the next {@link #train()}
     */
    public void setK(int k)
    {
        this.k = k;
    }

    /**
     * @return the current cluster centers, one row per cluster
     */
    public double[][] getCenters()
    {
        return centers;
    }

    /**
     * @param centers the centers to store; note that {@link #train()}
     *        re-initialises the centers at random
     */
    public void setCenters(double[][] centers)
    {
        this.centers = centers;
    }

    /**
     * @return the sum of distances from every sample to its nearest center
     */
    public double getTotalSumOfdistances()
    {
        return totalSumOfdistances;
    }

    /**
     * @param totalSumOfdistances the distance sum to store
     */
    public void setTotalSumOfdistances(double totalSumOfdistances)
    {
        this.totalSumOfdistances = totalSumOfdistances;
    }

    /**
     * @return the number of iterations of the run exposed by this object
     */
    public int getIter()
    {
        return iter;
    }

    /**
     * @return the convergence threshold
     */
    public double getConvergenceDis()
    {
        return convergenceDis;
    }

    /**
     * @param convergenceDis the convergence threshold to use
     */
    public void setConvergenceDis(double convergenceDis)
    {
        this.convergenceDis = convergenceDis;
    }

    /**
     * @return the configured number of runs
     */
    public int getReplicates()
    {
        return replicates;
    }

    /**
     * @param replicates the number of runs; values of 1 or less mean one run
     */
    public void setReplicates(int replicates)
    {
        this.replicates = replicates;
    }

    /**
     * @return the per-run results recorded by the last replicated
     *         {@link #train()}, or whatever was stored through the setter
     */
    public KMCResult[] getKmcresults()
    {
        return kmcresults;
    }

    /**
     * @param kmcresults the per-run results to store
     */
    public void setKmcresults(KMCResult[] kmcresults)
    {
        this.kmcresults = kmcresults;
    }

    /**
     * Outcome of a single clustering run.
     *
     * <p>time:2011-6-2</p>
     * @author T. QIN
     */
    public class KMCResult
    {
        /** Final centers of the run. */
        public double[][] centers;

        /** Distance sum of the run. */
        public double sum;

        /** Number of iterations the run took. */
        public int iters;
    }
}


/**
 * Unchecked exception raised for clustering errors.
 *
 * <p>time:2011-5-25</p>
 * @author T. QIN
 */
public class ClusterException extends RuntimeException
{
    /**
     * Creates an exception without a detail message.
     */
    public ClusterException()
    {
        // The superclass no-argument constructor is invoked implicitly.
    }

    /**
     * Creates an exception carrying a detail message.
     *
     * @param s the detail message
     */
    public ClusterException(String s)
    {
        super(s);
    }
}


/**
* 简单数组计算
*
* <p>time:2011-5-27</p>
* @author T. QIN
*/
public class ArrayCompute
{
/**
* 数组相加
*
* @param x1
* @param x2
* @return
* @see:
*/
public static double[] add(final double[] x1, final double[] x2)
{
if (x1.length != x2.length)
{
System.err.print("向量长度不等不能相加!");
System.exit(0);
}
double[] result = new double[x1.length];
for (int i = 0; i < result.length; i++)
{
result[i] = x1[i] + x2[i];
}
return result;
}
/**
* 数组相减
*
* @param x1
* @param x2
* @return
* @see:
*/
public static double[] minus(final double[] x1, final double[] x2)
{
if (x1.length != x2.length)
{
System.err.print("向量长度不等不能相减!");
System.exit(0);
}
double[] result = new double[x1.length];
for (int i = 0; i < result.length; i++)
{
result[i] = x1[i] - x2[i];
}
return result;
}
/**
* 数组乘以一个常数
*
* @param x1
* @param c
* @return
* @see:
*/
public static double[] multiplyC(final double[] x1, final double c)
{
double[] ret = new double[x1.length];
for (int i = 0; i < x1.length; i++)
{
ret[i] = x1[i] * c;
}
return ret;
}
/**
* 数组除以一个常数
*
* @param x1
* @param c
* @return
* @see:
*/
public static double[] devideC(final double[] x1, final double c)
{
double[] ret = new double[x1.length];
for (int i = 0; i < x1.length; i++)
{
ret[i] = x1[i] / c;
}
return ret;
}
}


import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import org.tadoo.ml.Matrix;
/**
 * Small numeric helpers for {@code double[]} vectors.
 *
 * <p>time:2011-3-28</p>
 * @author T. QIN
 */
public class Utils
{
    /**
     * Computes the Euclidean distance between two points.
     *
     * <p>The loop runs over the dimensions of {@code x1}; {@code x2} must
     * have at least that many elements, and any extra elements of
     * {@code x2} are not looked at.</p>
     *
     * @param x1 first point
     * @param x2 second point
     * @return the square root of the summed squared coordinate differences
     */
    public static double distance(double[] x1, double[] x2)
    {
        double r = 0.0;
        for (int i = 0; i < x1.length; i++)
        {
            double diff = x1[i] - x2[i];
            r += diff * diff;
        }
        return Math.sqrt(r);
    }

    /**
     * Adds two vectors element by element.
     *
     * @param x1 first operand
     * @param x2 second operand, same length as {@code x1}
     * @return a new array holding {@code x1[i] + x2[i]}
     * @throws IllegalArgumentException if the lengths differ; reported as an
     *         exception so the caller can react, rather than by terminating
     *         the JVM
     */
    public static double[] add(final double[] x1, final double[] x2)
    {
        if (x1.length != x2.length)
        {
            throw new IllegalArgumentException(
                    "向量长度不等不能相加!" + " (" + x1.length + " != " + x2.length + ")");
        }
        double[] result = new double[x1.length];
        for (int i = 0; i < result.length; i++)
        {
            result[i] = x1[i] + x2[i];
        }
        return result;
    }
}


import java.io.PrintStream;
import org.tadoo.ml.exception.MatrixComputeException;
/**
 * A dense matrix of {@code double} values with explicit row and column
 * counts.
 *
 * <p>The backing array may be absent when the matrix was created with
 * {@code isInitialMemory == false}; {@link #print(PrintStream)} and
 * {@link #toString()} then produce no rows.</p>
 *
 * <p>time:2011-3-23</p>
 * @author T. QIN
 */
public class Matrix
{
    private int rowNum;
    private int colNum;
    private double value[][];

    /**
     * Creates a zero-filled matrix.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @author: T. QIN
     */
    public Matrix(int rows, int cols)
    {
        this.rowNum = rows;
        this.colNum = cols;
        this.value = new double[rows][cols];
    }

    /**
     * Creates a matrix, optionally without allocating its backing array.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @param isInitialMemory whether to allocate the zero-filled backing
     *        array now; when {@code false} the values stay unset until
     *        {@link #changeWholeValue(double[][])} or
     *        {@link #setValue(double[][])} is called
     * @author: T. QIN
     */
    public Matrix(int rows, int cols, boolean isInitialMemory)
    {
        this.rowNum = rows;
        this.colNum = cols;
        if (isInitialMemory)
        {
            this.value = new double[rows][cols];
        }
    }

    /**
     * Replaces the whole backing array.
     *
     * @param v the new values; must have exactly {@code rowNum} rows of
     *        {@code colNum} columns each
     * @throws MatrixComputeException if {@code v} is null or its shape does
     *         not match in either dimension
     */
    public void changeWholeValue(double v[][]) throws MatrixComputeException
    {
        if (!hasShape(v, this.rowNum, this.colNum))
        {
            throw new MatrixComputeException("矩阵大小不拟合");
        }
        this.value = v;
    }

    /**
     * Checks the row count and the length of every row, so that a mismatch
     * in only one dimension, an empty array or a jagged array is detected.
     */
    private static boolean hasShape(double[][] v, int rows, int cols)
    {
        if (v == null || v.length != rows)
        {
            return false;
        }
        for (int i = 0; i < v.length; i++)
        {
            if (v[i] == null || v[i].length != cols)
            {
                return false;
            }
        }
        return true;
    }

    /**
     * Prints the matrix, one row per line, each value followed by a tab.
     *
     * @param ps destination stream; {@code System.out} is used when null
     */
    public void print(PrintStream ps)
    {
        PrintStream out = (ps == null) ? System.out : ps;
        if (value == null)
        {
            // Backing array was never allocated: there is nothing to print.
            return;
        }
        for (int i = 0; i < rowNum; i++)
        {
            for (int j = 0; j < colNum; j++)
            {
                out.print(value[i][j] + "\t");
            }
            out.println();
        }
    }

    /**
     * Renders the matrix with one row per line and each value followed by a
     * tab.
     *
     * @return the textual form; empty when the backing array is unset
     */
    @Override
    public String toString()
    {
        StringBuilder sb = new StringBuilder();
        if (value == null)
        {
            return sb.toString();
        }
        for (int i = 0; i < rowNum; i++)
        {
            for (int j = 0; j < colNum; j++)
            {
                sb.append(value[i][j]).append("\t");
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * @return the number of rows
     */
    public int getRowNum()
    {
        return rowNum;
    }

    /**
     * @param rowNum the number of rows to record; the backing array is not
     *        resized
     */
    public void setRowNum(int rowNum)
    {
        this.rowNum = rowNum;
    }

    /**
     * @return the number of columns
     */
    public int getColNum()
    {
        return colNum;
    }

    /**
     * @param colNum the number of columns to record; the backing array is
     *        not resized
     */
    public void setColNum(int colNum)
    {
        this.colNum = colNum;
    }

    /**
     * @return the backing array itself (not a copy); may be null
     */
    public double[][] getValue()
    {
        return value;
    }

    /**
     * @param value the backing array to use; stored without a shape check
     */
    public void setValue(double[][] value)
    {
        this.value = value;
    }
}


/**
 * Checked exception raised for matrix computation errors.
 *
 * <p>time:2011-3-23</p>
 * @author T. QIN
 */
public class MatrixComputeException extends Exception
{
    /**
     * Creates an exception carrying a detail message.
     *
     * @param s the detail message
     */
    public MatrixComputeException(String s)
    {
        super(s);
    }
}


import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Reads and writes whitespace-separated numeric data files.
 *
 * <p>time:2011-5-31</p>
 * @author T. QIN
 */
public class DataUtil
{
    /**
     * Loads a numeric table from a text file.
     *
     * <p>Each non-blank line is one row; fields are separated by runs of
     * whitespace. The column count is taken from the first row. Rows with
     * fewer fields leave their remaining cells at 0.0.</p>
     *
     * @param filePath path of the file to read
     * @return the parsed table; an empty array when the file holds no data
     *         rows; {@code null} when the file cannot be opened or read (the
     *         stack trace is printed)
     * @throws NumberFormatException if a field is not a valid number
     * @throws IllegalArgumentException if a row has more fields than the
     *         first row
     */
    public static double[][] load(String filePath)
    {
        List<String[]> rows = new ArrayList<String[]>();
        BufferedReader reader = null;
        try
        {
            reader = new BufferedReader(new FileReader(new File(filePath)));
            String line;
            while ((line = reader.readLine()) != null)
            {
                String trimmed = line.trim();
                if (trimmed.length() == 0)
                {
                    // A blank line carries no fields; skip it.
                    continue;
                }
                rows.add(trimmed.split("[\\s]+"));
            }
        }
        catch (FileNotFoundException e)
        {
            e.printStackTrace();
            return null;
        }
        catch (IOException e)
        {
            e.printStackTrace();
            return null;
        }
        finally
        {
            // Always release the file handle, also on failure.
            if (reader != null)
            {
                try
                {
                    reader.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }

        if (rows.isEmpty())
        {
            return new double[0][0];
        }
        int cols = rows.get(0).length;
        double[][] result = new double[rows.size()][cols];
        for (int i = 0, n = rows.size(); i < n; i++)
        {
            String[] fields = rows.get(i);
            if (fields.length > cols)
            {
                throw new IllegalArgumentException("row " + (i + 1) + " has " + fields.length
                        + " fields, expected at most " + cols + ": " + filePath);
            }
            for (int j = 0; j < fields.length; j++)
            {
                result[i][j] = Double.parseDouble(fields[j]);
            }
        }
        return result;
    }

    //TODO:
    /**
     * Writes selected columns of a table to a text file.
     *
     * <p>Columns are written in ascending index order, each value followed
     * by a single space, one row per line. The caller's {@code columns}
     * array is not modified. I/O failures are reported by printing the stack
     * trace.</p>
     *
     * @param data the table to write
     * @param saveFilename path of the file to create or overwrite
     * @param columns indices of the columns to write; must not be null
     */
    public static void save(double[][] data, String saveFilename, int[] columns)
    {
        // Sort a private copy so the caller's array keeps its order.
        int[] ordered = columns.clone();
        Arrays.sort(ordered);
        BufferedWriter fp_saver = null;
        try
        {
            fp_saver = new BufferedWriter(new FileWriter(saveFilename));
            for (int i = 0; i < data.length; i++)
            {
                for (int j = 0; j < ordered.length; j++)
                {
                    fp_saver.write(String.valueOf(data[i][ordered[j]]) + " ");
                }
                fp_saver.write("\n");
            }
            fp_saver.flush();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            // The writer is still null when the file could not be opened.
            if (fp_saver != null)
            {
                try
                {
                    fp_saver.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
    }
}
然后是测试:


import junit.framework.TestCase;
import org.tadoo.ml.Matrix;
import org.tadoo.ml.cluster.kmeans.KmeansCluster;
import org.tadoo.ml.util.DataUtil;
import org.tadoo.ml.util.Utils;
/**
 * Exercises the K-means clusterer and prints what it produces; the test
 * methods make no assertions.
 *
 * <p>time:2011-6-2</p>
 * @author T. QIN
 */
public class TestKmeansCluster extends TestCase
{
    /** Input matrix read in {@link #setUp()}. */
    Matrix dataSet;

    /** Raw table read in {@link #setUp()}. */
    double[][] ds;

    /**
     * Reads both input files from their fixed locations before each test.
     */
    protected void setUp()
    {
        dataSet = Utils.uniformFileInputIntoFeatures("D:\\test.s.txt");
        ds = DataUtil.load("D:\\data1.txt");
    }

    /**
     * Runs a single two-cluster training and prints the distance sum, the
     * iteration count and the resulting centers.
     */
    public void testKmeansCenters()
    {
        KmeansCluster clusterer = new KmeansCluster(dataSet.getValue(), 2);
        clusterer.train();
        System.out.println(clusterer.getTotalSumOfdistances());
        System.out.println(clusterer.getIter());
        for (double[] center : clusterer.getCenters())
        {
            for (double coordinate : center)
            {
                System.out.print(coordinate + "\t");
            }
            System.out.println();
        }
    }

    /**
     * Runs an eleven-cluster training with twelve replicates and prints the
     * iteration count and distance sum of every run.
     */
    public void testKmeansReplicate()
    {
        KmeansCluster clusterer = new KmeansCluster(dataSet.getValue(), 11);
        clusterer.setReplicates(12);
        clusterer.train();
        for (KmeansCluster.KMCResult run : clusterer.getKmcresults())
        {
            System.out.println("iters:" + run.iters + "\tSum:" + run.sum);
        }
    }
}