Skip to main content

trustformers_training/few_shot/
mod.rs

1pub mod cross_task;
2pub mod in_context;
3pub mod meta_learning;
4pub mod prompt_tuning;
5pub mod task_adaptation;
6
7pub use cross_task::{CrossTaskGeneralizer, GeneralizationConfig, TaskEmbedding};
8pub use in_context::{ICLExample, InContextConfig, InContextLearner};
9pub use meta_learning::{
10    MAMLConfig, MAMLTrainer, MetaLearningAlgorithm, ReptileConfig, ReptileTrainer, TaskBatch,
11};
12pub use prompt_tuning::{PromptConfig, PromptTuner, SoftPrompt};
13pub use task_adaptation::{AdaptationConfig, TaskAdapter, TaskDescriptor};
14
15use anyhow::Result;
16use serde::{Deserialize, Serialize};
17
18/// Configuration for few-shot and zero-shot learning
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct FewShotConfig {
21    /// Number of examples per class (K in K-shot)
22    pub k_shot: usize,
23    /// Method for few-shot learning
24    pub method: FewShotMethod,
25    /// In-context learning configuration
26    pub in_context: Option<InContextConfig>,
27    /// Prompt tuning configuration
28    pub prompt_tuning: Option<PromptConfig>,
29    /// Meta-learning configuration
30    pub meta_learning: Option<MetaLearningConfig>,
31    /// Task adaptation configuration
32    pub task_adaptation: Option<AdaptationConfig>,
33    /// Whether to use cross-task generalization
34    pub enable_cross_task: bool,
35}
36
37impl Default for FewShotConfig {
38    fn default() -> Self {
39        Self {
40            k_shot: 5,
41            method: FewShotMethod::InContext,
42            in_context: Some(InContextConfig::default()),
43            prompt_tuning: None,
44            meta_learning: None,
45            task_adaptation: None,
46            enable_cross_task: false,
47        }
48    }
49}
50
51/// Methods for few-shot learning
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub enum FewShotMethod {
54    /// In-context learning (like GPT-3)
55    InContext,
56    /// Prompt tuning with soft prompts
57    PromptTuning,
58    /// Meta-learning (MAML, Reptile)
59    MetaLearning,
60    /// Task-specific adaptation
61    TaskAdaptation,
62    /// Combined approach
63    Combined(Vec<FewShotMethod>),
64}
65
66/// Meta-learning configuration wrapper
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub enum MetaLearningConfig {
69    MAML(MAMLConfig),
70    Reptile(ReptileConfig),
71}
72
73/// Few-shot learning example
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct FewShotExample {
76    pub input: Vec<f32>,
77    pub output: Vec<f32>,
78    pub task_id: Option<String>,
79    pub metadata: Option<serde_json::Value>,
80}
81
82/// Support set for few-shot learning
83#[derive(Debug, Clone)]
84pub struct SupportSet {
85    pub examples: Vec<FewShotExample>,
86    pub task_id: String,
87    pub k_shot: usize,
88    pub num_classes: Option<usize>,
89}
90
91impl SupportSet {
92    pub fn new(task_id: String, k_shot: usize) -> Self {
93        Self {
94            examples: Vec::new(),
95            task_id,
96            k_shot,
97            num_classes: None,
98        }
99    }
100
101    pub fn add_example(&mut self, example: FewShotExample) -> Result<()> {
102        if self.examples.len() >= self.k_shot * self.num_classes.unwrap_or(usize::MAX) {
103            return Err(anyhow::anyhow!("Support set is full"));
104        }
105        self.examples.push(example);
106        Ok(())
107    }
108
109    pub fn is_complete(&self) -> bool {
110        if let Some(num_classes) = self.num_classes {
111            self.examples.len() == self.k_shot * num_classes
112        } else {
113            false
114        }
115    }
116}
117
118/// Query set for evaluation
119#[derive(Debug, Clone)]
120pub struct QuerySet {
121    pub examples: Vec<FewShotExample>,
122    pub task_id: String,
123}
124
125/// Few-shot learning manager
126pub struct FewShotLearningManager {
127    config: FewShotConfig,
128    support_sets: std::collections::HashMap<String, SupportSet>,
129    query_sets: std::collections::HashMap<String, QuerySet>,
130}
131
132impl FewShotLearningManager {
133    pub fn new(config: FewShotConfig) -> Self {
134        Self {
135            config,
136            support_sets: std::collections::HashMap::new(),
137            query_sets: std::collections::HashMap::new(),
138        }
139    }
140
141    pub fn create_support_set(&mut self, task_id: String, num_classes: usize) -> Result<()> {
142        let mut support_set = SupportSet::new(task_id.clone(), self.config.k_shot);
143        support_set.num_classes = Some(num_classes);
144        self.support_sets.insert(task_id, support_set);
145        Ok(())
146    }
147
148    pub fn add_support_example(&mut self, task_id: &str, example: FewShotExample) -> Result<()> {
149        let support_set = self
150            .support_sets
151            .get_mut(task_id)
152            .ok_or_else(|| anyhow::anyhow!("Support set not found for task: {}", task_id))?;
153        support_set.add_example(example)?;
154        Ok(())
155    }
156
157    pub fn create_query_set(&mut self, task_id: String) -> Result<()> {
158        let query_set = QuerySet {
159            examples: Vec::new(),
160            task_id: task_id.clone(),
161        };
162        self.query_sets.insert(task_id, query_set);
163        Ok(())
164    }
165
166    pub fn add_query_example(&mut self, task_id: &str, example: FewShotExample) -> Result<()> {
167        let query_set = self
168            .query_sets
169            .get_mut(task_id)
170            .ok_or_else(|| anyhow::anyhow!("Query set not found for task: {}", task_id))?;
171        query_set.examples.push(example);
172        Ok(())
173    }
174
175    pub fn get_support_set(&self, task_id: &str) -> Option<&SupportSet> {
176        self.support_sets.get(task_id)
177    }
178
179    pub fn get_query_set(&self, task_id: &str) -> Option<&QuerySet> {
180        self.query_sets.get(task_id)
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-feature example for `task1` with the given input and label.
    fn make_example(input: f32, label: f32) -> FewShotExample {
        FewShotExample {
            input: vec![input],
            output: vec![label],
            task_id: Some("task1".to_string()),
            metadata: None,
        }
    }

    #[test]
    fn test_support_set() {
        let mut support_set = SupportSet::new("task1".to_string(), 5);
        support_set.num_classes = Some(2);

        for i in 0..10 {
            assert!(
                !support_set.is_complete(),
                "set reported complete after {} examples",
                i
            );
            support_set
                .add_example(make_example(i as f32, (i % 2) as f32))
                .expect("add operation failed");
        }

        assert!(support_set.is_complete());
        assert_eq!(support_set.examples.len(), 10);

        // Capacity is k_shot * num_classes = 10, so an 11th example must be
        // rejected and must leave the set unchanged.
        assert!(support_set.add_example(make_example(10.0, 0.0)).is_err());
        assert_eq!(support_set.examples.len(), 10);
    }

    #[test]
    fn test_support_set_without_num_classes_is_never_complete() {
        let support_set = SupportSet::new("task1".to_string(), 5);
        assert!(support_set.num_classes.is_none());
        assert!(!support_set.is_complete());
    }

    #[test]
    fn test_few_shot_manager() {
        let config = FewShotConfig::default();
        let mut manager = FewShotLearningManager::new(config);

        manager
            .create_support_set("task1".to_string(), 2)
            .expect("operation failed in test");
        manager.create_query_set("task1".to_string()).expect("operation failed in test");

        let example = FewShotExample {
            input: vec![1.0, 2.0],
            output: vec![0.0],
            task_id: Some("task1".to_string()),
            metadata: None,
        };

        manager
            .add_support_example("task1", example.clone())
            .expect("add operation failed");
        manager.add_query_example("task1", example).expect("add operation failed");

        let support_set = manager.get_support_set("task1").expect("support set should exist");
        assert_eq!(support_set.examples.len(), 1);
        assert_eq!(support_set.num_classes, Some(2));
        // The manager sizes support sets from the config's default k_shot.
        assert_eq!(support_set.k_shot, 5);

        let query_set = manager.get_query_set("task1").expect("query set should exist");
        assert_eq!(query_set.examples.len(), 1);
        assert_eq!(query_set.task_id, "task1");
    }

    #[test]
    fn test_unknown_task_is_rejected() {
        let mut manager = FewShotLearningManager::new(FewShotConfig::default());

        assert!(manager
            .add_support_example("missing", make_example(0.0, 0.0))
            .is_err());
        assert!(manager
            .add_query_example("missing", make_example(0.0, 0.0))
            .is_err());
        assert!(manager.get_support_set("missing").is_none());
        assert!(manager.get_query_set("missing").is_none());
    }
}