1use crate::{Factor, FactorGraph, PgmError, Result};
31use scirs2_core::ndarray::{Array1, Array2};
32
/// A feature function scored by the CRF for a label (or label pair) at one
/// position of an input sequence.
///
/// Implementors must be `Send + Sync`.
pub trait FeatureFunction: Send + Sync {
    /// Returns the value of this feature.
    ///
    /// * `prev_label` - label at the previous position. `LinearChainCRF` passes
    ///   `None` when it scores a single state and `Some(..)` when it scores a
    ///   transition between two states.
    /// * `curr_label` - label being scored at `position`.
    /// * `input_sequence` - the full observation sequence.
    /// * `position` - index into `input_sequence` being scored.
    fn compute(
        &self,
        prev_label: Option<usize>,
        curr_label: usize,
        input_sequence: &[usize],
        position: usize,
    ) -> f64;

    /// Returns the name of this feature.
    fn name(&self) -> &str;
}
57
/// A linear-chain conditional random field over `num_states` labels.
pub struct LinearChainCRF {
    /// Number of distinct labels each position can take.
    num_states: usize,
    /// Registered feature functions, each paired with its weight.
    features: Vec<(Box<dyn FeatureFunction>, f64)>,
    /// Optional explicit transition scores, indexed `[prev_state, curr_state]`.
    /// The setter only accepts a `num_states x num_states` matrix.
    transition_weights: Option<Array2<f64>>,
    /// Optional explicit emission scores, indexed `[state, observation]`.
    /// The setter only accepts a matrix with `num_states` rows.
    emission_weights: Option<Array2<f64>>,
}
72
73impl LinearChainCRF {
74 pub fn new(num_states: usize) -> Self {
76 Self {
77 num_states,
78 features: Vec::new(),
79 transition_weights: None,
80 emission_weights: None,
81 }
82 }
83
84 pub fn add_feature(&mut self, feature: Box<dyn FeatureFunction>, weight: f64) {
86 self.features.push((feature, weight));
87 }
88
89 pub fn set_transition_weights(&mut self, weights: Array2<f64>) -> Result<()> {
93 if weights.shape() != [self.num_states, self.num_states] {
94 return Err(PgmError::DimensionMismatch {
95 expected: vec![self.num_states, self.num_states],
96 got: weights.shape().to_vec(),
97 });
98 }
99 self.transition_weights = Some(weights);
100 Ok(())
101 }
102
103 pub fn set_emission_weights(&mut self, weights: Array2<f64>) -> Result<()> {
105 if weights.shape()[0] != self.num_states {
106 return Err(PgmError::DimensionMismatch {
107 expected: vec![self.num_states, weights.shape()[1]],
108 got: weights.shape().to_vec(),
109 });
110 }
111 self.emission_weights = Some(weights);
112 Ok(())
113 }
114
115 fn compute_feature_scores(&self, input_sequence: &[usize], position: usize) -> Array2<f64> {
117 let mut scores = Array2::zeros((self.num_states, self.num_states));
118
119 for prev_state in 0..self.num_states {
121 for curr_state in 0..self.num_states {
122 let mut score = 0.0;
123
124 for (feature, weight) in &self.features {
126 let feat_val =
127 feature.compute(Some(prev_state), curr_state, input_sequence, position);
128 score += weight * feat_val;
129 }
130
131 scores[[prev_state, curr_state]] = score;
132 }
133 }
134
135 scores
136 }
137
138 fn compute_emission_scores(&self, input_sequence: &[usize], position: usize) -> Array1<f64> {
140 let mut scores = Array1::zeros(self.num_states);
141
142 for state in 0..self.num_states {
143 let mut score = 0.0;
144
145 for (feature, weight) in &self.features {
147 let feat_val = feature.compute(None, state, input_sequence, position);
148 score += weight * feat_val;
149 }
150
151 if let Some(ref emission_weights) = self.emission_weights {
153 if position < input_sequence.len() {
154 let obs = input_sequence[position];
155 if obs < emission_weights.shape()[1] {
156 score += emission_weights[[state, obs]];
157 }
158 }
159 }
160
161 scores[state] = score;
162 }
163
164 scores
165 }
166
167 pub fn viterbi(&self, input_sequence: &[usize]) -> Result<(Vec<usize>, f64)> {
171 if input_sequence.is_empty() {
172 return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
173 }
174
175 let seq_len = input_sequence.len();
176
177 let mut viterbi_table = Array2::zeros((seq_len, self.num_states));
179
180 let mut backpointers = Array2::zeros((seq_len, self.num_states));
182
183 let emission_scores = self.compute_emission_scores(input_sequence, 0);
185 for state in 0..self.num_states {
186 viterbi_table[[0, state]] = emission_scores[state];
187 }
188
189 for t in 1..seq_len {
191 let emission_scores = self.compute_emission_scores(input_sequence, t);
192 let transition_scores = if let Some(ref weights) = self.transition_weights {
193 weights.clone()
194 } else {
195 self.compute_feature_scores(input_sequence, t)
196 };
197
198 for curr_state in 0..self.num_states {
199 let mut max_score = f64::NEG_INFINITY;
200 let mut best_prev_state = 0;
201
202 for prev_state in 0..self.num_states {
203 let score = viterbi_table[[t - 1, prev_state]]
204 + transition_scores[[prev_state, curr_state]]
205 + emission_scores[curr_state];
206
207 if score > max_score {
208 max_score = score;
209 best_prev_state = prev_state;
210 }
211 }
212
213 viterbi_table[[t, curr_state]] = max_score;
214 backpointers[[t, curr_state]] = best_prev_state as f64;
215 }
216 }
217
218 let mut best_final_state = 0;
220 let mut best_final_score = f64::NEG_INFINITY;
221 for state in 0..self.num_states {
222 let score = viterbi_table[[seq_len - 1, state]];
223 if score > best_final_score {
224 best_final_score = score;
225 best_final_state = state;
226 }
227 }
228
229 let mut path = vec![0; seq_len];
231 path[seq_len - 1] = best_final_state;
232
233 for t in (1..seq_len).rev() {
234 path[t - 1] = backpointers[[t, path[t]]] as usize;
235 }
236
237 Ok((path, best_final_score))
238 }
239
240 pub fn forward(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
244 if input_sequence.is_empty() {
245 return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
246 }
247
248 let seq_len = input_sequence.len();
249 let mut alpha = Array2::zeros((seq_len, self.num_states));
250
251 let emission_scores = self.compute_emission_scores(input_sequence, 0);
253 for state in 0..self.num_states {
254 alpha[[0, state]] = emission_scores[state].exp();
255 }
256
257 let init_sum: f64 = alpha.row(0).sum();
259 if init_sum > 0.0 {
260 for state in 0..self.num_states {
261 alpha[[0, state]] /= init_sum;
262 }
263 }
264
265 for t in 1..seq_len {
267 let emission_scores = self.compute_emission_scores(input_sequence, t);
268 let transition_scores = if let Some(ref weights) = self.transition_weights {
269 weights.clone()
270 } else {
271 self.compute_feature_scores(input_sequence, t)
272 };
273
274 for curr_state in 0..self.num_states {
275 let mut sum = 0.0;
276
277 for prev_state in 0..self.num_states {
278 sum += alpha[[t - 1, prev_state]]
279 * (transition_scores[[prev_state, curr_state]]
280 + emission_scores[curr_state])
281 .exp();
282 }
283
284 alpha[[t, curr_state]] = sum;
285 }
286
287 let row_sum: f64 = alpha.row(t).sum();
289 if row_sum > 0.0 {
290 for state in 0..self.num_states {
291 alpha[[t, state]] /= row_sum;
292 }
293 }
294 }
295
296 Ok(alpha)
297 }
298
299 pub fn backward(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
303 if input_sequence.is_empty() {
304 return Err(PgmError::InvalidGraph("Empty input sequence".to_string()));
305 }
306
307 let seq_len = input_sequence.len();
308 let mut beta = Array2::zeros((seq_len, self.num_states));
309
310 for state in 0..self.num_states {
312 beta[[seq_len - 1, state]] = 1.0;
313 }
314
315 for t in (0..seq_len - 1).rev() {
317 let emission_scores = self.compute_emission_scores(input_sequence, t + 1);
318 let transition_scores = if let Some(ref weights) = self.transition_weights {
319 weights.clone()
320 } else {
321 self.compute_feature_scores(input_sequence, t + 1)
322 };
323
324 for curr_state in 0..self.num_states {
325 let mut sum = 0.0;
326
327 for next_state in 0..self.num_states {
328 sum += beta[[t + 1, next_state]]
329 * (transition_scores[[curr_state, next_state]]
330 + emission_scores[next_state])
331 .exp();
332 }
333
334 beta[[t, curr_state]] = sum;
335 }
336
337 let row_sum: f64 = beta.row(t).sum();
339 if row_sum > 0.0 {
340 for state in 0..self.num_states {
341 beta[[t, state]] /= row_sum;
342 }
343 }
344 }
345
346 Ok(beta)
347 }
348
349 pub fn marginals(&self, input_sequence: &[usize]) -> Result<Array2<f64>> {
353 let alpha = self.forward(input_sequence)?;
354 let beta = self.backward(input_sequence)?;
355
356 let seq_len = input_sequence.len();
357 let mut marginals = Array2::zeros((seq_len, self.num_states));
358
359 for t in 0..seq_len {
360 for state in 0..self.num_states {
361 marginals[[t, state]] = alpha[[t, state]] * beta[[t, state]];
362 }
363
364 let row_sum: f64 = marginals.row(t).sum();
366 if row_sum > 0.0 {
367 for state in 0..self.num_states {
368 marginals[[t, state]] /= row_sum;
369 }
370 }
371 }
372
373 Ok(marginals)
374 }
375
376 pub fn to_factor_graph(&self, input_sequence: &[usize]) -> Result<FactorGraph> {
378 let mut graph = FactorGraph::new();
379 let seq_len = input_sequence.len();
380
381 for t in 0..seq_len {
383 graph.add_variable_with_card(format!("y_{}", t), "Label".to_string(), self.num_states);
384 }
385
386 for t in 0..seq_len {
388 let emission_scores = self.compute_emission_scores(input_sequence, t);
389 let emission_potentials = emission_scores.mapv(|x| x.exp());
390
391 let factor = Factor::new(
392 format!("emission_{}", t),
393 vec![format!("y_{}", t)],
394 emission_potentials.into_dyn(),
395 )?;
396
397 graph.add_factor(factor)?;
398 }
399
400 for t in 1..seq_len {
402 let transition_scores = if let Some(ref weights) = self.transition_weights {
403 weights.clone()
404 } else {
405 self.compute_feature_scores(input_sequence, t)
406 };
407
408 let transition_potentials = transition_scores.mapv(|x| x.exp());
409
410 let factor = Factor::new(
411 format!("transition_{}", t),
412 vec![format!("y_{}", t - 1), format!("y_{}", t)],
413 transition_potentials.into_dyn(),
414 )?;
415
416 graph.add_factor(factor)?;
417 }
418
419 Ok(graph)
420 }
421}
422
/// A feature whose value does not depend on labels, input or position.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IdentityFeature {
    /// Name reported for this feature.
    name: String,
}
427
428impl IdentityFeature {
429 pub fn new(name: String) -> Self {
430 Self { name }
431 }
432}
433
434impl FeatureFunction for IdentityFeature {
435 fn compute(
436 &self,
437 _prev_label: Option<usize>,
438 _curr_label: usize,
439 _input_sequence: &[usize],
440 _position: usize,
441 ) -> f64 {
442 1.0
443 }
444
445 fn name(&self) -> &str {
446 &self.name
447 }
448}
449
/// An indicator feature for one specific `from_state -> to_state` transition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransitionFeature {
    /// Label the previous position must carry for the feature to fire.
    from_state: usize,
    /// Label the current position must carry for the feature to fire.
    to_state: usize,
    /// Name reported for this feature.
    name: String,
}
456
457impl TransitionFeature {
458 pub fn new(from_state: usize, to_state: usize) -> Self {
459 Self {
460 from_state,
461 to_state,
462 name: format!("transition_{}_{}", from_state, to_state),
463 }
464 }
465}
466
467impl FeatureFunction for TransitionFeature {
468 fn compute(
469 &self,
470 prev_label: Option<usize>,
471 curr_label: usize,
472 _input_sequence: &[usize],
473 _position: usize,
474 ) -> f64 {
475 if let Some(prev) = prev_label {
476 if prev == self.from_state && curr_label == self.to_state {
477 return 1.0;
478 }
479 }
480 0.0
481 }
482
483 fn name(&self) -> &str {
484 &self.name
485 }
486}
487
/// An indicator feature for one specific `(state, observation)` pair.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmissionFeature {
    /// Label the current position must carry for the feature to fire.
    state: usize,
    /// Observation value that must appear at the scored position.
    observation: usize,
    /// Name reported for this feature.
    name: String,
}
494
495impl EmissionFeature {
496 pub fn new(state: usize, observation: usize) -> Self {
497 Self {
498 state,
499 observation,
500 name: format!("emission_{}_{}", state, observation),
501 }
502 }
503}
504
505impl FeatureFunction for EmissionFeature {
506 fn compute(
507 &self,
508 _prev_label: Option<usize>,
509 curr_label: usize,
510 input_sequence: &[usize],
511 position: usize,
512 ) -> f64 {
513 if curr_label == self.state
514 && position < input_sequence.len()
515 && input_sequence[position] == self.observation
516 {
517 return 1.0;
518 }
519 0.0
520 }
521
522 fn name(&self) -> &str {
523 &self.name
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a 2x2 weight matrix from row-major `values`.
    fn weights_2x2(values: [f64; 4]) -> scirs2_core::ndarray::Array2<f64> {
        Array::from_shape_vec(vec![2, 2], values.to_vec())
            .expect("four values fill a 2x2 matrix")
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .expect("a [2, 2] shape is two-dimensional")
    }

    /// Two-state CRF with the given transition weights and nothing else.
    fn two_state_crf(transitions: [f64; 4]) -> LinearChainCRF {
        let mut crf = LinearChainCRF::new(2);
        crf.set_transition_weights(weights_2x2(transitions))
            .expect("2x2 weights match a two-state CRF");
        crf
    }

    #[test]
    fn test_linear_chain_crf_creation() {
        let crf = LinearChainCRF::new(3);
        assert_eq!(crf.num_states, 3);
        assert_eq!(crf.features.len(), 0);
    }

    #[test]
    fn test_add_feature() {
        let mut crf = LinearChainCRF::new(2);
        let feature = Box::new(IdentityFeature::new("test".to_string()));
        crf.add_feature(feature, 1.0);
        assert_eq!(crf.features.len(), 1);
    }

    #[test]
    fn test_set_weights_rejects_wrong_shape() {
        let mut crf = LinearChainCRF::new(2);

        let too_big = scirs2_core::ndarray::Array2::<f64>::zeros((3, 3));
        assert!(crf.set_transition_weights(too_big).is_err());

        let wrong_rows = scirs2_core::ndarray::Array2::<f64>::zeros((3, 4));
        assert!(crf.set_emission_weights(wrong_rows).is_err());

        // Any number of observation columns is accepted.
        let any_cols = scirs2_core::ndarray::Array2::<f64>::zeros((2, 7));
        assert!(crf.set_emission_weights(any_cols).is_ok());
    }

    #[test]
    fn test_viterbi_simple() {
        // Staying in a state scores +1, switching scores -1.
        let crf = two_state_crf([1.0, -1.0, -1.0, 1.0]);

        let input_sequence = vec![0, 0, 0];

        let (path, score) = crf.viterbi(&input_sequence).unwrap();

        assert_eq!(path.len(), 3);
        // The best path never switches state: two "stay" transitions.
        assert!(path.iter().all(|&state| state == path[0]));
        assert_abs_diff_eq!(score, 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_empty_sequence_is_rejected() {
        let crf = two_state_crf([0.0, 0.0, 0.0, 0.0]);

        assert!(crf.viterbi(&[]).is_err());
        assert!(crf.forward(&[]).is_err());
        assert!(crf.backward(&[]).is_err());
        assert!(crf.marginals(&[]).is_err());
    }

    #[test]
    fn test_forward_backward() {
        let crf = two_state_crf([0.0, 0.0, 0.0, 0.0]);

        let input_sequence = vec![0, 1];

        let alpha = crf.forward(&input_sequence).unwrap();
        assert_eq!(alpha.shape(), &[2, 2]);

        let beta = crf.backward(&input_sequence).unwrap();
        assert_eq!(beta.shape(), &[2, 2]);

        for t in 0..2 {
            let sum: f64 = alpha.row(t).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
        }

        // With all-zero scores every state is equally likely.
        for &value in alpha.iter() {
            assert_abs_diff_eq!(value, 0.5, epsilon = 1e-6);
        }
        // The last backward row is initialised to one and never rescaled.
        for &value in beta.row(1).iter() {
            assert_abs_diff_eq!(value, 1.0, epsilon = 1e-10);
        }
        for &value in beta.row(0).iter() {
            assert_abs_diff_eq!(value, 0.5, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_marginals() {
        let crf = two_state_crf([0.0, 0.0, 0.0, 0.0]);

        let input_sequence = vec![0, 1];

        let marginals = crf.marginals(&input_sequence).unwrap();

        assert_eq!(marginals.shape(), &[2, 2]);

        for t in 0..2 {
            let sum: f64 = marginals.row(t).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
        }

        // With all-zero scores the marginals are uniform.
        for &value in marginals.iter() {
            assert_abs_diff_eq!(value, 0.5, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_transition_feature() {
        let feature = TransitionFeature::new(0, 1);

        let val = feature.compute(Some(0), 1, &[0, 1], 1);
        assert_abs_diff_eq!(val, 1.0, epsilon = 1e-10);

        let val = feature.compute(Some(0), 0, &[0, 1], 1);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);

        // No previous label means no transition to match.
        let val = feature.compute(None, 1, &[0, 1], 0);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_emission_feature() {
        let feature = EmissionFeature::new(0, 5);

        let val = feature.compute(None, 0, &[5, 3], 0);
        assert_abs_diff_eq!(val, 1.0, epsilon = 1e-10);

        let val = feature.compute(None, 0, &[3, 5], 0);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);

        let val = feature.compute(None, 1, &[5, 3], 0);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);

        // A position past the end of the input never matches.
        let val = feature.compute(None, 0, &[5, 3], 2);
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_to_factor_graph() {
        let crf = two_state_crf([0.5, 0.5, 0.5, 0.5]);

        let input_sequence = vec![0, 1, 0];

        let graph = crf.to_factor_graph(&input_sequence).unwrap();

        // One variable per position.
        assert_eq!(graph.num_variables(), 3);

        // Three emission factors plus two transition factors.
        assert_eq!(graph.num_factors(), 5);
    }
}