hpc-lab-code/submit/gemm/matmul_youhua.cpp
2026-01-21 18:02:30 +08:00

277 lines
8.8 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

#include <mpi.h>
#include <omp.h>
using namespace std;
// Allocates a rows x cols row-major matrix with new[] and returns it through
// Mat; the caller owns the buffer and must release it with delete[].
// Despite the name, the contents are deterministic: every element is 1.0,
// which is what the result check in main() relies on.
void randMat(int rows, int cols, float *&Mat) {
  Mat = new float[rows * cols];
  for (int r = 0; r < rows; ++r) {
    float *rowPtr = Mat + r * cols;
    for (int c = 0; c < cols; ++c) rowPtr[c] = 1.0f;
  }
}
// 改进的 OpenMP 子矩阵乘法:块化以提升缓存局部性
// Tiled OpenMP kernel computing
//   C[r][c] = sum_{t < N} A[r][t] * B[c][t]
// where A_buf is M x N row-major, B_buf is K x N row-major (the right-hand
// operand already transposed, so every output element reads two contiguous
// rows) and C_buf is M x K row-major. C_buf is fully overwritten.
// Tiling all three loops keeps the active A/B rows resident in L1/L2.
void omp_blocked_sgemm(int M, int N, int K, float *A_buf, float *B_buf,
                       float *C_buf) {
  // Tile edge length (64 floats = 256 bytes per row segment).
  const int TILE_SZ = 64;
  // Each (rb, cb) output tile is owned by exactly one thread, so the thread
  // can zero and accumulate it without any synchronization.
#pragma omp parallel for collapse(2)
  for (int rb = 0; rb < M; rb += TILE_SZ) {
    for (int cb = 0; cb < K; cb += TILE_SZ) {
      const int r_end = std::min(rb + TILE_SZ, M);
      const int c_end = std::min(cb + TILE_SZ, K);
      // Zero this tile here instead of in a separate parallel pass; this also
      // yields an all-zero C when N == 0.
      for (int r = rb; r < r_end; ++r) {
        float *c_row = C_buf + static_cast<std::size_t>(r) * K;
        for (int c = cb; c < c_end; ++c) c_row[c] = 0.0f;
      }
      for (int ib = 0; ib < N; ib += TILE_SZ) {
        const int i_end = std::min(ib + TILE_SZ, N);
        for (int r = rb; r < r_end; ++r) {
          // size_t offsets: r * N or r * K can exceed INT_MAX for big inputs.
          const float *a_row = A_buf + static_cast<std::size_t>(r) * N;
          float *c_row = C_buf + static_cast<std::size_t>(r) * K;
          for (int c = cb; c < c_end; ++c) {
            const float *b_row = B_buf + static_cast<std::size_t>(c) * N;
            float acc = c_row[c];
            for (int t = ib; t < i_end; ++t) acc += a_row[t] * b_row[t];
            c_row[c] = acc;
          }
        }
      }
    }
  }
}
void mpi_blocked_sgemm(int M, int N, int K, float *&A_buf, float *&B_buf,
float *&C_buf, int myRank, int worldN) {
// 选择接近平方的进程网格rows x cols
int rbCount = (int)sqrt((double)worldN);
while (rbCount > 0 && worldN % rbCount != 0) rbCount--;
int cbCount = worldN / rbCount;
int rLen, cLen;
float *localC = nullptr;
float *locA = A_buf;
float *locB = B_buf;
if (myRank == 0) {
// 将 B 矩阵按行与列交换以便后续按列访问更高效
float *tmp = new float[K * N];
#pragma omp parallel for collapse(2)
for (int r = 0; r < N; ++r)
for (int c = 0; c < K; ++c)
tmp[c * N + r] = B_buf[r * K + c];
#pragma omp parallel for collapse(2)
for (int r = 0; r < K; ++r)
for (int c = 0; c < N; ++c)
B_buf[r * N + c] = tmp[r * N + c];
delete[] tmp;
// 主进程将子块数据通过非阻塞发送分发给其他进程
std::vector<MPI_Request> outReqs;
outReqs.reserve(1000);
for (int rb = 0; rb < rbCount; ++rb) {
for (int cb = 0; cb < cbCount; ++cb) {
int rBeg = rb * (M / rbCount);
int rEnd = (rb == rbCount - 1) ? M : (rb + 1) * (M / rbCount);
rLen = rEnd - rBeg;
int cBeg = cb * (K / cbCount);
int cEnd = (cb == cbCount - 1) ? K : (cb + 1) * (K / cbCount);
cLen = cEnd - cBeg;
int dest = rb * cbCount + cb;
if (dest == 0) {
localC = new float[rLen * cLen];
locA = A_buf + rBeg * N;
locB = B_buf + cBeg * N;
continue;
}
MPI_Request rq;
MPI_Isend(&rLen, 1, MPI_INT, dest, 0, MPI_COMM_WORLD, &rq);
outReqs.push_back(rq);
MPI_Isend(&cLen, 1, MPI_INT, dest, 0, MPI_COMM_WORLD, &rq);
outReqs.push_back(rq);
for (int rr = 0; rr < rLen; ++rr) {
MPI_Isend(A_buf + (rBeg + rr) * N, N, MPI_FLOAT, dest, 1, MPI_COMM_WORLD, &rq);
outReqs.push_back(rq);
}
for (int cc = 0; cc < cLen; ++cc) {
MPI_Isend(B_buf + (cBeg + cc) * N, N, MPI_FLOAT, dest, 2, MPI_COMM_WORLD, &rq);
outReqs.push_back(rq);
}
}
}
for (size_t i = 0; i < outReqs.size(); ++i) MPI_Wait(&outReqs[i], MPI_STATUS_IGNORE);
} else {
if (myRank < worldN) {
int rb = myRank / cbCount;
int cb = myRank % cbCount;
int rBeg = rb * (M / rbCount);
int rEnd = (rb == rbCount - 1) ? M : (rb + 1) * (M / rbCount);
rLen = rEnd - rBeg;
int cBeg = cb * (K / cbCount);
int cEnd = (cb == cbCount - 1) ? K : (cb + 1) * (K / cbCount);
cLen = cEnd - cBeg;
MPI_Recv(&rLen, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
MPI_Recv(&cLen, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
locA = new float[rLen * N];
locB = new float[cLen * N];
for (int rr = 0; rr < rLen; ++rr)
MPI_Recv(locA + rr * N, N, MPI_FLOAT, 0, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
for (int cc = 0; cc < cLen; ++cc)
MPI_Recv(locB + cc * N, N, MPI_FLOAT, 0, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
localC = new float[rLen * cLen];
}
}
MPI_Barrier(MPI_COMM_WORLD);
// 调用本地优化的乘法实现
if (myRank < worldN) {
int rb = myRank / cbCount;
int cb = myRank % cbCount;
int rBeg = rb * (M / rbCount);
int rEnd = (rb == rbCount - 1) ? M : (rb + 1) * (M / rbCount);
rLen = rEnd - rBeg;
int cBeg = cb * (K / cbCount);
int cEnd = (cb == cbCount - 1) ? K : (cb + 1) * (K / cbCount);
cLen = cEnd - cBeg;
omp_blocked_sgemm(rLen, N, cLen, locA, locB, localC);
}
MPI_Barrier(MPI_COMM_WORLD);
// 汇总各子块到根进程
if (myRank == 0) {
int rb = 0, cb = 0;
int rBeg = rb * (M / rbCount);
int cBeg = cb * (K / cbCount);
for (int rr = 0; rr < rLen; ++rr)
for (int cc = 0; cc < cLen; ++cc)
C_buf[(rBeg + rr) * K + (cBeg + cc)] = localC[rr * cLen + cc];
delete[] localC;
for (int rb = 0; rb < rbCount; ++rb) {
for (int cb = 0; cb < cbCount; ++cb) {
int src = rb * cbCount + cb;
if (src == 0) continue;
MPI_Recv(&rLen, 1, MPI_INT, src, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
MPI_Recv(&cLen, 1, MPI_INT, src, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
float *tmp = new float[rLen * cLen];
MPI_Recv(tmp, rLen * cLen, MPI_FLOAT, src, 4, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
int rStart = rb * (M / rbCount);
int cStart = cb * (K / cbCount);
for (int rr = 0; rr < rLen; ++rr)
for (int cc = 0; cc < cLen; ++cc)
C_buf[(rStart + rr) * K + (cStart + cc)] = tmp[rr * cLen + cc];
delete[] tmp;
}
}
} else {
if (myRank < worldN) {
MPI_Send(&rLen, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
MPI_Send(&cLen, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
MPI_Send(localC, rLen * cLen, MPI_FLOAT, 0, 4, MPI_COMM_WORLD);
delete[] localC;
delete[] locA;
delete[] locB;
}
}
MPI_Barrier(MPI_COMM_WORLD);
}
// Benchmark driver. Usage: <prog> M N K
// Rank 0 builds all-ones A (M x N) and B (N x K), every rank takes part in
// mpi_blocked_sgemm, and rank 0 reports the elapsed time and checks that
// every element of C equals N (the exact product for all-ones inputs).
int main(int argc, char *argv[]) {
  if (argc != 4) {
    std::cout << "Usage: " << argv[0] << " M N K\n";
    exit(-1);
  }
  int rank;
  int worldSize;
  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  const int m = atoi(argv[1]);
  const int n = atoi(argv[2]);
  const int k = atoi(argv[3]);
  // atoi returns 0 for non-numeric text and negative sizes make new[] throw;
  // every rank sees the same argv, so all ranks bail out together.
  if (m <= 0 || n <= 0 || k <= 0) {
    if (rank == 0) std::cout << "M, N and K must be positive integers\n";
    MPI_Finalize();
    return 1;
  }
  // Only rank 0 allocates the matrices; other ranks pass null pointers
  // rather than indeterminate ones.
  float *A_mat = nullptr;
  float *B_mat = nullptr;
  float *C_mat = nullptr;
  struct timeval start, stop;
  if (rank == 0) {
    randMat(m, n, A_mat);
    randMat(n, k, B_mat);
    randMat(m, k, C_mat);
  }
  gettimeofday(&start, nullptr);
  mpi_blocked_sgemm(m, n, k, A_mat, B_mat, C_mat, rank, worldSize);
  gettimeofday(&stop, nullptr);
  if (rank == 0) {
    const double elapsed = (stop.tv_sec - start.tv_sec) * 1000.0 +
                           (stop.tv_usec - start.tv_usec) / 1000.0;
    std::cout << "optimized mpi matmul: " << elapsed << " ms" << std::endl;
    // Report only the first mismatching element.
    bool correct = true;
    for (int i = 0; i < m && correct; ++i) {
      for (int j = 0; j < k; ++j) {
        const float v = C_mat[static_cast<std::size_t>(i) * k + j];
        if (int(v) != n) {
          std::cout << "Error at [" << i << "][" << j << "]: " << v
                    << " (expected " << n << ")\n";
          correct = false;
          break;
        }
      }
    }
    if (correct) {
      std::cout << "Result verification: PASSED" << std::endl;
    } else {
      std::cout << "Result verification: FAILED" << std::endl;
    }
  }
  // delete[] on nullptr is a no-op, so this is safe on every rank.
  delete[] A_mat;
  delete[] B_mat;
  delete[] C_mat;
  MPI_Finalize();
  return 0;
}