// MPI + OpenMP blocked single-precision matrix multiplication (SGEMM) benchmark.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

#include <mpi.h>
#include <omp.h>

using namespace std;

/**
 * Allocates a rows x cols row-major float matrix and sets every element to 1.0f.
 *
 * Despite the name, no random values are produced: the matrix is filled with
 * ones.
 *
 * @param rows number of rows
 * @param cols number of columns
 * @param Mat  output parameter; receives a new[]-allocated array of
 *             rows * cols floats that the caller must release with delete[]
 */
void randMat(int rows, int cols, float *&Mat) {
  // Compute the element count in size_t so that rows * cols cannot overflow
  // the range of int for large matrices.
  const std::size_t count =
      static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols);
  Mat = new float[count];
  for (std::size_t idx = 0; idx < count; ++idx) {
    Mat[idx] = 1.0f;
  }
}

/**
 * Cache-blocked, OpenMP-parallel single-precision matrix multiply.
 *
 * Computes resultMat[row][col] = sum_i leftMat[row][i] * rightMat[col][i],
 * i.e. the right operand is supplied already transposed so that both operands
 * are read along contiguous rows.
 *
 * @param m         number of rows of leftMat and resultMat
 * @param n         shared (inner) dimension; row length of leftMat and rightMat
 * @param k         number of rows of rightMat and columns of resultMat
 * @param leftMat   m x n matrix, row-major
 * @param rightMat  k x n matrix, row-major (transpose of the n x k operand)
 * @param resultMat m x k matrix, row-major; fully overwritten
 */
void openmp_sgemm_optimized(int m, int n, int k, float *leftMat, float *rightMat,
                            float *resultMat) {
  // Edge length of the square tiles used for all three loop dimensions.
  const int BLOCK_SIZE = 64;

  // Each (row_block, col_block) output tile is owned by exactly one thread, so
  // no synchronisation is needed on resultMat.
#pragma omp parallel for collapse(2)
  for (int row_block = 0; row_block < m; row_block += BLOCK_SIZE) {
    for (int col_block = 0; col_block < k; col_block += BLOCK_SIZE) {
      // Tile bounds do not depend on the inner-dimension block; compute once.
      const int row_end = std::min(row_block + BLOCK_SIZE, m);
      const int col_end = std::min(col_block + BLOCK_SIZE, k);

      // Clear the output tile first so the result is zero even when n == 0.
      for (int row = row_block; row < row_end; row++) {
        float *out_row = resultMat + static_cast<std::size_t>(row) * k;
        for (int col = col_block; col < col_end; col++) {
          out_row[col] = 0.0f;
        }
      }

      // Accumulate partial dot products one inner-dimension block at a time.
      for (int i_block = 0; i_block < n; i_block += BLOCK_SIZE) {
        const int i_end = std::min(i_block + BLOCK_SIZE, n);

        for (int row = row_block; row < row_end; row++) {
          // Index arithmetic is done in size_t to avoid int overflow on
          // large matrices.
          const float *lhs_row = leftMat + static_cast<std::size_t>(row) * n;
          float *out_row = resultMat + static_cast<std::size_t>(row) * k;

          for (int col = col_block; col < col_end; col++) {
            const float *rhs_row = rightMat + static_cast<std::size_t>(col) * n;
            float sum = out_row[col];
            for (int i = i_block; i < i_end; i++) {
              sum += lhs_row[i] * rhs_row[i];
            }
            out_row[col] = sum;
          }
        }
      }
    }
  }
}

