98 lines
2.2 KiB
C++
98 lines
2.2 KiB
C++
#include <math.h>
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <sys/time.h>
|
|
#include <iostream>
|
|
|
|
using namespace std;
|
|
|
|
/// Allocates a rows x cols row-major float matrix and sets every element to 1.0.
///
/// Despite the name, the contents are not random: every entry is the constant
/// 1.0f.
///
/// @param rows Number of rows; a non-positive value yields an empty matrix.
/// @param cols Number of columns; a non-positive value yields an empty matrix.
/// @param Mat  Output. Overwritten with a pointer to the new buffer; whatever
///             it pointed to before is NOT freed here. The caller releases the
///             buffer with delete[].
void randMat(int rows, int cols, float *&Mat) {
  // Compute the element count in size_t: `rows * cols` evaluated in int
  // overflows (undefined behaviour) once the matrix exceeds INT_MAX elements,
  // and a product of two negative dimensions would look like a valid size.
  const size_t count =
      (rows > 0 && cols > 0)
          ? static_cast<size_t>(rows) * static_cast<size_t>(cols)
          : 0;

  Mat = new float[count];

  // The storage is contiguous, so a single flat pass covers every (row, col).
  for (size_t idx = 0; idx < count; ++idx) {
    Mat[idx] = 1.0f;
  }
}
|
|
|
|
/// Single-threaded single-precision matrix multiply: resultMat = leftMat * rightMat.
///
/// All matrices are dense and row-major.
///
/// @param m         Rows of leftMat and of resultMat.
/// @param n         Columns of leftMat and rows of rightMat (the shared
///                  dimension).
/// @param k         Columns of rightMat and of resultMat.
/// @param leftMat   m x n input; only read.
/// @param rightMat  n x k input; only read. It is NOT modified, so the same
///                  operands can be passed to repeated calls.
/// @param resultMat m x k output; every element is overwritten.
void serial_sgemm(int m, int n, int k, float *&leftMat, float *&rightMat,
                  float *&resultMat) {
  // With no rows or no columns there is no output element to write.
  if (m <= 0 || k <= 0) {
    return;
  }

  // size_t arithmetic keeps the flattened indices from overflowing int.
  const size_t outRows = static_cast<size_t>(m);
  const size_t outCols = static_cast<size_t>(k);
  const size_t inner = n > 0 ? static_cast<size_t>(n) : 0;

  // Private transposed copy of rightMat (outCols x inner). Having each column
  // of rightMat laid out contiguously lets the innermost loop walk both
  // operands with unit stride. The copy is used directly instead of being
  // written back over the caller's matrix.
  float *rightT = new float[outCols * inner];
  for (size_t i = 0; i < inner; ++i) {
    for (size_t c = 0; c < outCols; ++c) {
      rightT[c * inner + i] = rightMat[i * outCols + c];
    }
  }

  for (size_t r = 0; r < outRows; ++r) {
    const float *leftRow = leftMat + r * inner;
    for (size_t c = 0; c < outCols; ++c) {
      const float *rightCol = rightT + c * inner;
      // Accumulate in a local and store once per output element.
      float acc = 0.0f;
      for (size_t i = 0; i < inner; ++i) {
        acc += leftRow[i] * rightCol[i];
      }
      resultMat[r * outCols + c] = acc;
    }
  }

  delete[] rightT;
}
|
|
|
|
/// Benchmark driver: builds all-ones matrices, times one serial_sgemm call and
/// checks the product.
///
/// Command line: <prog> M N K use-blas
///   M, N, K   positive integers; leftMat is M x N, rightMat is N x K and the
///             result is M x K.
///   use-blas  required for the argument count but not used by this program.
///
/// @return 0 after a completed run (whether or not verification passed),
///         -1 on a usage or argument error.
int main(int argc, char *argv[]) {
  if (argc != 5) {
    std::cout << "Usage: " << argv[0] << " M N K use-blas\n";
    return -1;
  }

  // Parse M, N, K strictly. atoi() would silently turn garbage into 0 and let
  // negative sizes through to the allocations.
  int dims[3];
  for (int a = 0; a < 3; a++) {
    const char *text = argv[a + 1];
    char *end = NULL;
    const long parsed = strtol(text, &end, 10);
    const int narrowed = static_cast<int>(parsed);
    const bool wholeString = (end != text) && (*end == '\0');
    if (!wholeString || parsed <= 0 || narrowed != parsed) {
      std::cout << "Error: M, N and K must be positive integers (got \""
                << text << "\")\n";
      return -1;
    }
    dims[a] = narrowed;
  }
  const int m = dims[0];
  const int n = dims[1];
  const int k = dims[2];
  // argv[4] (use-blas) is deliberately ignored.

  float *leftMat, *rightMat, *resMat;
  randMat(m, n, leftMat);
  randMat(n, k, rightMat);
  randMat(m, k, resMat);

  // Time only the multiplication, using wall-clock time.
  struct timeval start, stop;
  gettimeofday(&start, NULL);

  serial_sgemm(m, n, k, leftMat, rightMat, resMat);

  gettimeofday(&stop, NULL);
  std::cout << "matmul: "
            << (stop.tv_sec - start.tv_sec) * 1000.0 +
                   (stop.tv_usec - start.tv_usec) / 1000.0
            << " ms" << std::endl;

  // Verify the result: both inputs are all ones, so every output element must
  // equal n. Only the first mismatch is reported.
  // NOTE(review): float sums of 1.0 stop growing at 2^24, so this check is
  // expected to fail for n above 16777216 — confirm if such sizes matter.
  bool correct = true;
  for (int i = 0; i < m && correct; i++) {
    for (int j = 0; j < k; j++) {
      const float value =
          resMat[static_cast<size_t>(i) * static_cast<size_t>(k) + j];
      if (int(value) != n) {
        std::cout << "Error at [" << i << "][" << j << "]: "
                  << value << " (expected " << n << ")\n";
        correct = false;
        break;
      }
    }
  }

  if (correct) {
    std::cout << "Result verification: PASSED" << std::endl;
  } else {
    std::cout << "Result verification: FAILED" << std::endl;
  }

  delete[] leftMat;
  delete[] rightMat;
  delete[] resMat;

  return 0;
}
|