Skip to main content

rustkernel_audit/
feature_extraction.rs

1//! Feature extraction kernel for financial audit.
2//!
3//! Extracts feature vectors from audit records for analysis and anomaly detection.
4
5use crate::types::*;
6use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
7use std::collections::{HashMap, HashSet};
8
9// ============================================================================
10// FeatureExtraction Kernel
11// ============================================================================
12
13/// Feature extraction kernel for audit records.
14///
15/// Extracts numerical feature vectors from audit records for
16/// machine learning analysis and anomaly detection.
17#[derive(Debug, Clone)]
18pub struct FeatureExtraction {
19    metadata: KernelMetadata,
20}
21
22impl Default for FeatureExtraction {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl FeatureExtraction {
29    /// Create a new feature extraction kernel.
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            metadata: KernelMetadata::batch("audit/feature-extraction", Domain::FinancialAudit)
34                .with_description("Audit feature vector extraction")
35                .with_throughput(50_000)
36                .with_latency_us(50.0),
37        }
38    }
39
40    /// Extract features from audit records.
41    pub fn extract(records: &[AuditRecord], config: &FeatureConfig) -> FeatureExtractionResult {
42        // Group records by entity
43        let mut entity_records: HashMap<String, Vec<&AuditRecord>> = HashMap::new();
44        for record in records {
45            entity_records
46                .entry(record.entity_id.clone())
47                .or_default()
48                .push(record);
49        }
50
51        // Extract features for each entity
52        let mut entity_features = Vec::new();
53        for (entity_id, records) in &entity_records {
54            let features = Self::extract_entity_features(entity_id, records, config);
55            entity_features.push(features);
56        }
57
58        // Calculate global statistics
59        let global_stats = Self::calculate_global_stats(&entity_features, config);
60
61        // Calculate anomaly scores
62        let anomaly_scores = if config.detect_anomalies {
63            Self::calculate_anomaly_scores(&entity_features, &global_stats)
64        } else {
65            HashMap::new()
66        };
67
68        FeatureExtractionResult {
69            entity_features,
70            global_stats,
71            anomaly_scores,
72        }
73    }
74
75    /// Extract features for a single entity.
76    fn extract_entity_features(
77        entity_id: &str,
78        records: &[&AuditRecord],
79        config: &FeatureConfig,
80    ) -> EntityFeatureVector {
81        let mut features = Vec::new();
82        let mut feature_names = Vec::new();
83
84        // Transaction volume features
85        if config.include_volume_features {
86            let (volume_features, volume_names) = Self::extract_volume_features(records);
87            features.extend(volume_features);
88            feature_names.extend(volume_names);
89        }
90
91        // Temporal features
92        if config.include_temporal_features {
93            let (temporal_features, temporal_names) = Self::extract_temporal_features(records);
94            features.extend(temporal_features);
95            feature_names.extend(temporal_names);
96        }
97
98        // Distribution features
99        if config.include_distribution_features {
100            let (dist_features, dist_names) = Self::extract_distribution_features(records);
101            features.extend(dist_features);
102            feature_names.extend(dist_names);
103        }
104
105        // Network features
106        if config.include_network_features {
107            let (network_features, network_names) = Self::extract_network_features(records);
108            features.extend(network_features);
109            feature_names.extend(network_names);
110        }
111
112        EntityFeatureVector {
113            entity_id: entity_id.to_string(),
114            features,
115            feature_names,
116            metadata: HashMap::new(),
117        }
118    }
119
120    /// Extract volume-based features.
121    fn extract_volume_features(records: &[&AuditRecord]) -> (Vec<f64>, Vec<String>) {
122        let mut features = Vec::new();
123        let mut names = Vec::new();
124
125        // Total transaction count
126        features.push(records.len() as f64);
127        names.push("total_count".to_string());
128
129        // Total amount
130        let total_amount: f64 = records.iter().filter_map(|r| r.amount).sum();
131        features.push(total_amount);
132        names.push("total_amount".to_string());
133
134        // Average amount
135        let amounts: Vec<f64> = records.iter().filter_map(|r| r.amount).collect();
136        let avg_amount = if !amounts.is_empty() {
137            total_amount / amounts.len() as f64
138        } else {
139            0.0
140        };
141        features.push(avg_amount);
142        names.push("avg_amount".to_string());
143
144        // Max amount
145        let max_amount = amounts.iter().cloned().fold(0.0, f64::max);
146        features.push(max_amount);
147        names.push("max_amount".to_string());
148
149        // Amount standard deviation
150        let std_amount = Self::std_dev(&amounts);
151        features.push(std_amount);
152        names.push("std_amount".to_string());
153
154        // Count by record type
155        let mut type_counts: HashMap<AuditRecordType, usize> = HashMap::new();
156        for record in records {
157            *type_counts.entry(record.record_type).or_insert(0) += 1;
158        }
159
160        let record_types = [
161            AuditRecordType::JournalEntry,
162            AuditRecordType::Invoice,
163            AuditRecordType::Payment,
164            AuditRecordType::Receipt,
165            AuditRecordType::Adjustment,
166            AuditRecordType::Transfer,
167            AuditRecordType::Expense,
168            AuditRecordType::Revenue,
169        ];
170
171        for rt in record_types {
172            features.push(*type_counts.get(&rt).unwrap_or(&0) as f64);
173            names.push(format!("count_{:?}", rt).to_lowercase());
174        }
175
176        (features, names)
177    }
178
179    /// Extract temporal features.
180    fn extract_temporal_features(records: &[&AuditRecord]) -> (Vec<f64>, Vec<String>) {
181        let mut features = Vec::new();
182        let mut names = Vec::new();
183
184        if records.is_empty() {
185            return (
186                vec![0.0; 6],
187                vec![
188                    "time_span_days".to_string(),
189                    "avg_interval_hours".to_string(),
190                    "activity_ratio".to_string(),
191                    "weekend_ratio".to_string(),
192                    "month_end_ratio".to_string(),
193                    "off_hours_ratio".to_string(),
194                ],
195            );
196        }
197
198        // Time span
199        let timestamps: Vec<u64> = records.iter().map(|r| r.timestamp).collect();
200        let min_ts = *timestamps.iter().min().unwrap_or(&0);
201        let max_ts = *timestamps.iter().max().unwrap_or(&0);
202        let time_span_days = (max_ts - min_ts) as f64 / 86400.0;
203        features.push(time_span_days);
204        names.push("time_span_days".to_string());
205
206        // Average interval between transactions
207        let mut sorted_ts = timestamps.clone();
208        sorted_ts.sort();
209        let avg_interval = if sorted_ts.len() > 1 {
210            let intervals: Vec<f64> = sorted_ts
211                .windows(2)
212                .map(|w| (w[1] - w[0]) as f64 / 3600.0)
213                .collect();
214            intervals.iter().sum::<f64>() / intervals.len() as f64
215        } else {
216            0.0
217        };
218        features.push(avg_interval);
219        names.push("avg_interval_hours".to_string());
220
221        // Activity concentration
222        let unique_days: HashSet<u64> = timestamps.iter().map(|t| t / 86400).collect();
223        let activity_ratio = if time_span_days > 0.0 {
224            unique_days.len() as f64 / time_span_days.max(1.0)
225        } else {
226            0.0
227        };
228        features.push(activity_ratio);
229        names.push("activity_ratio".to_string());
230
231        // Weekend activity ratio
232        let weekend_count = timestamps
233            .iter()
234            .filter(|t| {
235                let day_of_week = (*t / 86400) % 7;
236                day_of_week == 5 || day_of_week == 6 // Simplified weekend check
237            })
238            .count();
239        features.push(weekend_count as f64 / records.len() as f64);
240        names.push("weekend_ratio".to_string());
241
242        // Month-end activity ratio (last 5 days of month, simplified)
243        let month_end_count = timestamps
244            .iter()
245            .filter(|t| {
246                let day_of_month = ((*t / 86400) % 30) as u32;
247                day_of_month >= 25
248            })
249            .count();
250        features.push(month_end_count as f64 / records.len() as f64);
251        names.push("month_end_ratio".to_string());
252
253        // Off-hours activity (outside 9-17, simplified)
254        let off_hours_count = timestamps
255            .iter()
256            .filter(|t| {
257                let hour = ((*t / 3600) % 24) as u32;
258                !(9..17).contains(&hour)
259            })
260            .count();
261        features.push(off_hours_count as f64 / records.len() as f64);
262        names.push("off_hours_ratio".to_string());
263
264        (features, names)
265    }
266
267    /// Extract distribution features.
268    fn extract_distribution_features(records: &[&AuditRecord]) -> (Vec<f64>, Vec<String>) {
269        let mut features = Vec::new();
270        let mut names = Vec::new();
271
272        let amounts: Vec<f64> = records.iter().filter_map(|r| r.amount).collect();
273
274        if amounts.is_empty() {
275            return (
276                vec![0.0; 4],
277                vec![
278                    "amount_skewness".to_string(),
279                    "amount_kurtosis".to_string(),
280                    "round_number_ratio".to_string(),
281                    "category_concentration".to_string(),
282                ],
283            );
284        }
285
286        // Skewness
287        let skewness = Self::skewness(&amounts);
288        features.push(skewness);
289        names.push("amount_skewness".to_string());
290
291        // Kurtosis
292        let kurtosis = Self::kurtosis(&amounts);
293        features.push(kurtosis);
294        names.push("amount_kurtosis".to_string());
295
296        // Round number ratio
297        let round_count = amounts
298            .iter()
299            .filter(|a| (**a % 100.0).abs() < 0.01 || (**a % 1000.0).abs() < 0.01)
300            .count();
301        features.push(round_count as f64 / amounts.len() as f64);
302        names.push("round_number_ratio".to_string());
303
304        // Category concentration (HHI)
305        let mut category_counts: HashMap<&str, usize> = HashMap::new();
306        for record in records {
307            *category_counts.entry(&record.category).or_insert(0) += 1;
308        }
309        let total = records.len() as f64;
310        let hhi: f64 = category_counts
311            .values()
312            .map(|c| (*c as f64 / total).powi(2))
313            .sum();
314        features.push(hhi);
315        names.push("category_concentration".to_string());
316
317        (features, names)
318    }
319
320    /// Extract network features.
321    fn extract_network_features(records: &[&AuditRecord]) -> (Vec<f64>, Vec<String>) {
322        let mut features = Vec::new();
323        let mut names = Vec::new();
324
325        // Unique accounts
326        let unique_accounts: HashSet<&str> = records
327            .iter()
328            .filter_map(|r| r.account.as_deref())
329            .collect();
330        features.push(unique_accounts.len() as f64);
331        names.push("unique_accounts".to_string());
332
333        // Unique counterparties
334        let unique_counterparties: HashSet<&str> = records
335            .iter()
336            .filter_map(|r| r.counter_party.as_deref())
337            .collect();
338        features.push(unique_counterparties.len() as f64);
339        names.push("unique_counterparties".to_string());
340
341        // Counterparty concentration
342        let mut cp_counts: HashMap<&str, usize> = HashMap::new();
343        for record in records {
344            if let Some(cp) = &record.counter_party {
345                *cp_counts.entry(cp.as_str()).or_insert(0) += 1;
346            }
347        }
348        let total_with_cp = cp_counts.values().sum::<usize>() as f64;
349        let cp_hhi: f64 = if total_with_cp > 0.0 {
350            cp_counts
351                .values()
352                .map(|c| (*c as f64 / total_with_cp).powi(2))
353                .sum()
354        } else {
355            0.0
356        };
357        features.push(cp_hhi);
358        names.push("counterparty_concentration".to_string());
359
360        // Self-transactions ratio
361        let self_tx_count = records
362            .iter()
363            .filter(|r| r.account.as_ref() == r.counter_party.as_ref() && r.account.is_some())
364            .count();
365        features.push(self_tx_count as f64 / records.len().max(1) as f64);
366        names.push("self_transaction_ratio".to_string());
367
368        (features, names)
369    }
370
371    /// Calculate global statistics.
372    fn calculate_global_stats(
373        entity_features: &[EntityFeatureVector],
374        _config: &FeatureConfig,
375    ) -> FeatureStats {
376        if entity_features.is_empty() {
377            return FeatureStats {
378                entity_count: 0,
379                record_count: 0,
380                means: Vec::new(),
381                std_devs: Vec::new(),
382                feature_names: Vec::new(),
383            };
384        }
385
386        let feature_count = entity_features[0].features.len();
387        let entity_count = entity_features.len();
388
389        let mut means = vec![0.0; feature_count];
390        let mut std_devs = vec![0.0; feature_count];
391
392        // Calculate means
393        for ef in entity_features {
394            for (i, f) in ef.features.iter().enumerate() {
395                means[i] += f;
396            }
397        }
398        for m in &mut means {
399            *m /= entity_count as f64;
400        }
401
402        // Calculate standard deviations
403        for ef in entity_features {
404            for (i, f) in ef.features.iter().enumerate() {
405                std_devs[i] += (f - means[i]).powi(2);
406            }
407        }
408        for s in &mut std_devs {
409            *s = (*s / entity_count as f64).sqrt();
410        }
411
412        FeatureStats {
413            entity_count,
414            record_count: entity_features
415                .iter()
416                .map(|ef| ef.features.first().map(|f| *f as usize).unwrap_or(0))
417                .sum(),
418            means,
419            std_devs,
420            feature_names: entity_features[0].feature_names.clone(),
421        }
422    }
423
424    /// Calculate anomaly scores using z-score method.
425    fn calculate_anomaly_scores(
426        entity_features: &[EntityFeatureVector],
427        stats: &FeatureStats,
428    ) -> HashMap<String, f64> {
429        let mut scores = HashMap::new();
430
431        for ef in entity_features {
432            let mut entity_score = 0.0;
433            let mut count = 0;
434
435            for (i, f) in ef.features.iter().enumerate() {
436                if i < stats.std_devs.len() && stats.std_devs[i] > 0.0 {
437                    let z_score = (f - stats.means[i]).abs() / stats.std_devs[i];
438                    entity_score += z_score;
439                    count += 1;
440                }
441            }
442
443            if count > 0 {
444                scores.insert(ef.entity_id.clone(), entity_score / count as f64);
445            }
446        }
447
448        scores
449    }
450
451    /// Calculate standard deviation.
452    fn std_dev(values: &[f64]) -> f64 {
453        if values.is_empty() {
454            return 0.0;
455        }
456        let mean = values.iter().sum::<f64>() / values.len() as f64;
457        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
458        variance.sqrt()
459    }
460
461    /// Calculate skewness.
462    fn skewness(values: &[f64]) -> f64 {
463        if values.len() < 3 {
464            return 0.0;
465        }
466        let mean = values.iter().sum::<f64>() / values.len() as f64;
467        let std = Self::std_dev(values);
468        if std < f64::EPSILON {
469            return 0.0;
470        }
471        let n = values.len() as f64;
472        values
473            .iter()
474            .map(|v| ((v - mean) / std).powi(3))
475            .sum::<f64>()
476            / n
477    }
478
479    /// Calculate kurtosis.
480    fn kurtosis(values: &[f64]) -> f64 {
481        if values.len() < 4 {
482            return 0.0;
483        }
484        let mean = values.iter().sum::<f64>() / values.len() as f64;
485        let std = Self::std_dev(values);
486        if std < f64::EPSILON {
487            return 0.0;
488        }
489        let n = values.len() as f64;
490        values
491            .iter()
492            .map(|v| ((v - mean) / std).powi(4))
493            .sum::<f64>()
494            / n
495            - 3.0 // Excess kurtosis
496    }
497
498    /// Get feature vector for a specific entity.
499    pub fn get_entity_features<'a>(
500        result: &'a FeatureExtractionResult,
501        entity_id: &str,
502    ) -> Option<&'a EntityFeatureVector> {
503        result
504            .entity_features
505            .iter()
506            .find(|ef| ef.entity_id == entity_id)
507    }
508
509    /// Get top anomalous entities.
510    pub fn top_anomalies(result: &FeatureExtractionResult, limit: usize) -> Vec<(String, f64)> {
511        let mut anomalies: Vec<_> = result
512            .anomaly_scores
513            .iter()
514            .map(|(k, v)| (k.clone(), *v))
515            .collect();
516        anomalies.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
517        anomalies.truncate(limit);
518        anomalies
519    }
520}
521
522impl GpuKernel for FeatureExtraction {
523    fn metadata(&self) -> &KernelMetadata {
524        &self.metadata
525    }
526}
527
528// ============================================================================
529// Configuration Types
530// ============================================================================
531
/// Switches selecting which feature groups are extracted and whether anomaly
/// scores are computed.
#[derive(Clone, Debug)]
pub struct FeatureConfig {
    /// Emit the volume group (record counts and amount statistics).
    pub include_volume_features: bool,
    /// Emit the temporal group (time span, intervals and activity ratios).
    pub include_temporal_features: bool,
    /// Emit the distribution group (amount shape and category concentration).
    pub include_distribution_features: bool,
    /// Emit the network group (accounts and counterparties).
    pub include_network_features: bool,
    /// Compute anomaly scores after the features have been extracted.
    pub detect_anomalies: bool,
}

