1use crate::types::*;
6use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
7use std::collections::{HashMap, HashSet};
8
/// Feature-extraction kernel for audit records.
///
/// The struct itself only carries the [`KernelMetadata`] describing the kernel; the extraction
/// entry points in its `impl` block are associated functions that take their inputs explicitly.
#[derive(Debug, Clone)]
pub struct FeatureExtraction {
    /// Descriptor built in `new()` and handed out through `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
21
22impl Default for FeatureExtraction {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl FeatureExtraction {
29 #[must_use]
31 pub fn new() -> Self {
32 Self {
33 metadata: KernelMetadata::batch("audit/feature-extraction", Domain::FinancialAudit)
34 .with_description("Audit feature vector extraction")
35 .with_throughput(50_000)
36 .with_latency_us(50.0),
37 }
38 }
39
40 pub fn extract(records: &[AuditRecord], config: &FeatureConfig) -> FeatureExtractionResult {
42 let mut entity_records: HashMap<String, Vec<&AuditRecord>> = HashMap::new();
44 for record in records {
45 entity_records
46 .entry(record.entity_id.clone())
47 .or_default()
48 .push(record);
49 }
50
51 let mut entity_features = Vec::new();
53 for (entity_id, records) in &entity_records {
54 let features = Self::extract_entity_features(entity_id, records, config);
55 entity_features.push(features);
56 }
57
58 let global_stats = Self::calculate_global_stats(&entity_features, config);
60
61 let anomaly_scores = if config.detect_anomalies {
63 Self::calculate_anomaly_scores(&entity_features, &global_stats)
64 } else {
65 HashMap::new()
66 };
67
68 FeatureExtractionResult {
69 entity_features,
70 global_stats,
71 anomaly_scores,
72 }
73 }
74
75 fn extract_entity_features(
77 entity_id: &str,
78 records: &[&AuditRecord],
79 config: &FeatureConfig,
80 ) -> EntityFeatureVector {
81 let mut features = Vec::new();
82 let mut feature_names = Vec::new();
83
84 if config.include_volume_features {
86 let (volume_features, volume_names) = Self::extract_volume_features(records);
87 features.extend(volume_features);
88 feature_names.extend(volume_names);
89 }
90
91 if config.include_temporal_features {
93 let (temporal_features, temporal_names) = Self::extract_temporal_features(records);
94 features.extend(temporal_features);
95 feature_names.extend(temporal_names);
96 }
97
98 if config.include_distribution_features {
100 let (dist_features, dist_names) = Self::extract_distribution_features(records);
101 features.extend(dist_features);
102 feature_names.extend(dist_names);
103 }
104
105 if config.include_network_features {
107 let (network_features, network_names) = Self::extract_network_features(records);
108 features.extend(network_features);
109 feature_names.extend(network_names);
110 }
111
112 EntityFeatureVector {
113 entity_id: entity_id.to_string(),
114 features,
115 feature_names,
116 metadata: HashMap::new(),
117 }
118 }
119
120 fn extract_volume_features(records: &[&AuditRecord]) -> (Vec<f64>, Vec<String>) {
122 let mut features = Vec::new();
123 let mut names = Vec::new();
124
125 features.push(records.len() as f64);
127 names.push("total_count".to_string());
128
129 let total_amount: f64 = records.iter().filter_map(|r| r.amount).sum();
131 features.push(total_amount);
132 names.push("total_amount".to_string());
133
134 let amounts: Vec<f64> = records.iter().filter_map(|r| r.amount).collect();
136 let avg_amount = if !amounts.is_empty() {
137 total_amount / amounts.len() as f64
138 } else {
139 0.0
140 };
141 features.push(avg_amount);
142 names.push("avg_amount".to_string());
143
144 let max_amount = amounts.iter().cloned().fold(0.0, f64::max);
146 features.push(max_amount);
147 names.push("max_amount".to_string());
148
149 let std_amount = Self::std_dev(&amounts);
151 features.push(std_amount);
152 names.push("std_amount".to_string());
153
154 let mut type_counts: HashMap<AuditRecordType, usize> = HashMap::new();
156 for record in records {
157 *type_counts.entry(record.record_type).or_insert(0) += 1;
158 }
159
160 let record_types = [
161 AuditRecordType::JournalEntry,
162 AuditRecordType::Invoice,
163 AuditRecordType::Payment,
164 AuditRecordType::Receipt,
165 AuditRecordType::Adjustment,
166 AuditRecordType::Transfer,
167 AuditRecordType::Expense,
168 AuditRecordType::Revenue,
169 ];
170
171 for rt in record_types {
172 features.push(*type_counts.get(&rt).unwrap_or(&0) as f64);
173 names.push(format!("count_{:?}", rt).to_lowercase());
174 }
175
176 (features, names)
177 }
178
179 fn extract_temporal_features(records: &[&AuditRecord]) -> (Vec<f64>, Vec<String>) {
181 let mut features = Vec::new();
182 let mut names = Vec::new();
183
184 if records.is_empty() {
185 return (
186 vec![0.0; 6],
187 vec![
188 "time_span_days".to_string(),
189 "avg_interval_hours".to_string(),
190 "activity_ratio".to_string(),
191 "weekend_ratio".to_string(),
192 "month_end_ratio".to_string(),
193 "off_hours_ratio".to_string(),
194 ],
195 );
196 }
197
198 let timestamps: Vec<u64> = records.iter().map(|r| r.timestamp).collect();
200 let min_ts = *timestamps.iter().min().unwrap_or(&0);
201 let max_ts = *timestamps.iter().max().unwrap_or(&0);
202 let time_span_days = (max_ts - min_ts) as f64 / 86400.0;
203 features.push(time_span_days);
204 names.push("time_span_days".to_string());
205
206 let mut sorted_ts = timestamps.clone();
208 sorted_ts.sort();
209 let avg_interval = if sorted_ts.len() > 1 {
210 let intervals: Vec<f64> = sorted_ts
211 .windows(2)
212 .map(|w| (w[1] - w[0]) as f64 / 3600.0)
213 .collect();
214 intervals.iter().sum::<f64>() / intervals.len() as f64
215 } else {
216 0.0
217 };
218 features.push(avg_interval);
219 names.push("avg_interval_hours".to_string());
220
221 let unique_days: HashSet<u64> = timestamps.iter().map(|t| t / 86400).collect();
223 let activity_ratio = if time_span_days > 0.0 {
224 unique_days.len() as f64 / time_span_days.max(1.0)
225 } else {
226 0.0
227 };
228 features.push(activity_ratio);
229 names.push("activity_ratio".to_string());
230
231 let weekend_count = timestamps
233 .iter()
234 .filter(|t| {
235 let day_of_week = (*t / 86400) % 7;
236 day_of_week == 5 || day_of_week == 6 })
238 .count();
239 features.push(weekend_count as f64 / records.len() as f64);
240 names.push("weekend_ratio".to_string());
241
242 let month_end_count = timestamps
244 .iter()
245 .filter(|t| {
246 let day_of_month = ((*t / 86400) % 30) as u32;
247 day_of_month >= 25
248 })
249 .count();
250 features.push(month_end_count as f64 / records.len() as f64);
251 names.push("month_end_ratio".to_string());
252
253 let off_hours_count = timestamps
255 .iter()
256 .filter(|t| {
257 let hour = ((*t / 3600) % 24) as u32;
258 !(9..17).contains(&hour)
259 })
260 .count();
261 features.push(off_hours_count as f64 / records.len() as f64);
262 names.push("off_hours_ratio".to_string());
263
264 (features, names)
265 }
266
267 fn extract_distribution_features(records: &[&AuditRecord]) -> (Vec<f64>, Vec<String>) {
269 let mut features = Vec::new();
270 let mut names = Vec::new();
271
272 let amounts: Vec<f64> = records.iter().filter_map(|r| r.amount).collect();
273
274 if amounts.is_empty() {
275 return (
276 vec![0.0; 4],
277 vec![
278 "amount_skewness".to_string(),
279 "amount_kurtosis".to_string(),
280 "round_number_ratio".to_string(),
281 "category_concentration".to_string(),
282 ],
283 );
284 }
285
286 let skewness = Self::skewness(&amounts);
288 features.push(skewness);
289 names.push("amount_skewness".to_string());
290
291 let kurtosis = Self::kurtosis(&amounts);
293 features.push(kurtosis);
294 names.push("amount_kurtosis".to_string());
295
296 let round_count = amounts
298 .iter()
299 .filter(|a| (**a % 100.0).abs() < 0.01 || (**a % 1000.0).abs() < 0.01)
300 .count();
301 features.push(round_count as f64 / amounts.len() as f64);
302 names.push("round_number_ratio".to_string());
303
304 let mut category_counts: HashMap<&str, usize> = HashMap::new();
306 for record in records {
307 *category_counts.entry(&record.category).or_insert(0) += 1;
308 }
309 let total = records.len() as f64;
310 let hhi: f64 = category_counts
311 .values()
312 .map(|c| (*c as f64 / total).powi(2))
313 .sum();
314 features.push(hhi);
315 names.push("category_concentration".to_string());
316
317 (features, names)
318 }
319
320 fn extract_network_features(records: &[&AuditRecord]) -> (Vec<f64>, Vec<String>) {
322 let mut features = Vec::new();
323 let mut names = Vec::new();
324
325 let unique_accounts: HashSet<&str> = records
327 .iter()
328 .filter_map(|r| r.account.as_deref())
329 .collect();
330 features.push(unique_accounts.len() as f64);
331 names.push("unique_accounts".to_string());
332
333 let unique_counterparties: HashSet<&str> = records
335 .iter()
336 .filter_map(|r| r.counter_party.as_deref())
337 .collect();
338 features.push(unique_counterparties.len() as f64);
339 names.push("unique_counterparties".to_string());
340
341 let mut cp_counts: HashMap<&str, usize> = HashMap::new();
343 for record in records {
344 if let Some(cp) = &record.counter_party {
345 *cp_counts.entry(cp.as_str()).or_insert(0) += 1;
346 }
347 }
348 let total_with_cp = cp_counts.values().sum::<usize>() as f64;
349 let cp_hhi: f64 = if total_with_cp > 0.0 {
350 cp_counts
351 .values()
352 .map(|c| (*c as f64 / total_with_cp).powi(2))
353 .sum()
354 } else {
355 0.0
356 };
357 features.push(cp_hhi);
358 names.push("counterparty_concentration".to_string());
359
360 let self_tx_count = records
362 .iter()
363 .filter(|r| r.account.as_ref() == r.counter_party.as_ref() && r.account.is_some())
364 .count();
365 features.push(self_tx_count as f64 / records.len().max(1) as f64);
366 names.push("self_transaction_ratio".to_string());
367
368 (features, names)
369 }
370
371 fn calculate_global_stats(
373 entity_features: &[EntityFeatureVector],
374 _config: &FeatureConfig,
375 ) -> FeatureStats {
376 if entity_features.is_empty() {
377 return FeatureStats {
378 entity_count: 0,
379 record_count: 0,
380 means: Vec::new(),
381 std_devs: Vec::new(),
382 feature_names: Vec::new(),
383 };
384 }
385
386 let feature_count = entity_features[0].features.len();
387 let entity_count = entity_features.len();
388
389 let mut means = vec![0.0; feature_count];
390 let mut std_devs = vec![0.0; feature_count];
391
392 for ef in entity_features {
394 for (i, f) in ef.features.iter().enumerate() {
395 means[i] += f;
396 }
397 }
398 for m in &mut means {
399 *m /= entity_count as f64;
400 }
401
402 for ef in entity_features {
404 for (i, f) in ef.features.iter().enumerate() {
405 std_devs[i] += (f - means[i]).powi(2);
406 }
407 }
408 for s in &mut std_devs {
409 *s = (*s / entity_count as f64).sqrt();
410 }
411
412 FeatureStats {
413 entity_count,
414 record_count: entity_features
415 .iter()
416 .map(|ef| ef.features.first().map(|f| *f as usize).unwrap_or(0))
417 .sum(),
418 means,
419 std_devs,
420 feature_names: entity_features[0].feature_names.clone(),
421 }
422 }
423
424 fn calculate_anomaly_scores(
426 entity_features: &[EntityFeatureVector],
427 stats: &FeatureStats,
428 ) -> HashMap<String, f64> {
429 let mut scores = HashMap::new();
430
431 for ef in entity_features {
432 let mut entity_score = 0.0;
433 let mut count = 0;
434
435 for (i, f) in ef.features.iter().enumerate() {
436 if i < stats.std_devs.len() && stats.std_devs[i] > 0.0 {
437 let z_score = (f - stats.means[i]).abs() / stats.std_devs[i];
438 entity_score += z_score;
439 count += 1;
440 }
441 }
442
443 if count > 0 {
444 scores.insert(ef.entity_id.clone(), entity_score / count as f64);
445 }
446 }
447
448 scores
449 }
450
451 fn std_dev(values: &[f64]) -> f64 {
453 if values.is_empty() {
454 return 0.0;
455 }
456 let mean = values.iter().sum::<f64>() / values.len() as f64;
457 let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
458 variance.sqrt()
459 }
460
461 fn skewness(values: &[f64]) -> f64 {
463 if values.len() < 3 {
464 return 0.0;
465 }
466 let mean = values.iter().sum::<f64>() / values.len() as f64;
467 let std = Self::std_dev(values);
468 if std < f64::EPSILON {
469 return 0.0;
470 }
471 let n = values.len() as f64;
472 values
473 .iter()
474 .map(|v| ((v - mean) / std).powi(3))
475 .sum::<f64>()
476 / n
477 }
478
479 fn kurtosis(values: &[f64]) -> f64 {
481 if values.len() < 4 {
482 return 0.0;
483 }
484 let mean = values.iter().sum::<f64>() / values.len() as f64;
485 let std = Self::std_dev(values);
486 if std < f64::EPSILON {
487 return 0.0;
488 }
489 let n = values.len() as f64;
490 values
491 .iter()
492 .map(|v| ((v - mean) / std).powi(4))
493 .sum::<f64>()
494 / n
495 - 3.0 }
497
498 pub fn get_entity_features<'a>(
500 result: &'a FeatureExtractionResult,
501 entity_id: &str,
502 ) -> Option<&'a EntityFeatureVector> {
503 result
504 .entity_features
505 .iter()
506 .find(|ef| ef.entity_id == entity_id)
507 }
508
509 pub fn top_anomalies(result: &FeatureExtractionResult, limit: usize) -> Vec<(String, f64)> {
511 let mut anomalies: Vec<_> = result
512 .anomaly_scores
513 .iter()
514 .map(|(k, v)| (k.clone(), *v))
515 .collect();
516 anomalies.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
517 anomalies.truncate(limit);
518 anomalies
519 }
520}
521
522impl GpuKernel for FeatureExtraction {
523 fn metadata(&self) -> &KernelMetadata {
524 &self.metadata
525 }
526}
527
/// Switches controlling which feature families are extracted and whether anomaly scores
/// are computed.
#[derive(Debug, Clone)]
pub struct FeatureConfig {
    /// Extract the volume family (counts and amount statistics).
    pub include_volume_features: bool,
    /// Extract the temporal family (time span, intervals, timing ratios).
    pub include_temporal_features: bool,
    /// Extract the distribution family (skewness, kurtosis, round numbers, categories).
    pub include_distribution_features: bool,
    /// Extract the network family (accounts and counterparties).
    pub include_network_features: bool,
    /// Compute per-entity anomaly scores from the global statistics.
    pub detect_anomalies: bool,
}

