#include <mpi.h>
#include <omp.h>
#include <sys/time.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

using namespace std;

// Fill a rows x cols matrix with 1.0 so that the product is easy to verify.
void randMat(int rows, int cols, float *&Mat) {
    Mat = new float[rows * cols];
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            Mat[i * cols + j] = 1.0;
}

// Optimized kernel: loop blocking (tiling) for better cache locality.
// leftMat is m x n (row-major), rightMat is k x n (already transposed),
// resultMat is m x k.
void openmp_sgemm_optimized(int m, int n, int k, float *leftMat,
                            float *rightMat, float *resultMat) {
    // Use a fairly large tile to improve cache utilization.
    const int BLOCK_SIZE = 64;

#pragma omp parallel for collapse(2)
    for (int row = 0; row < m; row++) {
        for (int col = 0; col < k; col++) {
            resultMat[row * k + col] = 0.0;
        }
    }

    // Blocked computation to improve the cache hit rate.
#pragma omp parallel for collapse(2)
    for (int row_block = 0; row_block < m; row_block += BLOCK_SIZE) {
        for (int col_block = 0; col_block < k; col_block += BLOCK_SIZE) {
            for (int i_block = 0; i_block < n; i_block += BLOCK_SIZE) {
                int row_end = min(row_block + BLOCK_SIZE, m);
                int col_end = min(col_block + BLOCK_SIZE, k);
                int i_end = min(i_block + BLOCK_SIZE, n);
                for (int row = row_block; row < row_end; row++) {
                    for (int col = col_block; col < col_end; col++) {
                        float sum = resultMat[row * k + col];
                        for (int i = i_block; i < i_end; i++) {
                            sum += leftMat[row * n + i] * rightMat[col * n + i];
                        }
                        resultMat[row * k + col] = sum;
                    }
                }
            }
        }
    }
}

void mpi_sgemm_optimized(int m, int n, int k, float *&leftMat, float *&rightMat,
                         float *&resultMat, int rank, int worldsize) {
    // Compute the row/column dimensions of the process grid.
    int rowBlock = (int)sqrt((double)worldsize);
    while (rowBlock > 0 && worldsize % rowBlock != 0) {
        rowBlock--;
    }
    int colBlock = worldsize / rowBlock;

    int rowStride, colStride;
    float *res = nullptr;
    float *localLeftMat = leftMat;
    float *localRightMat = rightMat;

    if (rank == 0) {
        // Transpose the right matrix (n x k -> k x n) with OpenMP so that each
        // worker receives contiguous rows of length n.
        float *buf = new float[k * n];
#pragma omp parallel for collapse(2)
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < k; c++) {
                buf[c * n + r] = rightMat[r * k + c];
            }
        }
