1use super::client::AnthropicClient;
4use super::domain::{classify_domain, extract_course_dir};
5use super::types::GroundTruthEntry;
6use rand::seq::SliceRandom;
7use rand::SeedableRng;
8use std::collections::HashMap;
9
/// System prompt for the question-generation model: asks for exactly one
/// student-style question per transcript chunk, or the literal sentinel
/// "SKIP" when the chunk is too vague/navigational. `generate_question`
/// relies on that sentinel (it checks for a "SKIP" prefix in the reply).
const SYSTEM_PROMPT: &str = "You generate evaluation questions from video transcript chunks.
Given a transcript chunk, generate ONE specific question this text answers.
Rules:
(1) The question must be answerable only from the provided text.
(2) Write a student-style query, 8-20 words long.
(3) Do NOT reference \"the video\", \"the instructor\", \"the speaker\", or \"this lecture\".
(4) Do NOT ask yes/no questions.
(5) If the text is too vague or navigational to generate a good question, respond with exactly: SKIP";
18
/// A transcript chunk as loaded from the search index.
///
/// This is the raw input to sampling; eligible chunks are converted into
/// [`SampledChunk`] (which adds derived course/domain metadata).
#[derive(Debug, Clone)]
pub struct IndexChunk {
    /// Raw transcript text of the chunk.
    pub content: String,
    /// Source file path; the course directory is derived from it
    /// via `extract_course_dir`.
    pub source: String,
    /// Optional human-readable title. Not read anywhere in this module.
    pub title: Option<String>,
    /// Start timestamp within the video, in seconds, when known.
    pub start_secs: Option<f64>,
    /// End timestamp within the video, in seconds, when known.
    pub end_secs: Option<f64>,
}
33
/// Generates ground-truth (query, chunk) pairs by sampling index chunks
/// and asking an Anthropic model to write one question per chunk.
pub struct GroundTruthGenerator {
    /// API client used to call the model.
    client: AnthropicClient,
    /// Model identifier passed to every completion request.
    model: String,
    /// Maximum number of chunks to sample overall.
    sample_size: usize,
    /// RNG seed; the same seed yields the same sample for the same input.
    seed: u64,
}
41
42impl GroundTruthGenerator {
43 pub fn new(client: AnthropicClient, model: &str, sample_size: usize, seed: u64) -> Self {
45 Self { client, model: model.to_string(), sample_size, seed }
46 }
47
48 pub fn sample_chunks(&self, chunks: &[IndexChunk]) -> Vec<SampledChunk> {
50 let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
51
52 let mut by_course: HashMap<String, Vec<&IndexChunk>> = HashMap::new();
54 for chunk in chunks {
55 let course = extract_course_dir(&chunk.source).to_string();
56 by_course.entry(course).or_default().push(chunk);
57 }
58
59 let mut courses: Vec<(String, Vec<&IndexChunk>)> = by_course.into_iter().collect();
61 courses.sort_by(|a, b| b.1.len().cmp(&a.1.len()).then_with(|| a.0.cmp(&b.0)));
62
63 let mut sampled = Vec::new();
64
65 for (course, course_chunks) in &courses {
66 let eligible: Vec<&&IndexChunk> =
68 course_chunks.iter().filter(|c| is_eligible(c)).collect();
69
70 if eligible.len() < 2 {
71 continue;
72 }
73
74 let n = eligible.len().min(3);
76 let mut indices: Vec<usize> = (0..eligible.len()).collect();
77 indices.shuffle(&mut rng);
78
79 for &idx in indices.iter().take(n) {
80 let chunk = eligible[idx];
81 sampled.push(SampledChunk {
82 content: chunk.content.clone(),
83 source: chunk.source.clone(),
84 start_secs: chunk.start_secs,
85 end_secs: chunk.end_secs,
86 course: course.clone(),
87 domain: classify_domain(course).to_string(),
88 });
89 }
90
91 if sampled.len() >= self.sample_size {
92 break;
93 }
94 }
95
96 sampled.truncate(self.sample_size);
98
99 let mut domain_counts: HashMap<&str, usize> = HashMap::new();
101 for s in &sampled {
102 *domain_counts.entry(&s.domain).or_default() += 1;
103 }
104 eprintln!(
105 "Sampled {} chunks from {} courses",
106 sampled.len(),
107 courses.len().min(sampled.len())
108 );
109 let mut sorted_domains: Vec<_> = domain_counts.into_iter().collect();
110 sorted_domains.sort_by_key(|(_, c)| std::cmp::Reverse(*c));
111 for (domain, count) in &sorted_domains {
112 eprintln!(" {domain}: {count}");
113 }
114
115 sampled
116 }
117
118 pub async fn generate_question(&self, content: &str) -> Result<Option<String>, String> {
120 let user_msg = format!("Transcript chunk:\n---\n{content}\n---");
121
122 let result = self.client.complete(&self.model, Some(SYSTEM_PROMPT), &user_msg, 150).await?;
123
124 let text = result.text.trim().to_string();
125 if text == "SKIP" || text.starts_with("SKIP") {
126 return Ok(None);
127 }
128
129 let mut question = text.trim_matches('"').trim_matches('\'').trim().to_string();
131 if !question.ends_with('?') {
132 question.push('?');
133 }
134
135 Ok(Some(question))
136 }
137
138 pub async fn generate(&self, chunks: &[IndexChunk]) -> Result<Vec<GroundTruthEntry>, String> {
140 let sampled = self.sample_chunks(chunks);
141 let total = sampled.len();
142 let mut results = Vec::new();
143 let mut skipped = 0usize;
144 let mut errors = 0usize;
145
146 for (i, sample) in sampled.iter().enumerate() {
147 eprint!("[{}/{}] {} ({})...", i + 1, total, sample.course, sample.domain);
148
149 match self.generate_question(&sample.content).await {
150 Ok(Some(question)) => {
151 eprintln!(" {}", &question[..question.len().min(60)]);
152 results.push(GroundTruthEntry {
153 query: question,
154 chunk_content: sample.content.clone(),
155 chunk_source: sample.source.clone(),
156 chunk_start_secs: sample.start_secs,
157 chunk_end_secs: sample.end_secs,
158 domain: sample.domain.clone(),
159 course: sample.course.clone(),
160 });
161 }
162 Ok(None) => {
163 eprintln!(" SKIP");
164 skipped += 1;
165 }
166 Err(e) => {
167 eprintln!(" ERROR: {e}");
168 errors += 1;
169 }
170 }
171 }
172
173 eprintln!("\nGenerated {} queries, {} skipped, {} errors", results.len(), skipped, errors);
174
175 Ok(results)
176 }
177}
178
/// An [`IndexChunk`] selected for question generation, enriched with the
/// course directory and classified domain it was sampled from.
#[derive(Debug, Clone)]
pub struct SampledChunk {
    /// Raw transcript text, copied from the source chunk.
    pub content: String,
    /// Source file path, copied from the source chunk.
    pub source: String,
    /// Start timestamp in seconds, when known.
    pub start_secs: Option<f64>,
    /// End timestamp in seconds, when known.
    pub end_secs: Option<f64>,
    /// Course directory derived from `source` via `extract_course_dir`.
    pub course: String,
    /// Domain label produced by `classify_domain(course)`.
    pub domain: String,
}
195
196fn is_eligible(chunk: &IndexChunk) -> bool {
198 let words: Vec<&str> = chunk.content.split_whitespace().collect();
199 if words.len() < 50 {
200 return false;
201 }
202 let lowered: Vec<String> = words.iter().map(|w| w.to_lowercase()).collect();
203 let unique: std::collections::HashSet<&str> = lowered.iter().map(|w| w.as_str()).collect();
204 if unique.len() < 15 {
205 return false;
206 }
207
208 let lower = chunk.content.to_lowercase();
210 let nav_phrases = [
211 "welcome back",
212 "in this video",
213 "let's go ahead",
214 "see you in the next",
215 "don't forget to subscribe",
216 "click the link",
217 "table of contents",
218 ];
219 let nav_count = nav_phrases.iter().filter(|p| lower.contains(*p)).count();
220 nav_count < 3
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `IndexChunk` with fixed timing metadata for tests.
    fn chunk_with(content: &str, source: &str) -> IndexChunk {
        IndexChunk {
            content: content.to_string(),
            source: source.to_string(),
            title: None,
            start_secs: Some(0.0),
            end_secs: Some(30.0),
        }
    }

    #[test]
    fn test_is_eligible_short() {
        // Far below the 50-word minimum.
        let chunk = chunk_with("too short", "/data/courses/test/build/a.srt");
        assert!(!is_eligible(&chunk));
    }

    #[test]
    fn test_is_eligible_valid() {
        // 60 distinct words clears both the length and uniqueness gates.
        let content = (0..60).map(|i| format!("word{i}")).collect::<Vec<_>>().join(" ");
        let chunk = chunk_with(&content, "/data/courses/test/build/a.srt");
        assert!(is_eligible(&chunk));
    }

    #[test]
    fn test_sampling_deterministic() {
        // 100 eligible chunks spread over 20 courses (5 chunks each).
        let chunks: Vec<IndexChunk> = (0..100)
            .map(|i| {
                let body = (0..60).map(|j| format!("w{j}c{i}")).collect::<Vec<_>>().join(" ");
                chunk_with(&body, &format!("/data/courses/course-{}/build/vid.srt", i / 5))
            })
            .collect();

        // Two generators built with identical seeds must sample identically.
        let make = || GroundTruthGenerator::new(AnthropicClient::new("fake"), "model", 20, 42);
        let s1 = make().sample_chunks(&chunks);
        let s2 = make().sample_chunks(&chunks);

        assert_eq!(s1.len(), s2.len());
        for (a, b) in s1.iter().zip(&s2) {
            assert_eq!(a.source, b.source);
            assert_eq!(a.course, b.course);
        }
    }
}
275}