Skip to main content

sh_layer3/
example_selectors.rs

//! # Example Selectors
//!
//! Example selectors: pick the most relevant examples to include in a prompt.
4
5use crate::retriever_engine::RetrieverEngine;
6use crate::types::Layer3Result;
7use async_trait::async_trait;
8
9/// 示例选择器 trait
10///
11/// 定义示例选择接口。
12#[async_trait]
13pub trait ExampleSelector: Send + Sync {
14    /// 选择示例
15    async fn select_examples(&self, query: &str, top_k: usize) -> Layer3Result<Vec<Example>>;
16
17    /// 添加示例
18    async fn add_example(&self, example: Example) -> Layer3Result<bool>;
19
20    /// 获取所有示例数量
21    async fn count(&self) -> usize;
22}
23
24/// 示例
25#[derive(Debug, Clone)]
26pub struct Example {
27    /// 输入
28    pub input: String,
29    /// 输出
30    pub output: String,
31    /// 元数据
32    pub metadata: std::collections::HashMap<String, serde_json::Value>,
33}
34
35impl Example {
36    pub fn new(input: impl Into<String>, output: impl Into<String>) -> Self {
37        Self {
38            input: input.into(),
39            output: output.into(),
40            metadata: std::collections::HashMap::new(),
41        }
42    }
43
44    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
45        self.metadata.insert(key.into(), value);
46        self
47    }
48}
49
50/// 语义相似度示例选择器
51pub struct SemanticExampleSelector {
52    retriever: Box<dyn RetrieverEngine>,
53    examples: Vec<Example>,
54}
55
56impl SemanticExampleSelector {
57    pub fn new(retriever: Box<dyn RetrieverEngine>) -> Self {
58        Self {
59            retriever,
60            examples: Vec::new(),
61        }
62    }
63}
64
65#[async_trait]
66impl ExampleSelector for SemanticExampleSelector {
67    async fn select_examples(&self, query: &str, top_k: usize) -> Layer3Result<Vec<Example>> {
68        let _results = self.retriever.retrieve(query, top_k).await?;
69        // 根据检索结果匹配示例
70        Ok(self.examples.iter().take(top_k).cloned().collect())
71    }
72
73    async fn add_example(&self, _example: Example) -> Layer3Result<bool> {
74        // 实际实现需要索引到 retriever
75        Ok(true)
76    }
77
78    async fn count(&self) -> usize {
79        self.examples.len()
80    }
81}
82
83/// 固定长度示例选择器
84pub struct LengthBasedSelector {
85    examples: Vec<Example>,
86    max_length: usize,
87}
88
89impl LengthBasedSelector {
90    pub fn new(max_length: usize) -> Self {
91        Self {
92            examples: Vec::new(),
93            max_length,
94        }
95    }
96
97    /// 计算示例长度
98    fn example_length(&self, example: &Example) -> usize {
99        example.input.len() + example.output.len()
100    }
101}
102
103#[async_trait]
104impl ExampleSelector for LengthBasedSelector {
105    async fn select_examples(&self, query: &str, top_k: usize) -> Layer3Result<Vec<Example>> {
106        let query_len = query.len();
107        let mut selected = Vec::new();
108        let mut total_len = 0;
109
110        for example in &self.examples {
111            let ex_len = self.example_length(example);
112            if total_len + query_len + ex_len <= self.max_length {
113                selected.push(example.clone());
114                total_len += ex_len;
115                if selected.len() >= top_k {
116                    break;
117                }
118            }
119        }
120
121        Ok(selected)
122    }
123
124    async fn add_example(&self, _example: Example) -> Layer3Result<bool> {
125        Ok(true)
126    }
127
128    async fn count(&self) -> usize {
129        self.examples.len()
130    }
131}
132
133/// 随机示例选择器
134pub struct RandomSelector {
135    examples: Vec<Example>,
136}
137
138impl RandomSelector {
139    pub fn new() -> Self {
140        Self {
141            examples: Vec::new(),
142        }
143    }
144}
145
146#[async_trait]
147impl ExampleSelector for RandomSelector {
148    async fn select_examples(&self, _query: &str, top_k: usize) -> Layer3Result<Vec<Example>> {
149        // 随机选择(简化实现)
150        Ok(self.examples.iter().take(top_k).cloned().collect())
151    }
152
153    async fn add_example(&self, _example: Example) -> Layer3Result<bool> {
154        Ok(true)
155    }
156
157    async fn count(&self) -> usize {
158        self.examples.len()
159    }
160}
161
162impl Default for RandomSelector {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
168#[cfg(test)]
169mod tests {
170    use super::*;
171
172    #[test]
173    fn test_example_creation() {
174        let ex = Example::new("input", "output");
175        assert_eq!(ex.input, "input");
176    }
177
178    #[test]
179    fn test_random_selector() {
180        let selector = RandomSelector::new();
181        assert_eq!(selector.examples.len(), 0);
182    }
183}