void mpi_sgemm_optimized(int m, int n, int k, float *&leftMat, float *&rightMat,
|
|
float *&resultMat, int rank, int worldsize) {
|
|
|
|
// 计算行列分块数
|
|
int rowBlock = (int)sqrt((double)worldsize);
|
|
while (rowBlock > 0 && worldsize % rowBlock != 0) {
|
|
rowBlock--;
|
|
}
|
|
int colBlock = worldsize / rowBlock;
|
|
|
|
int rowStride, colStride;
|
|
float *res = nullptr;
|
|
float *localLeftMat = leftMat;
|
|
float *localRightMat = rightMat;
|
|
|
|
if (rank == 0) {
|
|
// 矩阵转置 - 使用OpenMP加速
|
|
float *buf = new float[k * n];
|
|
#pragma omp parallel for collapse(2)
|
|
for (int r = 0; r < n; r++) {
|
|
for (int c = 0; c < k; c++) {
|
|
buf[c * n + r] = rightMat[r * k + c];
|
|
}
|
|
}
|
|
|
|
#pragma omp parallel for collapse(2)
|
|
for (int r = 0; r < k; r++) {
|
|
for (int c = 0; c < n; c++) {
|
|
rightMat[r * n + c] = buf[r * n + c];
|
|
}
|
|
}
|
|
delete[] buf;
|
|
|
|
// 使用非阻塞通信重叠计算和通信
|
|
std::vector<MPI_Request> sendRequests;
|
|
sendRequests.reserve(1000);
|
|
|
|
for (int rowB = 0; rowB < rowBlock; rowB++) {
|
|
for (int colB = 0; colB < colBlock; colB++) {
|
|
int rowStart = rowB * (m / rowBlock);
|
|
int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
|
|
rowStride = rowEnd - rowStart;
|
|
|
|
int colStart = colB * (k / colBlock);
|
|
int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
|
|
colStride = colEnd - colStart;
|
|
|
|
int sendto = rowB * colBlock + colB;
|
|
if (sendto == 0) {
|
|
res = new float[rowStride * colStride];
|
|
localLeftMat = leftMat + rowStart * n;
|
|
localRightMat = rightMat + colStart * n;
|
|
continue;
|
|
}
|
|
|
|
// 发送分块大小
|
|
MPI_Request req;
|
|
MPI_Isend(&rowStride, 1, MPI_INT, sendto, 0, MPI_COMM_WORLD, &req);
|
|
sendRequests.push_back(req);
|
|
MPI_Isend(&colStride, 1, MPI_INT, sendto, 0, MPI_COMM_WORLD, &req);
|
|
sendRequests.push_back(req);
|
|
|
|
// 发送矩阵数据
|
|
for (int r = 0; r < rowStride; r++) {
|
|
MPI_Isend(leftMat + (rowStart + r) * n, n, MPI_FLOAT, sendto,
|
|
1, MPI_COMM_WORLD, &req);
|
|
sendRequests.push_back(req);
|
|
}
|
|
|
|
for (int c = 0; c < colStride; c++) {
|
|
MPI_Isend(rightMat + (colStart + c) * n, n, MPI_FLOAT, sendto,
|
|
2, MPI_COMM_WORLD, &req);
|
|
sendRequests.push_back(req);
|
|
}
|
|
}
|
|
}
|
|
|
|
// 等待所有发送完成
|
|
for (size_t i = 0; i < sendRequests.size(); i++) {
|
|
MPI_Wait(&sendRequests[i], MPI_STATUS_IGNORE);
|
|
}
|
|
} else {
|
|
if (rank < worldsize) {
|
|
int rowB = rank / colBlock;
|
|
int colB = rank % colBlock;
|
|
|
|
int rowStart = rowB * (m / rowBlock);
|
|
int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
|
|
rowStride = rowEnd - rowStart;
|
|
|
|
int colStart = colB * (k / colBlock);
|
|
int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
|
|
colStride = colEnd - colStart;
|
|
|
|
MPI_Recv(&rowStride, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
|
MPI_Recv(&colStride, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
|
|
|
localLeftMat = new float[rowStride * n];
|
|
localRightMat = new float[colStride * n];
|
|
|
|
for (int r = 0; r < rowStride; r++) {
|
|
MPI_Recv(localLeftMat + r * n, n, MPI_FLOAT, 0, 1, MPI_COMM_WORLD,
|
|
MPI_STATUS_IGNORE);
|
|
}
|
|
|
|
for (int c = 0; c < colStride; c++) {
|
|
MPI_Recv(localRightMat + c * n, n, MPI_FLOAT, 0, 2, MPI_COMM_WORLD,
|
|
MPI_STATUS_IGNORE);
|
|
}
|
|
|
|
res = new float[rowStride * colStride];
|
|
}
|
|
}
|
|
|
|
MPI_Barrier(MPI_COMM_WORLD);
|
|
|
|
// 本地计算 - 使用优化版本
|
|
if (rank < worldsize) {
|
|
int rowB = rank / colBlock;
|
|
int colB = rank % colBlock;
|
|
|
|
int rowStart = rowB * (m / rowBlock);
|
|
int rowEnd = (rowB == rowBlock - 1) ? m : (rowB + 1) * (m / rowBlock);
|
|
rowStride = rowEnd - rowStart;
|
|
|
|
int colStart = colB * (k / colBlock);
|
|
int colEnd = (colB == colBlock - 1) ? k : (colB + 1) * (k / colBlock);
|
|
colStride = colEnd - colStart;
|
|
|
|
openmp_sgemm_optimized(rowStride, n, colStride, localLeftMat, localRightMat, res);
|
|
}
|
|
|
|
MPI_Barrier(MPI_COMM_WORLD);
|
|
|
|
// 收集结果
|
|
if (rank == 0) {
|
|
int rowB = 0;
|
|
int colB = 0;
|
|
int rowStart = rowB * (m / rowBlock);
|
|
int colStart = colB * (k / colBlock);
|
|
|
|
for (int r = 0; r < rowStride; r++) {
|
|
for (int c = 0; c < colStride; c++) {
|
|
resultMat[(rowStart + r) * k + (colStart + c)] = res[r * colStride + c];
|
|
}
|
|
}
|
|
delete[] res;
|
|
|
|
for (int rowB = 0; rowB < rowBlock; rowB++) {
|
|
for (int colB = 0; colB < colBlock; colB++) {
|
|
int recvfrom = rowB * colBlock + colB;
|
|
if (recvfrom == 0) continue;
|
|
|
|
MPI_Recv(&rowStride, 1, MPI_INT, recvfrom, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
|
MPI_Recv(&colStride, 1, MPI_INT, recvfrom, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
|
|
|
float *tmpRes = new float[rowStride * colStride];
|
|
MPI_Recv(tmpRes, rowStride * colStride, MPI_FLOAT, recvfrom, 4,
|
|
MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
|
|
|
int rowStart = rowB * (m / rowBlock);
|
|
int colStart = colB * (k / colBlock);
|
|
|
|
for (int r = 0; r < rowStride; r++) {
|
|
for (int c = 0; c < colStride; c++) {
|
|
resultMat[(rowStart + r) * k + (colStart + c)] = tmpRes[r * colStride + c];
|
|
}
|
|
}
|
|
delete[] tmpRes;
|
|
}
|
|
}
|
|
} else {
|
|
if (rank < worldsize) {
|
|
MPI_Send(&rowStride, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
|
|
MPI_Send(&colStride, 1, MPI_INT, 0, 3, MPI_COMM_WORLD);
|
|
MPI_Send(res, rowStride * colStride, MPI_FLOAT, 0, 4, MPI_COMM_WORLD);
|
|
|
|
delete[] res;
|
|
delete[] localLeftMat;
|
|
delete[] localRightMat;
|
|
}
|
|
}
|
|
|
|
MPI_Barrier(MPI_COMM_WORLD);
|
|
}
|
|
|
|
int main(int argc, char *argv[]) {
|
|
if (argc != 4) {
|
|
cout << "Usage: " << argv[0] << " M N K\n";
|
|
exit(-1);
|
|
}
|
|
|
|
int rank;
|
|
int worldSize;
|
|
MPI_Init(&argc, &argv);
|
|
|
|
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
|
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
|
|
|
int m = atoi(argv[1]);
|
|
int n = atoi(argv[2]);
|
|
int k = atoi(argv[3]);
|
|
|
|
float *leftMat, *rightMat, *resMat;
|
|
struct timeval start, stop;
|
|
|
|
if (rank == 0) {
|
|
randMat(m, n, leftMat);
|
|
randMat(n, k, rightMat);
|
|
randMat(m, k, resMat);
|
|
}
|
|
|
|
gettimeofday(&start, NULL);
|
|
mpi_sgemm_optimized(m, n, k, leftMat, rightMat, resMat, rank, worldSize);
|
|
gettimeofday(&stop, NULL);
|
|
|
|
if (rank == 0) {
|
|
double elapsed = (stop.tv_sec - start.tv_sec) * 1000.0 +
|
|
(stop.tv_usec - start.tv_usec) / 1000.0;
|
|
cout << "optimized mpi matmul: " << elapsed << " ms" << endl;
|
|
|
|
bool correct = true;
|
|
for (int i = 0; i < m; i++) {
|
|
for (int j = 0; j < k; j++){
|
|
if (int(resMat[i * k + j]) != n) {
|
|
cout << "Error at [" << i << "][" << j << "]: "
|
|
<< resMat[i * k + j] << " (expected " << n << ")\n";
|
|
correct = false;
|
|
goto end_check;
|
|
}
|
|
}
|
|
}
|
|
end_check:
|
|
if (correct) {
|
|
cout << "Result verification: PASSED" << endl;
|
|
} else {
|
|
cout << "Result verification: FAILED" << endl;
|
|
}
|
|
|
|
delete[] leftMat;
|
|
delete[] rightMat;
|
|
delete[] resMat;
|
|
}
|
|
|
|
MPI_Finalize();
|
|
return 0;
|
|
}
|