1use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1};
9use scirs2_core::random::thread_rng;
10use scirs2_core::random::{RandNormal, Rng};
11use sklears_core::{
12 error::{Result as SklResult, SklearsError},
13 traits::{Estimator, Fit, Predict, Untrained},
14 types::Float,
15};
16use std::collections::{HashMap, HashSet};
17
/// Collins-style structured perceptron for sequence labeling.
///
/// Uses the type-state pattern: `State = Untrained` until `fit` is called,
/// after which a `StructuredPerceptron<StructuredPerceptronTrained>` carrying
/// the learned weights is returned.
#[derive(Debug, Clone)]
pub struct StructuredPerceptron<State = Untrained> {
    // Maximum number of training passes over the data.
    max_iterations: usize,
    // Step size applied to each perceptron weight update.
    learning_rate: Float,
    // Optional RNG seed. NOTE(review): not consumed anywhere in `fit` — confirm intent.
    random_state: Option<u64>,
    // Type-state payload: `Untrained`, or the trained parameters after `fit`.
    state: State,
}
50
/// Learned parameters of a fitted [`StructuredPerceptron`].
#[derive(Debug, Clone)]
pub struct StructuredPerceptronTrained {
    // Flattened weights: `n_features * n_classes` emission weights followed by
    // an `n_classes * n_classes` transition segment (the latter is allocated
    // by `fit` but not read during prediction).
    weights: Array1<Float>,
    // Number of input features seen during fitting.
    n_features: usize,
    // Number of distinct classes inferred from the training labels.
    n_classes: usize,
}
58
59impl Default for StructuredPerceptron<Untrained> {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl StructuredPerceptron<Untrained> {
66 pub fn new() -> Self {
68 Self {
69 max_iterations: 100,
70 learning_rate: 1.0,
71 random_state: None,
72 state: Untrained,
73 }
74 }
75
76 pub fn max_iterations(mut self, max_iterations: usize) -> Self {
78 self.max_iterations = max_iterations;
79 self
80 }
81
82 pub fn learning_rate(mut self, learning_rate: Float) -> Self {
84 self.learning_rate = learning_rate;
85 self
86 }
87
88 pub fn random_state(mut self, random_state: u64) -> Self {
90 self.random_state = Some(random_state);
91 self
92 }
93}
94
impl Estimator for StructuredPerceptron<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // No tunable config beyond the builder methods; `&()` is constant-promoted
    // to a `'static` reference, so returning it is sound.
    fn config(&self) -> &Self::Config {
        &()
    }
}
104
105impl Fit<Array3<Float>, Array2<i32>> for StructuredPerceptron<Untrained> {
106 type Fitted = StructuredPerceptron<StructuredPerceptronTrained>;
107
108 fn fit(self, X: &Array3<Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
109 let (n_sequences, max_seq_len, n_features) = X.dim();
110
111 if n_sequences != y.nrows() {
112 return Err(SklearsError::InvalidInput(
113 "X and y must have the same number of sequences".to_string(),
114 ));
115 }
116
117 if y.ncols() != max_seq_len {
118 return Err(SklearsError::InvalidInput(
119 "y sequence length must match X sequence length".to_string(),
120 ));
121 }
122
123 let n_classes = y.iter().max().unwrap_or(&0) + 1;
124 let feature_dim = n_features * n_classes as usize + n_classes as usize * n_classes as usize;
125 let mut weights = Array1::<Float>::zeros(feature_dim);
126
127 let rng = thread_rng();
128
129 for _iteration in 0..self.max_iterations {
130 let mut updated = false;
131
132 for seq_idx in 0..n_sequences {
133 let sequence = X.slice(s![seq_idx, .., ..]);
134 let true_labels = y.row(seq_idx);
135
136 let mut predicted_labels = Array1::<Float>::zeros(max_seq_len);
138
139 for pos in 0..max_seq_len {
140 let features = sequence.slice(s![pos, ..]);
141 let mut best_score = Float::NEG_INFINITY;
142 let mut best_label = 0;
143
144 for label in 0..n_classes {
145 let feature_offset = label as usize * n_features;
146 let score = features
147 .iter()
148 .enumerate()
149 .map(|(feat_idx, &feat_val)| {
150 weights[feature_offset + feat_idx] * feat_val
151 })
152 .sum::<Float>();
153
154 if score > best_score {
155 best_score = score;
156 best_label = label;
157 }
158 }
159 predicted_labels[pos] = best_label as Float;
160 }
161
162 let correct = true_labels
164 .iter()
165 .zip(predicted_labels.iter())
166 .all(|(&true_label, &pred_label)| true_label == pred_label as i32);
167
168 if !correct {
169 for pos in 0..max_seq_len {
171 let features = sequence.slice(s![pos, ..]);
172 let true_label = true_labels[pos] as usize;
173 let pred_label = predicted_labels[pos] as usize;
174
175 let true_offset = true_label * n_features;
177 for (feat_idx, &feat_val) in features.iter().enumerate() {
178 weights[true_offset + feat_idx] += self.learning_rate * feat_val;
179 }
180
181 let pred_offset = pred_label * n_features;
183 for (feat_idx, &feat_val) in features.iter().enumerate() {
184 weights[pred_offset + feat_idx] -= self.learning_rate * feat_val;
185 }
186 }
187 updated = true;
188 }
189 }
190
191 if !updated {
192 break;
193 }
194 }
195
196 Ok(StructuredPerceptron {
197 max_iterations: self.max_iterations,
198 learning_rate: self.learning_rate,
199 random_state: self.random_state,
200 state: StructuredPerceptronTrained {
201 weights,
202 n_features,
203 n_classes: n_classes as usize,
204 },
205 })
206 }
207}
208
impl StructuredPerceptron<Untrained> {
    /// Weight accessor on the untrained state: always `None`, since weights
    /// only exist after `fit`.
    pub fn weights(&self) -> Option<&Array1<Float>> {
        None
    }
}
215
216impl Predict<Array3<Float>, Array2<i32>> for StructuredPerceptron<StructuredPerceptronTrained> {
217 fn predict(&self, X: &Array3<Float>) -> SklResult<Array2<i32>> {
218 let (n_sequences, max_seq_len, n_features) = X.dim();
219
220 if n_features != self.state.n_features {
221 return Err(SklearsError::InvalidInput(
222 "X has different number of features than training data".to_string(),
223 ));
224 }
225
226 let mut predictions = Array2::<i32>::zeros((n_sequences, max_seq_len));
227
228 for seq_idx in 0..n_sequences {
229 let sequence = X.slice(s![seq_idx, .., ..]);
230
231 for pos in 0..max_seq_len {
232 let features = sequence.slice(s![pos, ..]);
233 let mut best_score = Float::NEG_INFINITY;
234 let mut best_label = 0;
235
236 for label in 0..self.state.n_classes {
237 let feature_offset = label * n_features;
238 let score = features
239 .iter()
240 .enumerate()
241 .map(|(feat_idx, &feat_val)| {
242 self.state.weights[feature_offset + feat_idx] * feat_val
243 })
244 .sum::<Float>();
245
246 if score > best_score {
247 best_score = score;
248 best_label = label;
249 }
250 }
251 predictions[[seq_idx, pos]] = best_label as i32;
252 }
253 }
254
255 Ok(predictions)
256 }
257}
258
impl StructuredPerceptron<StructuredPerceptronTrained> {
    /// Learned weight vector (emission block followed by the unused
    /// transition block — see `StructuredPerceptronTrained`).
    pub fn weights(&self) -> &Array1<Float> {
        &self.state.weights
    }
}
265
/// Hidden Markov model with Gaussian emissions, trained from fully labeled
/// sequences (supervised counting rather than Baum-Welch).
#[derive(Debug, Clone)]
pub struct HiddenMarkovModel<State = Untrained> {
    // Number of hidden states.
    n_states: usize,
    // Maximum number of fitting iterations.
    max_iterations: usize,
    // Convergence threshold on the fitting objective.
    tolerance: Float,
    // Optional RNG seed for parameter initialization.
    random_state: Option<u64>,
    // Type-state payload: `Untrained` or the trained parameters.
    state: State,
}
296
/// Learned parameters of a fitted [`HiddenMarkovModel`].
#[derive(Debug, Clone)]
pub struct HiddenMarkovModelTrained {
    // Row-stochastic state-transition matrix (`n_states x n_states`).
    transition_matrix: Array2<Float>,
    // Per-state emission means (`n_states x n_features`).
    emission_means: Array2<Float>,
    // Per-state emission covariances (`n_states x n_features x n_features`).
    // NOTE(review): stored but not consulted by `gaussian_probability`, which
    // assumes identity covariance — confirm before relying on these values.
    emission_covariances: Array3<Float>,
    // Initial state distribution.
    initial_probs: Array1<Float>,
    // Feature dimensionality seen during fitting.
    n_features: usize,
    // Number of hidden states.
    n_states: usize,
}
307
308impl Default for HiddenMarkovModel<Untrained> {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
314impl HiddenMarkovModel<Untrained> {
315 pub fn new() -> Self {
317 Self {
318 n_states: 2,
319 max_iterations: 100,
320 tolerance: 1e-6,
321 random_state: None,
322 state: Untrained,
323 }
324 }
325
326 pub fn n_states(mut self, n_states: usize) -> Self {
328 self.n_states = n_states;
329 self
330 }
331
332 pub fn max_iterations(mut self, max_iterations: usize) -> Self {
334 self.max_iterations = max_iterations;
335 self
336 }
337
338 pub fn tolerance(mut self, tolerance: Float) -> Self {
340 self.tolerance = tolerance;
341 self
342 }
343
344 pub fn random_state(mut self, random_state: u64) -> Self {
346 self.random_state = Some(random_state);
347 self
348 }
349}
350
impl Estimator for HiddenMarkovModel<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // No tunable config beyond the builder methods; `&()` is constant-promoted
    // to a `'static` reference.
    fn config(&self) -> &Self::Config {
        &()
    }
}
360
361impl Fit<Array3<Float>, Array2<i32>> for HiddenMarkovModel<Untrained> {
362 type Fitted = HiddenMarkovModel<HiddenMarkovModelTrained>;
363
364 fn fit(self, X: &Array3<Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
365 let (n_sequences, max_seq_len, n_features) = X.dim();
366
367 if n_sequences != y.nrows() {
368 return Err(SklearsError::InvalidInput(
369 "X and y must have the same number of sequences".to_string(),
370 ));
371 }
372
373 if y.ncols() != max_seq_len {
374 return Err(SklearsError::InvalidInput(
375 "y sequence length must match X sequence length".to_string(),
376 ));
377 }
378
379 let mut rng = thread_rng();
380
381 let normal_dist = RandNormal::new(0.0, 1.0).unwrap();
383 let mut transition_matrix = Array2::<Float>::zeros((self.n_states, self.n_states));
384 for i in 0..self.n_states {
385 for j in 0..self.n_states {
386 transition_matrix[[i, j]] = rng.sample(normal_dist);
387 }
388 }
389 let mut emission_means = Array2::<Float>::zeros((self.n_states, n_features));
390 for i in 0..self.n_states {
391 for j in 0..n_features {
392 emission_means[[i, j]] = rng.sample(normal_dist);
393 }
394 }
395 let mut emission_covariances =
396 Array3::from_elem((self.n_states, n_features, n_features), 1.0);
397 let mut initial_probs = Array1::from_elem(self.n_states, 1.0 / self.n_states as Float);
398
399 for i in 0..self.n_states {
401 let row_sum = transition_matrix.row(i).sum();
402 if row_sum > 0.0 {
403 for j in 0..self.n_states {
404 transition_matrix[[i, j]] /= row_sum;
405 }
406 }
407 }
408
409 for state in 0..self.n_states {
411 for i in 0..n_features {
412 emission_covariances[[state, i, i]] = 1.0;
413 }
414 }
415
416 let mut prev_likelihood = Float::NEG_INFINITY;
417
418 for _iteration in 0..self.max_iterations {
420 let mut total_likelihood = 0.0;
421
422 let mut state_counts = Array1::<Float>::zeros(self.n_states);
427 let mut transition_counts = Array2::<Float>::zeros((self.n_states, self.n_states));
428 let mut emission_sums = Array2::<Float>::zeros((self.n_states, n_features));
429
430 for seq_idx in 0..n_sequences {
431 let sequence_data = X.slice(s![seq_idx, .., ..]);
432 let sequence_states = y.row(seq_idx);
433
434 for pos in 0..max_seq_len {
435 let state = sequence_states[pos] as usize;
436 if state < self.n_states {
437 state_counts[state] += 1.0;
438
439 for feat in 0..n_features {
441 emission_sums[[state, feat]] += sequence_data[[pos, feat]];
442 }
443
444 if pos < max_seq_len - 1 {
446 let next_state = sequence_states[pos + 1] as usize;
447 if next_state < self.n_states {
448 transition_counts[[state, next_state]] += 1.0;
449 }
450 }
451 }
452 }
453
454 let first_state = sequence_states[0] as usize;
456 if first_state < self.n_states {
457 initial_probs[first_state] += 1.0;
458 }
459 }
460
461 for state in 0..self.n_states {
463 if state_counts[state] > 0.0 {
464 for feat in 0..n_features {
465 emission_means[[state, feat]] =
466 emission_sums[[state, feat]] / state_counts[state];
467 }
468 }
469
470 let row_sum = transition_counts.row(state).sum();
471 if row_sum > 0.0 {
472 for next_state in 0..self.n_states {
473 transition_matrix[[state, next_state]] =
474 transition_counts[[state, next_state]] / row_sum;
475 }
476 }
477 }
478
479 let init_sum = initial_probs.sum();
481 if init_sum > 0.0 {
482 initial_probs /= init_sum;
483 }
484
485 total_likelihood = state_counts.sum();
487
488 if (total_likelihood - prev_likelihood).abs() < self.tolerance {
489 break;
490 }
491 prev_likelihood = total_likelihood;
492 }
493
494 Ok(HiddenMarkovModel {
495 n_states: self.n_states,
496 max_iterations: self.max_iterations,
497 tolerance: self.tolerance,
498 random_state: self.random_state,
499 state: HiddenMarkovModelTrained {
500 transition_matrix,
501 emission_means,
502 emission_covariances,
503 initial_probs,
504 n_features,
505 n_states: self.n_states,
506 },
507 })
508 }
509}
510
impl HiddenMarkovModel<Untrained> {
    /// Transition-matrix accessor on the untrained state: always `None`.
    pub fn transition_matrix(&self) -> Option<&Array2<Float>> {
        None
    }

