hpc-lab-code/work/gemm_serial.cpp
2026-01-21 18:02:30 +08:00

98 lines
2.2 KiB
C++

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>
using namespace std;
/**
 * Allocate a rows x cols row-major float matrix and set every element to 1.0.
 *
 * Despite the name, the contents are not random: every entry is the constant
 * 1.0f.
 *
 * @param rows Number of rows; values <= 0 produce an empty matrix.
 * @param cols Number of columns; values <= 0 produce an empty matrix.
 * @param Mat  Output. Receives a buffer obtained with new[]; the caller owns
 *             it and must release it with delete[]. Any previous value of the
 *             pointer is overwritten without being freed.
 */
void randMat(int rows, int cols, float *&Mat) {
  // Compute the element count in size_t: rows * cols evaluated in int
  // overflows (undefined behaviour) once the matrix exceeds INT_MAX elements.
  const std::size_t count =
      (rows > 0 && cols > 0)
          ? static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols)
          : 0;
  Mat = new float[count];
  std::fill_n(Mat, count, 1.0f);
}
/**
 * Single-threaded single-precision matrix multiply:
 * resultMat = leftMat * rightMat.
 *
 * All matrices are dense and stored row-major.
 *
 * @param m         Rows of leftMat and resultMat.
 * @param n         Columns of leftMat, rows of rightMat (inner dimension).
 * @param k         Columns of rightMat and resultMat.
 * @param leftMat   m x n input, only read.
 * @param rightMat  n x k input, only read. It is NOT modified: the transposed
 *                  copy used by the kernel lives in a private scratch buffer,
 *                  so the function can be called repeatedly on the same data.
 * @param resultMat m x k output; every element is overwritten.
 *
 * Non-positive m or k means there is nothing to write. A non-positive n is
 * treated as an empty inner dimension, so the result is filled with zeros.
 */
void serial_sgemm(int m, int n, int k, float *&leftMat, float *&rightMat,
                  float *&resultMat) {
  if (m <= 0 || k <= 0) {
    return;
  }

  // Index arithmetic is done in size_t so that products such as row * k
  // cannot overflow int for large matrices.
  const std::size_t rows = static_cast<std::size_t>(m);
  const std::size_t inner = n > 0 ? static_cast<std::size_t>(n) : 0;
  const std::size_t cols = static_cast<std::size_t>(k);

  // Transpose rightMat (n x k) into a k x n scratch matrix so the innermost
  // loop below walks both operands with unit stride.
  std::vector<float> rightT(cols * inner);
  for (std::size_t r = 0; r < inner; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      rightT[c * inner + r] = rightMat[r * cols + c];
    }
  }

  for (std::size_t row = 0; row < rows; ++row) {
    const float *leftRow = leftMat + row * inner;
    float *outRow = resultMat + row * cols;
    for (std::size_t col = 0; col < cols; ++col) {
      const float *rightCol = rightT.data() + col * inner;
      // Accumulate in a local and store once, in the same summation order as
      // a plain i-ascending dot product.
      float sum = 0.0f;
      for (std::size_t i = 0; i < inner; ++i) {
        sum += leftRow[i] * rightCol[i];
      }
      outRow[col] = sum;
    }
  }
}
/**
 * Benchmark driver for serial_sgemm.
 *
 * Usage: <prog> M N K use-blas
 *
 * Builds an M x N and an N x K matrix with randMat, times one call to
 * serial_sgemm with gettimeofday, prints the elapsed wall-clock time in
 * milliseconds, then checks the M x K result.
 *
 * Exits with status -1 when the argument count is wrong or when M, N or K is
 * not a positive integer; returns 0 otherwise, whether or not verification
 * passed.
 */
int main(int argc, char *argv[]) {
  if (argc != 5) {
    std::cout << "Usage: " << argv[0] << " M N K use-blas\n";
    exit(-1);
  }

  const int m = atoi(argv[1]);
  const int n = atoi(argv[2]);
  const int k = atoi(argv[3]);
  // argv[4] ("use-blas") is still required so the command line stays
  // unchanged, but this serial driver never reads it.

  // atoi yields 0 for non-numeric text, so this check also rejects garbage.
  // Without it a zero or negative size would reach the allocations below.
  if (m <= 0 || n <= 0 || k <= 0) {
    std::cerr << "Error: M, N and K must be positive integers (got M="
              << argv[1] << ", N=" << argv[2] << ", K=" << argv[3] << ")\n";
    exit(-1);
  }

  float *leftMat = nullptr;
  float *rightMat = nullptr;
  float *resMat = nullptr;
  randMat(m, n, leftMat);
  randMat(n, k, rightMat);
  randMat(m, k, resMat);

  struct timeval start, stop;
  gettimeofday(&start, nullptr);
  serial_sgemm(m, n, k, leftMat, rightMat, resMat);
  gettimeofday(&stop, nullptr);
  std::cout << "matmul: "
            << (stop.tv_sec - start.tv_sec) * 1000.0 +
                   (stop.tv_usec - start.tv_usec) / 1000.0
            << " ms" << std::endl;

  // Verify the result: every entry is expected to equal n. Only the first
  // mismatch is reported; the flag in the loop conditions stops the scan.
  // NOTE(review): assuming IEEE-754 single precision, a float sum of ones is
  // exact only up to 2^24, so this check cannot pass for larger n.
  bool correct = true;
  for (int i = 0; i < m && correct; i++) {
    for (int j = 0; j < k && correct; j++) {
      // Widen before multiplying so the flat index cannot overflow int.
      const float value = resMat[static_cast<std::size_t>(i) * k + j];
      if (static_cast<int>(value) != n) {
        std::cout << "Error at [" << i << "][" << j << "]: " << value
                  << " (expected " << n << ")\n";
        correct = false;
      }
    }
  }

  if (correct) {
    std::cout << "Result verification: PASSED" << std::endl;
  } else {
    std::cout << "Result verification: FAILED" << std::endl;
  }

  delete[] leftMat;
  delete[] rightMat;
  delete[] resMat;
  return 0;
}