Skip to main content

tensorlogic_quantrs_hooks/
linear_chain_crf.rs

1//! Linear-chain Conditional Random Fields (CRFs).
2//!
3//! Linear-chain CRFs are a special case of CRFs where the output variables form a chain.
4//! This structure enables efficient inference and learning using dynamic programming.
5//!
6//! # Applications
7//!
8//! - Sequence labeling (POS tagging, NER, etc.)
9//! - Speech recognition
10//! - Bioinformatics (protein sequence analysis)
11//!
12//! # Algorithm
13//!
14//! Given input sequence x = (x₁, ..., xₙ) and output sequence y = (y₁, ..., yₙ):
15//!
16//! ```text
17//! P(y|x) = (1/Z(x)) × exp(Σᵢ Σₖ λₖ fₖ(yᵢ₋₁, yᵢ, x, i))
18//! ```
19//!
20//! Where:
21//! - fₖ are feature functions
22//! - λₖ are learned weights
23//! - Z(x) is the partition function
24//!
25//! # References
26//!
27//! - Lafferty et al., "Conditional Random Fields: Probabilistic Models for Segmenting
28//!   and Labeling Sequence Data" (2001)
29
30use crate::{Factor, FactorGraph, PgmError, Result};
31use scirs2_core::ndarray::{Array1, Array2};
32
/// Feature function for linear-chain CRF.
///
/// A feature maps a (previous label, current label, input sequence, position)
/// tuple to a real value. `LinearChainCRF` multiplies that value by the weight
/// registered alongside the feature and sums the results into its scores.
///
/// Features can be:
/// - Transition features: depend on (yᵢ₋₁, yᵢ, x, i)
/// - Emission features: depend on (yᵢ, x, i)
///
/// Implementors must be `Send + Sync`.
pub trait FeatureFunction: Send + Sync {
    /// Compute feature value for a transition.
    ///
    /// # Arguments
    /// * `prev_label` - Previous output label (None for first position).
    ///   `LinearChainCRF` also passes `None` whenever it computes emission
    ///   scores, and `Some(..)` when it computes transition scores.
    /// * `curr_label` - Current output label
    /// * `input_sequence` - Input sequence
    /// * `position` - Current position in sequence
    ///
    /// # Returns
    /// The unweighted feature value.
    fn compute(
        &self,
        prev_label: Option<usize>,
        curr_label: usize,
        input_sequence: &[usize],
        position: usize,
    ) -> f64;