#pragma omp parallel for collapse(2)
        for (int r = 0; r < k; r++) {
            for (int c = 0; c < n; c++) {
                rightMat[r * n + c] = buf[r * n + c];
            }
        }
        delete[] buf;

        // Non-blocking sends let the distribution of the blocks to the
        // different workers overlap.
        std::vector<MPI_Request> sendRequests;
        sendRequests.reserve(1000);

        for (int rowB = 0; rowB < rowBlock; rowB++) {
            for (int colB = 0; colB < colBlock; colB++) {
                int rowStart = rowB * (m / rowBlock);
                int rowEnd =
                    (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
                rowStride = rowEnd - rowStart;
                int colStart = colB * (k / colBlock);
                int colEnd =
                    (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
                colStride = colEnd - colStart;
                int sendto = rowB * colBlock + colB;
                if (sendto == 0) {
                    // Rank 0 keeps its own block and computes on it directly.
                    res = new float[rowStride * colStride];
                    localLeftMat = leftMat + rowStart * n;
                    localRightMat = rightMat + colStart * n;
                    continue;
                }
                // Send the block sizes with blocking sends: rowStride and
                // colStride are overwritten on the next iteration, so their
                // buffers must not be left pending in a non-blocking send.
                MPI_Send(&rowStride, 1, MPI_INT, sendto, 0, MPI_COMM_WORLD);
                MPI_Send(&colStride, 1, MPI_INT, sendto, 0, MPI_COMM_WORLD);
                // Send the matrix data; leftMat and rightMat are not modified
                // again, so non-blocking sends are safe here.
                MPI_Request req;
                for (int r = 0; r < rowStride; r++) {
                    MPI_Isend(leftMat + (rowStart + r) * n, n, MPI_FLOAT,
                              sendto, 1, MPI_COMM_WORLD, &req);
                    sendRequests.push_back(req);
                }
                for (int c = 0; c < colStride; c++) {
                    MPI_Isend(rightMat + (colStart + c) * n, n, MPI_FLOAT,
                              sendto, 2, MPI_COMM_WORLD, &req);
                    sendRequests.push_back(req);
                }
            }
        }
        // Wait for all outstanding sends to complete.
        for (size_t i = 0; i < sendRequests.size(); i++) {
            MPI_Wait(&sendRequests[i], MPI_STATUS_IGNORE);
        }
    } else {
        if (rank < worldsize) {
            // Worker ranks: receive the block sizes, then the rows of the
            // left matrix and the (transposed) rows of the right matrix.
            MPI_Recv(&rowStride, 1, MPI_INT, 0, 0, MPI_COMM_WORLD,
                     MPI_STATUS_IGNORE);
            MPI_Recv(&colStride, 1, MPI_INT, 0, 0, MPI_COMM_WORLD,
                     MPI_STATUS_IGNORE);
            localLeftMat = new float[rowStride * n];
            localRightMat = new float[colStride * n];
            for (int r = 0; r < rowStride; r++) {
                MPI_Recv(localLeftMat + r * n, n, MPI_FLOAT, 0, 1,
                         MPI_COMM_WORLD, MPI_STATUS_IGNORE);
            }
            for (int c = 0; c < colStride; c++) {
                MPI_Recv(localRightMat + c * n, n, MPI_FLOAT, 0, 2,
                         MPI_COMM_WORLD, MPI_STATUS_IGNORE);
            }
            res = new float[rowStride * colStride];
        }
    }
    MPI_Barrier(MPI_COMM_WORLD);

    // Local computation with the blocked OpenMP kernel.
    if (rank < worldsize) {
        int rowB = rank / colBlock;
        int colB = rank % colBlock;
        int rowStart = rowB * (m / rowBlock);
        int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
        rowStride = rowEnd - rowStart;
        int colStart = colB * (k / colBlock);
        int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
        colStride = colEnd - colStart;
        openmp_sgemm_optimized(rowStride, n, colStride, localLeftMat,
                               localRightMat, res);
    }
    MPI_Barrier(MPI_COMM_WORLD);

    // Gather the partial results on rank 0.
    if (rank == 0) {
        // Copy rank 0's own block into the result matrix.
        int rowB = 0;
        int colB = 0;
        int rowStart = rowB * (m / rowBlock);
        int colStart = colB * (k / colBlock);
        for (int r = 0; r < rowStride; r++) {
            for (int c = 0; c < colStride; c++) {
                resultMat[(rowStart + r) * k + (colStart + c)] =
                    res[r * colStride + c];
            }
        }
        delete[] res;
        // Receive the remaining blocks from the workers.
        for (int rowB = 0; rowB < rowBlock; rowB++) {
            for (int colB = 0; colB < colBlock; colB++) {
                int recvfrom = rowB * colBlock + colB;
                if (recvfrom == 0) continue;
                MPI_Recv(&rowStride, 1, MPI_INT, recvfrom, 3, MPI_COMM_WORLD,
                         MPI_STATUS_IGNORE);
                MPI_Recv(&colStride, 1, MPI_INT, recvfrom, 3, MPI_COMM_WORLD,
                         MPI_STATUS_IGNORE);
                float *tmpRes = new float[rowStride * colStride];
                MPI_Recv(tmpRes, rowStride * colStride, MPI_FLOAT, recvfrom, 4,
                         MPI_COMM_WORLD, MPI_STATUS_IGNORE);
                int rowStart = rowB * (m / rowBlock);
                int colStart = colB * (k / colBlock);
                for (int r = 0; r < rowStride; r++) {
                    for (int c = 0; c < colStride; c++) {
                        resultMat[(rowStart + r) * k + (colStart + c)] =
                            tmpRes[r * colStride + c];
                    }
                }
                delete[] tmpRes;
            }
        }
    } else {
        if (rank < worldsize) {
            MPI_Send(&rowStride, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
            MPI_Send(&colStride, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
            MPI_Send(res, rowStride * colStride, MPI_FLOAT, 0, 4,
                     MPI_COMM_WORLD);
            delete[] res;
            delete[] localLeftMat;
            delete[] localRightMat;
        }
    }
    MPI_Barrier(MPI_COMM_WORLD);
}

int main(int argc, char *argv[]) {
    if (argc != 4) {
        cout << "Usage: " << argv[0] << " M N K\n";
        exit(-1);
    }

    int rank;
    int worldSize;
    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    int m = atoi(argv[1]);
    int n = atoi(argv[2]);
    int k = atoi(argv[3]);

    // Only rank 0 allocates the full matrices; the other ranks keep null
    // pointers and receive their blocks inside mpi_sgemm_optimized.
    float *leftMat = nullptr, *rightMat = nullptr, *resMat = nullptr;
    struct timeval start, stop;
    if (rank == 0) {
        randMat(m, n, leftMat);
        randMat(n, k, rightMat);
        randMat(m, k, resMat);
    }

    gettimeofday(&start, NULL);
    mpi_sgemm_optimized(m, n, k, leftMat, rightMat, resMat, rank, worldSize);
    gettimeofday(&stop, NULL);

    if (rank == 0) {
        double elapsed = (stop.tv_sec - start.tv_sec) * 1000.0 +
                         (stop.tv_usec - start.tv_usec) / 1000.0;
        cout << "optimized mpi matmul: " << elapsed << " ms" << endl;

        // With all-ones inputs, every entry of the product should equal n.
        bool correct = true;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < k; j++) {
                if (int(resMat[i * k + j]) != n) {
                    cout << "Error at [" << i << "][" << j
                         << "]: " << resMat[i * k + j] << " (expected " << n
                         << ")\n";
(expected " << n << ")\n"; correct = false; goto end_check; } } } end_check: if (correct) { cout << "Result verification: PASSED" << endl; } else { cout << "Result verification: FAILED" << endl; } delete[] leftMat; delete[] rightMat; delete[] resMat; } MPI_Finalize(); return 0; }