OpenCV3.3深度神经网络DNN模块 实例6:CNN模型预测性别与年龄

发布时间 2023-08-18 09:35:59 作者: 一杯清酒邀明月
 1 #include <opencv2/opencv.hpp>
 2 #include <opencv2/dnn.hpp>
 3 #include <iostream>
 4  
 5 using namespace cv;
 6 using namespace cv::dnn;
 7 using namespace std;
 8 //人脸检测文件
 9 String haar_file = "D:/opencv3.3/opencv/build/etc/haarcascades/haarcascade_frontalface_alt_tree.xml";
10 //年龄预测模型
11 String age_model = "D:/opencv3.3/opencv/sources/samples/data/dnn/age_net.caffemodel";
12 //年龄描述文件
13 String age_text = "D:/opencv3.3/opencv/sources/samples/data/dnn/deploy_age.prototxt";
14  
15 //性别预测模型
16 String gender_model = "D:/opencv3.3/opencv/sources/samples/data/dnn/gender_net.caffemodel";
17 //年龄描述文件
18 String gender_text = "D:/opencv3.3/opencv/sources/samples/data/dnn/deploy_gender.prototxt";
19  
20 void predict_age(Net &net, Mat &image);//预测年龄 
21 void predict_gender(Net &net, Mat &image);//预测性别
22 int main(int argc, char** argv) {
23     Mat src = imread("star_lady.png");
24     if (src.empty()) {
25         printf("could not load image...\n");
26         return -1;
27     }
28     namedWindow("input", CV_WINDOW_AUTOSIZE);
29     imshow("input", src);
30     CascadeClassifier detector;
31     detector.load(haar_file);//人脸检测
32     vector<Rect> faces;
33     Mat gray;
34     cvtColor(src, gray, COLOR_BGR2GRAY);
35     detector.detectMultiScale(gray, faces, 1.02, 1, 0, Size(40, 40), Size(200, 200));
36     //加载网络
37     Net age_net = readNetFromCaffe(age_text, age_model);
38     Net gender_net = readNetFromCaffe(gender_text, gender_model);
39  
40     for (size_t t= 0; t < faces.size(); t++) {
41         rectangle(src, faces[t], Scalar(30, 255, 30), 2, 8, 0);
42         //年龄、性别预测
43         Mat face = src(faces[t]);//自己加的,不加会报错,提示类型错误
44         predict_age(age_net, face);
45         predict_gender(age_net, face);
46     }
47     imshow("age-gender-prediction-demo", src);
48  
49     waitKey(0);
50     return 0;
51 }
52  
53 vector<String> ageLabels() {
54     vector<String> ages;
55     ages.push_back("0-2");
56     ages.push_back("4 - 6");
57     ages.push_back("8 - 13");
58     ages.push_back("15 - 20");
59     ages.push_back("25 - 32");
60     ages.push_back("38 - 43");
61     ages.push_back("48 - 53");
62     ages.push_back("60-");
63     return ages;
64 }
65  
66 void predict_age(Net &net, Mat &image) {
67     // 输入
68     Mat blob = blobFromImage(image, 1.0, Size(227, 227));
69     net.setInput(blob, "data");
70     // 预测分类
71     Mat prob = net.forward("prob");
72     Mat probMat = prob.reshape(1, 1);//变为一行
73     Point classNum;
74     double classProb;
75  
76     vector<String> ages = ageLabels();
77     minMaxLoc(probMat, NULL, &classProb, NULL, &classNum);//提取最大概率的编号和概率值
78     int classidx = classNum.x;
79     putText(image, format("age:%s", ages.at(classidx).c_str()), Point(2, 10), FONT_HERSHEY_PLAIN, 0.8, Scalar(0, 0, 255), 1);
80 }
81  
82 void predict_gender(Net &net, Mat &image) {
83     // 输入
84     Mat blob = blobFromImage(image, 1.0, Size(227, 227));
85     net.setInput(blob, "data");
86     // 预测分类
87     Mat prob = net.forward("prob");
88     Mat probMat = prob.reshape(1, 1);
89     putText(image, format("gender:%s", (probMat.at<float>(0, 0) > probMat.at<float>(0, 1) ? "M" : "F")),
90         Point(2, 20), FONT_HERSHEY_PLAIN, 0.8, Scalar(0, 0, 255), 1);
91 }