opencv SVM 训练ocr模型

发布时间 2023-09-14 12:08:37作者: 哈库拉

实现0-6字符分类

数据准备: 训练数据:

train_data.txt 

查看代码
D:/ocr/svm/train/imgs/0/0.png
0
D:/ocr/svm/train/imgs/0/0_1.jpg
0
D:/ocr/svm/train/imgs/1/1.png
1
D:/ocr/svm/train/imgs/1/1_1.jpg
1
D:/ocr/svm/train/imgs/1/1_2.jpg
1
D:/ocr/svm/train/imgs/1/1_3.jpg
1
D:/ocr/svm/train/imgs/2/2.png
2
D:/ocr/svm/train/imgs/2/2_1.jpg
2
D:/ocr/svm/train/imgs/2/2_2.jpg
2
D:/ocr/svm/train/imgs/3/3.png
3
D:/ocr/svm/train/imgs/3/3_1.jpg
3
D:/ocr/svm/train/imgs/3/3_2.jpg
3
D:/ocr/svm/train/imgs/4/4.png
4
D:/ocr/svm/train/imgs/4/4_1.jpg
4
D:/ocr/svm/train/imgs/4/4_2.jpg
4
D:/ocr/svm/train/imgs/5/5.png
5
D:/ocr/svm/train/imgs/5/5_1.jpg
5
D:/ocr/svm/train/imgs/5/5_2.jpg
5
D:/ocr/svm/train/imgs/6/6.png
6
D:/ocr/svm/train/imgs/6/6_1.jpg
6
D:/ocr/svm/train/imgs/6/6_2.jpg
6

数据处理及训练:

查看代码
#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/objdetect/objdetect.hpp>

void OcrTrain()
{
    using namespace cv::ml;
    using namespace std;
    vector<string> imgpath;  // path of train image
    vector<int> imglabel;    // label of train image
    int nLine = 0;
    string buf;
    ifstream svm_data;
    svm_data.open("d:/ocr/svm/train/train_data.txt", ios::in);
    if (!svm_data.is_open())
    {
        cout << "read error" << endl;
        exit(EXIT_FAILURE);
    }
    unsigned long n;
    while (svm_data)
    {
        if (getline(svm_data, buf))
        {
            nLine++;
            if (nLine % 2 == 0)
            {
                imglabel.push_back(atoi(buf.c_str()));
            }
            else
            {
                imgpath.push_back(buf);
            }
        }
    }
    svm_data.close();    //close file
    /// <summary>
    /// 训练
    /// </summary>

    Mat data_mat, res_mat;
    int nImgNum = nLine / 2;
    data_mat = Mat::zeros(nImgNum, 324, CV_32FC1); // store hog feature 324=9*4*9 single channel float
    res_mat = Mat::zeros(nImgNum, 1, CV_32S);    // store label 注意这里的数据类型32F不行
    Mat src;
    for (string::size_type i = 0; i != imgpath.size(); i++)
    {
        src = imread(imgpath[i], 1);  // read train image
        if (src.empty())
        {
            cout << "can not read the image: " << imgpath[i] << endl;
            continue;
        }
        cout << "processing:" << endl;
        Mat trainImg;
        resize(src, trainImg, Size(28, 28)); // resize to 28*28
        HOGDescriptor* hog = new HOGDescriptor(Size(28, 28), Size(14, 14), Size(7, 7), Size(7, 7), 9); // hog descriptor
        vector<float> descriptors; // store result
        hog->compute(trainImg, descriptors, Size(1, 1), Size(0, 0)); // compute hog descriptor
        cout << "HOG dims:";
        n = 0;
        for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
        {
            data_mat.at<float>(i, n) = (*iter); // put hog descriptor into data_mat
            n++;
        }
        res_mat.at<int>(i, 0) = imglabel[i]; // put label into res_mat
        cout << "processing done:" << " " << endl;
    }
    Ptr<cv::ml::SVM> svm = SVM::create();//创建一个svm对象
    svm->setType(cv::ml::SVM::C_SVC);
    svm->setKernel(SVM::LINEAR);
    svm->setDegree(0);
    svm->setGamma(1);
    svm->setCoef0(0);
    svm->setC(1);
    svm->setNu(0);
    svm->setP(0);
    svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 1000, TermCriteria::EPS));//设置SVM训练时迭代终止条件  10的12次方
    //训练
    cout << "开始进行训练..." << endl;
    Ptr<TrainData> tData = TrainData::create(data_mat, ROW_SAMPLE, res_mat);
    //svm->train(tData);  //这两行代码和下面一行代码等效
    svm->train(data_mat, cv::ml::SampleTypes::ROW_SAMPLE, res_mat);
    Mat resp;
    float err = svm->calcError(tData, false, resp);
    //CvSVM svm;
    //CvSVMParams param;
    //param = CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 0.09, 1.0, 10.0, 0.5, 1.0, NULL, criteria); // svm parameter
    //svm->train(data_mat, res_mat, Mat(), Mat(), param); //train
    svm->save("D:/ocr/svm/HOG_SVM_OCR.xml"); // preserve result
}

加载模型预测:

查看代码
   Mat test;
    char result[512];
    vector<string> img_tst_path;
    ifstream img_tst("D:/ocr/svm/test_data.txt");
    string test_dir = "D:/ocr/svm/test/";
    while (img_tst)
    {
        if (getline(img_tst, buf))
        {
            buf = test_dir + buf;
            img_tst_path.push_back(buf);
        }
    }
    img_tst.close();
    // 预测阶段
    Ptr<cv::ml::SVM> svmLoad = StatModel::load<SVM>("D:/ocr/svm/HOG_SVM_OCR.xml");
    ofstream predict_txt("D:/ocr/svm/SVM_PREDICT1.txt");
    for (string::size_type j = 0; j != img_tst_path.size(); j++)
    {
        test = imread(img_tst_path[j], 1);
        if (test.empty())
        {
            cout << "can not load the image:" << endl;
            continue;
        }
        Mat trainTempImg;
        resize(test, trainTempImg, Size(28, 28));
        HOGDescriptor* hog = new HOGDescriptor(Size(28, 28), Size(14, 14), Size(7, 7), Size(7, 7), 9);
        vector<float> descriptors;
        hog->compute(trainTempImg, descriptors, Size(1, 1), Size(0, 0));
        cout << "HOG dims:" << endl;
        Mat SVMtrainMat(1, descriptors.size(), CV_32FC1);
        int n = 0;
        for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
        {
            SVMtrainMat.at<float>(0, n) = (*iter);
            n++;
        }
        int ret = svmLoad->predict(SVMtrainMat); // predict by svm
        sprintf_s(result, "%s %d\r\n", img_tst_path[j], ret);
        cout << img_tst_path[j]<<"   " <<ret << endl;
        predict_txt << result;//predict result
    }
    predict_txt.close();

结果输出: