使用benchmark比较循环嵌套与strassen求解矩阵乘法的性能

发布时间 2023-04-05 18:21:43作者: 残影0无痕
#include <benchmark/benchmark.h>

#include <iostream>
#include <random>
#include <vector>

using namespace std;

static const int n = 200;
static const int _lrange = 0;
static const int _rrange = 10;
static const int _iter = 1;

using Matrix = vector<vector<int>>;

auto matrix_mult(Matrix _Amatrix, Matrix _Bmatrix, int n) {
	Matrix _Rmatrix(n, vector<int>(n, 0));
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++)
			for (int k = 0; k < n; k++)
				_Rmatrix[i][j] += _Amatrix[i][k] * _Bmatrix[k][j];
	return _Rmatrix;
}

Matrix operator+(Matrix _Amatrix, Matrix _Bmatrix) {
	int n = _Amatrix.size();
	Matrix _Rmatrix(n, vector<int>(n, 0));
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++)
			_Rmatrix[i][j] = _Amatrix[i][j] + _Bmatrix[i][j];
	return _Rmatrix;
}
Matrix operator-(Matrix _Amatrix, Matrix _Bmatrix) {
	int n = _Amatrix.size();
	Matrix _Rmatrix(n, vector<int>(n, 0));
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++)
			_Rmatrix[i][j] = _Amatrix[i][j] - _Bmatrix[i][j];
	return _Rmatrix;
}

Matrix slice_matrix(Matrix _matrix, int row, int col, int n) {
	Matrix _Rmatrix(n, vector<int>(n, 0));
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++) _Rmatrix[i][j] = _matrix[row + i][col + j];
	return _Rmatrix;
}

Matrix merge_matrix(Matrix _M11, Matrix _M12, Matrix _M21, Matrix _M22, int n) {
	Matrix _Rmatrix(n, vector<int>(n, 0));
	for (int i = 0; i < n / 2; i++)
		for (int j = 0; j < n / 2; j++) _Rmatrix[i][j] = _M11[i][j];
	for (int i = 0; i < n / 2; i++)
		for (int j = 0; j < n / 2; j++) _Rmatrix[i][n / 2 + j] = _M12[i][j];
	for (int i = 0; i < n / 2; i++)
		for (int j = 0; j < n / 2; j++) _Rmatrix[n / 2 + i][j] = _M21[i][j];
	for (int i = 0; i < n / 2; i++)
		for (int j = 0; j < n / 2; j++) _Rmatrix[n / 2 + i][n / 2 + j] = _M22[i][j];
	return _Rmatrix;
}

Matrix matrix_mult_strassen(Matrix _Amatrix, Matrix _Bmatrix, int n) {
	Matrix _Rmatrix(n, vector<int>(n, 0));
	if (n == 1) {
		_Rmatrix[0][0] = _Amatrix[0][0] * _Bmatrix[0][0];
		return _Rmatrix;
	}

	auto _A11 = slice_matrix(_Amatrix, 0, 0, n / 2);
	auto _A12 = slice_matrix(_Amatrix, 0, n / 2, n / 2);
	auto _A21 = slice_matrix(_Amatrix, n / 2, 0, n / 2);
	auto _A22 = slice_matrix(_Amatrix, n / 2, n / 2, n / 2);

	auto _B11 = slice_matrix(_Bmatrix, 0, 0, n / 2);
	auto _B12 = slice_matrix(_Bmatrix, 0, n / 2, n / 2);
	auto _B21 = slice_matrix(_Bmatrix, n / 2, 0, n / 2);
	auto _B22 = slice_matrix(_Bmatrix, n / 2, n / 2, n / 2);

	auto _S1 = _B12 - _B22;
	auto _S2 = _A11 + _A12;
	auto _S3 = _A21 + _A22;
	auto _S4 = _B21 - _B11;
	auto _S5 = _A11 + _A22;
	auto _S6 = _B11 + _B22;
	auto _S7 = _A12 - _A22;
	auto _S8 = _B21 + _B22;
	auto _S9 = _A11 - _A21;
	auto _S10 = _B11 + _B12;

	auto _P1 = matrix_mult_strassen(_A11, _S1, n / 2);
	auto _P2 = matrix_mult_strassen(_S2, _B22, n / 2);
	auto _P3 = matrix_mult_strassen(_S3, _B11, n / 2);
	auto _P4 = matrix_mult_strassen(_A22, _S4, n / 2);
	auto _P5 = matrix_mult_strassen(_S5, _S6, n / 2);
	auto _P6 = matrix_mult_strassen(_S7, _S8, n / 2);
	auto _P7 = matrix_mult_strassen(_S9, _S10, n / 2);

	auto _R11 = _P5 + _P4 - _P2 + _P6;
	auto _R12 = _P1 + _P2;
	auto _R21 = _P3 + _P4;
	auto _R22 = _P5 + _P1 - _P3 - _P7;

	_Rmatrix = merge_matrix(_R11, _R12, _R21, _R22, n);
	return _Rmatrix;
}

static void BM_demo_1(benchmark::State& state) {
	for (auto _ : state) {
		state.PauseTiming();
		Matrix a_matrix, b_matrix;

		random_device rd;
		mt19937 gen(rd());
		uniform_int_distribution<int> dist(_lrange, _rrange);
		for (int i = 0; i < n; ++i) {
			vector<int> row;
			for (int j = 0; j < n; ++j) {
				row.push_back(dist(gen));
			}
			a_matrix.push_back(row);
		}
		for (int i = 0; i < n; ++i) {
			vector<int> row;
			for (int j = 0; j < n; ++j) {
				row.push_back(dist(gen));
			}
			b_matrix.push_back(row);
		}
		state.ResumeTiming();
		matrix_mult(a_matrix, b_matrix, n);
	}
}
BENCHMARK(BM_demo_1)->Iterations(_iter);

static void BM_demo_2(benchmark::State& state) {
	for (auto _ : state) {
		state.PauseTiming();
		Matrix a_matrix, b_matrix;

		random_device rd;
		mt19937 gen(rd());
		uniform_int_distribution<int> dist(_lrange, _rrange);
		for (int i = 0; i < n; ++i) {
			vector<int> row;
			for (int j = 0; j < n; ++j) {
				row.push_back(dist(gen));
			}
			a_matrix.push_back(row);
		}
		for (int i = 0; i < n; ++i) {
			vector<int> row;
			for (int j = 0; j < n; ++j) {
				row.push_back(dist(gen));
			}
			b_matrix.push_back(row);
		}
		state.ResumeTiming();
		matrix_mult_strassen(a_matrix, b_matrix, n);
	}
}
BENCHMARK(BM_demo_2)->Iterations(_iter);

BENCHMARK_MAIN();