OpenCV3.3深度神经网络DNN模块 实例3:SSD模型实现对象检测

发布时间:2023-08-18 09:30:53 作者:一杯清酒邀明月
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
  4  
  5 using namespace cv;
  6 using namespace cv::dnn;
  7 using namespace std;
  8  
  9 const size_t width = 300;//模型尺寸为300*300
 10 const size_t height = 300;
 11 //label文件
 12 String labelFile = "D:/opencv3.3/opencv/sources/samples/data/dnn/labelmap_det.txt";
 13 //模型文件
 14 String modelFile = "D:/opencv3.3/opencv/sources/samples/data/dnn/VGG_ILSVRC2016_SSD_300x300_iter_440000.caffemodel";
 15 //模型描述文件
 16 String model_text_file = "D:/opencv3.3/opencv/sources/samples/data/dnn/deploy.prototxt";
 17  
 18 vector<String> readLabels();
 19 const int meanValues[3] = { 104, 117, 123 };
 20 static Mat getMean(const size_t &w, const size_t &h) {
 21     Mat mean;
 22     vector<Mat> channels;
 23     for (int i = 0; i < 3; i++) {
 24         Mat channel(h, w, CV_32F, Scalar(meanValues[i]));
 25         channels.push_back(channel);
 26     }
 27     merge(channels, mean);
 28     return mean;
 29 }
 30  
 31 static Mat preprocess(const Mat &frame) {
 32     Mat preprocessed;
 33     frame.convertTo(preprocessed, CV_32F);
 34     resize(preprocessed, preprocessed, Size(width, height)); // 300x300 image
 35     Mat mean = getMean(width, height);
 36     subtract(preprocessed, mean, preprocessed);
 37     return preprocessed;
 38 }
 39  
 40 int main(int argc, char** argv) {
 41     Mat frame = imread("persons.png");
 42     if (frame.empty()) {
 43         printf("could not load image...\n");
 44         return -1;
 45     }
 46     namedWindow("input image", CV_WINDOW_AUTOSIZE);
 47     imshow("input image", frame);
 48  
 49     vector<String> objNames = readLabels();
 50     // import Caffe SSD model
 51     Ptr<dnn::Importer> importer;
 52     try {
 53         importer = createCaffeImporter(model_text_file, modelFile);
 54     }
 55     catch (const cv::Exception &err) {
 56         cerr << err.msg << endl;
 57     }
 58     //初始化网络
 59     Net net;
 60     importer->populateNet(net);
 61     importer.release();
 62  
 63     Mat input_image = preprocess(frame);//获取输入图像
 64     Mat blobImage = blobFromImage(input_image);//将图像转换为blob
 65  
 66     net.setInput(blobImage, "data");//将图像转换的blob数据输入到网络的第一层“data”层,见deploy.protxt文件
 67     Mat detection = net.forward("detection_out");//结果输出(最后一层“detection_out”层)输出给detection
 68     Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
 69     float confidence_threshold = 0.2;//自信区间,可以修改,越低检测到的物体越多
 70     for (int i = 0; i < detectionMat.rows; i++) {
 71         float confidence = detectionMat.at<float>(i, 2);
 72         if (confidence > confidence_threshold) {
 73             size_t objIndex = (size_t)(detectionMat.at<float>(i, 1));
 74             float tl_x = detectionMat.at<float>(i, 3) * frame.cols;
 75             float tl_y = detectionMat.at<float>(i, 4) * frame.rows;
 76             float br_x = detectionMat.at<float>(i, 5) * frame.cols;
 77             float br_y = detectionMat.at<float>(i, 6) * frame.rows;
 78  
 79             Rect object_box((int)tl_x, (int)tl_y, (int)(br_x - tl_x), (int)(br_y - tl_y));
 80             //标记框
 81             rectangle(frame, object_box, Scalar(0, 0, 255), 2, 8, 0);
 82             //设置颜色
 83             putText(frame, format("%s", objNames[objIndex].c_str()), Point(tl_x, tl_y), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(255, 0, 0), 2);
 84         }
 85     }
 86     imshow("ssd-demo", frame);
 87  
 88     waitKey(0);
 89     return 0;
 90 }
 91  
 92 vector<String> readLabels() {
 93     vector<String> objNames;
 94     ifstream fp(labelFile);
 95     if (!fp.is_open()) {
 96         printf("could not open the file...\n");
 97         exit(-1);
 98     }
 99     string name;
100     while (!fp.eof()) {
101         getline(fp, name);
102         if (name.length() && (name.find("display_name:") == 0)) {
103             string temp = name.substr(15);
104             temp.replace(temp.end() - 1, temp.end(), "");
105             objNames.push_back(temp);
106         }
107     }
108     return objNames;
109 }