    /// Get feature name/description.
    fn name(&self) -> &str;
}
57
/// Linear-chain CRF for sequence labeling.
///
/// This specialized CRF structure enables efficient inference using
/// the forward-backward algorithm and Viterbi decoding.
///
/// Scores come from two sources: weighted feature functions, and optional
/// pre-trained weight matrices installed through the `set_*_weights` methods.
pub struct LinearChainCRF {
    /// Number of states (labels)
    num_states: usize,
    /// Feature functions with their weights
    features: Vec<(Box<dyn FeatureFunction>, f64)>,
    /// Transition weights matrix: [from_state, to_state].
    /// When set, inference uses it in place of the feature-based transition scores.
    transition_weights: Option<Array2<f64>>,
    /// Emission weights matrix: [state, observation].
    /// When set, its entries are added on top of the feature-based emission scores.
    emission_weights: Option<Array2<f64>>,
}
72
73impl LinearChainCRF {
74    /// Create a new linear-chain CRF.
75    pub fn new(num_states: usize) -> Self {
76        Self {
77            num_states,
78            features: Vec::new(),
79            transition_weights: None,
80            emission_weights: None,
81        }
82    }
83
84    /// Add a feature function with its weight.
85    pub fn add_feature(&mut self, feature: Box<dyn FeatureFunction>, weight: f64) {
86        self.features.push((feature, weight));
87    }
88
89    /// Set transition weights directly.
90    ///
91    /// This is useful when you have pre-trained weights.
92    pub fn set_transition_weights(&mut self, weights: Array2<f64>) -> Result<()> {
93        if weights.shape() != [self.num_states, self.num_states] {
94            return Err(PgmError::DimensionMismatch {
95                expected: vec![self.num_states, self.num_states],
96                got: weights.shape().to_vec(),
97            });
98        }
99        self.transition_weights = Some(weights);
100        Ok(())
101    }
102
103    /// Set emission weights directly.
104    pub fn set_emission_weights(&mut self, weights: Array2<f64>) -> Result<()> {
105        if weights.shape()[0] != self.num_states {
106            return Err(PgmError::DimensionMismatch {
107                expected: vec![self.num_states, weights.shape()[1]],
108                got: weights.shape().to_vec(),
109            });
110        }
111        self.emission_weights = Some(weights);
112        Ok(())
113    }
114
115    /// Compute feature scores for a sequence.
116    fn compute_feature_scores(&self, input_sequence: &[usize], position: usize) -> Array2<f64> {
117        let mut scores = Array2::zeros((self.num_states, self.num_states));
118
119        // Transition features (from prev_state to curr_state)
120        for prev_state in 0..self.num_states {
121            for curr_state in 0..self.num_states {
122                let mut score = 0.0;
123
124                // Compute weighted feature sum
125                for (feature, weight) in &self.features {
126                    let feat_val =
127                        feature.compute(Some(prev_state), curr_state, input_sequence, position);
128                    score += weight * feat_val;
129                }
130
131                scores[[prev_state, curr_state]] = score;
132            }
133        }
134
135        scores
136    }
137
138    /// Compute emission scores for a position.
139    fn compute_emission_scores(&self, input_sequence: &[usize], position: usize) -> Array1<f64> {
140        let mut scores = Array1::zeros(self.num_states);
141
142        for state in 0..self.num_states {
143            let mut score = 0.0;
144
145            // Emission features
146            for (feature, weight) in &self.features {
147                let feat_val = feature.compute(None, state, input_sequence, position);
148                score += weight * feat_val;
149            }
150
151            // Add pre-trained emission weights if available
152            if let Some(ref emission_weights) = self.emission_weights {
153                if position < input_sequence.len() {
154                    let obs = input_sequence[position];
155                    if obs < emission_weights.shape()[1] {
156                        score += emission_weights[[state, obs]];
157                    }
158                }
159            }
160
161            scores[state] = score;
162        }
163
164        scores
165    }
166
167    /// Viterbi algorithm for finding the most likely label sequence.
168    ///
169    /// Returns the optimal label sequence and its score.
170    pub fn viterbi(&self, input_sequence: &[usize]) -> Result<(Vec<usize>, f64)> {
171        if input_sequence.is_empty() {
172            return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
173        }
174
175        let seq_len = input_sequence.len();
176
177        // Viterbi table: [position, state] -> max score
178        let mut viterbi_table = Array2::zeros((seq_len, self.num_states));
179
180        // Backpointer table: [position, state] -> previous state
181        let mut backpointers = Array2::zeros((seq_len, self.num_states));
182
183        // Initialize first position
184        let emission_scores = self.compute_emission_scores(input_sequence, 0);
185        for state in 0..self.num_states {
186            viterbi_table[[0, state]] = emission_scores[state];
187        }
188
189        // Forward pass
190        for t in 1..seq_len {
191            let emission_scores = self.compute_emission_scores(input_sequence, t);
192            let transition_scores = if let Some(ref weights) = self.transition_weights {
193                weights.clone()
194            } else {
195                self.compute_feature_scores(input_sequence, t)
196            };
197
198            for curr_state in 0..self.num_states {
199                let mut max_score = f64::NEG_INFINITY;
200                let mut best_prev_state = 0;
201
202                for prev_state in 0..self.num_states {
203                    let score = viterbi_table[[t - 1, prev_state]]
204                        + transition_scores[[prev_state, curr_state]]
205                        + emission_scores[curr_state];
206
207                    if score > max_score {
208                        max_score = score;
209                        best_prev_state = prev_state;
210                    }
211                }
212
213                viterbi_table[[t, curr_state]] = max_score;
214                backpointers[[t, curr_state]] = best_prev_state as f64;
215            }
216        }
217
218        // Find best final state
219        let mut best_final_state = 0;
220        let mut best_final_score = f64::NEG_INFINITY;
221        for state in 0..self.num_states {
222            let score = viterbi_table[[seq_len - 1, state]];
223            if score > best_final_score {
224                best_final_score = score;
225                best_final_state = state;
226            }
227        }
228
229        // Backward pass to reconstruct path
230        let mut path = vec![0; seq_len];
231        path[seq_len - 1] = best_final_state;
232
233        for t in (1..seq_len).rev() {
234            path[t - 1] = backpointers[[t, path[t]]] as usize;
235        }
236
237        Ok((path, best_final_score))
238    }
239
240    /// Forward algorithm for computing marginal probabilities.
241    ///
242    /// Returns forward probabilities: α[t, s] = P(y₁...yₜ = s, x₁...xₜ)
243    pub fn forward(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
244        if input_sequence.is_empty() {
245            return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
246        }
247
248        let seq_len = input_sequence.len();
249        let mut alpha = Array2::zeros((seq_len, self.num_states));
250
251        // Initialize first position
252        let emission_scores = self.compute_emission_scores(input_sequence, 0);
253        for state in 0..self.num_states {
254            alpha[[0, state]] = emission_scores[state].exp();
255        }
256
257        // Normalize initial position
258        let init_sum: f64 = alpha.row(0).sum();
259        if init_sum > 0.0 {
260            for state in 0..self.num_states {
261                alpha[[0, state]] /= init_sum;
262            }
263        }
264
265        // Forward pass
266        for t in 1..seq_len {
267            let emission_scores = self.compute_emission_scores(input_sequence, t);
268            let transition_scores = if let Some(ref weights) = self.transition_weights {
269                weights.clone()
270            } else {
271                self.compute_feature_scores(input_sequence, t)
272            };
273
274            for curr_state in 0..self.num_states {
275                let mut sum = 0.0;
276
277                for prev_state in 0..self.num_states {
278                    sum += alpha[[t - 1, prev_state]]
279                        * (transition_scores[[prev_state, curr_state]]
280                            + emission_scores[curr_state])
281                            .exp();
282                }
283
284                alpha[[t, curr_state]] = sum;
285            }
286
287            // Normalize to prevent underflow
288            let row_sum: f64 = alpha.row(t).sum();
289            if row_sum > 0.0 {
290                for state in 0..self.num_states {
291                    alpha[[t, state]] /= row_sum;
292                }
293            }
294        }
295
296        Ok(alpha)
297    }
298
299    /// Backward algorithm for computing marginal probabilities.
300    ///
301    /// Returns backward probabilities: β[t, s] = P(yₜ₊₁...yₙ | yₜ = s, xₜ₊₁...xₙ)
302    pub fn backward(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
303        if input_sequence.is_empty() {
304            return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
305        }
306
307        let seq_len = input_sequence.len();
308        let mut beta = Array2::zeros((seq_len, self.num_states));
309
310        // Initialize last position
311        for state in 0..self.num_states {
312            beta[[seq_len - 1, state]] = 1.0;
313        }
314
315        // Backward pass
316        for t in (0..seq_len - 1).rev() {
317            let emission_scores = self.compute_emission_scores(input_sequence, t + 1);
318            let transition_scores = if let Some(ref weights) = self.transition_weights {
319                weights.clone()
320            } else {
321                self.compute_feature_scores(input_sequence, t + 1)
322            };
323
324            for curr_state in 0..self.num_states {
325                let mut sum = 0.0;
326
327                for next_state in 0..self.num_states {
328                    sum += beta[[t + 1, next_state]]
329                        * (transition_scores[[curr_state, next_state]]
330                            + emission_scores[next_state])
331                            .exp();
332                }
333
334                beta[[t, curr_state]] = sum;
335            }
336
337            // Normalize to prevent overflow
338            let row_sum: f64 = beta.row(t).sum();
339            if row_sum > 0.0 {
340                for state in 0..self.num_states {
341                    beta[[t, state]] /= row_sum;
342                }
343            }
344        }
345
346        Ok(beta)
347    }
348
349    /// Compute marginal probabilities for each position.
350    ///
351    /// Returns: P(yₜ = s | x) for all t and s
352    pub fn marginals(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
353        let alpha = self.forward(input_sequence)?;
354        let beta = self.backward(input_sequence)?;
355
356        let seq_len = input_sequence.len();
357        let mut marginals = Array2::zeros((seq_len, self.num_states));
358
359        for t in 0..seq_len {
360            for state in 0..self.num_states {
361                marginals[[t, state]] = alpha[[t, state]] * beta[[t, state]];
362            }
363
364            // Normalize
365            let row_sum: f64 = marginals.row(t).sum();
366            if row_sum > 0.0 {
367                for state in 0..self.num_states {
368                    marginals[[t, state]] /= row_sum;
369                }
370            }
371        }
372
373        Ok(marginals)
374    }
375
376    /// Convert to factor graph representation.
377    pub fn to_factor_graph(&self, input_sequence: &[usize]) -> Result<FactorGraph> {
378        let mut graph = FactorGraph::new();
379        let seq_len = input_sequence.len();
380
381        // Add variables for each position
382        for t in 0..seq_len {
383            graph.add_variable_with_card(format!("y_{}", t), "Label".to_string(), self.num_states);
384        }
385
386        // Add emission factors
387        for t in 0..seq_len {
388            let emission_scores = self.compute_emission_scores(input_sequence, t);
389            let emission_potentials = emission_scores.mapv(|x| x.exp());
390
391            let factor = Factor::new(
392                format!("emission_{}", t),
393                vec![format!("y_{}", t)],
394                emission_potentials.into_dyn(),
395            )?;
396
397            graph.add_factor(factor)?;
398        }
399
400        // Add transition factors
401        for t in 1..seq_len {
402            let transition_scores = if let Some(ref weights) = self.transition_weights {
403                weights.clone()
404            } else {
405                self.compute_feature_scores(input_sequence, t)
406            };
407
408            let transition_potentials = transition_scores.mapv(|x| x.exp());
409
410            let factor = Factor::new(
411                format!("transition_{}", t),
412                vec![format!("y_{}", t - 1), format!("y_{}", t)],
413                transition_potentials.into_dyn(),
414            )?;
415
416            graph.add_factor(factor)?;
417        }
418
419        Ok(graph)
420    }
421}
422
423/// Simple identity feature: always returns 1.0
424pub struct IdentityFeature {
425    name: String,
426}
427
428impl IdentityFeature {
429    pub fn new(name: String) -> Self {
430        Self { name }
431    }
432}
433
434impl FeatureFunction for IdentityFeature {
435    fn compute(
436        &self,
437        _prev_label: Option<usize>,
438        _curr_label: usize,
439        _input_sequence: &[usize],
440        _position: usize,
441    ) -> f64 {
442        1.0
443    }
444
445    fn name(&self) -> &str {
446        &self.name
447    }
448}
449
450/// Transition feature: fires when transitioning from one state to another.
451pub struct TransitionFeature {
452    from_state: usize,
453    to_state: usize,
454    name: String,
455}
456
457impl TransitionFeature {
458    pub fn new(from_state: usize, to_state: usize) -> Self {
459        Self {
460            from_state,
461            to_state,
462            name: format!("transition_{}_{}", from_state, to_state),
463        }
464    }
465}
466
467impl FeatureFunction for TransitionFeature {
468    fn compute(
469        &self,
470        prev_label: Option<usize>,
471        curr_label: usize,
472        _input_sequence: &[usize],
473        _position: usize,
474    ) -> f64 {
475        if let Some(prev) = prev_label {
476            if prev == self.from_state && curr_label == self.to_state {
477                return 1.0;
478            }
479        }
480        0.0
481    }
482
483    fn name(&self) -> &str {
484        &self.name
485    }
486}
487
488/// Emission feature: fires when a specific label is paired with a specific observation.
489pub struct EmissionFeature {
490    state: usize,
491    observation: usize,
492    name: String,
493}
494
495impl EmissionFeature {
496    pub fn new(state: usize, observation: usize) -> Self {
497        Self {
498            state,
499            observation,
500            name: format!("emission_{}_{}", state, observation),
501        }
502    }
503}
504
505impl FeatureFunction for EmissionFeature {
506    fn compute(
507        &self,
508        _prev_label: Option<usize>,
509        curr_label: usize,
510        input_sequence: &[usize],
511        position: usize,
512    ) -> f64 {
513        if curr_label == self.state
514            && position < input_sequence.len()
515            && input_sequence[position] == self.observation
516        {
517            return 1.0;
518        }
519        0.0
520    }
521
522    fn name(&self) -> &str {
523        &self.name
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Build a two-state CRF whose transition weights are `transition`,
    /// read row-major as a 2x2 `[from_state, to_state]` matrix.
    fn two_state_crf(transition: [f64; 4]) -> LinearChainCRF {
        let mut crf = LinearChainCRF::new(2);
        let weights = Array::from_shape_vec((2, 2), transition.to_vec()).unwrap();
        crf.set_transition_weights(weights).unwrap();
        crf
    }

    #[test]
    fn test_linear_chain_crf_creation() {
        let crf = LinearChainCRF::new(3);
        assert_eq!(crf.num_states, 3);
        assert_eq!(crf.features.len(), 0);
    }

    #[test]
    fn test_add_feature() {
        let mut crf = LinearChainCRF::new(2);
        let feature = Box::new(IdentityFeature::new("test".to_string()));
        crf.add_feature(feature, 1.0);
        assert_eq!(crf.features.len(), 1);
    }

    #[test]
    fn test_weight_shapes_are_validated() {
        let mut crf = LinearChainCRF::new(2);

        // Transition weights must be exactly num_states x num_states.
        assert!(crf.set_transition_weights(Array::zeros((3, 3))).is_err());
        assert!(crf.set_transition_weights(Array::zeros((2, 2))).is_ok());

        // Emission weights only need num_states rows.
        assert!(crf.set_emission_weights(Array::zeros((3, 4))).is_err());
        assert!(crf.set_emission_weights(Array::zeros((2, 4))).is_ok());
    }

    #[test]
    fn test_empty_sequence_is_rejected() {
        let crf = two_state_crf([0.0, 0.0, 0.0, 0.0]);

        assert!(crf.viterbi(&[]).is_err());
        assert!(crf.forward(&[]).is_err());
        assert!(crf.backward(&[]).is_err());
        assert!(crf.marginals(&[]).is_err());
    }

    #[test]
    fn test_viterbi_simple() {
        // Prefer staying in the same state: +1 on the diagonal, -1 off it.
        let crf = two_state_crf([1.0, -1.0, -1.0, 1.0]);

        let input_sequence = vec![0, 0, 0];

        let (path, score) = crf.viterbi(&input_sequence).unwrap();

        // Both constant paths score 2.0 (two +1 transitions, no emission
        // scores); the tie is resolved in favour of the lowest state index.
        assert_eq!(path, vec![0, 0, 0]);
        assert_abs_diff_eq!(score, 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_forward_backward() {
        // Uniform transition weights
        let crf = two_state_crf([0.0, 0.0, 0.0, 0.0]);

        let input_sequence = vec![0, 1];

        // Run forward
        let alpha = crf.forward(&input_sequence).unwrap();
        assert_eq!(alpha.shape(), &[2, 2]);

        // Run backward
        let beta = crf.backward(&input_sequence).unwrap();
        assert_eq!(beta.shape(), &[2, 2]);

        // Every forward row is normalised
        for t in 0..2 {
            let sum: f64 = alpha.row(t).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
        }

        // The last backward row is initialised to ones; earlier rows are normalised
        for state in 0..2 {
            assert_abs_diff_eq!(beta[[1, state]], 1.0, epsilon = 1e-10);
        }
        let first_row_sum: f64 = beta.row(0).sum();
        assert_abs_diff_eq!(first_row_sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_marginals() {
        let crf = two_state_crf([0.0, 0.0, 0.0, 0.0]);

        let input_sequence = vec![0, 1];

        let marginals = crf.marginals(&input_sequence).unwrap();

        assert_eq!(marginals.shape(), &[2, 2]);

        for t in 0..2 {
            // Each row should sum to 1
            let sum: f64 = marginals.row(t).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);

            // With all-zero scores nothing distinguishes the two states
            for state in 0..2 {
                assert_abs_diff_eq!(marginals[[t, state]], 0.5, epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn test_transition_feature() {
        let feature = TransitionFeature::new(0, 1);

        // Should fire when transitioning from 0 to 1
        let val = feature.compute(Some(0), 1, &[0, 1], 1);
        assert_abs_diff_eq!(val, 1.0, epsilon = 1e-10);

        // Should not fire for other transitions
        let val = feature.compute(Some(0), 0, &[0, 1], 1);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);

        // Should not fire without a previous label
        let val = feature.compute(None, 1, &[0, 1], 0);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_emission_feature() {
        let feature = EmissionFeature::new(0, 5);

        // Should fire when state=0 and observation=5
        let val = feature.compute(None, 0, &[5, 3], 0);
        assert_abs_diff_eq!(val, 1.0, epsilon = 1e-10);

        // Should not fire for different observation
        let val = feature.compute(None, 0, &[3, 5], 0);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);

        // Should not fire for different state
        let val = feature.compute(None, 1, &[5, 3], 0);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);

        // Should not fire for a position past the end of the input
        let val = feature.compute(None, 0, &[5, 3], 2);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_to_factor_graph() {
        let crf = two_state_crf([0.5, 0.5, 0.5, 0.5]);

        let input_sequence = vec![0, 1, 0];

        let graph = crf.to_factor_graph(&input_sequence).unwrap();

        // Should have 3 variables (one per position)
        assert_eq!(graph.num_variables(), 3);

        // Should have 3 emission factors + 2 transition factors = 5 total
        assert_eq!(graph.num_factors(), 5);
    }
}