矩阵乘法的指令集加速例子

发布时间: 2023-04-27 09:51:38  作者: 兜尼完

这里不再介绍基本概念，直接给出代码和对比结果。三个版本分别是：普通C++代码、SIMD加速代码（使用SSE的128位加载/存储，配合FMA3融合乘加指令_mm_fmadd_ps），以及OpenCV自带的矩阵乘法。代码基于VS2017、OpenCV 4.3.0和Qt 5.9，CPU型号为Intel Core i5-7400（支持FMA3指令集）。

/**
 * Naive single-precision matrix product c = a * b (plain C++ reference).
 *
 * Loops run in i-k-j order so the innermost loop walks one row of b and one
 * row of c contiguously instead of striding down a column of b.
 *
 * @param a  left operand; must be CV_32FC1 with a.cols == b.rows
 * @param b  right operand; must be CV_32FC1
 * @return   newly allocated CV_32FC1 matrix of size a.rows x b.cols
 */
Mat mul1(const Mat& a, const Mat& b)
{
    ASSERT(a.cols == b.rows);
    // ptr<float>() below reinterprets row data as float; any other element
    // type would silently produce garbage.
    ASSERT(a.type() == CV_32FC1 && b.type() == CV_32FC1);

    const int inner = a.cols; // shared dimension (a.cols == b.rows)
    Mat c = Mat::zeros(a.rows, b.cols, CV_32FC1);
    for (int i = 0; i < c.rows; i++)
    {
        const float* aptr = a.ptr<float>(i);
        float* optr = c.ptr<float>(i);
        for (int k = 0; k < inner; k++)
        {
            // Row i of c += a(i,k) * row k of b. Reading a(i,k) once keeps the
            // compiler from reloading it after every store to optr.
            const float aik = aptr[k];
            const float* bptr = b.ptr<float>(k);
            for (int j = 0; j < c.cols; j++)
            {
                optr[j] += aik * bptr[j];
            }
        }
    }
    return c;
}

/**
 * Single-precision matrix product c = a * b, vectorised on 128-bit registers.
 *
 * Uses the same i-k-j order as mul1, but the innermost loop updates four
 * columns of c per step: c[i][j..j+3] += a[i][k] * b[k][j..j+3]. Columns left
 * over when b.cols is not a multiple of 4 go through a scalar tail loop.
 *
 * NOTE: _mm_loadu_ps/_mm_storeu_ps are SSE, but _mm_fmadd_ps is an FMA3
 * instruction. The CPU must support FMA3, and GCC/Clang need it enabled
 * (e.g. -mfma). A fused multiply-add rounds once instead of twice, so results
 * can differ from mul1 in the last bits.
 *
 * @param a  left operand; must be CV_32FC1 with a.cols == b.rows
 * @param b  right operand; must be CV_32FC1
 * @return   newly allocated CV_32FC1 matrix of size a.rows x b.cols
 */
Mat mul2(const Mat& a, const Mat& b)
{
    ASSERT(a.cols == b.rows);
    // ptr<float>() and _mm_loadu_ps below assume float elements.
    ASSERT(a.type() == CV_32FC1 && b.type() == CV_32FC1);

    const int inner = a.cols; // shared dimension (a.cols == b.rows)
    Mat c = Mat::zeros(a.rows, b.cols, CV_32FC1);
    const int max4 = c.cols / 4 * 4; // largest column count divisible by 4
    for (int i = 0; i < c.rows; i++)
    {
        const float* aptr = a.ptr<float>(i);
        float* optr = c.ptr<float>(i);
        for (int k = 0; k < inner; k++)
        {
            const float aik = aptr[k];
            const float* bptr = b.ptr<float>(k);
            // a(i,k) does not change across the j loop: broadcast it once
            // instead of once per 4-column step.
            const __m128 mma = _mm_set_ps1(aik);
            for (int j = 0; j < max4; j += 4)
            {
                const __m128 mmb = _mm_loadu_ps(&bptr[j]);
                const __m128 mmo = _mm_loadu_ps(&optr[j]);
                _mm_storeu_ps(&optr[j], _mm_fmadd_ps(mma, mmb, mmo)); /* FMA3 */
            }
            // Scalar tail for the remaining (c.cols % 4) columns.
            for (int j = max4; j < c.cols; j++)
            {
                optr[j] += aik * bptr[j];
            }
        }
    }
    return c;
}

void main()
{
    Mat image = Mat::zeros(1600, 1600, CV_8UC1);
    circle(image, Point2i(800, 800), 200, 255, -1);
    Mat a, b;
    image.convertTo(a, CV_32FC1, 0.0039);
    image.convertTo(b, CV_32FC1, 0.0039);
    int64 t1, t2;

    Mat result1;
    t1 = getTickCount();
    result1 = mul1(a, b);
    t2 = getTickCount();
    qDebug() << u8"diy(ms):" << (t2 - t1) / getTickFrequency() * 1000;

    Mat result2;
    t1 = getTickCount();
    result2 = mul2(a, b);
    t2 = getTickCount();
    qDebug() << u8"mul2(ms)" << (t2 - t1) / getTickFrequency() * 1000;

    Mat result3;
    t1 = getTickCount();
    result3 = a * b;
    t2 = getTickCount();
    qDebug() << u8"cv(ms)" << (t2 - t1) / getTickFrequency() * 1000;

    qDebug() << result1.at<float>(800, 800) << result2.at<float>(800, 800) << result3.at<float>(800, 800);
}

下面是Release版程序的文本输出。可以看到，前两种实现在(800, 800)处的结果相同，OpenCV的结果则在末几位略有差异。这是单精度浮点运算的舍入误差：不同实现的累加顺序不同，是否使用融合乘加（FMA）也不同，都会改变最终的舍入结果。在本测试中，SIMD版本的速度明显快于OpenCV：

diy(ms): 3028.05
mul2(ms) 1165.66
cv(ms) 2454.72
396.603 396.603 396.601