java手撕KMeans算法实现手写数字聚类(失败案例)
2022/3/19 22:28:16
本文主要是介绍java手撕KMeans算法实现手写数字聚类(失败案例),对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
最近几天刚刚接触机器学习,学完K-Means聚类算法。正好又赶上一个课程项目是识别“手写数字”,因为KMeans能够实现聚类,因此自然而然地想要通过KMeans来实现。
前排提示:这是kmeans聚类的一个失败案例,没有成功聚类,仅供参考。
一,什么是KMeans聚类算法??
非常传统的聚类算法,目的是将一堆数据进行分类。
它的思想很朴素:假设这里有一群点,要将这些点分成两类。要是分成的类很合理的话,那不同类之间的中心点相聚是不是应该足够大,中心点附近的同一类的点是不是应该足够多?
举个例子:
a表示的是一堆原始点,没有处理。要将a聚类成两类,先随便找到两个点,计算所有点到这两个点的距离(欧式距离,曼哈顿距离,闵式距离等等都可以),根据距离最近的原则分配成两类。这时候是不是就能够得到两类的中心点,然后再次重复操作,直到最后聚出来的类不会发生变化。
so easy 是不是
二,使用的手写数字测试集??
我们在这里使用的是mnist测试集。这家伙的知名程度在机器学习中相当于是hello world了。不知道的小伙伴可以去查查。
但是一定有人会问到,mnist测试集应该怎么通过java使用呢?
不用担心,我用Python通过TensorFlow将mnist测试集打包成了txt文件,用java的文件操作直接调用就可以了。
具体效果像这样:
这是28 * 28的二维int数组,每个值介于0到255之间,熟悉图像处理的小伙伴一定知道这是灰度值,0表示最黑,255表示最亮,因此这是黑纸白字的测试集,大家要是自己写测试数据的使用要记着对图片进行预处理,要不然可能会出错。
我将txt命名为:数字名-标号的形式,方便之后训练和测试。
三,java手撕KMeans算法
先摆上一个算法流程图
1.首先定义:
训练图片(50000 * 28 * 28 的三维数组)
聚类中心(10 * 28 * 28的三维数组)
每张图片到聚类中心的距离(50000 * 10 的二维数组)
旧的类和新的类(ArrayList[] 数组,因为不知道一个类中到底会有多少个图片)
static float[][][] num = new float[50000][28][28]; static float[][][] center = new float[10][28][28];// 聚类中心 static long[][] distance = new long[num.length][10]; static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类 static ArrayList<Integer>[] newKinds = new ArrayList[10];
2.定义方法:
从Txt文件导入测试数据的方法
public static void getTXT(String path,int img,int x,int y) throws IOException { File file = new File(path); FileInputStream fis = new FileInputStream(file); InputStreamReader isr = new InputStreamReader(fis); BufferedReader br = new BufferedReader(isr); String line; while((line = br.readLine()) != null){ boolean isNum = false; for(int i = 0;i < line.length();i ++){ if(line.charAt(i) != ' ' && !isNum){ // 如果遇到数字 isNum = true; float tempNum = 0; // 取数字 while(i < line.length() && line.charAt(i) != ' '){ tempNum = tempNum * 10 + line.charAt(i) - '0'; i++; } isNum = false; if(y < 28){ } else{ y = 0; x ++; } num[img][x][y] = tempNum; y++; } } } br.close(); }
获得图片到聚类中心距离的方法
// 得到距离 public static long getDistance(float[][] n,float[][] k){ long ret = 0; for (int i = 0; i < 28;i ++){ for (int j = 0; j < 28; j ++){ ret += Math.pow((n[i][j] - k[i][j]),2); } } return ret; }
得到图片距离最近聚类中心索引的方法
// 获得数组元素最小值对应的下标 public static int getMinIndex(long dis[]){ int index = -1; long min = Integer.MAX_VALUE; for(int i = 0; i < 10;i ++){ if(dis[i] < min){ index = i; min = dis[i]; } } return index; }
比较旧的聚类和新的聚类是否相同的方法
public static boolean isSame(){ for(int i = 0; i < 10 ;i ++){ for(int j = 0; j < newKinds[i].size();j ++){ if(newKinds[i].size() != oldKinds[i].size()) return false; if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) { return false; } } } return true; }
需要注意的是!!!
两个Integer的比较需要通过.intValue()的方法先转换成为int!!!再进行比较,否则会因为内存什么什么奇奇怪怪的原因导致出现130 != 130这种很天真的错误。
我在这里被坑了一次,希望看到这片文章的人能够避一下坑。
3.开始while(true)死循环,直到旧类和新类相等不发生改变
int kindTime = 0; while(true){ // 3.计算每个文件和当前类中心之间的距离 for (int i = 0; i < num.length; i++){ for (int j = 0; j < 10; j++){ distance[i][j] = getDistance(num[i],center[j]); } } // 更新旧类 for(int i = 0;i < 10;i ++){ oldKinds[i].clear(); for(int j = 0 ; j < newKinds[i].size();j ++){ oldKinds[i].add(newKinds[i].get(j)); } } // 更新新类 for (int i = 0; i < 10 ; i ++){ newKinds[i].clear(); } for (int i = 0; i < num.length; i ++){ // 获得距离最小值,将其放到对应的类中 newKinds[getMinIndex(distance[i])].add(i); } // 4.更新聚类中心 for(int i = 0; i < 10; i ++){ for(int x = 0; x < 28; x++){ for(int y = 0; y < 28;y ++){ center[i][x][y] = getAverage(newKinds[i],x,y); } } } // 5.重复步骤,直到类不再发生改变 if(isSame()){ break; } System.out.println("第"+kindTime+"次聚类"); kindTime++; }
4.保存类中心点
因为如果训练数据不变的话,聚类聚出的中心是不会变化的,所以为了避免之后聚类的重复操作,我们还是将得到的聚类中心点保存成为txt文件放到电脑上比较好。
// 保存聚类中心点 public static void saveKind(int index){ FileWriter out = null; String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt"; File file = new File(path); try { out = new FileWriter(file); //二维数组按行存入到文件中 for (int i = 0; i < center[index].length; i++) { for (int j = 0; j < center[index][i].length; j++) { //将每个元素转换为字符串 String content = String.valueOf(center[index][i][j]) + " "; out.write(content + "\t"); } out.write("\r\n"); } out.close(); } catch (IOException e) { e.printStackTrace(); } }
到现在,所有kmeans要求的操作我们都已经实现了。我们看看效果怎么样吧
1.我从test测试集(刚刚是train训练集)中导入了8000张图片,0到9每个数字各800张。
导入的方式和上文中的相同,这里就不在赘述了。
然后通过刚刚聚出来的类中心对测试数据进行聚类。(因为kmeans是无监督聚类吗,所以我也不知道每个类中心代表的哪个数字)
这是最后聚出来的结果:
发现大问题!!!我将每个类聚到的数字分别列出来。比如第0类,聚到4个数字0,3个数字1……
最后得到的结果,很!不!理!想!
通过分析可以看到,数字1的聚类效果最好,800张图片中有787张被聚到第7类中了,但是第7类也混入了不少其他数字,还有129张2是什么鬼?!
其他的类就更不用说了,混杂了很多数字。
经过缜密思考之后,我认为是k的数值设置的问题,因为我们想要聚类出10个数字,所以很主观地将k设置成为了10,没有思考相同数字,因为书写原因而出现的数字内部聚类的问题。
就像数字0,分别被聚到了第1类和第4类中,这两类很少有其他数字。因此是将数字0进行了分类,把高的0矮的0胖的0瘦的0分开了!而不是将0之外的数字分开。
或许可以通过改变k的值进行改进呢!
这片文章才差不多就是这样了。最后贴上代码。
如果有朋友想要mnist手写数字数据集的txt文件,可以给我留言邮箱信息哦,我抽时间会发送的。
欢迎大佬们批评指正!
// 首先是kmeans聚类的代码 import java.io.*; import java.util.ArrayList; public class KMeans { // KMeans算法实现手写数字聚类 static float[][][] num = new float[50000][28][28]; static float[][][] center = new float[10][28][28];// 聚类中心 static long[][] distance = new long[num.length][10]; static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类 static ArrayList<Integer>[] newKinds = new ArrayList[10]; public static void main(String[] args) throws IOException { // 1.读取文件 System.out.println("导入文件中……"); for (int i = 0;i < num.length;i ++){ getTXT("D:\\Python\\jupyter\\trains2\\" + Integer.toString(i/5000) + "-" + Integer.toString(i%5000 + 1) + ".txt",i,0,0); if(i % 1000 == 0) System.out.println("已导入文件:" + i); } System.out.println("导入文件成功!!!"); // 随机选择聚类中心 for(int i = 0; i < 10; i ++){ oldKinds[i] = new ArrayList<>(); } for(int i = 0 ; i < 10;i ++) { transTwoArray(num[i], center[i]); newKinds[i] = new ArrayList<>(); newKinds[i].add(i); } int kindTime = 0; while(true){ // 3.计算每个文件和当前类中心之间的距离 for (int i = 0; i < num.length; i++){ for (int j = 0; j < 10; j++){ distance[i][j] = getDistance(num[i],center[j]); } } // 更新旧类 for(int i = 0;i < 10;i ++){ oldKinds[i].clear(); for(int j = 0 ; j < newKinds[i].size();j ++){ oldKinds[i].add(newKinds[i].get(j)); } } // 更新新类 for (int i = 0; i < 10 ; i ++){ newKinds[i].clear(); } for (int i = 0; i < num.length; i ++){ // 获得距离最小值,将其放到对应的类中 newKinds[getMinIndex(distance[i])].add(i); } // 4.更新聚类中心 for(int i = 0; i < 10; i ++){ for(int x = 0; x < 28; x++){ for(int y = 0; y < 28;y ++){ center[i][x][y] = getAverage(newKinds[i],x,y); } } } // 5.重复步骤,直到类不再发生改变 if(isSame()){ break; } System.out.println("第"+kindTime+"次聚类"); kindTime++; } // 保存聚类中心 System.out.println("聚类成功!!!"); System.out.println("-------------------------"); System.out.println("保存类中心点中……"); for(int i = 0; i < 10;i ++){ saveKind(i); } System.out.println("保存类中心点成功!!!"); } // 读取文件 public static void getTXT(String path,int img,int x,int y) throws IOException { File file = new File(path); FileInputStream fis = new FileInputStream(file); InputStreamReader isr = new InputStreamReader(fis); BufferedReader br = new BufferedReader(isr); String line; while((line = br.readLine()) != null){ boolean isNum = false; for(int i = 0;i < line.length();i ++){ if(line.charAt(i) != ' ' && !isNum){ // 如果遇到数字 isNum = true; float tempNum = 0; // 取数字 while(i < line.length() && line.charAt(i) != ' '){ tempNum = tempNum * 10 + line.charAt(i) - '0'; i++; } isNum = false; if(y < 28){ } else{ y = 0; x ++; } num[img][x][y] = tempNum; y++; } } } br.close(); } // 转移两个数组 public static void transTwoArray(float[][] array1,float[][] array2){ for(int i = 0; i < 28;i ++){ for (int j = 0; j < 28;j ++){ array2[i][j] = array1[i][j]; } } } // 得到距离 public static long getDistance(float[][] n,float[][] k){ long ret = 0; for (int i = 0; i < 28;i ++){ for (int j = 0; j < 28; j ++){ ret += Math.pow((n[i][j] - k[i][j]),2); } } return ret; } // 获得数组元素最小值对应的下标 public static int getMinIndex(long dis[]){ int index = -1; long min = Integer.MAX_VALUE; for(int i = 0; i < 10;i ++){ if(dis[i] < min){ index = i; min = dis[i]; } } return index; } // 计算均值 public static float getAverage(ArrayList<Integer> arr,int x,int y){ float ret = 0; for(int i = 0; i < arr.size(); i ++){ ret += num[arr.get(i)][x][y];// 将同一类中所有相同位置元素相加 } return ret / arr.size(); } // 保存聚类中心点 public static void saveKind(int index){ FileWriter out = null; String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt"; File file = new File(path); try { out = new FileWriter(file); //二维数组按行存入到文件中 for (int i = 0; i < center[index].length; i++) { for (int j = 0; j < center[index][i].length; j++) { //将每个元素转换为字符串 String content = String.valueOf(center[index][i][j]) + " "; out.write(content + "\t"); } out.write("\r\n"); } out.close(); } catch (IOException e) { e.printStackTrace(); } } // 是否相等 public static boolean isSame(){ for(int i = 0; i < 10 ;i ++){ for(int j = 0; j < newKinds[i].size();j ++){ if(newKinds[i].size() != oldKinds[i].size()) return false; if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) { return false; } } } return true; } }
测试聚类中心的代码
import java.io.*; import java.util.ArrayList; public class myKMeansTest { static float[][][] kMeans = new float[10][28][28]; static float[][][] test = new float[8000][28][28];// 测试数据,每个数字有800张 static long[][] distance = new long[8000][10];// 每张图片聚类类中心的距离 static ArrayList<Integer>[] kinds = new ArrayList[10];// 每个类中包含的图片索引 public static void main(String[] args) throws IOException { System.out.println("-----获取文件中-----"); // 读取聚类中心文件 for(int i = 0; i < 10;i ++){ String img = "D:\\java\\workSpace\\KMeans\\" + i + "kinds.txt"; getKMeansTxt(img,i); } // 读取测试文件 for(int i = 0;i < 8000;i ++){ String img = "D:\\Python\\jupyter\\test\\" + i/800 + "-" + (i%800 + 1) + ".txt"; getTestTxt(img,i,0,0); if(i % 800 == 0) System.out.println("已导入数据:"+i); } System.out.println("获取文件成功!!"); // 进行测试 System.out.println("开始聚类……"); for(int i = 0; i < 10;i ++){ kinds[i] = new ArrayList<>(); } for(int i = 0; i < 8000;i ++){ for (int j = 0; j < 10;j ++){ distance[i][j] = GoodKMeans.getDistance(kMeans[j],test[i]);// 获得每张图片对应聚类中心的距离 } } for(int i= 0;i< 8000;i++){ kinds[GoodKMeans.getMinIndex(distance[i])].add(i);// 将图片归为最小距离的类中 } System.out.println("聚类成功!!"); int[][] ans = new int[10][10]; for(int i = 0; i < 10;i ++){ for(int j = 0; j < kinds[i].size();j ++){ if(kinds[i].get(j) < 800) ans[i][0]++; else if(kinds[i].get(j) >= 800 && kinds[i].get(j) < 1600) ans[i][1]++; else if(kinds[i].get(j) >= 1600 && kinds[i].get(j)< 2400) ans[i][2]++; else if(kinds[i].get(j) >= 2400 && kinds[i].get(j)< 3200) ans[i][3]++; else if(kinds[i].get(j) >= 3200 && kinds[i].get(j)< 4000) ans[i][4]++; else if(kinds[i].get(j) >= 4000 && kinds[i].get(j)< 4800) ans[i][5]++; else if(kinds[i].get(j) >= 4800 && kinds[i].get(j)< 5600) ans[i][6]++; else if(kinds[i].get(j) >= 5600 && kinds[i].get(j)< 6400) ans[i][7]++; else if(kinds[i].get(j) >= 6400 && kinds[i].get(j)< 7200) ans[i][8]++; else if(kinds[i].get(j) >= 7200 && kinds[i].get(j)< 8000) ans[i][9]++; } } for (int i = 0; i < 10;i ++){ System.out.print("第"+i+"类中:"); for (int j = 0; j < 10;j ++){ System.out.print(j+":"); System.out.printf("%3d",ans[i][j]); System.out.print("\t"); } System.out.println(); } } // 获得聚类中心文件 public static void getKMeansTxt(String img,int index) throws IOException { File file = new File(img); FileInputStream fis = new FileInputStream(file); InputStreamReader isr = new InputStreamReader(fis); BufferedReader br = new BufferedReader(isr); int x = 0; int y = 0; String line; while((line = br.readLine()) != null){ boolean isNum = false; for(int i = 0;i < line.length();i ++){ if(line.charAt(i)-'0' <10 && line.charAt(i)-'0' >=0 && !isNum){ // 如果遇到数字 isNum = true; // 取数字 int j = i + 1; while(j < line.length() && line.charAt(j) != ' '){ j++; } isNum = false; if(y < 28){ } else{ y = 0; x ++; } kMeans[index][x][y] = Float.valueOf(line.substring(i,j)).floatValue(); i = j; y++; } } } br.close(); } // 获得测试文件 public static void getTestTxt(String path,int img,int x,int y) throws IOException { File file = new File(path); FileInputStream fis = new FileInputStream(file); InputStreamReader isr = new InputStreamReader(fis); BufferedReader br = new BufferedReader(isr); String line; while((line = br.readLine()) != null){ boolean isNum = false; for(int i = 0;i < line.length();i ++){ if(line.charAt(i) != ' ' && !isNum){ // 如果遇到数字 isNum = true; float tempNum = 0; // 取数字 while(i < line.length() && line.charAt(i) != ' '){ tempNum = tempNum * 10 + line.charAt(i) - '0'; i++; } isNum = false; if(y < 28){ } else{ y = 0; x ++; } test[img][x][y] = tempNum; y++; } } } br.close(); } }
这篇关于java手撕KMeans算法实现手写数字聚类(失败案例)的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23Springboot应用的多环境打包入门
- 2024-11-23Springboot应用的生产发布入门教程
- 2024-11-23Python编程入门指南
- 2024-11-23Java创业入门:从零开始的编程之旅
- 2024-11-23Java创业入门:新手必读的Java编程与创业指南
- 2024-11-23Java对接阿里云智能语音服务入门详解
- 2024-11-23Java对接阿里云智能语音服务入门教程
- 2024-11-23JAVA对接阿里云智能语音服务入门教程
- 2024-11-23Java副业入门:初学者的简单教程
- 2024-11-23JAVA副业入门:初学者的实战指南