//! Example selectors (`sh_layer3/example_selectors.rs`): strategies for choosing stored
//! input/output examples to pair with a query.

use async_trait::async_trait;

use crate::retriever_engine::RetrieverEngine;
use crate::types::Layer3Result;

9#[async_trait]
13pub trait ExampleSelector: Send + Sync {
14 async fn select_examples(&self, query: &str, top_k: usize) -> Layer3Result<Vec<Example>>;
16
17 async fn add_example(&self, example: Example) -> Layer3Result<bool>;
19
20 async fn count(&self) -> usize;
22}
23
24#[derive(Debug, Clone)]
26pub struct Example {
27 pub input: String,
29 pub output: String,
31 pub metadata: std::collections::HashMap<String, serde_json::Value>,
33}
34
35impl Example {
36 pub fn new(input: impl Into<String>, output: impl Into<String>) -> Self {
37 Self {
38 input: input.into(),
39 output: output.into(),
40 metadata: std::collections::HashMap::new(),
41 }
42 }
43
44 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
45 self.metadata.insert(key.into(), value);
46 self
47 }
48}
49
50pub struct SemanticExampleSelector {
52 retriever: Box<dyn RetrieverEngine>,
53 examples: Vec<Example>,
54}
55
56impl SemanticExampleSelector {
57 pub fn new(retriever: Box<dyn RetrieverEngine>) -> Self {
58 Self {
59 retriever,
60 examples: Vec::new(),
61 }
62 }
63}
64
65#[async_trait]
66impl ExampleSelector for SemanticExampleSelector {
67 async fn select_examples(&self, query: &str, top_k: usize) -> Layer3Result<Vec<Example>> {
68 let _results = self.retriever.retrieve(query, top_k).await?;
69 Ok(self.examples.iter().take(top_k).cloned().collect())
71 }
72
73 async fn add_example(&self, _example: Example) -> Layer3Result<bool> {
74 Ok(true)
76 }
77
78 async fn count(&self) -> usize {
79 self.examples.len()
80 }
81}
82
83pub struct LengthBasedSelector {
85 examples: Vec<Example>,
86 max_length: usize,
87}
88
89impl LengthBasedSelector {
90 pub fn new(max_length: usize) -> Self {
91 Self {
92 examples: Vec::new(),
93 max_length,
94 }
95 }
96
97 fn example_length(&self, example: &Example) -> usize {
99 example.input.len() + example.output.len()
100 }
101}
102
103#[async_trait]
104impl ExampleSelector for LengthBasedSelector {
105 async fn select_examples(&self, query: &str, top_k: usize) -> Layer3Result<Vec<Example>> {
106 let query_len = query.len();
107 let mut selected = Vec::new();
108 let mut total_len = 0;
109
110 for example in &self.examples {
111 let ex_len = self.example_length(example);
112 if total_len + query_len + ex_len <= self.max_length {
113 selected.push(example.clone());
114 total_len += ex_len;
115 if selected.len() >= top_k {
116 break;
117 }
118 }
119 }
120
121 Ok(selected)
122 }
123
124 async fn add_example(&self, _example: Example) -> Layer3Result<bool> {
125 Ok(true)
126 }
127
128 async fn count(&self) -> usize {
129 self.examples.len()
130 }
131}
132
133pub struct RandomSelector {
135 examples: Vec<Example>,
136}
137
138impl RandomSelector {
139 pub fn new() -> Self {
140 Self {
141 examples: Vec::new(),
142 }
143 }
144}
145
146#[async_trait]
147impl ExampleSelector for RandomSelector {
148 async fn select_examples(&self, _query: &str, top_k: usize) -> Layer3Result<Vec<Example>> {
149 Ok(self.examples.iter().take(top_k).cloned().collect())
151 }
152
153 async fn add_example(&self, _example: Example) -> Layer3Result<bool> {
154 Ok(true)
155 }
156
157 async fn count(&self) -> usize {
158 self.examples.len()
159 }
160}
161
162impl Default for RandomSelector {
163 fn default() -> Self {
164 Self::new()
165 }
166}
167
168#[cfg(test)]
169mod tests {
170 use super::*;
171
172 #[test]
173 fn test_example_creation() {
174 let ex = Example::new("input", "output");
175 assert_eq!(ex.input, "input");
176 }
177
178 #[test]
179 fn test_random_selector() {
180 let selector = RandomSelector::new();
181 assert_eq!(selector.examples.len(), 0);
182 }
183}