rustkernel_procint/prediction.rs

//! Next activity prediction kernels.
//!
//! This module provides process activity prediction:
//! - Markov chain-based prediction
//! - N-gram model prediction
//! - Batch inference for multiple traces
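//!
//! A minimal usage sketch (illustrative only; the paths assume this module and
//! the `EventLog` type are exported from the `rustkernel_procint` crate, as in
//! the tests at the bottom of this file):
//!
//! ```ignore
//! use rustkernel_procint::prediction::{NextActivityPrediction, PredictionConfig};
//!
//! // `log` is an EventLog populated elsewhere (see the test fixture below).
//! let config = PredictionConfig::default();
//! let model = NextActivityPrediction::train(&log, &config);
//! let predictions = model.predict_from_names(&["A", "B"], &config);
//! for p in &predictions {
//!     println!("{} (p = {:.2})", p.activity, p.probability);
//! }
//! ```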

use crate::types::EventLog;
use rustkernel_core::traits::GpuKernel;
use rustkernel_core::{domain::Domain, kernel::KernelMetadata};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;

// ============================================================================
// Next Activity Prediction Kernel
// ============================================================================

/// Prediction model type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PredictionModelType {
    /// First-order Markov chain (single previous activity).
    Markov1,
    /// Second-order Markov chain (two previous activities).
    Markov2,
    /// N-gram model with configurable n.
    NGram,
}

impl Default for PredictionModelType {
    fn default() -> Self {
        Self::Markov1
    }
}

/// Configuration for prediction.
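///
/// A sketch of overriding selected defaults (the field values are illustrative):
///
/// ```ignore
/// let config = PredictionConfig {
///     model_type: PredictionModelType::Markov2,
///     top_k: 3,
///     ..Default::default()
/// };
/// ```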
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionConfig {
    /// Type of prediction model.
    pub model_type: PredictionModelType,
    /// N for N-gram model (ignored for Markov).
    pub n_gram_size: usize,
    /// Number of top predictions to return.
    pub top_k: usize,
    /// Minimum probability threshold.
    pub min_probability: f64,
    /// Use Laplace smoothing for unseen transitions.
    pub laplace_smoothing: bool,
}

impl Default for PredictionConfig {
    fn default() -> Self {
        Self {
            model_type: PredictionModelType::Markov1,
            n_gram_size: 3,
            top_k: 5,
            min_probability: 0.01,
            laplace_smoothing: true,
        }
    }
}

/// Transition matrix for first-order Markov model.
/// Key: current activity, Value: map of next activity -> count
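///
/// A sketch of reading a count back out of a `TransitionMatrix` value
/// (`matrix` and the keys are illustrative):
///
/// ```ignore
/// // Number of times "B" directly followed "A" in the training log.
/// let count = matrix.get("A").and_then(|next| next.get("B")).copied().unwrap_or(0);
/// ```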
pub type TransitionMatrix = HashMap<String, HashMap<String, u64>>;

/// Higher-order transition matrix.
/// Key: sequence of preceding activities, Value: map of next activity -> count
pub type HigherOrderTransitions = HashMap<Vec<String>, HashMap<String, u64>>;

/// A trained prediction model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionModel {
    /// Model type.
    pub model_type: PredictionModelType,
    /// First-order transitions (activity -> next -> count).
    pub transitions: TransitionMatrix,
    /// Higher-order transitions (for Markov2 and N-gram).
    pub higher_order: HigherOrderTransitions,
    /// Start activity frequencies.
    pub start_activities: HashMap<String, u64>,
    /// End activity frequencies.
    pub end_activities: HashMap<String, u64>,
    /// Activity vocabulary.
    pub vocabulary: Vec<String>,
    /// Total traces trained on.
    pub trace_count: u64,
    /// Total events trained on.
    pub event_count: u64,
}

impl Default for PredictionModel {
    fn default() -> Self {
        Self {
            model_type: PredictionModelType::Markov1,
            transitions: HashMap::new(),
            higher_order: HashMap::new(),
            start_activities: HashMap::new(),
            end_activities: HashMap::new(),
            vocabulary: Vec::new(),
            trace_count: 0,
            event_count: 0,
        }
    }
}

impl PredictionModel {
    /// Create a new model from an event log.
    pub fn train(log: &EventLog, config: &PredictionConfig) -> Self {
        let mut model = Self {
            model_type: config.model_type,
            ..Default::default()
        };

        let mut vocab_set = std::collections::HashSet::new();

        for trace in log.traces.values() {
            if trace.events.is_empty() {
                continue;
            }

            model.trace_count += 1;
            model.event_count += trace.events.len() as u64;

            let activities: Vec<&str> = trace.events.iter().map(|e| e.activity.as_str()).collect();

            // Record start/end activities
            if let Some(first) = activities.first() {
                *model.start_activities.entry(first.to_string()).or_default() += 1;
            }
            if let Some(last) = activities.last() {
                *model.end_activities.entry(last.to_string()).or_default() += 1;
            }

            // Build vocabulary
            for act in &activities {
                vocab_set.insert(act.to_string());
            }

            // Build transition matrix
            for window in activities.windows(2) {
                let from = window[0].to_string();
                let to = window[1].to_string();
                *model
                    .transitions
                    .entry(from)
                    .or_default()
                    .entry(to)
                    .or_default() += 1;
            }

            // Build higher-order transitions if needed
            match config.model_type {
                PredictionModelType::Markov2 => {
                    for window in activities.windows(3) {
                        let key = vec![window[0].to_string(), window[1].to_string()];
                        let next = window[2].to_string();
                        *model
                            .higher_order
                            .entry(key)
                            .or_default()
                            .entry(next)
                            .or_default() += 1;
                    }
                }
                PredictionModelType::NGram => {
                    let n = config.n_gram_size;
                    if activities.len() >= n {
                        for window in activities.windows(n) {
                            let key: Vec<String> =
                                window[..n - 1].iter().map(|s| s.to_string()).collect();
                            let next = window[n - 1].to_string();
                            *model
                                .higher_order
                                .entry(key)
                                .or_default()
                                .entry(next)
                                .or_default() += 1;
                        }
                    }
                }
                PredictionModelType::Markov1 => {}
            }
        }

        model.vocabulary = vocab_set.into_iter().collect();
        model.vocabulary.sort();

        model
    }

    /// Predict next activities for a given sequence.
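    ///
    /// Probabilities are relative transition frequencies for the matched context;
    /// with Laplace smoothing each count is incremented by one and the denominator
    /// by the vocabulary size. A usage sketch (the history values are illustrative):
    ///
    /// ```ignore
    /// let history = vec!["A".to_string(), "B".to_string()];
    /// let predictions = model.predict(&history, &config);
    /// if let Some(top) = predictions.first() {
    ///     println!("most likely next: {} (p = {:.2})", top.activity, top.probability);
    /// }
    /// ```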
    pub fn predict(
        &self,
        history: &[String],
        config: &PredictionConfig,
    ) -> Vec<ActivityPrediction> {
        let vocab_size = self.vocabulary.len();
        let smoothing = if config.laplace_smoothing { 1.0 } else { 0.0 };

        // Get transition counts based on model type
        let counts: Option<&HashMap<String, u64>> = match self.model_type {
            PredictionModelType::Markov1 => {
                history.last().and_then(|last| self.transitions.get(last))
            }
            PredictionModelType::Markov2 => {
                if history.len() >= 2 {
                    let key = vec![
                        history[history.len() - 2].clone(),
                        history[history.len() - 1].clone(),
                    ];
                    self.higher_order.get(&key)
                } else if history.len() == 1 {
                    // Fall back to first-order
                    self.transitions.get(&history[0])
                } else {
                    None
                }
            }
            PredictionModelType::NGram => {
                let n = config.n_gram_size;
                if history.len() >= n - 1 {
                    let key: Vec<String> = history[history.len() - (n - 1)..].to_vec();
                    self.higher_order.get(&key)
                } else if !history.is_empty() {
                    // Fall back to first-order
                    self.transitions.get(&history[history.len() - 1])
                } else {
                    None
                }
            }
        };

        // Calculate probabilities
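        // With Laplace smoothing (smoothing = 1.0) this computes
        //   P(a | context) = (count(context, a) + 1) / (total(context) + |vocabulary|);
        // without smoothing it reduces to the raw relative frequency.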
        let mut predictions: Vec<ActivityPrediction> = if let Some(counts) = counts {
            let total: u64 = counts.values().sum();
            let total_with_smoothing = total as f64 + smoothing * vocab_size as f64;

            self.vocabulary
                .iter()
                .map(|activity| {
                    let count = counts.get(activity).copied().unwrap_or(0);
                    let prob = (count as f64 + smoothing) / total_with_smoothing;
                    ActivityPrediction {
                        activity: activity.clone(),
                        probability: prob,
                        confidence: if total > 10 { prob } else { prob * 0.5 },
                        is_end: self.end_activities.contains_key(activity),
                    }
                })
                .filter(|p| p.probability >= config.min_probability)
                .collect()
        } else if config.laplace_smoothing && !self.vocabulary.is_empty() {
            // Uniform distribution with smoothing for unseen context
            let prob = 1.0 / vocab_size as f64;
            self.vocabulary
                .iter()
                .map(|activity| ActivityPrediction {
                    activity: activity.clone(),
                    probability: prob,
                    confidence: 0.1, // Low confidence for uniform
                    is_end: self.end_activities.contains_key(activity),
                })
                .collect()
        } else {
            Vec::new()
        };

        // Sort by probability descending and take top_k
        predictions.sort_by(|a, b| {
            b.probability
                .partial_cmp(&a.probability)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        predictions.truncate(config.top_k);

        predictions
    }

    /// Predict from activity names (convenience method).
    pub fn predict_from_names(
        &self,
        history: &[&str],
        config: &PredictionConfig,
    ) -> Vec<ActivityPrediction> {
        let history: Vec<String> = history.iter().map(|s| s.to_string()).collect();
        self.predict(&history, config)
    }
}

/// A predicted next activity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivityPrediction {
    /// Predicted activity name.
    pub activity: String,
    /// Probability of this activity.
    pub probability: f64,
    /// Confidence in the prediction (adjusted for data sparsity).
    pub confidence: f64,
    /// Whether this is commonly an end activity.
    pub is_end: bool,
}

/// Input for batch prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionInput {
    /// Traces to predict next activities for.
    pub traces: Vec<TraceHistory>,
    /// Trained model.
    pub model: PredictionModel,
    /// Configuration.
    pub config: PredictionConfig,
}

/// A trace history for prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceHistory {
    /// Case/trace ID.
    pub case_id: String,
    /// Activity history (most recent last).
    pub activities: Vec<String>,
}

/// Output from batch prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionOutput {
    /// Predictions per trace.
    pub predictions: Vec<TracePrediction>,
    /// Compute time in microseconds.
    pub compute_time_us: u64,
}

/// Predictions for a single trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracePrediction {
    /// Case/trace ID.
    pub case_id: String,
    /// Top-k predictions.
    pub predictions: Vec<ActivityPrediction>,
    /// Expected number of remaining activities, if the model supports estimating it.
    pub expected_remaining: Option<f64>,
}

/// Next activity prediction kernel.
///
/// Predicts the next activity in a business process using
/// Markov chains or N-gram models trained on historical data.
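///
/// A batch usage sketch (illustrative; `log` is an `EventLog` built elsewhere,
/// e.g. as in the test fixture below):
///
/// ```ignore
/// let config = PredictionConfig::default();
/// let model = NextActivityPrediction::train(&log, &config);
/// let traces = vec![TraceHistory {
///     case_id: "case-1".to_string(),
///     activities: vec!["A".to_string(), "B".to_string()],
/// }];
/// let results = NextActivityPrediction::predict_batch(&traces, &model, &config);
/// ```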
#[derive(Debug, Clone)]
pub struct NextActivityPrediction {
    metadata: KernelMetadata,
}

impl Default for NextActivityPrediction {
    fn default() -> Self {
        Self::new()
    }
}

impl NextActivityPrediction {
    /// Create a new next activity prediction kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("procint/next-activity", Domain::ProcessIntelligence)
                .with_description("Markov/N-gram next activity prediction")
                .with_throughput(100_000)
                .with_latency_us(50.0),
        }
    }

    /// Train a model from an event log.
    pub fn train(log: &EventLog, config: &PredictionConfig) -> PredictionModel {
        PredictionModel::train(log, config)
    }

    /// Batch predict for multiple traces.
    pub fn predict_batch(
        traces: &[TraceHistory],
        model: &PredictionModel,
        config: &PredictionConfig,
    ) -> Vec<TracePrediction> {
        traces
            .iter()
            .map(|trace| {
                let predictions = model.predict(&trace.activities, config);
                TracePrediction {
                    case_id: trace.case_id.clone(),
                    predictions,
                    expected_remaining: None,
                }
            })
            .collect()
    }

    /// Compute batch predictions.
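    ///
    /// A sketch of the owned input/output flow (`traces`, `model`, and `config`
    /// are illustrative values built as in the examples above):
    ///
    /// ```ignore
    /// let input = PredictionInput { traces, model, config };
    /// let output = NextActivityPrediction::compute(&input);
    /// println!("{} predictions in {} us", output.predictions.len(), output.compute_time_us);
    /// ```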
    pub fn compute(input: &PredictionInput) -> PredictionOutput {
        let start = Instant::now();
        let predictions = Self::predict_batch(&input.traces, &input.model, &input.config);
        PredictionOutput {
            predictions,
            compute_time_us: start.elapsed().as_micros() as u64,
        }
    }
}

impl GpuKernel for NextActivityPrediction {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ProcessEvent;

    fn create_test_log() -> EventLog {
        let mut log = EventLog::new("test".to_string());

        // Trace 1: A -> B -> C -> D
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: i as u64,
                case_id: "trace1".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 2: A -> B -> C -> D (same pattern)
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (10 + i) as u64,
                case_id: "trace2".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 3: A -> B -> E -> D (different path)
        for (i, activity) in ["A", "B", "E", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (20 + i) as u64,
                case_id: "trace3".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        // Trace 4: A -> B -> C -> D
        for (i, activity) in ["A", "B", "C", "D"].iter().enumerate() {
            log.add_event(ProcessEvent {
                id: (30 + i) as u64,
                case_id: "trace4".to_string(),
                activity: activity.to_string(),
                timestamp: i as u64 * 100,
                resource: None,
                attributes: HashMap::new(),
            });
        }

        log
    }

    #[test]
    fn test_next_activity_prediction_metadata() {
        let kernel = NextActivityPrediction::new();
        assert_eq!(kernel.metadata().id, "procint/next-activity");
        assert_eq!(kernel.metadata().domain, Domain::ProcessIntelligence);
    }

    #[test]
    fn test_model_training() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        assert_eq!(model.trace_count, 4);
        assert!(model.vocabulary.contains(&"A".to_string()));
        assert!(model.vocabulary.contains(&"B".to_string()));
        assert!(model.vocabulary.contains(&"C".to_string()));
        assert!(model.vocabulary.contains(&"D".to_string()));
        assert!(model.vocabulary.contains(&"E".to_string()));

        // Check transitions
        assert!(model.transitions.contains_key("A"));
        assert!(model.transitions.contains_key("B"));
    }

    #[test]
    fn test_first_order_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::Markov1,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // After A, B should be predicted with high probability
        let predictions = model.predict_from_names(&["A"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "B");
        assert!(predictions[0].probability > 0.9);

        // After B, C should be most likely (3 traces), E second (1 trace)
        let predictions = model.predict_from_names(&["B"], &config);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].activity, "C");
    }

    #[test]
    fn test_second_order_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::Markov2,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // After A, B -> C should be predicted (using 2nd order)
        let predictions = model.predict_from_names(&["A", "B"], &config);
        assert!(!predictions.is_empty());
        // C appears after A,B in 3 traces, E in 1 trace
        assert_eq!(predictions[0].activity, "C");
    }

    #[test]
    fn test_batch_prediction() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        let traces = vec![
            TraceHistory {
                case_id: "test1".to_string(),
                activities: vec!["A".to_string()],
            },
            TraceHistory {
                case_id: "test2".to_string(),
                activities: vec!["A".to_string(), "B".to_string()],
            },
        ];

        let results = NextActivityPrediction::predict_batch(&traces, &model, &config);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].case_id, "test1");
        assert_eq!(results[1].case_id, "test2");
    }

    #[test]
    fn test_laplace_smoothing() {
        let log = create_test_log();
        let config_no_smooth = PredictionConfig {
            laplace_smoothing: false,
            top_k: 10,
            min_probability: 0.0,
            ..Default::default()
        };
        let config_smooth = PredictionConfig {
            laplace_smoothing: true,
            top_k: 10,
            min_probability: 0.0,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config_no_smooth);

        // Without smoothing, an unseen context yields no predictions at all.
        let pred_no_smooth = model.predict_from_names(&["D"], &config_no_smooth);
        // D is always an end activity, so it has no outgoing transitions.
        assert!(pred_no_smooth.is_empty());

        // With smoothing, should have non-zero probabilities
        let pred_smooth = model.predict_from_names(&["D"], &config_smooth);
        assert!(!pred_smooth.is_empty());
        assert!(pred_smooth.iter().all(|p| p.probability > 0.0));
    }

    #[test]
    fn test_start_end_activities() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        // A should be start activity
        assert!(model.start_activities.contains_key("A"));
        assert_eq!(model.start_activities.get("A"), Some(&4));

        // D should be end activity
        assert!(model.end_activities.contains_key("D"));
        assert_eq!(model.end_activities.get("D"), Some(&4));
    }

    #[test]
    fn test_ngram_prediction() {
        let log = create_test_log();
        let config = PredictionConfig {
            model_type: PredictionModelType::NGram,
            n_gram_size: 3,
            top_k: 3,
            min_probability: 0.0,
            laplace_smoothing: false,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // With 3-gram: A, B -> C or E
        let predictions = model.predict_from_names(&["A", "B"], &config);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_empty_history() {
        let log = create_test_log();
        let config = PredictionConfig {
            laplace_smoothing: true,
            ..Default::default()
        };
        let model = PredictionModel::train(&log, &config);

        // With an empty history and smoothing enabled, the model falls back to a
        // uniform distribution over the vocabulary.
        let predictions = model.predict(&[], &config);
        assert!(!predictions.is_empty());
        assert!(predictions.iter().all(|p| p.probability > 0.0));
    }

    #[test]
    fn test_compute_output() {
        let log = create_test_log();
        let config = PredictionConfig::default();
        let model = PredictionModel::train(&log, &config);

        let input = PredictionInput {
            traces: vec![TraceHistory {
                case_id: "test".to_string(),
                activities: vec!["A".to_string(), "B".to_string()],
            }],
            model,
            config,
        };

        let output = NextActivityPrediction::compute(&input);
        assert_eq!(output.predictions.len(), 1);
        assert!(output.compute_time_us < 1_000_000); // Should be fast
    }
}