OpenCV 3.2 图像分割 实例4：GMM（高斯混合模型）样本数据训练与预测

发布时间：2023-08-18 09:00:17  作者：一杯清酒邀明月
#include <cstdio>
#include <iostream>

#include <opencv2/opencv.hpp>

using namespace cv;
using namespace cv::ml;
using namespace std;

 8 int main(int argc, char** argv) {
 9     Mat img = Mat::zeros(500, 500, CV_8UC3);
10     RNG rng(12345);
11  
12     Scalar colorTab[] = {
13         Scalar(0, 0, 255),
14         Scalar(0, 255, 0),
15         Scalar(255, 0, 0),
16         Scalar(0, 255, 255),
17         Scalar(255, 0, 255)
18     };
19  
20     int numCluster = rng.uniform(2, 5);
21     printf("number of clusters : %d\n", numCluster);
22  
23     int sampleCount = rng.uniform(5, 1000);
24     Mat points(sampleCount, 2, CV_32FC1);
25     Mat labels;
26  
27     // 生成随机数
28     for (int k = 0; k < numCluster; k++) {
29         Point center;
30         center.x = rng.uniform(0, img.cols);
31         center.y = rng.uniform(0, img.rows);
32         Mat pointChunk = points.rowRange(k*sampleCount / numCluster,
33             k == numCluster - 1 ? sampleCount : (k + 1)*sampleCount / numCluster);
34  
35         rng.fill(pointChunk, RNG::NORMAL, Scalar(center.x, center.y), Scalar(img.cols*0.05, img.rows*0.05));
36     }
37     randShuffle(points, 1, &rng);
38     Ptr<EM> em_model = EM::create();
39     em_model->setClustersNumber(numCluster);
40     em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);//协方差矩阵
41     //训练次数设置为100
42     em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
43     em_model->trainEM(points, noArray(), labels, noArray());
44  
45     // classify every image pixels
46     Mat sample(1, 2, CV_32FC1);
47     for (int row = 0; row < img.rows; row++) {
48         for (int col = 0; col < img.cols; col++) {
49             sample.at<float>(0) = (float)col;
50             sample.at<float>(1) = (float)row;
51             int response = cvRound(em_model->predict2(sample, noArray())[1]);
52             Scalar c = colorTab[response];
53             circle(img, Point(col, row), 1, c*0.75, -1);
54         }
55     }
56  
57     // draw the clusters
58     for (int i = 0; i < sampleCount; i++) {
59         Point p(cvRound(points.at<float>(i, 0)), points.at<float>(i, 1));
60         circle(img, p, 1, colorTab[labels.at<int>(i)], -1);
61     }
62  
63     imshow("GMM-EM Demo", img);
64  
65     waitKey(0);
66     return 0;
67 }