1use anyhow::{anyhow, Result};
2use lazy_static::lazy_static;
3use std::collections::HashSet;
4use std::sync::RwLock;
5
6use super::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
7use crate::data::datatable::DataValue;
8
9include!(concat!(env!("OUT_DIR"), "/primes_data.rs"));
11
12lazy_static! {
14 static ref PRIME_SET: HashSet<u32> = PRIMES_100K.iter().copied().collect();
15 static ref EXTENDED_PRIME_CACHE: RwLock<Vec<u64>> = RwLock::new(Vec::new());
16}
17
/// Stateless namespace for prime-number utilities.
///
/// All functionality is exposed as associated functions (no `self`), backed by
/// the precomputed prime tables included at the top of this file.
pub struct PrimeEngine;
20
21impl PrimeEngine {
22 pub fn nth_prime(n: usize) -> Result<u64> {
24 if n == 0 {
25 return Err(anyhow!("Prime index must be >= 1"));
26 }
27
28 if n <= 1_000 {
30 return Ok(PRIMES_1K[n - 1] as u64);
31 }
32 if n <= 10_000 {
33 return Ok(PRIMES_10K[n - 1] as u64);
34 }
35 if n <= 100_000 {
36 return Ok(PRIMES_100K[n - 1] as u64);
37 }
38
39 Self::generate_nth_prime(n)
41 }
42
43 pub fn is_prime(n: u64) -> bool {
45 if n < 2 {
46 return false;
47 }
48 if n == 2 {
49 return true;
50 }
51 if n % 2 == 0 {
52 return false;
53 }
54
55 if n <= 1_299_709 {
57 return PRIME_SET.contains(&(n as u32));
58 }
59
60 if n < 1_000_000_000_000 {
62 let sqrt_n = (n as f64).sqrt() as u64;
63
64 for &p in PRIMES_100K.iter() {
66 let p64 = p as u64;
67 if p64 > sqrt_n {
68 return true;
69 }
70 if n % p64 == 0 {
71 return false;
72 }
73 }
74
75 Self::is_prime_wheel(n, PRIMES_100K[PRIMES_100K.len() - 1] as u64)
78 } else {
79 Self::miller_rabin(n)
81 }
82 }
83
84 pub fn prime_count(n: u64) -> usize {
86 if n < 2 {
87 return 0;
88 }
89
90 if n <= 1_299_709 {
92 match PRIMES_100K.binary_search(&(n as u32)) {
93 Ok(idx) => idx + 1, Err(idx) => idx, }
96 } else {
97 Self::approximate_prime_count(n)
100 }
101 }
102
103 pub fn next_prime(n: u64) -> u64 {
105 if n <= 2 {
106 return 2;
107 }
108
109 if n <= 1_299_709 {
111 let target = n as u32;
112 match PRIMES_100K.binary_search(&target) {
113 Ok(_) => n, Err(idx) => {
115 if idx < PRIMES_100K.len() {
116 PRIMES_100K[idx] as u64
117 } else {
118 Self::find_next_prime_slow(n)
120 }
121 }
122 }
123 } else {
124 Self::find_next_prime_slow(n)
125 }
126 }
127
128 pub fn prev_prime(n: u64) -> Option<u64> {
130 if n < 2 {
131 return None;
132 }
133 if n == 2 {
134 return Some(2);
135 }
136
137 if n <= 1_299_709 {
139 let target = n as u32;
140 match PRIMES_100K.binary_search(&target) {
141 Ok(_) => Some(n), Err(idx) => {
143 if idx > 0 {
144 Some(PRIMES_100K[idx - 1] as u64)
145 } else {
146 None }
148 }
149 }
150 } else {
151 Self::find_prev_prime_slow(n)
152 }
153 }
154
155 pub fn factor(mut n: u64) -> Vec<(u64, u32)> {
157 if n <= 1 {
158 return vec![];
159 }
160
161 let mut factors = Vec::new();
162
163 for &p in PRIMES_10K.iter() {
165 let p64 = p as u64;
166 if p64 * p64 > n {
167 break;
168 }
169
170 let mut count = 0;
171 while n % p64 == 0 {
172 n /= p64;
173 count += 1;
174 }
175
176 if count > 0 {
177 factors.push((p64, count));
178 }
179 }
180
181 if n > 1 {
183 if Self::is_prime(n) {
185 factors.push((n, 1));
186 } else {
187 factors.push((n, 1));
190 }
191 }
192
193 factors
194 }
195
196 fn generate_nth_prime(n: usize) -> Result<u64> {
200 let cache = EXTENDED_PRIME_CACHE.read().unwrap();
202 let cache_start = 100_001;
203 let cache_idx = n - cache_start;
204
205 if cache_idx < cache.len() {
206 return Ok(cache[cache_idx]);
207 }
208 drop(cache);
209
210 let mut cache = EXTENDED_PRIME_CACHE.write().unwrap();
212
213 let mut candidate = PRIMES_100K[PRIMES_100K.len() - 1] as u64 + 2;
215 let mut count = 100_000 + cache.len();
216
217 while count < n {
218 if Self::is_prime(candidate) {
219 cache.push(candidate);
220 count += 1;
221 }
222 candidate += 2;
223 }
224
225 Ok(cache[cache_idx])
226 }
227
228 fn is_prime_wheel(n: u64, start: u64) -> bool {
230 const WHEEL: &[u64] = &[1, 7, 11, 13, 17, 19, 23, 29];
232
233 let sqrt_n = (n as f64).sqrt() as u64;
234 let mut base = ((start / 30) + 1) * 30;
235
236 while base <= sqrt_n {
237 for &offset in WHEEL {
238 let candidate = base + offset;
239 if candidate > sqrt_n {
240 return true;
241 }
242 if candidate > start && n % candidate == 0 {
243 return false;
244 }
245 }
246 base += 30;
247 }
248 true
249 }
250
251 fn miller_rabin(n: u64) -> bool {
253 const WITNESSES: &[u64] = &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
255
256 let mut d = n - 1;
258 let mut r = 0;
259 while d % 2 == 0 {
260 d /= 2;
261 r += 1;
262 }
263
264 'witness: for &a in WITNESSES {
265 if a >= n {
266 continue;
267 }
268
269 let mut x = Self::mod_pow(a, d, n);
270 if x == 1 || x == n - 1 {
271 continue;
272 }
273
274 for _ in 0..r - 1 {
275 x = Self::mod_mul(x, x, n);
276 if x == n - 1 {
277 continue 'witness;
278 }
279 }
280
281 return false;
282 }
283
284 true
285 }
286
287 fn mod_pow(mut base: u64, mut exp: u64, m: u64) -> u64 {
289 let mut result = 1;
290 base %= m;
291
292 while exp > 0 {
293 if exp % 2 == 1 {
294 result = Self::mod_mul(result, base, m);
295 }
296 base = Self::mod_mul(base, base, m);
297 exp /= 2;
298 }
299
300 result
301 }
302
303 fn mod_mul(a: u64, b: u64, m: u64) -> u64 {
305 ((a as u128 * b as u128) % m as u128) as u64
306 }
307
308 fn find_next_prime_slow(mut n: u64) -> u64 {
310 if n % 2 == 0 {
311 n += 1;
312 }
313
314 while !Self::is_prime(n) {
315 n += 2;
316 }
317
318 n
319 }
320
321 fn find_prev_prime_slow(mut n: u64) -> Option<u64> {
323 if n % 2 == 0 {
324 n -= 1;
325 }
326
327 while n > 2 {
328 if Self::is_prime(n) {
329 return Some(n);
330 }
331 n -= 2;
332 }
333
334 if n == 2 {
335 Some(2)
336 } else {
337 None
338 }
339 }
340
341 fn approximate_prime_count(n: u64) -> usize {
343 if n < 2 {
344 return 0;
345 }
346
347 let n_f = n as f64;
348 let ln_n = n_f.ln();
349
350 let approx = n_f / (ln_n - 1.0);
352 approx as usize
353 }
354}
355
356pub struct PrimeFunction;
360
361impl SqlFunction for PrimeFunction {
362 fn signature(&self) -> FunctionSignature {
363 FunctionSignature {
364 name: "PRIME",
365 category: FunctionCategory::Mathematical,
366 arg_count: ArgCount::Fixed(1),
367 description: "Returns the Nth prime number (1-indexed)",
368 returns: "INTEGER",
369 examples: vec![
370 "SELECT PRIME(1)", "SELECT PRIME(100)", "SELECT PRIME(10000)", ],
374 }
375 }
376
377 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
378 self.validate_args(args)?;
379
380 let n = match &args[0] {
381 DataValue::Integer(i) if *i > 0 => *i as usize,
382 DataValue::Integer(_) => return Err(anyhow!("PRIME index must be positive")),
383 DataValue::Float(f) if *f > 0.0 => *f as usize,
384 _ => return Err(anyhow!("PRIME requires a positive integer argument")),
385 };
386
387 let prime = PrimeEngine::nth_prime(n)?;
388 Ok(DataValue::Integer(prime as i64))
389 }
390}
391
392pub struct IsPrimeFunction;
394
395impl SqlFunction for IsPrimeFunction {
396 fn signature(&self) -> FunctionSignature {
397 FunctionSignature {
398 name: "IS_PRIME",
399 category: FunctionCategory::Mathematical,
400 arg_count: ArgCount::Fixed(1),
401 description: "Returns true if the number is prime, false otherwise",
402 returns: "BOOLEAN",
403 examples: vec![
404 "SELECT IS_PRIME(17)", "SELECT IS_PRIME(100)", "SELECT IS_PRIME(104729)", ],
408 }
409 }
410
411 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
412 self.validate_args(args)?;
413
414 let n = match &args[0] {
415 DataValue::Integer(i) if *i >= 0 => *i as u64,
416 DataValue::Integer(_) => return Ok(DataValue::Boolean(false)),
417 DataValue::Float(f) if *f >= 0.0 => *f as u64,
418 _ => return Err(anyhow!("IS_PRIME requires a non-negative integer argument")),
419 };
420
421 Ok(DataValue::Boolean(PrimeEngine::is_prime(n)))
422 }
423}
424
425pub struct PrimeCountFunction;
427
428impl SqlFunction for PrimeCountFunction {
429 fn signature(&self) -> FunctionSignature {
430 FunctionSignature {
431 name: "PRIME_COUNT",
432 category: FunctionCategory::Mathematical,
433 arg_count: ArgCount::Fixed(1),
434 description: "Returns the count of prime numbers up to n (π(n))",
435 returns: "INTEGER",
436 examples: vec![
437 "SELECT PRIME_COUNT(10)", "SELECT PRIME_COUNT(100)", "SELECT PRIME_COUNT(1000)", ],
441 }
442 }
443
444 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
445 self.validate_args(args)?;
446
447 let n = match &args[0] {
448 DataValue::Integer(i) if *i >= 0 => *i as u64,
449 DataValue::Integer(_) => return Ok(DataValue::Integer(0)),
450 DataValue::Float(f) if *f >= 0.0 => *f as u64,
451 _ => {
452 return Err(anyhow!(
453 "PRIME_COUNT requires a non-negative integer argument"
454 ))
455 }
456 };
457
458 Ok(DataValue::Integer(PrimeEngine::prime_count(n) as i64))
459 }
460}
461
462pub struct NextPrimeFunction;
464
465impl SqlFunction for NextPrimeFunction {
466 fn signature(&self) -> FunctionSignature {
467 FunctionSignature {
468 name: "NEXT_PRIME",
469 category: FunctionCategory::Mathematical,
470 arg_count: ArgCount::Fixed(1),
471 description: "Returns the smallest prime number >= n",
472 returns: "INTEGER",
473 examples: vec![
474 "SELECT NEXT_PRIME(100)", "SELECT NEXT_PRIME(97)", "SELECT NEXT_PRIME(1000)", ],
478 }
479 }
480
481 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
482 self.validate_args(args)?;
483
484 let n = match &args[0] {
485 DataValue::Integer(i) if *i >= 0 => *i as u64,
486 DataValue::Integer(_) => return Ok(DataValue::Integer(2)),
487 DataValue::Float(f) if *f >= 0.0 => *f as u64,
488 _ => {
489 return Err(anyhow!(
490 "NEXT_PRIME requires a non-negative integer argument"
491 ))
492 }
493 };
494
495 Ok(DataValue::Integer(PrimeEngine::next_prime(n) as i64))
496 }
497}
498
499pub struct PrevPrimeFunction;
501
502impl SqlFunction for PrevPrimeFunction {
503 fn signature(&self) -> FunctionSignature {
504 FunctionSignature {
505 name: "PREV_PRIME",
506 category: FunctionCategory::Mathematical,
507 arg_count: ArgCount::Fixed(1),
508 description: "Returns the largest prime number <= n",
509 returns: "INTEGER",
510 examples: vec![
511 "SELECT PREV_PRIME(100)", "SELECT PREV_PRIME(97)", "SELECT PREV_PRIME(1000)", ],
515 }
516 }
517
518 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
519 self.validate_args(args)?;
520
521 let n = match &args[0] {
522 DataValue::Integer(i) if *i >= 0 => *i as u64,
523 DataValue::Integer(_) => return Ok(DataValue::Null),
524 DataValue::Float(f) if *f >= 0.0 => *f as u64,
525 _ => {
526 return Err(anyhow!(
527 "PREV_PRIME requires a non-negative integer argument"
528 ))
529 }
530 };
531
532 match PrimeEngine::prev_prime(n) {
533 Some(p) => Ok(DataValue::Integer(p as i64)),
534 None => Ok(DataValue::Null),
535 }
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Known (index, prime) pairs served by the precomputed tables.
    #[test]
    fn test_nth_prime() {
        let known: [(usize, u64); 5] = [
            (1, 2),
            (10, 29),
            (100, 541),
            (1000, 7919),
            (10000, 104729),
        ];
        for &(index, prime) in known.iter() {
            assert_eq!(PrimeEngine::nth_prime(index).unwrap(), prime);
        }
    }

    /// Primality verdicts for small values and for table entries.
    #[test]
    fn test_is_prime() {
        let verdicts: [(u64, bool); 7] = [
            (0, false),
            (1, false),
            (2, true),
            (17, true),
            (100, false),
            (104729, true),
            (1299709, true),
        ];
        for &(value, expected) in verdicts.iter() {
            assert_eq!(PrimeEngine::is_prime(value), expected);
        }
    }

    /// Exact prime counts within the table range.
    #[test]
    fn test_prime_count() {
        let counts: [(u64, usize); 3] = [(10, 4), (100, 25), (1000, 168)];
        for &(limit, expected) in counts.iter() {
            assert_eq!(PrimeEngine::prime_count(limit), expected);
        }
    }

    /// Neighbouring primes, including arguments that are themselves prime.
    #[test]
    fn test_next_prev_prime() {
        for &(from, expected) in [(100u64, 101u64), (97, 97)].iter() {
            assert_eq!(PrimeEngine::next_prime(from), expected);
        }

        let previous: [(u64, Option<u64>); 3] = [(100, Some(97)), (97, Some(97)), (1, None)];
        for &(from, expected) in previous.iter() {
            assert_eq!(PrimeEngine::prev_prime(from), expected);
        }
    }

    /// Factorizations of a composite, a prime and a product of distinct primes.
    #[test]
    fn test_factorization() {
        let cases: [(u64, Vec<(u64, u32)>); 3] = [
            (60, vec![(2, 2), (3, 1), (5, 1)]),
            (97, vec![(97, 1)]),
            (1001, vec![(7, 1), (11, 1), (13, 1)]),
        ];
        for (n, expected) in cases.iter() {
            assert_eq!(&PrimeEngine::factor(*n), expected);
        }
    }
}