tensorlogic_quantrs_hooks/models.rs

//! Specialized model builders for common PGM types.
//!
//! This module provides convenient APIs for constructing and working with
//! common probabilistic graphical models:
//! - Bayesian Networks
//! - Hidden Markov Models (HMMs)
//! - Conditional Random Fields (CRFs)
//! - Markov Random Fields (MRFs)

use scirs2_core::ndarray::ArrayD;
use std::collections::HashMap;

use crate::error::{PgmError, Result};
use crate::factor::Factor;
use crate::graph::FactorGraph;

/// Bayesian Network builder.
///
/// Provides a convenient API for constructing directed acyclic graphical models
/// with conditional probability distributions.
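///
/// # Example
///
/// A minimal sketch of a two-node network (values mirror the unit tests at the
/// bottom of this module; the block is illustrative and therefore marked
/// `ignore`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// let mut bn = BayesianNetwork::new();
/// bn.add_variable("x".to_string(), 2);
/// bn.add_variable("y".to_string(), 2);
///
/// // P(x): prior over the root node.
/// let prior = Array::from_shape_vec(vec![2], vec![0.6, 0.4]).unwrap().into_dyn();
/// bn.add_prior("x".to_string(), prior).unwrap();
///
/// // P(y | x): first axis indexes the parent, last axis the child.
/// let cpd = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8]).unwrap().into_dyn();
/// bn.add_cpd("y".to_string(), vec!["x".to_string()], cpd).unwrap();
///
/// assert!(bn.is_acyclic());
/// ```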
pub struct BayesianNetwork {
    graph: FactorGraph,
    structure: HashMap<String, Vec<String>>, // var -> parents
}

impl BayesianNetwork {
    /// Create a new Bayesian Network.
    pub fn new() -> Self {
        Self {
            graph: FactorGraph::new(),
            structure: HashMap::new(),
        }
    }

    /// Add a variable node to the network.
    pub fn add_variable(&mut self, name: String, cardinality: usize) -> &mut Self {
        self.graph
            .add_variable_with_card(name.clone(), "Discrete".to_string(), cardinality);
        self.structure.insert(name, Vec::new());
        self
    }

    /// Add a conditional probability distribution P(child | parents).
    ///
    /// # Arguments
    /// * `child` - The dependent variable
    /// * `parents` - Parent variables that child depends on
    /// * `cpd` - Conditional probability table (dimensions: [parent1_card, ..., child_card])
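    ///
    /// # Example
    ///
    /// A minimal sketch of the expected table layout for a single binary parent
    /// (values taken from the unit tests below; under the row-major order used by
    /// `from_shape_vec`, the first axis indexes the parent and the last axis the
    /// child; `bn` is a network with variables `x` and `y` already added):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array;
    ///
    /// // P(y=0|x=0)=0.9, P(y=1|x=0)=0.1, P(y=0|x=1)=0.2, P(y=1|x=1)=0.8
    /// let cpd = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
    ///     .unwrap()
    ///     .into_dyn();
    /// bn.add_cpd("y".to_string(), vec!["x".to_string()], cpd).unwrap();
    /// ```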
    pub fn add_cpd(
        &mut self,
        child: String,
        parents: Vec<String>,
        cpd: ArrayD<f64>,
    ) -> Result<&mut Self> {
        // Verify child exists
        if self.graph.get_variable(&child).is_none() {
            return Err(PgmError::VariableNotFound(child));
        }

        // Verify parents exist
        for parent in &parents {
            if self.graph.get_variable(parent).is_none() {
                return Err(PgmError::VariableNotFound(parent.clone()));
            }
        }

        // Record structure
        self.structure.insert(child.clone(), parents.clone());

        // Create factor: variables = [parents..., child]
        let mut factor_vars = parents.clone();
        factor_vars.push(child.clone());

        let factor = Factor::new(format!("P({}|{:?})", child, parents), factor_vars, cpd)?;

        self.graph.add_factor(factor)?;
        Ok(self)
    }

    /// Add a prior probability P(variable) for a root node.
    pub fn add_prior(&mut self, variable: String, prior: ArrayD<f64>) -> Result<&mut Self> {
        let factor = Factor::new(format!("P({})", variable), vec![variable.clone()], prior)?;
        self.graph.add_factor(factor)?;
        self.structure.insert(variable, Vec::new());
        Ok(self)
    }

    /// Get the underlying factor graph.
    pub fn graph(&self) -> &FactorGraph {
        &self.graph
    }

    /// Check if the network is acyclic (DAG property).
    pub fn is_acyclic(&self) -> bool {
        // Simple cycle detection using DFS
        let mut visited = HashMap::new();
        let mut rec_stack = HashMap::new();

        for node in self.structure.keys() {
            if !visited.contains_key(node) && self.has_cycle(node, &mut visited, &mut rec_stack) {
                return false;
            }
        }

        true
    }

    fn has_cycle(
        &self,
        node: &str,
        visited: &mut HashMap<String, bool>,
        rec_stack: &mut HashMap<String, bool>,
    ) -> bool {
        visited.insert(node.to_string(), true);
        rec_stack.insert(node.to_string(), true);

        if let Some(parents) = self.structure.get(node) {
            for parent in parents {
                if !visited.contains_key(parent) {
                    if self.has_cycle(parent, visited, rec_stack) {
                        return true;
                    }
                } else if rec_stack.get(parent) == Some(&true) {
                    return true;
                }
            }
        }

        rec_stack.insert(node.to_string(), false);
        false
    }

    /// Get topological ordering of variables (ancestors before descendants).
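    ///
    /// # Example
    ///
    /// A minimal sketch for a chain x -> y -> z (mirroring the unit test below;
    /// illustrative only):
    ///
    /// ```ignore
    /// let order = bn.topological_order().unwrap();
    /// assert_eq!(order.len(), 3); // for the chain x -> y -> z this is ["x", "y", "z"]
    /// ```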
    pub fn topological_order(&self) -> Result<Vec<String>> {
        if !self.is_acyclic() {
            return Err(PgmError::InvalidGraph(
                "Network contains cycles".to_string(),
            ));
        }

        let mut in_degree: HashMap<String, usize> = HashMap::new();
        let mut children: HashMap<String, Vec<String>> = HashMap::new();

        // Build reverse graph (child -> parents becomes parent -> children)
        for (child, parents) in &self.structure {
            in_degree.insert(child.clone(), parents.len());
            for parent in parents {
                children
                    .entry(parent.clone())
                    .or_default()
                    .push(child.clone());
            }
        }

        // Kahn's algorithm for topological sort
        let mut queue: Vec<String> = in_degree
            .iter()
            .filter(|(_, &deg)| deg == 0)
            .map(|(v, _)| v.clone())
            .collect();

        let mut result = Vec::new();

        while let Some(node) = queue.pop() {
            result.push(node.clone());

            if let Some(child_nodes) = children.get(&node) {
                for child in child_nodes {
                    if let Some(deg) = in_degree.get_mut(child) {
                        *deg -= 1;
                        if *deg == 0 {
                            queue.push(child.clone());
                        }
                    }
                }
            }
        }

        if result.len() != self.structure.len() {
            return Err(PgmError::InvalidGraph(
                "Could not compute topological order".to_string(),
            ));
        }

        Ok(result)
    }
}

impl Default for BayesianNetwork {
    fn default() -> Self {
        Self::new()
    }
}

/// Hidden Markov Model builder.
///
/// A temporal model with hidden states and observations.
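///
/// # Example
///
/// A minimal sketch of a two-state, two-symbol HMM over three time steps
/// (parameter values taken from the unit tests below; the block is illustrative
/// and therefore marked `ignore`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// let mut hmm = HiddenMarkovModel::new(2, 2, 3);
///
/// let initial = Array::from_shape_vec(vec![2], vec![0.6, 0.4]).unwrap().into_dyn();
/// hmm.set_initial_distribution(initial).unwrap();
///
/// let transition = Array::from_shape_vec(vec![2, 2], vec![0.7, 0.3, 0.4, 0.6]).unwrap().into_dyn();
/// hmm.set_transition_matrix(transition).unwrap();
///
/// let emission = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8]).unwrap().into_dyn();
/// hmm.set_emission_matrix(emission).unwrap();
/// ```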
pub struct HiddenMarkovModel {
    graph: FactorGraph,
    #[allow(dead_code)]
    num_states: usize,
    #[allow(dead_code)]
    num_observations: usize,
    time_steps: usize,
}

impl HiddenMarkovModel {
    /// Create a new HMM.
    ///
    /// # Arguments
    /// * `num_states` - Number of hidden states
    /// * `num_observations` - Number of observable symbols
    /// * `time_steps` - Length of sequence
    pub fn new(num_states: usize, num_observations: usize, time_steps: usize) -> Self {
        let mut graph = FactorGraph::new();

        // Add hidden state variables
        for t in 0..time_steps {
            graph.add_variable_with_card(
                format!("state_{}", t),
                "HiddenState".to_string(),
                num_states,
            );
        }

        // Add observation variables
        for t in 0..time_steps {
            graph.add_variable_with_card(
                format!("obs_{}", t),
                "Observation".to_string(),
                num_observations,
            );
        }

        Self {
            graph,
            num_states,
            num_observations,
            time_steps,
        }
    }

    /// Set initial state distribution P(state_0).
    pub fn set_initial_distribution(&mut self, initial: ArrayD<f64>) -> Result<&mut Self> {
        let factor = Factor::new(
            "P(state_0)".to_string(),
            vec!["state_0".to_string()],
            initial,
        )?;
        self.graph.add_factor(factor)?;
        Ok(self)
    }

    /// Set transition matrix P(state_t | state_{t-1}).
    ///
    /// # Arguments
    /// * `transition` - Transition probabilities [from_state, to_state]
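    ///
    /// # Example
    ///
    /// A minimal sketch of the [from_state, to_state] layout for two states
    /// (values from the unit tests below; each row is a distribution over the
    /// next state; `hmm` is built as in the struct-level example):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array;
    ///
    /// // P(to=0|from=0)=0.7, P(to=1|from=0)=0.3, P(to=0|from=1)=0.4, P(to=1|from=1)=0.6
    /// let transition = Array::from_shape_vec(vec![2, 2], vec![0.7, 0.3, 0.4, 0.6])
    ///     .unwrap()
    ///     .into_dyn();
    /// hmm.set_transition_matrix(transition).unwrap();
    /// ```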
    pub fn set_transition_matrix(&mut self, transition: ArrayD<f64>) -> Result<&mut Self> {
        // Add transition factors for all time steps
        for t in 1..self.time_steps {
            let factor = Factor::new(
                format!("P(state_{}|state_{})", t, t - 1),
                vec![format!("state_{}", t - 1), format!("state_{}", t)],
                transition.clone(),
            )?;
            self.graph.add_factor(factor)?;
        }
        Ok(self)
    }

    /// Set emission matrix P(obs_t | state_t).
    ///
    /// # Arguments
    /// * `emission` - Emission probabilities [state, observation]
    pub fn set_emission_matrix(&mut self, emission: ArrayD<f64>) -> Result<&mut Self> {
        // Add emission factors for all time steps
        for t in 0..self.time_steps {
            let factor = Factor::new(
                format!("P(obs_{}|state_{})", t, t),
                vec![format!("state_{}", t), format!("obs_{}", t)],
                emission.clone(),
            )?;
            self.graph.add_factor(factor)?;
        }
        Ok(self)
    }

    /// Get the underlying factor graph.
    pub fn graph(&self) -> &FactorGraph {
        &self.graph
    }

    /// Perform filtering: compute P(state_t | obs_0:t).
    ///
    /// Uses variable elimination to compute the marginal distribution over
    /// the hidden state at time t given observations from 0 to t.
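    ///
    /// # Example
    ///
    /// A minimal sketch (values from the unit tests below; `hmm` is a fully
    /// parameterized 2-state model over 3 time steps):
    ///
    /// ```ignore
    /// let observations = vec![0, 1, 0];
    /// let marginal = hmm.filter(&observations, 1).unwrap(); // P(state_1 | obs_0, obs_1)
    /// assert_eq!(marginal.len(), 2);
    /// ```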
    pub fn filter(&self, observations: &[usize], t: usize) -> Result<ArrayD<f64>> {
        if t >= self.time_steps {
            return Err(PgmError::InvalidDistribution(format!(
                "Time step {} exceeds sequence length {}",
                t, self.time_steps
            )));
        }

        if t >= observations.len() {
            return Err(PgmError::InvalidDistribution(format!(
                "Not enough observations: need {} but got {}",
                t + 1,
                observations.len()
            )));
        }

        // Create a copy of the graph with evidence
        let mut evidence_graph = self.graph.clone();

        // Apply observations up to time t
        for (time, &obs_value) in observations.iter().enumerate().take(t + 1) {
            let obs_var = format!("obs_{}", time);

            // Add evidence factor: indicator function for observed value
            let mut evidence_values = vec![0.0; self.num_observations];
            evidence_values[obs_value] = 1.0;
            let evidence_factor = Factor::new(
                format!("evidence_{}", time),
                vec![obs_var.clone()],
                ArrayD::from_shape_vec(vec![self.num_observations], evidence_values)?,
            )?;
            evidence_graph.add_factor(evidence_factor)?;
        }

        // Use variable elimination to compute marginal
        use crate::variable_elimination::VariableElimination;
        let ve = VariableElimination::new();
        let state_var = format!("state_{}", t);
        ve.marginalize(&evidence_graph, &state_var)
    }

    /// Perform smoothing: compute P(state_t | obs_0:T).
    ///
    /// Uses variable elimination with all observations to compute the
    /// marginal distribution over the hidden state at time t.
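    ///
    /// # Example
    ///
    /// A minimal sketch (unlike `filter`, the full observation sequence is
    /// required; values from the unit tests below):
    ///
    /// ```ignore
    /// let observations = vec![0, 1, 0];
    /// let marginal = hmm.smooth(&observations, 1).unwrap(); // P(state_1 | obs_0..obs_2)
    /// assert_eq!(marginal.len(), 2);
    /// ```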
    pub fn smooth(&self, observations: &[usize], t: usize) -> Result<ArrayD<f64>> {
        if t >= self.time_steps {
            return Err(PgmError::InvalidDistribution(format!(
                "Time step {} exceeds sequence length {}",
                t, self.time_steps
            )));
        }

        if observations.len() != self.time_steps {
            return Err(PgmError::InvalidDistribution(format!(
                "Expected {} observations but got {}",
                self.time_steps,
                observations.len()
            )));
        }

        // Create a copy of the graph with all evidence
        let mut evidence_graph = self.graph.clone();

        // Apply all observations
        for (time, &obs_value) in observations.iter().enumerate().take(self.time_steps) {
            let obs_var = format!("obs_{}", time);

            // Add evidence factor
            let mut evidence_values = vec![0.0; self.num_observations];
            evidence_values[obs_value] = 1.0;
            let evidence_factor = Factor::new(
                format!("evidence_{}", time),
                vec![obs_var.clone()],
                ArrayD::from_shape_vec(vec![self.num_observations], evidence_values)?,
            )?;
            evidence_graph.add_factor(evidence_factor)?;
        }

        // Use variable elimination to compute marginal
        use crate::variable_elimination::VariableElimination;
        let ve = VariableElimination::new();
        let state_var = format!("state_{}", t);
        ve.marginalize(&evidence_graph, &state_var)
    }

    /// Compute the most likely state sequence (Viterbi algorithm).
    ///
    /// Finds the most probable sequence of hidden states given the observations
    /// via MAP inference (max-product variable elimination), which on a chain is
    /// equivalent to the classic Viterbi dynamic program.
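    ///
    /// # Example
    ///
    /// A minimal sketch (values from the unit tests below; `hmm` is a fully
    /// parameterized 2-state model over 3 time steps):
    ///
    /// ```ignore
    /// let observations = vec![0, 1, 0];
    /// let states = hmm.viterbi(&observations).unwrap();
    /// assert_eq!(states.len(), 3); // one state index per time step
    /// ```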
    pub fn viterbi(&self, observations: &[usize]) -> Result<Vec<usize>> {
        if observations.len() != self.time_steps {
            return Err(PgmError::InvalidDistribution(format!(
                "Observations length {} does not match time steps {}",
                observations.len(),
                self.time_steps
            )));
        }

        // Create graph with evidence
        let mut evidence_graph = self.graph.clone();

        // Apply all observations
        for (time, &obs_value) in observations.iter().enumerate().take(self.time_steps) {
            let obs_var = format!("obs_{}", time);

            let mut evidence_values = vec![0.0; self.num_observations];
            evidence_values[obs_value] = 1.0;
            let evidence_factor = Factor::new(
                format!("evidence_{}", time),
                vec![obs_var.clone()],
                ArrayD::from_shape_vec(vec![self.num_observations], evidence_values)?,
            )?;
            evidence_graph.add_factor(evidence_factor)?;
        }

        // Use variable elimination with MAX to find MAP assignment
        use crate::variable_elimination::VariableElimination;
        let ve = VariableElimination::new();
        let assignment = ve.map(&evidence_graph)?;

        // Extract state sequence in temporal order
        let mut sequence = Vec::new();
        for t in 0..self.time_steps {
            let state_var = format!("state_{}", t);
            if let Some(&state) = assignment.get(&state_var) {
                sequence.push(state);
            } else {
                return Err(PgmError::VariableNotFound(state_var));
            }
        }

        Ok(sequence)
    }
}

/// Markov Random Field builder (undirected graphical model).
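///
/// # Example
///
/// A minimal sketch of a two-variable MRF with one pairwise potential (values
/// from the unit tests below; illustrative only):
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// let mut mrf = MarkovRandomField::new();
/// mrf.add_variable("x".to_string(), 2);
/// mrf.add_variable("y".to_string(), 2);
///
/// let potential = Array::from_shape_vec(vec![2, 2], vec![1.0, 0.5, 0.5, 1.0])
///     .unwrap()
///     .into_dyn();
/// mrf.add_pairwise_potential("x".to_string(), "y".to_string(), potential).unwrap();
/// ```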
pub struct MarkovRandomField {
    graph: FactorGraph,
}

impl MarkovRandomField {
    /// Create a new MRF.
    pub fn new() -> Self {
        Self {
            graph: FactorGraph::new(),
        }
    }

    /// Add a variable node.
    pub fn add_variable(&mut self, name: String, cardinality: usize) -> &mut Self {
        self.graph
            .add_variable_with_card(name, "Discrete".to_string(), cardinality);
        self
    }

    /// Add a pairwise potential φ(x_i, x_j).
    pub fn add_pairwise_potential(
        &mut self,
        var1: String,
        var2: String,
        potential: ArrayD<f64>,
    ) -> Result<&mut Self> {
        let factor = Factor::new(
            format!("φ({},{})", var1, var2),
            vec![var1.clone(), var2.clone()],
            potential,
        )?;
        self.graph.add_factor(factor)?;
        Ok(self)
    }

    /// Add a unary potential φ(x_i).
    pub fn add_unary_potential(
        &mut self,
        var: String,
        potential: ArrayD<f64>,
    ) -> Result<&mut Self> {
        let factor = Factor::new(format!("φ({})", var), vec![var.clone()], potential)?;
        self.graph.add_factor(factor)?;
        Ok(self)
    }

    /// Get the underlying factor graph.
    pub fn graph(&self) -> &FactorGraph {
        &self.graph
    }
}

impl Default for MarkovRandomField {
    fn default() -> Self {
        Self::new()
    }
}

/// Conditional Random Field builder (discriminative model for structured prediction).
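///
/// # Example
///
/// A minimal sketch pairing one observed input with one output label through a
/// single feature factor (values from the unit tests below; illustrative only):
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// let mut crf = ConditionalRandomField::new();
/// crf.add_input_variable("x".to_string(), 3);
/// crf.add_output_variable("y".to_string(), 2);
///
/// let feature = Array::from_shape_vec(vec![3, 2], vec![1.0, 0.5, 0.8, 0.2, 0.6, 0.4])
///     .unwrap()
///     .into_dyn();
/// crf.add_feature("f1".to_string(), vec!["x".to_string(), "y".to_string()], feature).unwrap();
/// ```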
pub struct ConditionalRandomField {
    graph: FactorGraph,
    input_vars: Vec<String>,
    output_vars: Vec<String>,
}

impl ConditionalRandomField {
    /// Create a new CRF.
    pub fn new() -> Self {
        Self {
            graph: FactorGraph::new(),
            input_vars: Vec::new(),
            output_vars: Vec::new(),
        }
    }

    /// Add an input (observed) variable.
    pub fn add_input_variable(&mut self, name: String, cardinality: usize) -> &mut Self {
        self.graph
            .add_variable_with_card(name.clone(), "Input".to_string(), cardinality);
        self.input_vars.push(name);
        self
    }

    /// Add an output (label) variable.
    pub fn add_output_variable(&mut self, name: String, cardinality: usize) -> &mut Self {
        self.graph
            .add_variable_with_card(name.clone(), "Output".to_string(), cardinality);
        self.output_vars.push(name);
        self
    }

    /// Add a feature function (factor).
    pub fn add_feature(
        &mut self,
        name: String,
        variables: Vec<String>,
        potential: ArrayD<f64>,
    ) -> Result<&mut Self> {
        let factor = Factor::new(format!("feature_{}", name), variables, potential)?;
        self.graph.add_factor(factor)?;
        Ok(self)
    }

    /// Get the underlying factor graph.
    pub fn graph(&self) -> &FactorGraph {
        &self.graph
    }
}

impl Default for ConditionalRandomField {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_bayesian_network_creation() {
        let mut bn = BayesianNetwork::new();
        bn.add_variable("x".to_string(), 2);
        bn.add_variable("y".to_string(), 2);

        assert!(bn.graph().get_variable("x").is_some());
        assert!(bn.graph().get_variable("y").is_some());
    }

    #[test]
    fn test_bayesian_network_cpd() {
        let mut bn = BayesianNetwork::new();
        bn.add_variable("x".to_string(), 2);
        bn.add_variable("y".to_string(), 2);

        let prior = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .unwrap()
            .into_dyn();
        bn.add_prior("x".to_string(), prior).unwrap();

        let cpd = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
            .unwrap()
            .into_dyn();
        bn.add_cpd("y".to_string(), vec!["x".to_string()], cpd)
            .unwrap();

        assert_eq!(bn.graph().num_factors(), 2);
    }

    #[test]
    fn test_bayesian_network_acyclic() {
        let mut bn = BayesianNetwork::new();
        bn.add_variable("x".to_string(), 2);
        bn.add_variable("y".to_string(), 2);

        let cpd = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
            .unwrap()
            .into_dyn();
        bn.add_cpd("y".to_string(), vec!["x".to_string()], cpd)
            .unwrap();

        assert!(bn.is_acyclic());
    }

    #[test]
    fn test_bayesian_network_topological_order() {
        let mut bn = BayesianNetwork::new();
        bn.add_variable("x".to_string(), 2);
        bn.add_variable("y".to_string(), 2);
        bn.add_variable("z".to_string(), 2);

        let cpd_y = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
            .unwrap()
            .into_dyn();
        bn.add_cpd("y".to_string(), vec!["x".to_string()], cpd_y)
            .unwrap();

        let cpd_z = Array::from_shape_vec(vec![2, 2], vec![0.8, 0.2, 0.3, 0.7])
            .unwrap()
            .into_dyn();
        bn.add_cpd("z".to_string(), vec!["y".to_string()], cpd_z)
            .unwrap();

        let order = bn.topological_order().unwrap();
        assert_eq!(order.len(), 3);
    }

    #[test]
    fn test_hmm_creation() {
        let hmm = HiddenMarkovModel::new(3, 2, 5);
        assert_eq!(hmm.graph().num_variables(), 10); // 5 state variables + 5 observation variables
    }

    #[test]
    fn test_hmm_parameters() {
        let mut hmm = HiddenMarkovModel::new(2, 2, 3);

        let initial = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .unwrap()
            .into_dyn();
        hmm.set_initial_distribution(initial).unwrap();

        let transition = Array::from_shape_vec(vec![2, 2], vec![0.7, 0.3, 0.4, 0.6])
            .unwrap()
            .into_dyn();
        hmm.set_transition_matrix(transition).unwrap();

        let emission = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
            .unwrap()
            .into_dyn();
        hmm.set_emission_matrix(emission).unwrap();

        assert!(hmm.graph().num_factors() > 0);
    }

    #[test]
    fn test_hmm_filtering() {
        let mut hmm = HiddenMarkovModel::new(2, 2, 3);

        let initial = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .unwrap()
            .into_dyn();
        hmm.set_initial_distribution(initial).unwrap();

        let transition = Array::from_shape_vec(vec![2, 2], vec![0.7, 0.3, 0.4, 0.6])
            .unwrap()
            .into_dyn();
        hmm.set_transition_matrix(transition).unwrap();

        let emission = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
            .unwrap()
            .into_dyn();
        hmm.set_emission_matrix(emission).unwrap();

        // Filter with observations
        let observations = vec![0, 1, 0];
        let result = hmm.filter(&observations, 1);
        assert!(result.is_ok());

        let marginal = result.unwrap();
        assert_eq!(marginal.len(), 2);
        // Should be normalized
        let sum: f64 = marginal.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_hmm_smoothing() {
        let mut hmm = HiddenMarkovModel::new(2, 2, 3);

        let initial = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .unwrap()
            .into_dyn();
        hmm.set_initial_distribution(initial).unwrap();

        let transition = Array::from_shape_vec(vec![2, 2], vec![0.7, 0.3, 0.4, 0.6])
            .unwrap()
            .into_dyn();
        hmm.set_transition_matrix(transition).unwrap();

        let emission = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
            .unwrap()
            .into_dyn();
        hmm.set_emission_matrix(emission).unwrap();

        // Smooth with all observations
        let observations = vec![0, 1, 0];
        let result = hmm.smooth(&observations, 1);
        assert!(result.is_ok());

        let marginal = result.unwrap();
        assert_eq!(marginal.len(), 2);
        let sum: f64 = marginal.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_hmm_viterbi() {
        let mut hmm = HiddenMarkovModel::new(2, 2, 3);

        let initial = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
            .unwrap()
            .into_dyn();
        hmm.set_initial_distribution(initial).unwrap();

        let transition = Array::from_shape_vec(vec![2, 2], vec![0.7, 0.3, 0.4, 0.6])
            .unwrap()
            .into_dyn();
        hmm.set_transition_matrix(transition).unwrap();

        let emission = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])
            .unwrap()
            .into_dyn();
        hmm.set_emission_matrix(emission).unwrap();

        // Run Viterbi
        let observations = vec![0, 1, 0];
        let result = hmm.viterbi(&observations);
        assert!(result.is_ok());

        let sequence = result.unwrap();
        assert_eq!(sequence.len(), 3);
        // Each state should be valid (0 or 1)
        for state in sequence {
            assert!(state < 2);
        }
    }

    #[test]
    fn test_mrf_creation() {
        let mut mrf = MarkovRandomField::new();
        mrf.add_variable("x".to_string(), 2);
        mrf.add_variable("y".to_string(), 2);

        let potential = Array::from_shape_vec(vec![2, 2], vec![1.0, 0.5, 0.5, 1.0])
            .unwrap()
            .into_dyn();
        mrf.add_pairwise_potential("x".to_string(), "y".to_string(), potential)
            .unwrap();

        assert_eq!(mrf.graph().num_factors(), 1);
    }

    #[test]
    fn test_crf_creation() {
        let mut crf = ConditionalRandomField::new();
        crf.add_input_variable("x".to_string(), 3);
        crf.add_output_variable("y".to_string(), 2);

        let feature = Array::from_shape_vec(vec![3, 2], vec![1.0, 0.5, 0.8, 0.2, 0.6, 0.4])
            .unwrap()
            .into_dyn();
        crf.add_feature(
            "f1".to_string(),
            vec!["x".to_string(), "y".to_string()],
            feature,
        )
        .unwrap();

        assert_eq!(crf.graph().num_factors(), 1);
    }
776}