#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <vector>

#include <sys/time.h>

// Allocates a rows x cols row-major matrix with new[] and fills every element
// with 1.0. Despite the name, the contents are NOT random: main()'s result
// check relies on every entry being exactly 1.0. The caller owns Mat and must
// release it with delete[].
void randMat(int rows, int cols, float *&Mat)
{
    // Compute the element count in size_t so large dimensions do not overflow int.
    const std::size_t count = static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols);
    Mat = new float[count];
    for (std::size_t i = 0; i < count; ++i)
        Mat[i] = 1.0f;
}

// Naive single-threaded SGEMM: resultMat = leftMat * rightMat, where
//   leftMat   is m x n (row-major),
//   rightMat  is n x k (row-major),
//   resultMat is m x k (row-major) and is fully overwritten.
// rightMat is transposed into a private k x n scratch buffer so the inner
// product walks both operands contiguously. The caller's input matrices are
// left unmodified. The pointers are taken by reference only to keep the
// existing signature. Negative dimensions are treated as "nothing to do".
void serial_sgemm(int m, int n, int k, float *&leftMat, float *&rightMat, float *&resultMat)
{
    if (m < 0 || n < 0 || k < 0)
        return;

    // size_t dimensions keep the index arithmetic below free of int overflow.
    const std::size_t M = static_cast<std::size_t>(m);
    const std::size_t N = static_cast<std::size_t>(n);
    const std::size_t K = static_cast<std::size_t>(k);

    // rightT[c * N + r] == rightMat[r * K + c]  (k x n, i.e. column c of rightMat is row c here)
    std::vector<float> rightT(K * N);
    for (std::size_t r = 0; r < N; ++r)
        for (std::size_t c = 0; c < K; ++c)
            rightT[c * N + r] = rightMat[r * K + c];

    for (std::size_t row = 0; row < M; ++row) {
        const float *lhsRow = leftMat + row * N;
        for (std::size_t col = 0; col < K; ++col) {
            // Accumulate in a local; same summation order as a direct += into resultMat.
            float sum = 0.0f;
            for (std::size_t i = 0; i < N; ++i)
                sum += lhsRow[i] * rightT[col * N + i];
            resultMat[row * K + col] = sum;
        }
    }
}

// Checks that every entry of the m x k result equals n. That is the exact
// expected value when both inputs are all-ones (see randMat). Prints the first
// mismatching entry and returns false on failure. The comparison truncates to
// int, matching the original check. Note that float accumulation of 1.0 is
// exact only while n <= 2^24.
static bool verify_result(int m, int n, int k, const float *resMat)
{
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < k; j++) {
            const float value = resMat[static_cast<std::size_t>(i) * k + j];
            if (int(value) != n) {
                std::cout << "Error at [" << i << "][" << j << "]: " << value
                          << " (expected " << n << ")\n";
                return false;
            }
        }
    }
    return true;
}

int main(int argc, char *argv[])
{
    if (argc != 5) {
        std::cout << "Usage: " << argv[0] << " M N K use-blas\n";
        return EXIT_FAILURE;
    }

    const int m = std::atoi(argv[1]);
    const int n = std::atoi(argv[2]);
    const int k = std::atoi(argv[3]);
    // The use-blas flag is accepted for CLI compatibility, but this file has no
    // BLAS path: serial_sgemm is always the kernel that runs.
    const int blas = std::atoi(argv[4]);
    (void)blas;

    // atoi yields 0 for non-numeric input, so this also rejects garbage arguments.
    if (m <= 0 || n <= 0 || k <= 0) {
        std::cerr << "Error: M, N and K must be positive integers\n";
        return EXIT_FAILURE;
    }

    float *leftMat = NULL;
    float *rightMat = NULL;
    float *resMat = NULL;
    randMat(m, n, leftMat);
    randMat(n, k, rightMat);
    randMat(m, k, resMat);

    struct timeval start, stop;
    gettimeofday(&start, NULL);
    serial_sgemm(m, n, k, leftMat, rightMat, resMat);
    gettimeofday(&stop, NULL);
    std::cout << "matmul: "
              << (stop.tv_sec - start.tv_sec) * 1000.0 + (stop.tv_usec - start.tv_usec) / 1000.0
              << " ms" << std::endl;

    // Verify the result.
    const bool correct = verify_result(m, n, k, resMat);
    if (correct) {
        std::cout << "Result verification: PASSED" << std::endl;
    } else {
        std::cout << "Result verification: FAILED" << std::endl;
    }

    delete[] leftMat;
    delete[] rightMat;
    delete[] resMat;
    return 0;
}