#include #include #include #include #include #include #include #include using namespace std; void randMat(int rows, int cols, float *&Mat) { Mat = new float[rows * cols]; for (int i = 0; i < rows; i++) for (int j = 0; j < cols; j++) Mat[i * cols + j] = 1.0; } // 改进的 OpenMP 子矩阵乘法:块化以提升缓存局部性 void omp_blocked_sgemm(int M, int N, int K, float *A_buf, float *B_buf, float *C_buf) { // 块大小,用于提高 L1/L2 缓存命中 const int TILE_SZ = 64; #pragma omp parallel for collapse(2) for (int rr = 0; rr < M; ++rr) { for (int cc = 0; cc < K; ++cc) { C_buf[rr * K + cc] = 0.0f; } } // 三重循环按块执行,减少主存访问并重用缓存数据 #pragma omp parallel for collapse(2) for (int rb = 0; rb < M; rb += TILE_SZ) { for (int cb = 0; cb < K; cb += TILE_SZ) { for (int ib = 0; ib < N; ib += TILE_SZ) { int r_end = min(rb + TILE_SZ, M); int c_end = min(cb + TILE_SZ, K); int i_end = min(ib + TILE_SZ, N); for (int r = rb; r < r_end; ++r) { for (int c = cb; c < c_end; ++c) { float acc = C_buf[r * K + c]; for (int t = ib; t < i_end; ++t) { acc += A_buf[r * N + t] * B_buf[c * N + t]; } C_buf[r * K + c] = acc; } } } } } } void mpi_blocked_sgemm(int M, int N, int K, float *&A_buf, float *&B_buf, float *&C_buf, int myRank, int worldN) { // 选择接近平方的进程网格(rows x cols) int rbCount = (int)sqrt((double)worldN); while (rbCount > 0 && worldN % rbCount != 0) rbCount--; int cbCount = worldN / rbCount; int rLen, cLen; float *localC = nullptr; float *locA = A_buf; float *locB = B_buf; if (myRank == 0) { // 将 B 矩阵按行与列交换以便后续按列访问更高效 float *tmp = new float[K * N]; #pragma omp parallel for collapse(2) for (int r = 0; r < N; ++r) for (int c = 0; c < K; ++c) tmp[c * N + r] = B_buf[r * K + c]; #pragma omp parallel for collapse(2) for (int r = 0; r < K; ++r) for (int c = 0; c < N; ++c) B_buf[r * N + c] = tmp[r * N + c]; delete[] tmp; // 主进程将子块数据通过非阻塞发送分发给其他进程 std::vector outReqs; outReqs.reserve(1000); for (int rb = 0; rb < rbCount; ++rb) { for (int cb = 0; cb < cbCount; ++cb) { int rBeg = rb * (M / rbCount); int rEnd = (rb == rbCount - 1) ? M : (rb + 1) * (M / rbCount); rLen = rEnd - rBeg; int cBeg = cb * (K / cbCount); int cEnd = (cb == cbCount - 1) ? K : (cb + 1) * (K / cbCount); cLen = cEnd - cBeg; int dest = rb * cbCount + cb; if (dest == 0) { localC = new float[rLen * cLen]; locA = A_buf + rBeg * N; locB = B_buf + cBeg * N; continue; } MPI_Request rq; MPI_Isend(&rLen, 1, MPI_INT, dest, 0, MPI_COMM_WORLD, &rq); outReqs.push_back(rq); MPI_Isend(&cLen, 1, MPI_INT, dest, 0, MPI_COMM_WORLD, &rq); outReqs.push_back(rq); for (int rr = 0; rr < rLen; ++rr) { MPI_Isend(A_buf + (rBeg + rr) * N, N, MPI_FLOAT, dest, 1, MPI_COMM_WORLD, &rq); outReqs.push_back(rq); } for (int cc = 0; cc < cLen; ++cc) { MPI_Isend(B_buf + (cBeg + cc) * N, N, MPI_FLOAT, dest, 2, MPI_COMM_WORLD, &rq); outReqs.push_back(rq); } } } for (size_t i = 0; i < outReqs.size(); ++i) MPI_Wait(&outReqs[i], MPI_STATUS_IGNORE); } else { if (myRank < worldN) { int rb = myRank / cbCount; int cb = myRank % cbCount; int rBeg = rb * (M / rbCount); int rEnd = (rb == rbCount - 1) ? M : (rb + 1) * (M / rbCount); rLen = rEnd - rBeg; int cBeg = cb * (K / cbCount); int cEnd = (cb == cbCount - 1) ? K : (cb + 1) * (K / cbCount); cLen = cEnd - cBeg; MPI_Recv(&rLen, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); MPI_Recv(&cLen, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); locA = new float[rLen * N]; locB = new float[cLen * N]; for (int rr = 0; rr < rLen; ++rr) MPI_Recv(locA + rr * N, N, MPI_FLOAT, 0, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE); for (int cc = 0; cc < cLen; ++cc) MPI_Recv(locB + cc * N, N, MPI_FLOAT, 0, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE); localC = new float[rLen * cLen]; } } MPI_Barrier(MPI_COMM_WORLD); // 调用本地优化的乘法实现 if (myRank < worldN) { int rb = myRank / cbCount; int cb = myRank % cbCount; int rBeg = rb * (M / rbCount); int rEnd = (rb == rbCount - 1) ? M : (rb + 1) * (M / rbCount); rLen = rEnd - rBeg; int cBeg = cb * (K / cbCount); int cEnd = (cb == cbCount - 1) ? K : (cb + 1) * (K / cbCount); cLen = cEnd - cBeg; omp_blocked_sgemm(rLen, N, cLen, locA, locB, localC); } MPI_Barrier(MPI_COMM_WORLD); // 汇总各子块到根进程 if (myRank == 0) { int rb = 0, cb = 0; int rBeg = rb * (M / rbCount); int cBeg = cb * (K / cbCount); for (int rr = 0; rr < rLen; ++rr) for (int cc = 0; cc < cLen; ++cc) C_buf[(rBeg + rr) * K + (cBeg + cc)] = localC[rr * cLen + cc]; delete[] localC; for (int rb = 0; rb < rbCount; ++rb) { for (int cb = 0; cb < cbCount; ++cb) { int src = rb * cbCount + cb; if (src == 0) continue; MPI_Recv(&rLen, 1, MPI_INT, src, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE); MPI_Recv(&cLen, 1, MPI_INT, src, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE); float *tmp = new float[rLen * cLen]; MPI_Recv(tmp, rLen * cLen, MPI_FLOAT, src, 4, MPI_COMM_WORLD, MPI_STATUS_IGNORE); int rStart = rb * (M / rbCount); int cStart = cb * (K / cbCount); for (int rr = 0; rr < rLen; ++rr) for (int cc = 0; cc < cLen; ++cc) C_buf[(rStart + rr) * K + (cStart + cc)] = tmp[rr * cLen + cc]; delete[] tmp; } } } else { if (myRank < worldN) { MPI_Send(&rLen, 1, MPI_INT, 0, 3, MPI_COMM_WORLD); MPI_Send(&cLen, 1, MPI_INT, 0, 3, MPI_COMM_WORLD); MPI_Send(localC, rLen * cLen, MPI_FLOAT, 0, 4, MPI_COMM_WORLD); delete[] localC; delete[] locA; delete[] locB; } } MPI_Barrier(MPI_COMM_WORLD); } int main(int argc, char *argv[]) { if (argc != 4) { cout << "Usage: " << argv[0] << " M N K\n"; exit(-1); } int rank; int worldSize; MPI_Init(&argc, &argv); MPI_Comm_size(MPI_COMM_WORLD, &worldSize); MPI_Comm_rank(MPI_COMM_WORLD, &rank); int m = atoi(argv[1]); int n = atoi(argv[2]); int k = atoi(argv[3]); float *A_mat, *B_mat, *C_mat; struct timeval start, stop; if (rank == 0) { randMat(m, n, A_mat); randMat(n, k, B_mat); randMat(m, k, C_mat); } gettimeofday(&start, NULL); mpi_blocked_sgemm(m, n, k, A_mat, B_mat, C_mat, rank, worldSize); gettimeofday(&stop, NULL); if (rank == 0) { double elapsed = (stop.tv_sec - start.tv_sec) * 1000.0 + (stop.tv_usec - start.tv_usec) / 1000.0; cout << "optimized mpi matmul: " << elapsed << " ms" << endl; bool correct = true; for (int i = 0; i < m; i++) { for (int j = 0; j < k; j++){ if (int(C_mat[i * k + j]) != n) { cout << "Error at [" << i << "][" << j << "]: " << C_mat[i * k + j] << " (expected " << n << ")\n"; correct = false; goto end_check; } } } end_check: if (correct) { cout << "Result verification: PASSED" << endl; } else { cout << "Result verification: FAILED" << endl; } delete[] A_mat; delete[] B_mat; delete[] C_mat; } MPI_Finalize(); return 0; }