impl Default for FeatureConfig {
    /// Turns on every feature family as well as anomaly detection.
    fn default() -> Self {
        const ENABLED: bool = true;

        FeatureConfig {
            include_volume_features: ENABLED,
            include_temporal_features: ENABLED,
            include_distribution_features: ENABLED,
            include_network_features: ENABLED,
            detect_anomalies: ENABLED,
        }
    }
}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record with the given identity, type, amount and timestamp.
    ///
    /// Every record uses currency `USD`, account `ACC-<entity_id>`, counterparty `CP001`
    /// and category `Operating`.
    fn create_test_record(
        id: &str,
        entity_id: &str,
        record_type: AuditRecordType,
        amount: f64,
        timestamp: u64,
    ) -> AuditRecord {
        AuditRecord {
            id: id.to_string(),
            record_type,
            entity_id: entity_id.to_string(),
            timestamp,
            amount: Some(amount),
            currency: Some("USD".to_string()),
            account: Some(format!("ACC-{}", entity_id)),
            counter_party: Some("CP001".to_string()),
            category: "Operating".to_string(),
            attributes: HashMap::new(),
        }
    }

    /// Shared fixture: three records for `E001` (amounts 1000, 1500, 500) and two for
    /// `E002` (amounts 10000, 3000), with timestamps 100 seconds apart.
    fn create_test_records() -> Vec<AuditRecord> {
        vec![
            create_test_record("R001", "E001", AuditRecordType::Payment, 1000.0, 1000000),
            create_test_record("R002", "E001", AuditRecordType::Invoice, 1500.0, 1000100),
            create_test_record("R003", "E001", AuditRecordType::Payment, 500.0, 1000200),
            create_test_record("R004", "E002", AuditRecordType::Revenue, 10000.0, 1000300),
            create_test_record("R005", "E002", AuditRecordType::Expense, 3000.0, 1000400),
        ]
    }

    /// The fixture contains two entities, so two vectors and an entity count of 2 are expected.
    #[test]
    fn test_extract_features() {
        let records = create_test_records();
        let config = FeatureConfig::default();

        let result = FeatureExtraction::extract(&records, &config);

        assert_eq!(result.entity_features.len(), 2);
        assert_eq!(result.global_stats.entity_count, 2);
    }

    /// Looks up `E001` and checks that its first feature is its record count.
    #[test]
    fn test_entity_features() {
        let records = create_test_records();
        let config = FeatureConfig::default();

        let result = FeatureExtraction::extract(&records, &config);

        let e001 = FeatureExtraction::get_entity_features(&result, "E001").unwrap();
        assert_eq!(e001.entity_id, "E001");
        assert!(!e001.features.is_empty());

        // With volume features enabled, index 0 is `total_count`; E001 has 3 records.
        assert_eq!(e001.features[0], 3.0);
    }

    /// With only the volume family enabled, checks count, total and average for `E001`.
    #[test]
    fn test_volume_features() {
        let records = create_test_records();
        let config = FeatureConfig {
            include_volume_features: true,
            include_temporal_features: false,
            include_distribution_features: false,
            include_network_features: false,
            detect_anomalies: false,
        };

        let result = FeatureExtraction::extract(&records, &config);

        let e001 = FeatureExtraction::get_entity_features(&result, "E001").unwrap();
        // total_count: three records.
        assert_eq!(e001.features[0], 3.0);
        // total_amount: 1000 + 1500 + 500.
        assert_eq!(e001.features[1], 3000.0);
        // avg_amount: 3000 / 3.
        assert_eq!(e001.features[2], 1000.0);
    }

    /// Adds a third entity with ten records of amount 100000 and expects it to rank as the
    /// top anomaly.
    #[test]
    fn test_anomaly_detection() {
        let mut records = create_test_records();
        for i in 0..10 {
            records.push(create_test_record(
                &format!("R1{}", i),
                "E003",
                AuditRecordType::Payment,
                100000.0,
                1000000 + i * 100,
            ));
        }

        let config = FeatureConfig::default();
        let result = FeatureExtraction::extract(&records, &config);

        assert!(result.anomaly_scores.contains_key("E003"));
        let top = FeatureExtraction::top_anomalies(&result, 1);
        assert_eq!(top[0].0, "E003");
    }

    /// An empty input yields no entity vectors and an entity count of 0.
    #[test]
    fn test_empty_records() {
        let records: Vec<AuditRecord> = vec![];
        let config = FeatureConfig::default();

        let result = FeatureExtraction::extract(&records, &config);

        assert!(result.entity_features.is_empty());
        assert_eq!(result.global_stats.entity_count, 0);
    }

    /// Feature values and names must have the same length, and the volume names must be present.
    #[test]
    fn test_feature_names() {
        let records = create_test_records();
        let config = FeatureConfig::default();

        let result = FeatureExtraction::extract(&records, &config);

        // Whichever entity comes first is checked; the assertions do not depend on which one.
        let ef = &result.entity_features[0];
        assert_eq!(ef.features.len(), ef.feature_names.len());
        assert!(ef.feature_names.contains(&"total_count".to_string()));
        assert!(ef.feature_names.contains(&"total_amount".to_string()));
    }

    /// Global statistics are populated for the two-entity fixture.
    #[test]
    fn test_global_stats() {
        let records = create_test_records();
        let config = FeatureConfig::default();

        let result = FeatureExtraction::extract(&records, &config);

        assert_eq!(result.global_stats.entity_count, 2);
        assert!(!result.global_stats.means.is_empty());
        assert!(!result.global_stats.std_devs.is_empty());
    }

    /// Gives `E001` a second counterparty (`CP002`) and checks the unique-counterparty count.
    #[test]
    fn test_network_features() {
        let mut records = create_test_records();
        records.push(AuditRecord {
            id: "R006".to_string(),
            record_type: AuditRecordType::Payment,
            entity_id: "E001".to_string(),
            timestamp: 1000500,
            amount: Some(500.0),
            currency: Some("USD".to_string()),
            account: Some("ACC-E001".to_string()),
            counter_party: Some("CP002".to_string()),
            category: "Operating".to_string(),
            attributes: HashMap::new(),
        });

        let config = FeatureConfig {
            include_volume_features: false,
            include_temporal_features: false,
            include_distribution_features: false,
            include_network_features: true,
            detect_anomalies: false,
        };

        let result = FeatureExtraction::extract(&records, &config);

        let e001 = FeatureExtraction::get_entity_features(&result, "E001").unwrap();
        // With only network features enabled, index 1 is `unique_counterparties`
        // (CP001 and CP002 for E001).
        assert!(e001.features[1] >= 2.0);
    }

    /// Enabling every family must produce more features than enabling the volume family alone.
    #[test]
    fn test_selective_features() {
        let records = create_test_records();

        // Volume features only.
        let config_vol = FeatureConfig {
            include_volume_features: true,
            include_temporal_features: false,
            include_distribution_features: false,
            include_network_features: false,
            detect_anomalies: false,
        };
        let result_vol = FeatureExtraction::extract(&records, &config_vol);

        // All families (the default).
        let config_all = FeatureConfig::default();
        let result_all = FeatureExtraction::extract(&records, &config_all);

        assert!(
            result_all.entity_features[0].features.len()
                > result_vol.entity_features[0].features.len()
        );
    }
}