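// Blocked SGEMM: MPI distributes tiles of A and B across the ranks, and each rank
// multiplies its tile with an OpenMP-parallel, cache-blocked kernel.
//
// Build/run sketch (assumes the file is saved as matmul.cpp; the exact compiler
// wrapper and flags depend on your MPI installation):
//   mpicxx -O2 -fopenmp matmul.cpp -o matmul
//   mpirun -np 4 ./matmul 1024 1024 1024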
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <algorithm>  // std::min, used by the blocked kernel
#include <iostream>
#include <mpi.h>
#include <omp.h>
#include <vector>

using namespace std;

void randMat(int rows, int cols, float *&Mat) {
  Mat = new float[rows * cols];
  for (int i = 0; i < rows; i++)
    for (int j = 0; j < cols; j++)
      Mat[i * cols + j] = 1.0;
}
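// Note: randMat fills every matrix with 1.0, so each entry of the product
// C = A * B should equal the inner dimension; main() relies on this to verify the result.
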
// Improved OpenMP sub-matrix multiplication: tiled to improve cache locality.
void omp_blocked_sgemm(int M, int N, int K, float *A_buf, float *B_buf,
                       float *C_buf) {
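  // Layout used throughout this file: A_buf is M x N row-major, B_buf holds B
  // transposed (K rows of length N), and C_buf is M x K. With the transposed B,
  // the innermost dot product walks both operands contiguously.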
  // Tile size, chosen to improve L1/L2 cache hit rates.
  const int TILE_SZ = 64;

#pragma omp parallel for collapse(2)
  for (int rr = 0; rr < M; ++rr) {
    for (int cc = 0; cc < K; ++cc) {
      C_buf[rr * K + cc] = 0.0f;
    }
  }

  // Run the triple loop tile by tile to cut main-memory traffic and reuse cached data.
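  // collapse(2) parallelizes over (row-tile, column-tile) pairs, so each thread
  // owns its C tiles exclusively and can accumulate over the ib tiles without
  // atomics or reductions.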
#pragma omp parallel for collapse(2)
  for (int rb = 0; rb < M; rb += TILE_SZ) {
    for (int cb = 0; cb < K; cb += TILE_SZ) {
      for (int ib = 0; ib < N; ib += TILE_SZ) {
        int r_end = min(rb + TILE_SZ, M);
        int c_end = min(cb + TILE_SZ, K);
        int i_end = min(ib + TILE_SZ, N);

        for (int r = rb; r < r_end; ++r) {
          for (int c = cb; c < c_end; ++c) {
            float acc = C_buf[r * K + c];
            for (int t = ib; t < i_end; ++t) {
              acc += A_buf[r * N + t] * B_buf[c * N + t];
            }
            C_buf[r * K + c] = acc;
          }
        }
      }
    }
  }
}

void mpi_blocked_sgemm(int M, int N, int K, float *&A_buf, float *&B_buf,
                       float *&C_buf, int myRank, int worldN) {

  // Pick a process grid (rbCount x cbCount) that is as close to square as possible.
  int rbCount = (int)sqrt((double)worldN);
  while (rbCount > 0 && worldN % rbCount != 0) rbCount--;
  int cbCount = worldN / rbCount;
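  // For example, worldN = 4 gives a 2 x 2 grid and worldN = 6 gives 2 x 3;
  // a prime worldN degenerates to a 1 x worldN grid.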

  int rLen, cLen;
  float *localC = nullptr;
  float *locA = A_buf;
  float *locB = B_buf;

  if (myRank == 0) {
    // Transpose B in place so that later column accesses become contiguous row accesses.
    float *tmp = new float[K * N];
#pragma omp parallel for collapse(2)
    for (int r = 0; r < N; ++r)
      for (int c = 0; c < K; ++c)
        tmp[c * N + r] = B_buf[r * K + c];

#pragma omp parallel for collapse(2)
    for (int r = 0; r < K; ++r)
      for (int c = 0; c < N; ++c)
        B_buf[r * N + c] = tmp[r * N + c];
    delete[] tmp;
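    // B_buf now holds B^T (K rows of length N): row c is column c of the original B,
    // so a rank's B tile can be shipped as cLen contiguous rows.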

    // The root distributes each rank's sub-blocks with non-blocking sends.
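    // Message tags: 0 = tile dimensions, 1 = rows of A, 2 = rows of B^T,
    // 3 = result-tile dimensions, 4 = result-tile data.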
    std::vector<MPI_Request> outReqs;
    outReqs.reserve(1000);
    // The dimension values must stay valid until the non-blocking sends complete,
    // so they live in per-destination storage instead of the reused loop variables.
    std::vector<int> dims(2 * worldN);

    for (int rb = 0; rb < rbCount; ++rb) {
      for (int cb = 0; cb < cbCount; ++cb) {
        int rBeg = rb * (M / rbCount);
        int rEnd = (rb == rbCount - 1) ? M : (rb + 1) * (M / rbCount);
        rLen = rEnd - rBeg;

        int cBeg = cb * (K / cbCount);
        int cEnd = (cb == cbCount - 1) ? K : (cb + 1) * (K / cbCount);
        cLen = cEnd - cBeg;

        int dest = rb * cbCount + cb;
        if (dest == 0) {
          localC = new float[rLen * cLen];
          locA = A_buf + rBeg * N;
          locB = B_buf + cBeg * N;
          continue;
        }

        dims[2 * dest] = rLen;
        dims[2 * dest + 1] = cLen;

        MPI_Request rq;
        MPI_Isend(&dims[2 * dest], 1, MPI_INT, dest, 0, MPI_COMM_WORLD, &rq);
        outReqs.push_back(rq);
        MPI_Isend(&dims[2 * dest + 1], 1, MPI_INT, dest, 0, MPI_COMM_WORLD, &rq);
        outReqs.push_back(rq);

        for (int rr = 0; rr < rLen; ++rr) {
          MPI_Isend(A_buf + (rBeg + rr) * N, N, MPI_FLOAT, dest, 1, MPI_COMM_WORLD, &rq);
          outReqs.push_back(rq);
        }
        for (int cc = 0; cc < cLen; ++cc) {
          MPI_Isend(B_buf + (cBeg + cc) * N, N, MPI_FLOAT, dest, 2, MPI_COMM_WORLD, &rq);
          outReqs.push_back(rq);
        }
      }
    }

    MPI_Waitall((int)outReqs.size(), outReqs.data(), MPI_STATUSES_IGNORE);
  } else {
    if (myRank < worldN) {
      // The tile dimensions arrive from rank 0 (tag 0); there is no need to recompute them here.
      MPI_Recv(&rLen, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      MPI_Recv(&cLen, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

      locA = new float[rLen * N];
      locB = new float[cLen * N];

      // Receive this rank's rows of A (tag 1) and rows of B^T (tag 2).
      for (int rr = 0; rr < rLen; ++rr)
        MPI_Recv(locA + rr * N, N, MPI_FLOAT, 0, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      for (int cc = 0; cc < cLen; ++cc)
        MPI_Recv(locB + cc * N, N, MPI_FLOAT, 0, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

      localC = new float[rLen * cLen];
    }
  }

  MPI_Barrier(MPI_COMM_WORLD);

  // Run the locally optimized multiplication kernel on this rank's tile.
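  // The call below maps M -> rLen (rows of the A tile), N -> shared inner dimension,
  // and K -> cLen (rows of the B^T tile), producing the rLen x cLen block localC.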
  if (myRank < worldN) {
    int rb = myRank / cbCount;
    int cb = myRank % cbCount;

    int rBeg = rb * (M / rbCount);
    int rEnd = (rb == rbCount - 1) ? M : (rb + 1) * (M / rbCount);
    rLen = rEnd - rBeg;

    int cBeg = cb * (K / cbCount);
    int cEnd = (cb == cbCount - 1) ? K : (cb + 1) * (K / cbCount);
    cLen = cEnd - cBeg;

    omp_blocked_sgemm(rLen, N, cLen, locA, locB, localC);
  }

  MPI_Barrier(MPI_COMM_WORLD);

  // Gather the sub-blocks back onto the root process.
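  // Root first copies its own (0,0) block into C_buf, then receives every other
  // block's dimensions (tag 3) and data (tag 4) and writes it at its offset.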
  if (myRank == 0) {
    int rb = 0, cb = 0;
    int rBeg = rb * (M / rbCount);
    int cBeg = cb * (K / cbCount);

    for (int rr = 0; rr < rLen; ++rr)
      for (int cc = 0; cc < cLen; ++cc)
        C_buf[(rBeg + rr) * K + (cBeg + cc)] = localC[rr * cLen + cc];
    delete[] localC;

    for (int rb = 0; rb < rbCount; ++rb) {
      for (int cb = 0; cb < cbCount; ++cb) {
        int src = rb * cbCount + cb;
        if (src == 0) continue;

        MPI_Recv(&rLen, 1, MPI_INT, src, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        MPI_Recv(&cLen, 1, MPI_INT, src, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

        float *tmp = new float[rLen * cLen];
        MPI_Recv(tmp, rLen * cLen, MPI_FLOAT, src, 4, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

        int rStart = rb * (M / rbCount);
        int cStart = cb * (K / cbCount);
        for (int rr = 0; rr < rLen; ++rr)
          for (int cc = 0; cc < cLen; ++cc)
            C_buf[(rStart + rr) * K + (cStart + cc)] = tmp[rr * cLen + cc];

        delete[] tmp;
      }
    }
  } else {
    if (myRank < worldN) {
      MPI_Send(&rLen, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
      MPI_Send(&cLen, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
      MPI_Send(localC, rLen * cLen, MPI_FLOAT, 0, 4, MPI_COMM_WORLD);

      delete[] localC;
      delete[] locA;
      delete[] locB;
    }
  }

  MPI_Barrier(MPI_COMM_WORLD);
}

int main(int argc, char *argv[]) {
  if (argc != 4) {
    cout << "Usage: " << argv[0] << " M N K\n";
    exit(-1);
  }

  int rank;
  int worldSize;
  MPI_Init(&argc, &argv);

  MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  int m = atoi(argv[1]);
  int n = atoi(argv[2]);
  int k = atoi(argv[3]);

  // Only rank 0 allocates the full matrices; the other ranks only ever touch
  // the tiles they receive inside mpi_blocked_sgemm.
  float *A_mat = nullptr, *B_mat = nullptr, *C_mat = nullptr;
  struct timeval start, stop;

  if (rank == 0) {
    randMat(m, n, A_mat);
    randMat(n, k, B_mat);
    randMat(m, k, C_mat);
  }

  gettimeofday(&start, NULL);
  mpi_blocked_sgemm(m, n, k, A_mat, B_mat, C_mat, rank, worldSize);
  gettimeofday(&stop, NULL);
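  // Note: the measured time includes distributing A and B and gathering C,
  // not just the local compute kernels.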

  if (rank == 0) {
    double elapsed = (stop.tv_sec - start.tv_sec) * 1000.0 +
                     (stop.tv_usec - start.tv_usec) / 1000.0;
    cout << "optimized mpi matmul: " << elapsed << " ms" << endl;

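    // Since A and B are filled with 1.0, every entry of C should equal n;
    // the scan stops at the first mismatch.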
    bool correct = true;
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < k; j++) {
        if (int(C_mat[i * k + j]) != n) {
          cout << "Error at [" << i << "][" << j << "]: "
               << C_mat[i * k + j] << " (expected " << n << ")\n";
          correct = false;
          goto end_check;
        }
      }
    }
  end_check:
    if (correct) {
      cout << "Result verification: PASSED" << endl;
    } else {
      cout << "Result verification: FAILED" << endl;
    }

    delete[] A_mat;
    delete[] B_mat;
    delete[] C_mat;
  }

  MPI_Finalize();
  return 0;
}