155 lines
4.7 KiB
C++
155 lines
4.7 KiB
C++
#include <iostream>
|
||
#include <vector>
|
||
#include <cmath>
|
||
#include <mpi.h>
|
||
|
||
// 对局部区间执行埃拉托斯特尼筛法
|
||
void local_sieve(int low, int high, std::vector<bool>& is_prime, const std::vector<int>& base_primes) {
|
||
// 初始化局部区间内的所有数为可能的素数
|
||
is_prime.assign(high - low + 1, true);
|
||
|
||
// 如果区间从0或1开始,标记它们为非素数
|
||
if (low == 0) {
|
||
is_prime[0] = false;
|
||
if (high >= 1) {
|
||
is_prime[1] = false;
|
||
}
|
||
} else if (low == 1) {
|
||
is_prime[0] = false;
|
||
}
|
||
|
||
// 使用基础素数标记局部区间中的非素数
|
||
for (int p : base_primes) {
|
||
// 找到p在[low, high]范围内的第一个倍数
|
||
int start_multiple = (low / p) * p;
|
||
if (start_multiple < low) {
|
||
start_multiple += p;
|
||
}
|
||
// 确保不将素数本身标记为非素数
|
||
if (start_multiple == p) {
|
||
start_multiple += p;
|
||
}
|
||
|
||
// 标记局部区间中p的所有倍数为非素数
|
||
for (int multiple = start_multiple; multiple <= high; multiple += p) {
|
||
is_prime[multiple - low] = false;
|
||
}
|
||
}
|
||
}
|
||
|
||
int main(int argc, char* argv[]) {
|
||
MPI_Init(&argc, &argv);
|
||
|
||
int rank, size;
|
||
double wtime;
|
||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||
MPI_Comm_size(MPI_COMM_WORLD, &size);
|
||
|
||
// 检查参数数量
|
||
if (argc != 3) {
|
||
if (rank == 0) {
|
||
std::cerr << "用法: " << argv[0] << " <N> <B>" << std::endl;
|
||
std::cerr << " N: 区间[2, N]的上界" << std::endl;
|
||
std::cerr << " B: 分配区间的块大小" << std::endl;
|
||
}
|
||
MPI_Finalize();
|
||
return 1;
|
||
}
|
||
|
||
int N = std::atoi(argv[1]);
|
||
int B = std::atoi(argv[2]);
|
||
|
||
if (N < 2) {
|
||
if (rank == 0) {
|
||
std::cout << "区间[2, " << N << "]包含0个素数" << std::endl;
|
||
}
|
||
MPI_Finalize();
|
||
return 0;
|
||
}
|
||
|
||
// 步骤1: 进程0找出sqrt(N)内的基础素数
|
||
std::vector<int> base_primes;
|
||
int limit = static_cast<int>(std::sqrt(N));
|
||
if (rank == 0) {
|
||
wtime = MPI_Wtime();
|
||
|
||
std::vector<bool> is_prime_small(limit + 1, true);
|
||
is_prime_small[0] = is_prime_small[1] = false;
|
||
for (int p = 2; p * p <= limit; ++p) {
|
||
if (is_prime_small[p]) {
|
||
for (int i = p * p; i <= limit; i += p) {
|
||
is_prime_small[i] = false;
|
||
}
|
||
}
|
||
}
|
||
for (int i = 2; i <= limit; ++i) {
|
||
if (is_prime_small[i]) {
|
||
base_primes.push_back(i);
|
||
}
|
||
}
|
||
}
|
||
|
||
// 步骤2: 广播基础素数到所有进程
|
||
int num_base_primes = base_primes.size();
|
||
MPI_Bcast(&num_base_primes, 1, MPI_INT, 0, MPI_COMM_WORLD);
|
||
if (rank != 0) {
|
||
base_primes.resize(num_base_primes);
|
||
}
|
||
MPI_Bcast(base_primes.data(), num_base_primes, MPI_INT, 0, MPI_COMM_WORLD);
|
||
|
||
// 步骤3: 在进程间分配区间[sqrt(N)+1, N]
|
||
int start_range = limit + 1;
|
||
if (start_range > N) {
|
||
// 无需分配,所有素数都是基础素数
|
||
int total_count = base_primes.size();
|
||
if (rank == 0) {
|
||
std::cout << "区间[2, " << N << "]内的素数总数为 " << total_count << std::endl;
|
||
}
|
||
MPI_Finalize();
|
||
return 0;
|
||
}
|
||
|
||
int total_elements = N - start_range + 1;
|
||
int local_low, local_high;
|
||
std::vector<bool> is_prime_local;
|
||
|
||
// 计算每个进程分配的区间
|
||
int elements_per_proc = total_elements / size;
|
||
int remainder = total_elements % size;
|
||
|
||
if (rank < remainder) {
|
||
local_low = start_range + rank * (elements_per_proc + 1);
|
||
local_high = local_low + elements_per_proc;
|
||
} else {
|
||
local_low = start_range + rank * elements_per_proc + remainder;
|
||
local_high = local_low + elements_per_proc - 1;
|
||
}
|
||
local_high = std::min(local_high, N);
|
||
|
||
// 对分配的局部区间执行筛法
|
||
local_sieve(local_low, local_high, is_prime_local, base_primes);
|
||
|
||
// 统计局部区间内的素数数量
|
||
int local_prime_count = 0;
|
||
for (bool prime : is_prime_local) {
|
||
if (prime) {
|
||
local_prime_count++;
|
||
}
|
||
}
|
||
|
||
// 步骤4: 汇总局部素数计数
|
||
int global_prime_count = 0;
|
||
MPI_Reduce(&local_prime_count, &global_prime_count, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
|
||
|
||
// 步骤5: 进程0输出最终结果
|
||
if (rank == 0) {
|
||
double end_wtime = MPI_Wtime() - wtime;
|
||
int total_count = base_primes.size() + global_prime_count;
|
||
std::cout << "区间[2, " << N << "]内的素数总数为 " << total_count << std::endl;
|
||
std::cout << "计算时间: " << end_wtime << " 秒" << std::endl;
|
||
}
|
||
|
||
MPI_Finalize();
|
||
return 0;
|
||
}
|