Skip to main content

xz_embed/embedder/
mock.rs

1use async_trait::async_trait;
2use std::fmt::Debug;
3use std::sync::Mutex;
4
5use crate::error::EmbedError;
6use crate::traits::{EmbedModelInfo, EmbedPricing, EmbeddingModel};
7
/// Mock embedder for tests.
///
/// All behavior is driven by the configuration fields below: the inherent
/// setters write them through `&mut self`, and `EmbeddingModel::embed` only
/// reads them through `&self`.
// NOTE(review): nothing visible mutates these fields through `&self`; the
// `Mutex` wrappers presumably exist to keep the type `Sync` for the
// `EmbeddingModel` trait object — confirm before simplifying.
#[derive(Debug)]
pub struct MockEmbedder {
    /// Static metadata returned by `model_info`, `max_batch_size` and `dimensions`.
    info: EmbedModelInfo,
    /// When `Some`, the exact batch `embed` must receive; any other input is an error.
    expected_input: Mutex<Option<Vec<String>>>,
    /// Canned vectors returned by `embed`; when empty, zero vectors are produced instead.
    mock_output: Mutex<Vec<Vec<f32>>>,
    /// When `Some`, every non-empty `embed` call fails with an error derived from it.
    should_error: Mutex<Option<EmbedError>>,
}
16
17impl MockEmbedder {
18    /// 创建新的 MockEmbedder
19    pub fn new(dimensions: usize, max_batch_size: usize) -> Self {
20        Self {
21            info: EmbedModelInfo {
22                name: "mock-embedder".into(),
23                display_name: "Mock Embedder".into(),
24                supported_dimensions: None,
25                current_dimension: dimensions,
26                max_input_tokens: 1024,
27                max_batch_size,
28                pricing: EmbedPricing {
29                    input_per_million: 0.0,
30                },
31            },
32            expected_input: Mutex::new(None),
33            mock_output: Mutex::new(vec![]),
34            should_error: Mutex::new(None),
35        }
36    }
37
38    /// 设置期望输入和返回输出
39    pub fn expect_embed(&mut self, inputs: Vec<&str>, outputs: Vec<Vec<f32>>) -> &mut Self {
40        *self.expected_input.get_mut().unwrap() = Some(inputs.iter().map(|s| s.to_string()).collect());
41        *self.mock_output.get_mut().unwrap() = outputs;
42        self
43    }
44
45    /// 设置应返回的错误
46    pub fn set_error(&mut self, error: EmbedError) {
47        *self.should_error.get_mut().unwrap() = Some(error);
48    }
49
50    /// 直接设置返回向量
51    pub fn set_output(&mut self, vectors: Vec<Vec<f32>>) {
52        *self.mock_output.get_mut().unwrap() = vectors;
53    }
54}
55
56#[async_trait]
57impl EmbeddingModel for MockEmbedder {
58    async fn embed(&self, input: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError> {
59        if input.is_empty() {
60            return Err(EmbedError::EmptyBatch);
61        }
62
63        // 检查是否设定了错误
64        if let Some(ref err) = *self.should_error.lock().unwrap() {
65            return Err(EmbedError::Model(format!("Mock error: {err}")));
66        }
67
68        // 检查期望输入
69        if let Some(ref expected) = *self.expected_input.lock().unwrap() {
70            let actual: Vec<String> = input.iter().map(|s| s.to_string()).collect();
71            if &actual != expected {
72                return Err(EmbedError::Model(format!(
73                    "输入不匹配: expected {expected:?}, got {actual:?}"
74                )));
75            }
76        }
77
78        let output = self.mock_output.lock().unwrap();
79        if !output.is_empty() {
80            if output.len() != input.len() {
81                return Err(EmbedError::Model(format!(
82                    "输出数量不匹配: expected {}, got {}",
83                    input.len(),
84                    output.len()
85                )));
86            }
87            return Ok(output.clone());
88        }
89
90        // 默认行为:每个输入生成零向量
91        Ok(vec![vec![0.0; self.info.current_dimension]; input.len()])
92    }
93
94    fn model_info(&self) -> &EmbedModelInfo {
95        &self.info
96    }
97
98    fn max_batch_size(&self) -> usize {
99        self.info.max_batch_size
100    }
101
102    fn dimensions(&self) -> usize {
103        self.info.current_dimension
104    }
105}