    /// Emission-means accessor on the untrained state: always `None`.
    pub fn emission_means(&self) -> Option<&Array2<Float>> {
        None
    }

    /// Initial-probabilities accessor on the untrained state: always `None`.
    pub fn initial_probabilities(&self) -> Option<&Array1<Float>> {
        None
    }
}
527
528impl Predict<Array3<Float>, Array2<i32>> for HiddenMarkovModel<HiddenMarkovModelTrained> {
529 fn predict(&self, X: &Array3<Float>) -> SklResult<Array2<i32>> {
530 let (n_sequences, max_seq_len, n_features) = X.dim();
531
532 if n_features != self.state.n_features {
533 return Err(SklearsError::InvalidInput(
534 "X has different number of features than training data".to_string(),
535 ));
536 }
537
538 let mut predictions = Array2::<i32>::zeros((n_sequences, max_seq_len));
539
540 for seq_idx in 0..n_sequences {
541 let sequence = X.slice(s![seq_idx, .., ..]);
542
543 let mut viterbi = Array2::<Float>::zeros((max_seq_len, self.state.n_states));
545 let mut path = Array2::<Float>::zeros((max_seq_len, self.state.n_states));
546
547 for state in 0..self.state.n_states {
549 let emission_prob = self.gaussian_probability(&sequence.slice(s![0, ..]), state);
550 viterbi[[0, state]] = self.state.initial_probs[state].ln() + emission_prob.ln();
551 }
552
553 for t in 1..max_seq_len {
555 for state in 0..self.state.n_states {
556 let emission_prob =
557 self.gaussian_probability(&sequence.slice(s![t, ..]), state);
558 let mut best_prob = Float::NEG_INFINITY;
559 let mut best_prev_state = 0;
560
561 for prev_state in 0..self.state.n_states {
562 let prob = viterbi[[t - 1, prev_state]]
563 + self.state.transition_matrix[[prev_state, state]].ln()
564 + emission_prob.ln();
565 if prob > best_prob {
566 best_prob = prob;
567 best_prev_state = prev_state;
568 }
569 }
570 viterbi[[t, state]] = best_prob;
571 path[[t, state]] = best_prev_state as Float;
572 }
573 }
574
575 let mut states = Array1::<Float>::zeros(max_seq_len);
577
578 let mut best_final_prob = Float::NEG_INFINITY;
580 let mut best_final_state = 0;
581 for state in 0..self.state.n_states {
582 if viterbi[[max_seq_len - 1, state]] > best_final_prob {
583 best_final_prob = viterbi[[max_seq_len - 1, state]];
584 best_final_state = state;
585 }
586 }
587
588 states[max_seq_len - 1] = best_final_state as Float;
589
590 for t in (0..max_seq_len - 1).rev() {
592 states[t] = path[[t + 1, states[t + 1] as usize]];
593 }
594
595 for t in 0..max_seq_len {
597 predictions[[seq_idx, t]] = states[t] as i32;
598 }
599 }
600
601 Ok(predictions)
602 }
603}
604
605impl HiddenMarkovModel<HiddenMarkovModelTrained> {
606 pub fn transition_matrix(&self) -> &Array2<Float> {
608 &self.state.transition_matrix
609 }
610
611 pub fn emission_means(&self) -> &Array2<Float> {
613 &self.state.emission_means
614 }
615
616 pub fn initial_probabilities(&self) -> &Array1<Float> {
618 &self.state.initial_probs
619 }
620
621 fn gaussian_probability(&self, observation: &ArrayView1<Float>, state: usize) -> Float {
623 let mean = self.state.emission_means.row(state);
624 let diff = observation.to_owned() - &mean.to_owned();
625
626 let exponent = -0.5 * diff.mapv(|x| x * x).sum();
628 let normalization =
629 (2.0 * std::f64::consts::PI).powf(self.state.n_features as f64 / 2.0) as Float;
630
631 (exponent.exp() / normalization).max(1e-10)
632 }
633}
634
/// Maximum-entropy Markov model trained with gradient descent on the
/// per-position conditional log-loss.
#[derive(Debug, Clone)]
pub struct MaximumEntropyMarkovModel<S = Untrained> {
    // Type-state payload: `Untrained` or the trained parameters.
    state: S,
    // Maximum number of gradient-descent iterations.
    max_iter: usize,
    // Gradient-descent step size.
    learning_rate: Float,
    // L2 regularization strength.
    l2_reg: Float,
    // Early-stopping threshold on the gradient's L1 norm.
    tolerance: Float,
    // NOTE(review): unused on the untrained model and reset to an empty Vec
    // by `fit`; the trained feature set lives in the trained state instead.
    feature_functions: Vec<FeatureFunction>,
    // Optional RNG seed for reproducible weight initialization.
    random_state: Option<u64>,
}
680
/// Learned parameters of a fitted [`MaximumEntropyMarkovModel`].
#[derive(Debug, Clone)]
pub struct MaximumEntropyMarkovModelTrained {
    // One weight per feature function, in feature-function order.
    weights: Array1<Float>,
    // The feature templates the weights were trained against.
    feature_functions: Vec<FeatureFunction>,
    // Number of distinct labels seen during fitting.
    n_labels: usize,
    // Feature dimensionality of the observations.
    n_features: usize,
    // Label value -> dense index (sorted by label value).
    label_to_idx: HashMap<i32, usize>,
    // Dense index -> label value (inverse of `label_to_idx`).
    idx_to_label: HashMap<usize, i32>,
}
691
/// One feature template of the MEMM, pairing a feature kind with its slot in
/// the weight vector.
#[derive(Debug, Clone)]
pub struct FeatureFunction {
    // What this feature measures.
    pub feature_type: FeatureType,
    // Index into the model's weight vector; equals this function's position
    // in the feature list as constructed by `fit`.
    pub weight_index: usize,
}
704
/// The kinds of MEMM features supported by `extract_features`.
#[derive(Debug, Clone)]
pub enum FeatureType {
    // Raw observation value at the given feature column.
    Observation(usize),
    // Indicator that the previous position carried the given label.
    PreviousLabel(i32),
    // Observation value at the given column, gated on the previous label
    // matching the given label.
    LabelObservationInteraction(i32, usize),
}
715
716impl MaximumEntropyMarkovModel<Untrained> {
717 pub fn new() -> Self {
719 Self {
720 state: Untrained,
721 max_iter: 100,
722 learning_rate: 0.01,
723 l2_reg: 0.01,
724 tolerance: 1e-6,
725 feature_functions: Vec::new(),
726 random_state: None,
727 }
728 }
729
730 pub fn max_iter(mut self, max_iter: usize) -> Self {
732 self.max_iter = max_iter;
733 self
734 }
735
736 pub fn learning_rate(mut self, learning_rate: Float) -> Self {
738 self.learning_rate = learning_rate;
739 self
740 }
741
742 pub fn l2_regularization(mut self, l2_reg: Float) -> Self {
744 self.l2_reg = l2_reg;
745 self
746 }
747
748 pub fn tolerance(mut self, tolerance: Float) -> Self {
750 self.tolerance = tolerance;
751 self
752 }
753
754 pub fn random_state(mut self, random_state: u64) -> Self {
756 self.random_state = Some(random_state);
757 self
758 }
759}
760
761impl Default for MaximumEntropyMarkovModel<Untrained> {
762 fn default() -> Self {
763 Self::new()
764 }
765}
766
impl Estimator for MaximumEntropyMarkovModel<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // No tunable config beyond the builder methods; `&()` is constant-promoted
    // to a `'static` reference.
    fn config(&self) -> &Self::Config {
        &()
    }
}
776
777impl Fit<Vec<Array2<Float>>, Vec<Vec<i32>>> for MaximumEntropyMarkovModel<Untrained> {
778 type Fitted = MaximumEntropyMarkovModel<MaximumEntropyMarkovModelTrained>;
779
780 fn fit(self, X: &Vec<Array2<Float>>, y: &Vec<Vec<i32>>) -> SklResult<Self::Fitted> {
781 if X.len() != y.len() {
782 return Err(SklearsError::InvalidInput(
783 "X and y must have the same number of sequences".to_string(),
784 ));
785 }
786
787 if X.is_empty() {
788 return Err(SklearsError::InvalidInput(
789 "Cannot fit with 0 sequences".to_string(),
790 ));
791 }
792
793 let n_features = X[0].ncols();
795 let mut unique_labels = HashSet::new();
796
797 for sequence_labels in y {
798 for &label in sequence_labels {
799 unique_labels.insert(label);
800 }
801 }
802
803 let mut label_to_idx = HashMap::new();
804 let mut idx_to_label = HashMap::new();
805 let mut sorted_labels: Vec<_> = unique_labels.iter().cloned().collect();
807 sorted_labels.sort();
808 for (idx, label) in sorted_labels.iter().enumerate() {
809 label_to_idx.insert(*label, idx);
810 idx_to_label.insert(idx, *label);
811 }
812 let n_labels = unique_labels.len();
813
814 let mut feature_functions = Vec::new();
816 let mut weight_idx = 0;
817
818 for feat_idx in 0..n_features {
820 for &label in &sorted_labels {
821 feature_functions.push(FeatureFunction {
822 feature_type: FeatureType::LabelObservationInteraction(label, feat_idx),
823 weight_index: weight_idx,
824 });
825 weight_idx += 1;
826 }
827 }
828
829 for &prev_label in &sorted_labels {
831 for &curr_label in &sorted_labels {
832 feature_functions.push(FeatureFunction {
833 feature_type: FeatureType::PreviousLabel(prev_label),
834 weight_index: weight_idx,
835 });
836 weight_idx += 1;
837 }
838 }
839
840 let n_weights = weight_idx;
841 let mut weights = Array1::<Float>::zeros(n_weights);
842
843 let mut rng = if let Some(seed) = self.random_state {
845 scirs2_core::random::seeded_rng(seed)
846 } else {
847 use std::time::{SystemTime, UNIX_EPOCH};
849 let time_seed = SystemTime::now()
850 .duration_since(UNIX_EPOCH)
851 .unwrap()
852 .as_secs();
853 scirs2_core::random::seeded_rng(time_seed)
854 };
855
856 for w in weights.iter_mut() {
857 *w = rng.gen_range(-0.1..0.1);
858 }
859
860 for _iter in 0..self.max_iter {
862 let mut gradient = Array1::<Float>::zeros(n_weights);
863 let mut total_loss = 0.0;
864
865 for (seq_idx, (sequence_x, sequence_y)) in X.iter().zip(y.iter()).enumerate() {
867 let seq_len = sequence_x.nrows();
868
869 if sequence_y.len() != seq_len {
870 return Err(SklearsError::InvalidInput(format!(
871 "Sequence {} length mismatch between X and y",
872 seq_idx
873 )));
874 }
875
876 for pos in 0..seq_len {
878 let current_obs = sequence_x.row(pos);
879 let true_label = sequence_y[pos];
880 let prev_label = if pos == 0 { -1 } else { sequence_y[pos - 1] };
881
882 let features =
884 self.extract_features(¤t_obs, prev_label, &feature_functions);
885
886 let mut scores = Array1::<Float>::zeros(n_labels);
888 let mut max_score = Float::NEG_INFINITY;
889
890 for (label_idx, &label) in unique_labels.iter().enumerate() {
891 let label_features =
892 self.extract_features(¤t_obs, prev_label, &feature_functions);
893 scores[label_idx] = label_features.dot(&weights);
894 max_score = max_score.max(scores[label_idx]);
895 }
896
897 for score in scores.iter_mut() {
899 *score -= max_score;
900 }
901
902 let exp_scores: Array1<Float> = scores.mapv(|x| x.exp());
904 let sum_exp_scores = exp_scores.sum();
905 let probabilities = exp_scores / sum_exp_scores;
906
907 let true_label_idx = label_to_idx[&true_label];
909 total_loss -= probabilities[true_label_idx].ln();
910
911 for (label_idx, &label) in sorted_labels.iter().enumerate() {
913 let label_features =
914 self.extract_features(¤t_obs, prev_label, &feature_functions);
915 let prob = probabilities[label_idx];
916 let indicator = if label_idx == true_label_idx {
917 1.0
918 } else {
919 0.0
920 };
921
922 for (feat_idx, &feat_val) in label_features.iter().enumerate() {
923 gradient[feat_idx] += feat_val * (prob - indicator);
924 }
925 }
926 }
927 }
928
929 for (i, w) in weights.iter().enumerate() {
931 gradient[i] += self.l2_reg * w;
932 }
933
934 let gradient_norm = gradient.mapv(|x| x.abs()).sum();
936 weights = &weights - self.learning_rate * &gradient;
937
938 if gradient_norm < self.tolerance {
940 break;
941 }
942 }
943
944 let trained_state = MaximumEntropyMarkovModelTrained {
945 weights,
946 feature_functions,
947 n_labels,
948 n_features,
949 label_to_idx,
950 idx_to_label,
951 };
952
953 Ok(MaximumEntropyMarkovModel {
954 state: trained_state,
955 max_iter: self.max_iter,
956 learning_rate: self.learning_rate,
957 l2_reg: self.l2_reg,
958 tolerance: self.tolerance,
959 feature_functions: Vec::new(),
960 random_state: self.random_state,
961 })
962 }
963}
964
965impl MaximumEntropyMarkovModel<Untrained> {
966 fn extract_features(
968 &self,
969 observation: &ArrayView1<Float>,
970 prev_label: i32,
971 feature_functions: &[FeatureFunction],
972 ) -> Array1<Float> {
973 let mut features = Array1::<Float>::zeros(feature_functions.len());
974
975 for (i, func) in feature_functions.iter().enumerate() {
976 features[i] = match &func.feature_type {
977 FeatureType::Observation(feat_idx) => observation[*feat_idx],
978 FeatureType::PreviousLabel(label) => {
979 if prev_label == *label {
980 1.0
981 } else {
982 0.0
983 }
984 }
985 FeatureType::LabelObservationInteraction(label, feat_idx) => {
986 if prev_label == *label {
987 observation[*feat_idx]
988 } else {
989 0.0
990 }
991 }
992 };
993 }
994
995 features
996 }
997}
998
999impl Predict<Vec<Array2<Float>>, Vec<Vec<i32>>>
1000 for MaximumEntropyMarkovModel<MaximumEntropyMarkovModelTrained>
1001{
1002 fn predict(&self, X: &Vec<Array2<Float>>) -> SklResult<Vec<Vec<i32>>> {
1003 if X.is_empty() {
1004 return Ok(Vec::new());
1005 }
1006
1007 let mut predictions = Vec::with_capacity(X.len());
1008
1009 for sequence in X {
1010 let seq_len = sequence.nrows();
1011
1012 if sequence.ncols() != self.state.n_features {
1013 return Err(SklearsError::InvalidInput(
1014 "X has different number of features than training data".to_string(),
1015 ));
1016 }
1017
1018 let mut sequence_predictions = Vec::with_capacity(seq_len);
1019
1020 for pos in 0..seq_len {
1021 let current_obs = sequence.row(pos);
1022 let prev_label = if pos == 0 {
1023 -1
1024 } else {
1025 sequence_predictions[pos - 1]
1026 };
1027
1028 let mut best_score = Float::NEG_INFINITY;
1030 let mut best_label = 0;
1031
1032 let mut labels_sorted: Vec<_> = self.state.label_to_idx.iter().collect();
1034 labels_sorted.sort_by_key(|(&label, _)| label);
1035
1036 for (&label, &label_idx) in labels_sorted {
1037 let features = self.extract_features(
1038 ¤t_obs,
1039 prev_label,
1040 &self.state.feature_functions,
1041 );
1042 let score = features.dot(&self.state.weights);
1043
1044 if score > best_score {
1045 best_score = score;
1046 best_label = label;
1047 }
1048 }
1049
1050 sequence_predictions.push(best_label);
1051 }
1052
1053 predictions.push(sequence_predictions);
1054 }
1055
1056 Ok(predictions)
1057 }
1058}
1059
1060impl MaximumEntropyMarkovModel<MaximumEntropyMarkovModelTrained> {
1061 pub fn weights(&self) -> &Array1<Float> {
1063 &self.state.weights
1064 }
1065
1066 pub fn n_labels(&self) -> usize {
1068 self.state.n_labels
1069 }
1070
1071 fn extract_features(
1073 &self,
1074 observation: &ArrayView1<Float>,
1075 prev_label: i32,
1076 feature_functions: &[FeatureFunction],
1077 ) -> Array1<Float> {
1078 let mut features = Array1::<Float>::zeros(feature_functions.len());
1079
1080 for (i, func) in feature_functions.iter().enumerate() {
1081 features[i] = match &func.feature_type {
1082 FeatureType::Observation(feat_idx) => observation[*feat_idx],
1083 FeatureType::PreviousLabel(label) => {
1084 if prev_label == *label {
1085 1.0
1086 } else {
1087 0.0
1088 }
1089 }
1090 FeatureType::LabelObservationInteraction(label, feat_idx) => {
1091 if prev_label == *label {
1092 observation[*feat_idx]
1093 } else {
1094 0.0
1095 }
1096 }
1097 };
1098 }
1099
1100 features
1101 }
1102}
1103
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Smoke test: fit + predict round-trip preserves the
    // (sequences, positions) output shape.
    #[test]
    #[allow(non_snake_case)]
    fn test_structured_perceptron_basic() {
        // One sequence of length 2 with 2 features per position.
        let X = Array3::from_shape_vec((1, 2, 2), vec![1.0, 2.0, 2.0, 3.0]).unwrap();
        let y = Array2::from_shape_vec((1, 2), vec![0, 1]).unwrap();

        let perceptron = StructuredPerceptron::new().max_iterations(10);
        let trained = perceptron.fit(&X, &y).unwrap();
        let predictions = trained.predict(&X).unwrap();

        assert_eq!(predictions.dim(), (1, 2));
    }

    // Smoke test for the supervised HMM: prediction shape matches the labels.
    #[test]
    #[allow(non_snake_case)]
    fn test_hidden_markov_model_basic() {
        // One sequence of length 3 with 2 features per position.
        let X = Array3::from_shape_vec((1, 3, 2), vec![1.0, 2.0, 2.0, 3.0, 1.5, 2.5]).unwrap();
        let y = Array2::from_shape_vec((1, 3), vec![0, 1, 0]).unwrap();

        let hmm = HiddenMarkovModel::new().n_states(2).max_iterations(5);
        let trained = hmm.fit(&X, &y).unwrap();
        let predictions = trained.predict(&X).unwrap();

        assert_eq!(predictions.dim(), (1, 3));
    }

    // Smoke test for the MEMM on variable-length sequences: per-sequence
    // prediction lengths must match the inputs.
    #[test]
    #[allow(non_snake_case)]
    fn test_memm_basic() {
        let X = vec![array![[1.0, 2.0], [2.0, 3.0]], array![[3.0, 1.0]]];
        let y = vec![vec![0, 1], vec![0]];

        let memm = MaximumEntropyMarkovModel::new()
            .max_iter(5)
            .learning_rate(0.1);
        let trained = memm.fit(&X, &y).unwrap();
        let predictions = trained.predict(&X).unwrap();

        assert_eq!(predictions.len(), 2);
        assert_eq!(predictions[0].len(), 2);
        assert_eq!(predictions[1].len(), 1);
    }

    // Two fits with the same seed must yield identical predictions.
    #[test]
    #[allow(non_snake_case)]
    fn test_memm_reproducibility() {
        let X = vec![
            array![[1.0, 2.0], [2.0, 3.0]],
            array![[3.0, 1.0], [4.0, 2.0]],
        ];
        let y = vec![vec![0, 1], vec![1, 0]];

        let memm1 = MaximumEntropyMarkovModel::new()
            .max_iter(10)
            .random_state(42);
        let trained_memm1 = memm1.fit(&X, &y).unwrap();
        let pred1 = trained_memm1.predict(&X).unwrap();

        let memm2 = MaximumEntropyMarkovModel::new()
            .max_iter(10)
            .random_state(42);
        let trained_memm2 = memm2.fit(&X, &y).unwrap();
        let pred2 = trained_memm2.predict(&X).unwrap();

        assert_eq!(pred1, pred2);
    }
}