impl Default for FeatureConfig {
    /// Enable every feature group and anomaly detection.
    fn default() -> Self {
        const ENABLED: bool = true;

        FeatureConfig {
            include_volume_features: ENABLED,
            include_temporal_features: ENABLED,
            include_distribution_features: ENABLED,
            include_network_features: ENABLED,
            detect_anomalies: ENABLED,
        }
    }
}
558
559// ============================================================================
560// Tests
561// ============================================================================
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration with every feature group and anomaly detection disabled;
    /// tests enable what they need through struct-update syntax.
    fn config_none() -> FeatureConfig {
        FeatureConfig {
            include_volume_features: false,
            include_temporal_features: false,
            include_distribution_features: false,
            include_network_features: false,
            detect_anomalies: false,
        }
    }

    /// Build a USD record in category "Operating" with counterparty `CP001`
    /// and an account named after the entity.
    fn create_test_record(
        id: &str,
        entity_id: &str,
        record_type: AuditRecordType,
        amount: f64,
        timestamp: u64,
    ) -> AuditRecord {
        AuditRecord {
            id: id.to_string(),
            record_type,
            entity_id: entity_id.to_string(),
            timestamp,
            amount: Some(amount),
            currency: Some("USD".to_string()),
            account: Some(format!("ACC-{}", entity_id)),
            counter_party: Some("CP001".to_string()),
            category: "Operating".to_string(),
            attributes: HashMap::new(),
        }
    }

    /// Three records for E001 and two for E002.
    fn create_test_records() -> Vec<AuditRecord> {
        let mut records = Vec::with_capacity(5);
        records.push(create_test_record("R001", "E001", AuditRecordType::Payment, 1000.0, 1000000));
        records.push(create_test_record("R002", "E001", AuditRecordType::Invoice, 1500.0, 1000100));
        records.push(create_test_record("R003", "E001", AuditRecordType::Payment, 500.0, 1000200));
        records.push(create_test_record("R004", "E002", AuditRecordType::Revenue, 10000.0, 1000300));
        records.push(create_test_record("R005", "E002", AuditRecordType::Expense, 3000.0, 1000400));
        records
    }

    #[test]
    fn test_extract_features() {
        let result = FeatureExtraction::extract(&create_test_records(), &FeatureConfig::default());

        assert_eq!(result.entity_features.len(), 2);
        assert_eq!(result.global_stats.entity_count, 2);
    }

    #[test]
    fn test_entity_features() {
        let result = FeatureExtraction::extract(&create_test_records(), &FeatureConfig::default());

        let e001 = FeatureExtraction::get_entity_features(&result, "E001").unwrap();
        assert_eq!(e001.entity_id, "E001");
        assert!(!e001.features.is_empty());

        // E001 has 3 records
        assert_eq!(e001.features[0], 3.0); // total_count
    }

    #[test]
    fn test_volume_features() {
        let config = FeatureConfig {
            include_volume_features: true,
            ..config_none()
        };

        let result = FeatureExtraction::extract(&create_test_records(), &config);

        let e001 = FeatureExtraction::get_entity_features(&result, "E001").unwrap();
        // E001: 3 transactions, total 3000, avg 1000
        assert_eq!(e001.features[0], 3.0); // total_count
        assert_eq!(e001.features[1], 3000.0); // total_amount
        assert_eq!(e001.features[2], 1000.0); // avg_amount
    }

    #[test]
    fn test_anomaly_detection() {
        let mut records = create_test_records();
        // Add an anomalous entity with very high amounts
        records.extend((0..10u64).map(|i| {
            create_test_record(
                &format!("R1{}", i),
                "E003",
                AuditRecordType::Payment,
                100000.0,
                1000000 + i * 100,
            )
        }));

        let result = FeatureExtraction::extract(&records, &FeatureConfig::default());

        // E003 should have high anomaly score
        assert!(result.anomaly_scores.contains_key("E003"));
        let top = FeatureExtraction::top_anomalies(&result, 1);
        assert_eq!(top[0].0, "E003");
    }

    #[test]
    fn test_empty_records() {
        let no_records: Vec<AuditRecord> = Vec::new();

        let result = FeatureExtraction::extract(&no_records, &FeatureConfig::default());

        assert!(result.entity_features.is_empty());
        assert_eq!(result.global_stats.entity_count, 0);
    }

    #[test]
    fn test_feature_names() {
        let result = FeatureExtraction::extract(&create_test_records(), &FeatureConfig::default());

        let ef = &result.entity_features[0];
        assert_eq!(ef.features.len(), ef.feature_names.len());
        for expected in ["total_count", "total_amount"] {
            assert!(ef.feature_names.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_global_stats() {
        let result = FeatureExtraction::extract(&create_test_records(), &FeatureConfig::default());
        let stats = &result.global_stats;

        assert_eq!(stats.entity_count, 2);
        assert!(!stats.means.is_empty());
        assert!(!stats.std_devs.is_empty());
    }

    #[test]
    fn test_network_features() {
        let mut records = create_test_records();
        // Add a record for E001 with a different counterparty
        let mut extra = create_test_record("R006", "E001", AuditRecordType::Payment, 500.0, 1000500);
        extra.counter_party = Some("CP002".to_string());
        records.push(extra);

        let config = FeatureConfig {
            include_network_features: true,
            ..config_none()
        };

        let result = FeatureExtraction::extract(&records, &config);

        let e001 = FeatureExtraction::get_entity_features(&result, "E001").unwrap();
        // E001 now has 2 unique counterparties
        assert!(e001.features[1] >= 2.0); // unique_counterparties
    }

    #[test]
    fn test_selective_features() {
        let records = create_test_records();

        // Volume only
        let volume_only = FeatureConfig {
            include_volume_features: true,
            ..config_none()
        };
        let result_vol = FeatureExtraction::extract(&records, &volume_only);

        // All features
        let result_all = FeatureExtraction::extract(&records, &FeatureConfig::default());

        // Enabling every group must yield a longer vector than volume alone
        let all_len = result_all.entity_features[0].features.len();
        let vol_len = result_vol.entity_features[0].features.len();
        assert!(all_len > vol_len);
    }
}