trustformers_training/few_shot/
mod.rs1pub mod cross_task;
2pub mod in_context;
3pub mod meta_learning;
4pub mod prompt_tuning;
5pub mod task_adaptation;
6
7pub use cross_task::{CrossTaskGeneralizer, GeneralizationConfig, TaskEmbedding};
8pub use in_context::{ICLExample, InContextConfig, InContextLearner};
9pub use meta_learning::{
10 MAMLConfig, MAMLTrainer, MetaLearningAlgorithm, ReptileConfig, ReptileTrainer, TaskBatch,
11};
12pub use prompt_tuning::{PromptConfig, PromptTuner, SoftPrompt};
13pub use task_adaptation::{AdaptationConfig, TaskAdapter, TaskDescriptor};
14
15use anyhow::Result;
16use serde::{Deserialize, Serialize};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct FewShotConfig {
21 pub k_shot: usize,
23 pub method: FewShotMethod,
25 pub in_context: Option<InContextConfig>,
27 pub prompt_tuning: Option<PromptConfig>,
29 pub meta_learning: Option<MetaLearningConfig>,
31 pub task_adaptation: Option<AdaptationConfig>,
33 pub enable_cross_task: bool,
35}
36
37impl Default for FewShotConfig {
38 fn default() -> Self {
39 Self {
40 k_shot: 5,
41 method: FewShotMethod::InContext,
42 in_context: Some(InContextConfig::default()),
43 prompt_tuning: None,
44 meta_learning: None,
45 task_adaptation: None,
46 enable_cross_task: false,
47 }
48 }
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub enum FewShotMethod {
54 InContext,
56 PromptTuning,
58 MetaLearning,
60 TaskAdaptation,
62 Combined(Vec<FewShotMethod>),
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub enum MetaLearningConfig {
69 MAML(MAMLConfig),
70 Reptile(ReptileConfig),
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct FewShotExample {
76 pub input: Vec<f32>,
77 pub output: Vec<f32>,
78 pub task_id: Option<String>,
79 pub metadata: Option<serde_json::Value>,
80}
81
82#[derive(Debug, Clone)]
84pub struct SupportSet {
85 pub examples: Vec<FewShotExample>,
86 pub task_id: String,
87 pub k_shot: usize,
88 pub num_classes: Option<usize>,
89}
90
91impl SupportSet {
92 pub fn new(task_id: String, k_shot: usize) -> Self {
93 Self {
94 examples: Vec::new(),
95 task_id,
96 k_shot,
97 num_classes: None,
98 }
99 }
100
101 pub fn add_example(&mut self, example: FewShotExample) -> Result<()> {
102 if self.examples.len() >= self.k_shot * self.num_classes.unwrap_or(usize::MAX) {
103 return Err(anyhow::anyhow!("Support set is full"));
104 }
105 self.examples.push(example);
106 Ok(())
107 }
108
109 pub fn is_complete(&self) -> bool {
110 if let Some(num_classes) = self.num_classes {
111 self.examples.len() == self.k_shot * num_classes
112 } else {
113 false
114 }
115 }
116}
117
118#[derive(Debug, Clone)]
120pub struct QuerySet {
121 pub examples: Vec<FewShotExample>,
122 pub task_id: String,
123}
124
125pub struct FewShotLearningManager {
127 config: FewShotConfig,
128 support_sets: std::collections::HashMap<String, SupportSet>,
129 query_sets: std::collections::HashMap<String, QuerySet>,
130}
131
132impl FewShotLearningManager {
133 pub fn new(config: FewShotConfig) -> Self {
134 Self {
135 config,
136 support_sets: std::collections::HashMap::new(),
137 query_sets: std::collections::HashMap::new(),
138 }
139 }
140
141 pub fn create_support_set(&mut self, task_id: String, num_classes: usize) -> Result<()> {
142 let mut support_set = SupportSet::new(task_id.clone(), self.config.k_shot);
143 support_set.num_classes = Some(num_classes);
144 self.support_sets.insert(task_id, support_set);
145 Ok(())
146 }
147
148 pub fn add_support_example(&mut self, task_id: &str, example: FewShotExample) -> Result<()> {
149 let support_set = self
150 .support_sets
151 .get_mut(task_id)
152 .ok_or_else(|| anyhow::anyhow!("Support set not found for task: {}", task_id))?;
153 support_set.add_example(example)?;
154 Ok(())
155 }
156
157 pub fn create_query_set(&mut self, task_id: String) -> Result<()> {
158 let query_set = QuerySet {
159 examples: Vec::new(),
160 task_id: task_id.clone(),
161 };
162 self.query_sets.insert(task_id, query_set);
163 Ok(())
164 }
165
166 pub fn add_query_example(&mut self, task_id: &str, example: FewShotExample) -> Result<()> {
167 let query_set = self
168 .query_sets
169 .get_mut(task_id)
170 .ok_or_else(|| anyhow::anyhow!("Query set not found for task: {}", task_id))?;
171 query_set.examples.push(example);
172 Ok(())
173 }
174
175 pub fn get_support_set(&self, task_id: &str) -> Option<&SupportSet> {
176 self.support_sets.get(task_id)
177 }
178
179 pub fn get_query_set(&self, task_id: &str) -> Option<&QuerySet> {
180 self.query_sets.get(task_id)
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an example tagged with `task1`, with input `i` and label `i % 2`.
    fn example(i: usize) -> FewShotExample {
        FewShotExample {
            input: vec![i as f32],
            output: vec![(i % 2) as f32],
            task_id: Some("task1".to_string()),
            metadata: None,
        }
    }

    #[test]
    fn test_support_set() {
        let mut support_set = SupportSet::new("task1".to_string(), 5);
        support_set.num_classes = Some(2);

        for i in 0..10 {
            assert!(!support_set.is_complete());
            support_set.add_example(example(i)).expect("add operation failed");
        }

        assert!(support_set.is_complete());
        assert_eq!(support_set.examples.len(), 10);

        // Capacity is k_shot * num_classes = 10, so an eleventh example is rejected
        // and the set is left unchanged.
        assert!(support_set.add_example(example(10)).is_err());
        assert_eq!(support_set.examples.len(), 10);
        assert!(support_set.is_complete());
    }

    #[test]
    fn test_few_shot_manager() {
        let config = FewShotConfig::default();
        let mut manager = FewShotLearningManager::new(config);

        manager
            .create_support_set("task1".to_string(), 2)
            .expect("operation failed in test");
        manager.create_query_set("task1".to_string()).expect("operation failed in test");

        let example = FewShotExample {
            input: vec![1.0, 2.0],
            output: vec![0.0],
            task_id: Some("task1".to_string()),
            metadata: None,
        };

        manager
            .add_support_example("task1", example.clone())
            .expect("add operation failed");
        manager.add_query_example("task1", example).expect("add operation failed");

        let support_set = manager.get_support_set("task1").expect("support set should exist");
        assert_eq!(support_set.task_id, "task1");
        assert_eq!(support_set.k_shot, 5);
        assert_eq!(support_set.num_classes, Some(2));
        assert_eq!(support_set.examples.len(), 1);

        let query_set = manager.get_query_set("task1").expect("query set should exist");
        assert_eq!(query_set.task_id, "task1");
        assert_eq!(query_set.examples.len(), 1);
    }

    #[test]
    fn test_few_shot_manager_unknown_task() {
        let mut manager = FewShotLearningManager::new(FewShotConfig::default());

        assert!(manager.add_support_example("missing", example(0)).is_err());
        assert!(manager.add_query_example("missing", example(0)).is_err());
        assert!(manager.get_support_set("missing").is_none());
        assert!(manager.get_query_set("missing").is_none());
    }
}