hpc-lab-code/work/gemm_parallel.cpp

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <iostream>
#include <mpi.h>
#include <omp.h>
#include <vector>
using namespace std;
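// Example build and run commands (a sketch only; assumes an MPI compiler
// wrapper such as mpicxx and an OpenMP-capable compiler are available):
//   mpicxx -O2 -fopenmp gemm_parallel.cpp -o gemm_parallel
//   mpirun -np 4 ./gemm_parallel 1024 1024 1024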
void randMat(int rows, int cols, float *&Mat) {
  // Fill the matrix with 1.0 so the product can be verified against n in main.
  Mat = new float[rows * cols];
  for (int i = 0; i < rows; i++)
    for (int j = 0; j < cols; j++)
      Mat[i * cols + j] = 1.0;
}
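// openmp_sgemm computes C = A * B^T for row-major A (m x n), B^T stored as
// k rows of length n, and C (m x k).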
void openmp_sgemm(int m, int n, int k, float *leftMat, float *rightMat,
                  float *resultMat) {
  // rightMat is transposed
  // Parallelize the two outer loops with OpenMP
  #pragma omp parallel for collapse(2)
  for (int row = 0; row < m; row++) {
    for (int col = 0; col < k; col++) {
      resultMat[row * k + col] = 0.0;
      for (int i = 0; i < n; i++) {
        resultMat[row * k + col] +=
            leftMat[row * n + i] * rightMat[col * n + i];
      }
    }
  }
}
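// mpi_sgemm distributes C = A * B over a rowBlock x colBlock grid of ranks:
// rank 0 transposes B, sends each worker its block sizes (tag 0), the rows of
// A (tag 1) and the rows of B^T (tag 2) it needs, every rank multiplies its
// local block with openmp_sgemm, and rank 0 gathers the partial results
// (tags 3 and 4) into resultMat.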
void mpi_sgemm(int m, int n, int k, float *&leftMat, float *&rightMat,
               float *&resultMat, int rank, int worldsize) {
  // Choose a rowBlock x colBlock process grid with rowBlock as close to
  // sqrt(worldsize) as possible while still dividing worldsize evenly.
  int rowBlock = (int)sqrt((double)worldsize);
  while (rowBlock > 0 && worldsize % rowBlock != 0) {
    rowBlock--;
  }
  int colBlock = worldsize / rowBlock;
  int rowStride, colStride;
  float *res = nullptr;
  float *localLeftMat = leftMat;
  float *localRightMat = rightMat;
  if (rank == 0) {
    // Transpose rightMat in place (n x k -> k x n) so that each column block
    // can be sent as contiguous rows.
    float *buf = new float[k * n];
    #pragma omp parallel for collapse(2)
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < k; c++) {
        buf[c * n + r] = rightMat[r * k + c];
      }
    }
    #pragma omp parallel for collapse(2)
    for (int r = 0; r < k; r++) {
      for (int c = 0; c < n; c++) {
        rightMat[r * n + c] = buf[r * n + c];
      }
    }
    delete[] buf;
    // Master-worker distribution: send each sub-matrix to its worker rank.
    // Keep the non-blocking send requests so they can be waited on later.
    std::vector<MPI_Request> sendRequests;
    sendRequests.reserve(1000);  // reserve up front; the vector still grows as needed
    for (int rowB = 0; rowB < rowBlock; rowB++) {
      for (int colB = 0; colB < colBlock; colB++) {
        // Compute the extent of this block (the last block absorbs the remainder)
        int rowStart = rowB * (m / rowBlock);
        int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
        rowStride = rowEnd - rowStart;
        int colStart = colB * (k / colBlock);
        int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
        colStride = colEnd - colStart;
        int sendto = rowB * colBlock + colB;
        if (sendto == 0) {
          // Rank 0 keeps its own block
          res = new float[rowStride * colStride];
          localLeftMat = leftMat + rowStart * n;
          localRightMat = rightMat + colStart * n;
          continue;
        }
        // Send the block dimensions with blocking sends: rowStride/colStride
        // are reused on the next iteration, so they must not be handed to a
        // still-pending MPI_Isend.
        MPI_Send(&rowStride, 1, MPI_INT, sendto, 0, MPI_COMM_WORLD);
        MPI_Send(&colStride, 1, MPI_INT, sendto, 0, MPI_COMM_WORLD);
        // Send the rows of the left matrix block
        MPI_Request req;
        for (int r = 0; r < rowStride; r++) {
          MPI_Isend(leftMat + (rowStart + r) * n, n, MPI_FLOAT, sendto, 1,
                    MPI_COMM_WORLD, &req);
          sendRequests.push_back(req);
        }
        // Send the rows of the (transposed) right matrix block
        for (int c = 0; c < colStride; c++) {
          MPI_Isend(rightMat + (colStart + c) * n, n, MPI_FLOAT, sendto, 2,
                    MPI_COMM_WORLD, &req);
          sendRequests.push_back(req);
        }
      }
    }
    // Wait for all outstanding sends to complete
    for (size_t i = 0; i < sendRequests.size(); i++) {
      MPI_Wait(&sendRequests[i], MPI_STATUS_IGNORE);
    }
  } else {
    // Worker ranks: receive the data distributed by rank 0
    if (rank < worldsize) {
      // Receive the block dimensions (tag 0)
      MPI_Recv(&rowStride, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      MPI_Recv(&colStride, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      // Allocate local buffers and receive the block data
      localLeftMat = new float[rowStride * n];
      localRightMat = new float[colStride * n];
      for (int r = 0; r < rowStride; r++) {
        MPI_Recv(localLeftMat + r * n, n, MPI_FLOAT, 0, 1, MPI_COMM_WORLD,
                 MPI_STATUS_IGNORE);
      }
      for (int c = 0; c < colStride; c++) {
        MPI_Recv(localRightMat + c * n, n, MPI_FLOAT, 0, 2, MPI_COMM_WORLD,
                 MPI_STATUS_IGNORE);
      }
      res = new float[rowStride * colStride];
    }
  }
  MPI_Barrier(MPI_COMM_WORLD);
  // Multiply the local sub-matrices
  if (rank < worldsize) {
    // Recompute this rank's block size (on rank 0 the strides were
    // overwritten while distributing the other blocks)
    int rowB = rank / colBlock;
    int colB = rank % colBlock;
    int rowStart = rowB * (m / rowBlock);
    int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
    rowStride = rowEnd - rowStart;
    int colStart = colB * (k / colBlock);
    int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
    colStride = colEnd - colStart;
    // Use the OpenMP kernel for the local block product
    openmp_sgemm(rowStride, n, colStride, localLeftMat, localRightMat, res);
  }
  MPI_Barrier(MPI_COMM_WORLD);
  // Gather the partial results back on rank 0
  if (rank == 0) {
    // Rank 0 copies its own (top-left) block directly
    int rowStart = 0;
    int colStart = 0;
    for (int r = 0; r < rowStride; r++) {
      for (int c = 0; c < colStride; c++) {
        resultMat[(rowStart + r) * k + (colStart + c)] = res[r * colStride + c];
      }
    }
    delete[] res;
    // Receive the remaining blocks from the worker ranks
    for (int rowB = 0; rowB < rowBlock; rowB++) {
      for (int colB = 0; colB < colBlock; colB++) {
        int recvfrom = rowB * colBlock + colB;
        if (recvfrom == 0) continue;
        // Receive the block dimensions (tag 3)
        MPI_Recv(&rowStride, 1, MPI_INT, recvfrom, 3, MPI_COMM_WORLD,
                 MPI_STATUS_IGNORE);
        MPI_Recv(&colStride, 1, MPI_INT, recvfrom, 3, MPI_COMM_WORLD,
                 MPI_STATUS_IGNORE);
        // Receive the block data (tag 4)
        float *tmpRes = new float[rowStride * colStride];
        MPI_Recv(tmpRes, rowStride * colStride, MPI_FLOAT, recvfrom, 4,
                 MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        // Copy the block into the global result matrix
        int rowStart = rowB * (m / rowBlock);
        int colStart = colB * (k / colBlock);
        for (int r = 0; r < rowStride; r++) {
          for (int c = 0; c < colStride; c++) {
            resultMat[(rowStart + r) * k + (colStart + c)] =
                tmpRes[r * colStride + c];
          }
        }
        delete[] tmpRes;
      }
    }
  } else {
    if (rank < worldsize) {
      // Send the block dimensions and the local result back to rank 0
      MPI_Send(&rowStride, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
      MPI_Send(&colStride, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
      MPI_Send(res, rowStride * colStride, MPI_FLOAT, 0, 4, MPI_COMM_WORLD);
      delete[] res;
      delete[] localLeftMat;
      delete[] localRightMat;
    }
  }
  MPI_Barrier(MPI_COMM_WORLD);
}
int main(int argc, char *argv[]) {
  if (argc != 4) {
    cout << "Usage: " << (argc > 0 ? argv[0] : "gemm_parallel") << " M N K\n";
    exit(-1);
  }
  int rank;
  int worldSize;
  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  // Matrix dimensions: A is m x n, B is n x k, C is m x k
  int m = atoi(argv[1]);
  int n = atoi(argv[2]);
  int k = atoi(argv[3]);
  float *leftMat = nullptr, *rightMat = nullptr, *resMat = nullptr;
  struct timeval start, stop;
  // Only rank 0 allocates and initializes the full matrices
  if (rank == 0) {
    randMat(m, n, leftMat);
    randMat(n, k, rightMat);
    randMat(m, k, resMat);
  }
  gettimeofday(&start, NULL);
  // Hybrid MPI + OpenMP matrix multiplication
  mpi_sgemm(m, n, k, leftMat, rightMat, resMat, rank, worldSize);
  gettimeofday(&stop, NULL);
  // Report timing and verify the result on rank 0
  if (rank == 0) {
    double elapsed = (stop.tv_sec - start.tv_sec) * 1000.0 +
                     (stop.tv_usec - start.tv_usec) / 1000.0;
    cout << "mpi matmul: " << elapsed << " ms" << endl;
    // Every input entry is 1.0, so every entry of the product must equal n
    bool correct = true;
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < k; j++) {
        if (int(resMat[i * k + j]) != n) {
          cout << "Error at [" << i << "][" << j << "]: " << resMat[i * k + j]
               << " (expected " << n << ")\n";
          correct = false;
          goto end_check;
        }
      }
    }
  end_check:
    if (correct) {
      cout << "Result verification: PASSED" << endl;
    } else {
      cout << "Result verification: FAILED" << endl;
    }
    delete[] leftMat;
    delete[] rightMat;
    delete[] resMat;
  }
  MPI_Finalize();
  return 0;
}