// hpc-lab-code/work/gemm_optimized.cpp
// 2026-01-21 18:30:58 +08:00
//
// 303 lines
// 10 KiB
// C++

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <iostream>
#include <mpi.h>
#include <omp.h>
#include <vector>
using namespace std;
// Allocate a rows x cols row-major matrix and fill every entry with 1.0.
// Despite the name there is no randomness: constant inputs make the
// matrix product trivially verifiable (every C[i][j] == n).
// The pointer is returned through the by-reference out-parameter `Mat`;
// the caller owns the allocation and must delete[] it.
void randMat(int rows, int cols, float *&Mat) {
  const int total = rows * cols;
  Mat = new float[total];
  for (int idx = 0; idx < total; idx++) {
    Mat[idx] = 1.0;
  }
}
// Shared-memory SGEMM kernel with cache tiling.
// leftMat is m x n row-major. rightMat holds the right operand
// PRE-TRANSPOSED: k rows of length n, so rightMat[col * n + i] is B[i][col]
// and the inner product walks both operands with unit stride.
// resultMat (m x k) is fully overwritten.
void openmp_sgemm_optimized(int m, int n, int k, float *leftMat, float *rightMat,
                            float *resultMat) {
  // Tile edge chosen for L1/L2 reuse of the operand panels.
  const int TILE = 64;
  // Zero the output first so every tile pass can accumulate into it.
#pragma omp parallel for collapse(2)
  for (int r = 0; r < m; r++) {
    for (int c = 0; c < k; c++) {
      resultMat[r * k + c] = 0.0;
    }
  }
  // Tiled triple loop. Parallelism is over the (row-tile, col-tile) output
  // grid, so each thread owns a disjoint set of result elements and the
  // read-modify-write accumulation below is race-free.
#pragma omp parallel for collapse(2)
  for (int rt = 0; rt < m; rt += TILE) {
    for (int ct = 0; ct < k; ct += TILE) {
      for (int it = 0; it < n; it += TILE) {
        const int rEnd = std::min(rt + TILE, m);
        const int cEnd = std::min(ct + TILE, k);
        const int iEnd = std::min(it + TILE, n);
        for (int r = rt; r < rEnd; r++) {
          for (int c = ct; c < cEnd; c++) {
            // Carry the partial sum across successive it-tiles.
            float acc = resultMat[r * k + c];
            for (int i = it; i < iEnd; i++) {
              acc += leftMat[r * n + i] * rightMat[c * n + i];
            }
            resultMat[r * k + c] = acc;
          }
        }
      }
    }
  }
}
// Distributed SGEMM driver: resultMat = leftMat * rightMat, scattered over a
// rowBlock x colBlock grid of MPI ranks, with the per-rank multiply done by
// openmp_sgemm_optimized.
//
// Layouts: leftMat is m x n row-major; rightMat is n x k row-major on entry.
// Rank 0 transposes rightMat IN PLACE to k x n before distribution (so each
// rank's panel of B is contiguous rows) — NOTE(review): since the pointers are
// passed by reference, the caller's rightMat is left transposed on return.
//
// Message tags: 0 = block dimensions, 1 = rows of A, 2 = rows of transposed B,
// 3 = result dimensions, 4 = result payload.
//
// Edge ranks (rowB == rowBlock-1 / colB == colBlock-1) absorb the remainder
// rows/columns when m or k is not divisible by the grid.
void mpi_sgemm_optimized(int m, int n, int k, float *&leftMat, float *&rightMat,
float *&resultMat, int rank, int worldsize) {
// Pick the largest divisor of worldsize that is <= sqrt(worldsize) as the
// number of row blocks, so the process grid is as square as possible.
int rowBlock = (int)sqrt((double)worldsize);
while (rowBlock > 0 && worldsize % rowBlock != 0) {
rowBlock--;
}
int colBlock = worldsize / rowBlock;
int rowStride, colStride;
float *res = nullptr;
// Rank 0 aliases its own slices of the full matrices; workers allocate
// receive buffers below.
float *localLeftMat = leftMat;
float *localRightMat = rightMat;
if (rank == 0) {
// Transpose rightMat (n x k -> k x n) via a scratch buffer, OpenMP-parallel.
float *buf = new float[k * n];
#pragma omp parallel for collapse(2)
for (int r = 0; r < n; r++) {
for (int c = 0; c < k; c++) {
buf[c * n + r] = rightMat[r * k + c];
}
}
// Copy the transposed data back into rightMat (reinterpreted as k x n;
// the allocation size n*k is unchanged).
#pragma omp parallel for collapse(2)
for (int r = 0; r < k; r++) {
for (int c = 0; c < n; c++) {
rightMat[r * n + c] = buf[r * n + c];
}
}
delete[] buf;
// Non-blocking sends so the posting of all panels can overlap.
// NOTE(review): reserve(1000) is only a hint; if more requests are pushed
// the vector reallocates, which is safe because MPI_Request handles are
// plain values.
std::vector<MPI_Request> sendRequests;
sendRequests.reserve(1000);
for (int rowB = 0; rowB < rowBlock; rowB++) {
for (int colB = 0; colB < colBlock; colB++) {
int rowStart = rowB * (m / rowBlock);
int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
rowStride = rowEnd - rowStart;
int colStart = colB * (k / colBlock);
int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
colStride = colEnd - colStart;
int sendto = rowB * colBlock + colB;
if (sendto == 0) {
// Rank 0's own block: no send, just point at the slices in place.
res = new float[rowStride * colStride];
localLeftMat = leftMat + rowStart * n;
localRightMat = rightMat + colStart * n;
continue;
}
// Send the block dimensions first (tag 0).
MPI_Request req;
MPI_Isend(&rowStride, 1, MPI_INT, sendto, 0, MPI_COMM_WORLD, &req);
sendRequests.push_back(req);
MPI_Isend(&colStride, 1, MPI_INT, sendto, 0, MPI_COMM_WORLD, &req);
sendRequests.push_back(req);
// Then the operand panels, one row per message (tags 1 and 2).
for (int r = 0; r < rowStride; r++) {
MPI_Isend(leftMat + (rowStart + r) * n, n, MPI_FLOAT, sendto,
1, MPI_COMM_WORLD, &req);
sendRequests.push_back(req);
}
for (int c = 0; c < colStride; c++) {
MPI_Isend(rightMat + (colStart + c) * n, n, MPI_FLOAT, sendto,
2, MPI_COMM_WORLD, &req);
sendRequests.push_back(req);
}
}
}
// Wait for all posted sends to complete before continuing.
for (size_t i = 0; i < sendRequests.size(); i++) {
MPI_Wait(&sendRequests[i], MPI_STATUS_IGNORE);
}
} else {
if (rank < worldsize) {  // always true for valid ranks; kept defensively
int rowB = rank / colBlock;
int colB = rank % colBlock;
int rowStart = rowB * (m / rowBlock);
int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
rowStride = rowEnd - rowStart;
int colStart = colB * (k / colBlock);
int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
colStride = colEnd - colStart;
// NOTE(review): the strides computed locally above are immediately
// overwritten by the values received from rank 0 — redundant but
// harmless, since both computations use the same formula.
MPI_Recv(&rowStride, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
MPI_Recv(&colStride, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
localLeftMat = new float[rowStride * n];
localRightMat = new float[colStride * n];
// Receive the operand panels row by row, matching the sends above.
for (int r = 0; r < rowStride; r++) {
MPI_Recv(localLeftMat + r * n, n, MPI_FLOAT, 0, 1, MPI_COMM_WORLD,
MPI_STATUS_IGNORE);
}
for (int c = 0; c < colStride; c++) {
MPI_Recv(localRightMat + c * n, n, MPI_FLOAT, 0, 2, MPI_COMM_WORLD,
MPI_STATUS_IGNORE);
}
res = new float[rowStride * colStride];
}
}
MPI_Barrier(MPI_COMM_WORLD);
// Local compute: every rank (including 0) recomputes its own block bounds
// and multiplies its rowStride x n panel by the colStride x n transposed
// panel into res (rowStride x colStride).
if (rank < worldsize) {
int rowB = rank / colBlock;
int colB = rank % colBlock;
int rowStart = rowB * (m / rowBlock);
int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
rowStride = rowEnd - rowStart;
int colStart = colB * (k / colBlock);
int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
colStride = colEnd - colStart;
openmp_sgemm_optimized(rowStride, n, colStride, localLeftMat, localRightMat, res);
}
MPI_Barrier(MPI_COMM_WORLD);
// Gather: rank 0 copies its own block into resultMat, then receives every
// worker's block (dims on tag 3, payload on tag 4) and scatters it into
// the right submatrix of the full m x k result.
if (rank == 0) {
int rowB = 0;
int colB = 0;
int rowStart = rowB * (m / rowBlock);
int colStart = colB * (k / colBlock);
for (int r = 0; r < rowStride; r++) {
for (int c = 0; c < colStride; c++) {
resultMat[(rowStart + r) * k + (colStart + c)] = res[r * colStride + c];
}
}
delete[] res;
for (int rowB = 0; rowB < rowBlock; rowB++) {
for (int colB = 0; colB < colBlock; colB++) {
int recvfrom = rowB * colBlock + colB;
if (recvfrom == 0) continue;
MPI_Recv(&rowStride, 1, MPI_INT, recvfrom, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
MPI_Recv(&colStride, 1, MPI_INT, recvfrom, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
float *tmpRes = new float[rowStride * colStride];
MPI_Recv(tmpRes, rowStride * colStride, MPI_FLOAT, recvfrom, 4,
MPI_COMM_WORLD, MPI_STATUS_IGNORE);
int rowStart = rowB * (m / rowBlock);
int colStart = colB * (k / colBlock);
for (int r = 0; r < rowStride; r++) {
for (int c = 0; c < colStride; c++) {
resultMat[(rowStart + r) * k + (colStart + c)] = tmpRes[r * colStride + c];
}
}
delete[] tmpRes;
}
}
} else {
if (rank < worldsize) {
MPI_Send(&rowStride, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
MPI_Send(&colStride, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
MPI_Send(res, rowStride * colStride, MPI_FLOAT, 0, 4, MPI_COMM_WORLD);
// Workers own their buffers; rank 0's localLeft/RightMat alias the
// caller's matrices and are NOT freed here.
delete[] res;
delete[] localLeftMat;
delete[] localRightMat;
}
}
MPI_Barrier(MPI_COMM_WORLD);
}
// Entry point: parses M N K from the command line, runs the distributed
// SGEMM, times it, and (on rank 0) verifies the result. Both input
// matrices are all-ones, so every element of the m x k result must equal n.
// Returns 0 on success; exits/returns -1 on bad arguments.
int main(int argc, char *argv[]) {
  if (argc != 4) {
    cout << "Usage: " << argv[0] << " M N K\n";
    exit(-1);
  }
  int rank;
  int worldSize;
  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  int m = atoi(argv[1]);
  int n = atoi(argv[2]);
  int k = atoi(argv[3]);
  // Reject non-numeric or non-positive sizes: atoi returns 0 on parse
  // failure, and with m == 0 the verification loop below would be empty
  // and report a bogus PASS.
  if (m <= 0 || n <= 0 || k <= 0) {
    if (rank == 0) {
      cout << "Error: M, N and K must be positive integers\n";
    }
    MPI_Finalize();
    return -1;
  }
  // Initialize to nullptr: non-root ranks never allocate these, yet they
  // are passed by reference into mpi_sgemm_optimized, which copies their
  // values (localLeftMat = leftMat). Copying an indeterminate pointer is
  // undefined behavior; nullptr makes it well-defined.
  float *leftMat = nullptr;
  float *rightMat = nullptr;
  float *resMat = nullptr;
  struct timeval start, stop;
  if (rank == 0) {
    randMat(m, n, leftMat);   // A: m x n, all ones
    randMat(n, k, rightMat);  // B: n x k, all ones
    randMat(m, k, resMat);    // C: m x k, overwritten by the kernel
  }
  gettimeofday(&start, NULL);
  mpi_sgemm_optimized(m, n, k, leftMat, rightMat, resMat, rank, worldSize);
  gettimeofday(&stop, NULL);
  if (rank == 0) {
    double elapsed = (stop.tv_sec - start.tv_sec) * 1000.0 +
                     (stop.tv_usec - start.tv_usec) / 1000.0;
    cout << "optimized mpi matmul: " << elapsed << " ms" << endl;
    // With all-ones inputs every C[i][j] must be exactly n (integer sums
    // this small are exactly representable in float). A flag + break pair
    // replaces the original goto.
    bool correct = true;
    for (int i = 0; i < m && correct; i++) {
      for (int j = 0; j < k; j++) {
        if (int(resMat[i * k + j]) != n) {
          cout << "Error at [" << i << "][" << j << "]: "
               << resMat[i * k + j] << " (expected " << n << ")\n";
          correct = false;
          break;
        }
      }
    }
    if (correct) {
      cout << "Result verification: PASSED" << endl;
    } else {
      cout << "Result verification: FAILED" << endl;
    }
    delete[] leftMat;
    delete[] rightMat;
    delete[] resMat;
  }
  MPI_Finalize();
  return 0;
}