1use crate::{ChunkInput, Result, SqlRite, SqlRiteError};
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value, json};
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::process::Command;
7use std::thread;
8use std::time::{Duration, Instant};
9
/// Where ingestion content comes from.
#[derive(Debug, Clone)]
pub enum IngestionSource {
    /// Content supplied inline by the caller.
    Direct { content: String },
    /// Content read from a local file at ingest time.
    File { path: PathBuf },
    /// Content fetched over HTTP(S) by shelling out to the system `curl`.
    Url { url: String },
}
16
17impl IngestionSource {
18 pub fn load_content(&self) -> Result<String> {
19 match self {
20 Self::Direct { content } => Ok(content.clone()),
21 Self::File { path } => Ok(fs::read_to_string(path)?),
22 Self::Url { url } => {
23 let output = Command::new("curl").arg("-fsSL").arg(url).output()?;
24 if !output.status.success() {
25 return Err(SqlRiteError::UnsupportedOperation(format!(
26 "url ingestion failed for `{url}`"
27 )));
28 }
29 String::from_utf8(output.stdout).map_err(|_| {
30 SqlRiteError::UnsupportedOperation("url content is not valid UTF-8".to_string())
31 })
32 }
33 }
34 }
35
36 fn source_label(&self) -> String {
37 match self {
38 Self::Direct { .. } => "direct".to_string(),
39 Self::File { path } => path.display().to_string(),
40 Self::Url { url } => url.clone(),
41 }
42 }
43}
44
/// How source text is split into chunks before embedding.
///
/// NOTE(review): despite the `*_chars` naming, the chunkers below measure and
/// slice in UTF-8 *bytes* (snapped to char boundaries) — confirm intent.
#[derive(Debug, Clone)]
pub enum ChunkingStrategy {
    /// Fixed-size windows with overlap between consecutive windows.
    Fixed {
        max_chars: usize,
        overlap_chars: usize,
    },
    /// Split on Markdown-style `#` headings first, then apply fixed windowing
    /// within each section; chunks carry their section heading.
    HeadingAware {
        max_chars: usize,
        overlap_chars: usize,
    },
    /// Pack whole sentences into chunks of at most `max_chars`.
    Semantic {
        max_chars: usize,
    },
}

impl Default for ChunkingStrategy {
    /// Heading-aware chunking with 1200-byte windows and 120-byte overlap.
    fn default() -> Self {
        Self::HeadingAware {
            max_chars: 1200,
            overlap_chars: 120,
        }
    }
}
68
/// Durable record of ingestion progress, serialized as JSON on disk so an
/// interrupted job can resume.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestionCheckpoint {
    /// Job this checkpoint belongs to; resume only matches the same job.
    pub job_id: String,
    /// Source within the job; resume only matches the same source.
    pub source_id: String,
    /// Index of the first chunk that has not been ingested yet.
    pub next_chunk_index: usize,
    /// Last update time in milliseconds since the Unix epoch (0 when the
    /// system clock reported a pre-epoch time).
    pub updated_unix_ms: u64,
}
76
77impl IngestionCheckpoint {
78 pub fn load(path: impl AsRef<Path>) -> Result<Option<Self>> {
79 let path = path.as_ref();
80 if !path.exists() {
81 return Ok(None);
82 }
83 let payload = fs::read_to_string(path)?;
84 let checkpoint = serde_json::from_str::<Self>(&payload)
85 .map_err(|e| SqlRiteError::InvalidIngestionCheckpoint(e.to_string()))?;
86 Ok(Some(checkpoint))
87 }
88
89 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
90 let path = path.as_ref();
91 if let Some(parent) = path.parent()
92 && !parent.as_os_str().is_empty()
93 {
94 fs::create_dir_all(parent)?;
95 }
96
97 let payload = serde_json::to_string_pretty(self)?;
98 let temp = path.with_extension("tmp");
99 fs::write(&temp, payload)?;
100 fs::rename(temp, path)?;
101 Ok(())
102 }
103
104 pub fn clear(path: impl AsRef<Path>) -> Result<()> {
105 let path = path.as_ref();
106 if path.exists() {
107 fs::remove_file(path)?;
108 }
109 Ok(())
110 }
111}
112
/// Retry and backoff settings for embedding provider calls.
#[derive(Debug, Clone)]
pub struct EmbeddingRetryPolicy {
    /// Additional attempts after the first call (total calls = max_retries + 1).
    pub max_retries: usize,
    /// Backoff before the first retry, in milliseconds; doubles each retry.
    pub initial_backoff_ms: u64,
    /// Upper bound on the backoff, in milliseconds.
    pub max_backoff_ms: u64,
}

impl Default for EmbeddingRetryPolicy {
    /// Three retries, starting at 50 ms and capped at 1 s.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff_ms: 50,
            max_backoff_ms: 1_000,
        }
    }
}
129
/// A batch embedding backend.
pub trait EmbeddingProvider {
    /// Stable identifier recorded in chunk metadata and reports.
    fn provider_name(&self) -> &str;
    /// Model/version string recorded alongside every embedding.
    fn model_version(&self) -> &str;

    /// Embeds each text. The outer `Result` signals a transport-level failure
    /// for the whole batch; each inner entry reports per-item success or an
    /// error message. Implementations must return exactly one entry per input
    /// text (the worker rejects mismatched lengths).
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<std::result::Result<Vec<f32>, String>>>;
}
136
/// Offline, dependency-free provider producing hashed bag-of-words vectors.
/// Identical text always yields identical embeddings, which makes it suitable
/// for tests and benchmarks.
#[derive(Debug, Clone)]
pub struct DeterministicEmbeddingProvider {
    // Output vector length; validated non-zero in `new`.
    dimension: usize,
    // Reported via `EmbeddingProvider::model_version`.
    model_version: String,
}
142
143impl DeterministicEmbeddingProvider {
144 pub fn new(dimension: usize, model_version: impl Into<String>) -> Result<Self> {
145 if dimension == 0 {
146 return Err(SqlRiteError::EmbeddingProvider(
147 "dimension must be greater than 0".to_string(),
148 ));
149 }
150 Ok(Self {
151 dimension,
152 model_version: model_version.into(),
153 })
154 }
155
156 fn embed_text(&self, text: &str) -> Vec<f32> {
157 let mut vector = vec![0.0f32; self.dimension];
158 for token in text
159 .split(|ch: char| !ch.is_ascii_alphanumeric())
160 .filter(|token| !token.is_empty())
161 {
162 let hash = fnv1a64(token.as_bytes());
163 let idx = (hash % self.dimension as u64) as usize;
164 vector[idx] += 1.0;
165 }
166
167 if vector.iter().all(|v| *v == 0.0) {
169 vector[0] = 1.0;
170 }
171
172 normalize(&mut vector);
173 vector
174 }
175}
176
177impl EmbeddingProvider for DeterministicEmbeddingProvider {
178 fn provider_name(&self) -> &str {
179 "deterministic_local"
180 }
181
182 fn model_version(&self) -> &str {
183 &self.model_version
184 }
185
186 fn embed_batch(&self, texts: &[String]) -> Result<Vec<std::result::Result<Vec<f32>, String>>> {
187 Ok(texts
188 .iter()
189 .map(|text| Ok(self.embed_text(text)))
190 .collect::<Vec<_>>())
191 }
192}
193
/// Provider for OpenAI-compatible `/embeddings` HTTP APIs (bearer-token
/// auth), invoked through the system `curl` binary.
#[derive(Debug, Clone)]
pub struct OpenAiCompatibleEmbeddingProvider {
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Sent as `Authorization: Bearer <key>`.
    api_key: String,
    // Value of the `model` field in each request payload.
    model: String,
    // Mirrors `model`; reported via `EmbeddingProvider::model_version`.
    model_version: String,
    // Per-request timeout handed to curl's `--max-time`.
    timeout_secs: u64,
}
202
203impl OpenAiCompatibleEmbeddingProvider {
204 pub fn new(
205 endpoint: impl Into<String>,
206 api_key: impl Into<String>,
207 model: impl Into<String>,
208 ) -> Result<Self> {
209 let endpoint = endpoint.into();
210 let api_key = api_key.into();
211 let model = model.into();
212 if endpoint.trim().is_empty() || api_key.trim().is_empty() || model.trim().is_empty() {
213 return Err(SqlRiteError::EmbeddingProvider(
214 "endpoint, api_key and model are required".to_string(),
215 ));
216 }
217 Ok(Self {
218 endpoint,
219 api_key,
220 model_version: model.clone(),
221 model,
222 timeout_secs: 30,
223 })
224 }
225
226 pub fn from_env(
227 endpoint: impl Into<String>,
228 model: impl Into<String>,
229 api_key_env: &str,
230 ) -> Result<Self> {
231 let api_key = std::env::var(api_key_env).map_err(|_| {
232 SqlRiteError::EmbeddingProvider(format!("missing required env var `{api_key_env}`"))
233 })?;
234 Self::new(endpoint, api_key, model)
235 }
236
237 pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
238 self.timeout_secs = timeout_secs.max(1);
239 self
240 }
241}
242
impl EmbeddingProvider for OpenAiCompatibleEmbeddingProvider {
    fn provider_name(&self) -> &str {
        "openai_compatible_http"
    }

    fn model_version(&self) -> &str {
        &self.model_version
    }

    /// Posts `{"model": .., "input": [..]}` with a bearer token and parses the
    /// OpenAI-style `data` array from the response.
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<std::result::Result<Vec<f32>, String>>> {
        let payload = json!({
            "model": self.model,
            "input": texts,
        });
        let response = http_post_json(
            &self.endpoint,
            &payload,
            &[(
                "Authorization".to_string(),
                format!("Bearer {}", self.api_key),
            )],
            self.timeout_secs,
        )?;

        parse_openai_embeddings_response(&response)
    }
}
270
/// Provider for arbitrary HTTP embedding services with configurable request
/// and response field names, invoked through the system `curl` binary.
#[derive(Debug, Clone)]
pub struct CustomHttpEmbeddingProvider {
    // Full URL of the embedding endpoint.
    endpoint: String,
    // Optional `model` field added to each request payload.
    model: Option<String>,
    // Reported via `EmbeddingProvider::model_version`.
    model_version: String,
    // Request field carrying the input texts (default "inputs").
    input_field: String,
    // Response field expected to hold raw vectors (default "embeddings").
    embeddings_field: String,
    // Extra HTTP headers sent with every request.
    headers: Vec<(String, String)>,
    // Per-request timeout handed to curl's `--max-time`.
    timeout_secs: u64,
}
281
282impl CustomHttpEmbeddingProvider {
283 pub fn new(endpoint: impl Into<String>, model_version: impl Into<String>) -> Result<Self> {
284 let endpoint = endpoint.into();
285 if endpoint.trim().is_empty() {
286 return Err(SqlRiteError::EmbeddingProvider(
287 "endpoint is required".to_string(),
288 ));
289 }
290 Ok(Self {
291 endpoint,
292 model: None,
293 model_version: model_version.into(),
294 input_field: "inputs".to_string(),
295 embeddings_field: "embeddings".to_string(),
296 headers: Vec::new(),
297 timeout_secs: 30,
298 })
299 }
300
301 pub fn with_model(mut self, model: impl Into<String>) -> Self {
302 self.model = Some(model.into());
303 self
304 }
305
306 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
307 self.headers.push((key.into(), value.into()));
308 self
309 }
310
311 pub fn with_fields(
312 mut self,
313 input_field: impl Into<String>,
314 embeddings_field: impl Into<String>,
315 ) -> Self {
316 self.input_field = input_field.into();
317 self.embeddings_field = embeddings_field.into();
318 self
319 }
320
321 pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
322 self.timeout_secs = timeout_secs.max(1);
323 self
324 }
325}
326
impl EmbeddingProvider for CustomHttpEmbeddingProvider {
    fn provider_name(&self) -> &str {
        "custom_http"
    }

    fn model_version(&self) -> &str {
        &self.model_version
    }

    /// Posts the batch and tries three response schemas, in this order:
    /// 1. `{ <embeddings_field>: [[..], ..] }` — bare vectors;
    /// 2. `{ "results": [{"embedding": [..]} | {"error": ".."}] }` — per-item;
    /// 3. `{ "data": [...] }` — OpenAI-compatible.
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<std::result::Result<Vec<f32>, String>>> {
        let mut payload = serde_json::Map::new();
        payload.insert(
            self.input_field.clone(),
            Value::Array(texts.iter().cloned().map(Value::String).collect()),
        );
        if let Some(model) = &self.model {
            payload.insert("model".to_string(), Value::String(model.clone()));
        }

        let response = http_post_json(
            &self.endpoint,
            &Value::Object(payload),
            &self.headers,
            self.timeout_secs,
        )?;

        // Schema 1: configured embeddings field holding raw vectors.
        if let Some(vectors) = response
            .get(&self.embeddings_field)
            .and_then(Value::as_array)
        {
            let mut out = Vec::with_capacity(vectors.len());
            for vector in vectors {
                out.push(parse_embedding_array(vector).map_err(|e| e.to_string()));
            }
            return Ok(out);
        }

        // Schema 2: per-item results, each either an embedding or an error.
        if let Some(results) = response.get("results").and_then(Value::as_array) {
            let mut out = Vec::with_capacity(results.len());
            for item in results {
                if let Some(error) = item.get("error").and_then(Value::as_str) {
                    out.push(Err(error.to_string()));
                    continue;
                }
                let Some(embedding) = item.get("embedding") else {
                    out.push(Err("missing `embedding` field".to_string()));
                    continue;
                };
                out.push(parse_embedding_array(embedding).map_err(|e| e.to_string()));
            }
            return Ok(out);
        }

        // Schema 3: OpenAI-style `data` array.
        if response.get("data").is_some() {
            return parse_openai_embeddings_response(&response);
        }

        Err(SqlRiteError::EmbeddingProvider(
            "unsupported custom embedding response schema".to_string(),
        ))
    }
}
389
/// Everything needed to run one ingestion job.
#[derive(Debug, Clone)]
pub struct IngestionRequest {
    /// Identifies the job for checkpoint/resume matching.
    pub job_id: String,
    /// Document id shared by all chunks produced from this source.
    pub doc_id: String,
    /// Source identifier recorded on each chunk and in the checkpoint.
    pub source_id: String,
    /// Tenant scoping; must be non-blank (validated at ingest).
    pub tenant_id: String,
    /// Where the raw content comes from.
    pub source: IngestionSource,
    /// Base metadata object merged into every chunk's metadata.
    pub metadata: Value,
    /// How the content is split into chunks.
    pub chunking: ChunkingStrategy,
    /// Initial embedding batch size (must be >= 1; validated at ingest).
    pub batch_size: usize,
    /// Adaptive batch-size tuning knobs.
    pub batch_tuning: IngestionBatchTuning,
    /// When true, chunks that still fail after retries are skipped instead of
    /// aborting the whole job.
    pub continue_on_partial_failure: bool,
}
403
/// Knobs for latency-driven adaptive batch sizing.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct IngestionBatchTuning {
    /// Grow/shrink the batch size based on measured per-batch latency.
    pub adaptive: bool,
    /// Hard upper bound on the batch size (must be >= 1).
    pub max_batch_size: usize,
    /// Per-batch latency target in milliseconds (must be >= 1).
    pub target_batch_ms: u64,
}

impl Default for IngestionBatchTuning {
    /// Adaptive batching on, capped at 1024 chunks, targeting 80 ms per batch.
    fn default() -> Self {
        Self {
            adaptive: true,
            max_batch_size: 1024,
            target_batch_ms: 80,
        }
    }
}
420
421impl IngestionRequest {
422 pub fn from_direct(
423 job_id: impl Into<String>,
424 doc_id: impl Into<String>,
425 source_id: impl Into<String>,
426 tenant_id: impl Into<String>,
427 content: impl Into<String>,
428 ) -> Self {
429 Self {
430 job_id: job_id.into(),
431 doc_id: doc_id.into(),
432 source_id: source_id.into(),
433 tenant_id: tenant_id.into(),
434 source: IngestionSource::Direct {
435 content: content.into(),
436 },
437 metadata: json!({}),
438 chunking: ChunkingStrategy::default(),
439 batch_size: 64,
440 batch_tuning: IngestionBatchTuning::default(),
441 continue_on_partial_failure: false,
442 }
443 }
444
445 pub fn from_file(
446 job_id: impl Into<String>,
447 doc_id: impl Into<String>,
448 source_id: impl Into<String>,
449 tenant_id: impl Into<String>,
450 path: impl Into<PathBuf>,
451 ) -> Self {
452 Self {
453 job_id: job_id.into(),
454 doc_id: doc_id.into(),
455 source_id: source_id.into(),
456 tenant_id: tenant_id.into(),
457 source: IngestionSource::File { path: path.into() },
458 metadata: json!({}),
459 chunking: ChunkingStrategy::default(),
460 batch_size: 64,
461 batch_tuning: IngestionBatchTuning::default(),
462 continue_on_partial_failure: false,
463 }
464 }
465
466 pub fn from_url(
467 job_id: impl Into<String>,
468 doc_id: impl Into<String>,
469 source_id: impl Into<String>,
470 tenant_id: impl Into<String>,
471 url: impl Into<String>,
472 ) -> Self {
473 Self {
474 job_id: job_id.into(),
475 doc_id: doc_id.into(),
476 source_id: source_id.into(),
477 tenant_id: tenant_id.into(),
478 source: IngestionSource::Url { url: url.into() },
479 metadata: json!({}),
480 chunking: ChunkingStrategy::default(),
481 batch_size: 64,
482 batch_tuning: IngestionBatchTuning::default(),
483 continue_on_partial_failure: false,
484 }
485 }
486
487 pub fn with_metadata(mut self, metadata: Value) -> Self {
488 self.metadata = metadata;
489 self
490 }
491
492 pub fn with_chunking(mut self, chunking: ChunkingStrategy) -> Self {
493 self.chunking = chunking;
494 self
495 }
496
497 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
498 self.batch_size = batch_size;
499 self
500 }
501
502 pub fn with_batch_tuning(mut self, batch_tuning: IngestionBatchTuning) -> Self {
503 self.batch_tuning = batch_tuning;
504 self
505 }
506
507 pub fn with_adaptive_batching(mut self, enabled: bool) -> Self {
508 self.batch_tuning.adaptive = enabled;
509 self
510 }
511
512 pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self {
513 self.batch_tuning.max_batch_size = max_batch_size.max(1);
514 self
515 }
516
517 pub fn with_target_batch_ms(mut self, target_batch_ms: u64) -> Self {
518 self.batch_tuning.target_batch_ms = target_batch_ms.max(1);
519 self
520 }
521
522 pub fn with_continue_on_partial_failure(mut self, enabled: bool) -> Self {
523 self.continue_on_partial_failure = enabled;
524 self
525 }
526}
527
/// Summary statistics for one completed ingestion job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestionReport {
    /// Number of chunks produced from the source.
    pub total_chunks: usize,
    /// Chunks successfully embedded and upserted during this run.
    pub processed_chunks: usize,
    /// Chunks that still lacked an embedding after retries.
    pub failed_chunks: usize,
    /// Chunk index this run resumed from (0 for a fresh run).
    pub resumed_from_chunk: usize,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: f64,
    /// Processed-chunk throughput extrapolated to one minute.
    pub throughput_chunks_per_minute: f64,
    /// Mean number of chunks per embedding batch.
    pub average_batch_size: f64,
    /// Largest batch executed.
    pub peak_batch_size: usize,
    /// Number of embedding batches executed.
    pub batch_count: usize,
    /// Whether adaptive batch sizing was enabled for the run.
    pub adaptive_batching: bool,
    /// Embedding provider name.
    pub provider: String,
    /// Embedding model version.
    pub model_version: String,
    /// Human-readable source label (path, URL or "direct").
    pub source: String,
}
544
/// One chunk of source text together with its byte span in the original
/// document.
#[derive(Debug, Clone)]
struct Segment {
    // Chunk text (semantic chunking trims it relative to the span).
    content: String,
    // Byte offset of the span start within the source text.
    start: usize,
    // Byte offset one past the span end.
    end: usize,
    // Section heading; set only by heading-aware chunking.
    heading: Option<String>,
}
552
/// Per-chunk fields the ingestion worker injects into chunk metadata.
struct MetadataEnrichment<'a> {
    // Tenant that owns the chunk.
    tenant_id: &'a str,
    // Hex-encoded FNV-1a hash of the chunk content.
    content_hash: String,
    // Byte span of the chunk within the source document.
    source_start: usize,
    source_end: usize,
    // Section heading, when heading-aware chunking produced one.
    heading: Option<&'a str>,
    // Embedding provider name and model version, for provenance.
    provider: &'a str,
    model_version: &'a str,
}
562
/// Executes ingestion jobs against a [`SqlRite`] database, using an
/// [`EmbeddingProvider`] to vectorize chunk content.
pub struct IngestionWorker<'a, P: EmbeddingProvider> {
    // Target database; chunks are upserted via `ingest_chunks`.
    db: &'a SqlRite,
    // Embedding backend.
    provider: P,
    // Retry/backoff behaviour for embedding calls.
    retry_policy: EmbeddingRetryPolicy,
    // When set, progress is checkpointed here so jobs can resume.
    checkpoint_path: Option<PathBuf>,
}
569
570impl<'a, P: EmbeddingProvider> IngestionWorker<'a, P> {
571 pub fn new(db: &'a SqlRite, provider: P) -> Self {
572 Self {
573 db,
574 provider,
575 retry_policy: EmbeddingRetryPolicy::default(),
576 checkpoint_path: None,
577 }
578 }
579
580 pub fn with_retry_policy(mut self, retry_policy: EmbeddingRetryPolicy) -> Self {
581 self.retry_policy = retry_policy;
582 self
583 }
584
585 pub fn with_checkpoint_path(mut self, path: impl Into<PathBuf>) -> Self {
586 self.checkpoint_path = Some(path.into());
587 self
588 }
589
590 pub fn ingest(&self, request: IngestionRequest) -> Result<IngestionReport> {
591 if request.tenant_id.trim().is_empty() {
592 return Err(SqlRiteError::InvalidTenantId);
593 }
594 if request.batch_size == 0 {
595 return Err(SqlRiteError::InvalidBenchmarkConfig(
596 "ingestion batch_size must be >= 1".to_string(),
597 ));
598 }
599 if request.batch_tuning.max_batch_size == 0 {
600 return Err(SqlRiteError::InvalidBenchmarkConfig(
601 "ingestion max_batch_size must be >= 1".to_string(),
602 ));
603 }
604 if request.batch_tuning.target_batch_ms == 0 {
605 return Err(SqlRiteError::InvalidBenchmarkConfig(
606 "ingestion target_batch_ms must be >= 1".to_string(),
607 ));
608 }
609
610 let ingest_started = Instant::now();
611 let source_content = request.source.load_content()?;
612 let segments = chunk_content(&source_content, &request.chunking);
613
614 let resumed_from_chunk = self
615 .load_resume_checkpoint(&request.job_id, &request.source_id)?
616 .unwrap_or(0)
617 .min(segments.len());
618
619 let mut processed_chunks = 0usize;
620 let mut failed_chunks = 0usize;
621 let mut cursor = resumed_from_chunk;
622 let mut batch_count = 0usize;
623 let mut peak_batch_size = 0usize;
624 let mut total_batch_size = 0usize;
625 let mut next_batch_size = request
626 .batch_size
627 .max(1)
628 .min(request.batch_tuning.max_batch_size);
629
630 while cursor < segments.len() {
631 let batch_started = Instant::now();
632 let remaining = segments.len().saturating_sub(cursor);
633 let planned_batch = next_batch_size.min(remaining).max(1);
634 let end = (cursor + planned_batch).min(segments.len());
635 let batch = &segments[cursor..end];
636 batch_count += 1;
637 peak_batch_size = peak_batch_size.max(batch.len());
638 total_batch_size += batch.len();
639 let texts = batch
640 .iter()
641 .map(|segment| segment.content.clone())
642 .collect::<Vec<_>>();
643
644 let embedded = self.embed_with_retry(&texts)?;
645 let mut upserts = Vec::with_capacity(batch.len());
646 let mut batch_failed_chunks = 0usize;
647
648 for (idx, segment) in batch.iter().enumerate() {
649 let Some(embedding) = embedded[idx].clone() else {
650 batch_failed_chunks += 1;
651 failed_chunks += 1;
652 continue;
653 };
654
655 let chunk_id = chunk_id_for(
656 &request.tenant_id,
657 &request.doc_id,
658 segment.start,
659 segment.end,
660 &segment.content,
661 );
662
663 let metadata = merge_metadata(
664 &request.metadata,
665 &MetadataEnrichment {
666 tenant_id: &request.tenant_id,
667 content_hash: hex64(fnv1a64(segment.content.as_bytes())),
668 source_start: segment.start,
669 source_end: segment.end,
670 heading: segment.heading.as_deref(),
671 provider: self.provider.provider_name(),
672 model_version: self.provider.model_version(),
673 },
674 )?;
675
676 upserts.push(ChunkInput {
677 id: chunk_id,
678 doc_id: request.doc_id.clone(),
679 content: segment.content.clone(),
680 embedding,
681 metadata,
682 source: Some(request.source_id.clone()),
683 });
684 }
685
686 if !upserts.is_empty() {
687 self.db.ingest_chunks(&upserts)?;
688 processed_chunks += upserts.len();
689 }
690
691 cursor = end;
692 self.save_checkpoint(&request.job_id, &request.source_id, cursor)?;
693 let batch_duration_ms = batch_started.elapsed().as_secs_f64() * 1000.0;
694
695 if batch_failed_chunks > 0 && !request.continue_on_partial_failure {
696 return Err(SqlRiteError::EmbeddingBatchPartialFailure {
697 failed: failed_chunks,
698 });
699 }
700
701 if request.batch_tuning.adaptive {
702 let target_ms = request.batch_tuning.target_batch_ms as f64;
703 if batch_failed_chunks == 0 && batch_duration_ms <= target_ms * 0.60 {
704 let grown = next_batch_size
705 .saturating_add(next_batch_size / 2)
706 .saturating_add(1);
707 next_batch_size = grown.min(request.batch_tuning.max_batch_size).max(1);
708 } else if batch_duration_ms > target_ms || batch_failed_chunks > 0 {
709 next_batch_size = (next_batch_size / 2).max(1);
710 }
711 }
712 }
713
714 self.clear_checkpoint()?;
715 let duration_ms = ingest_started.elapsed().as_secs_f64() * 1000.0;
716 let throughput_chunks_per_minute = if duration_ms > 0.0 {
717 (processed_chunks as f64 / (duration_ms / 1000.0)) * 60.0
718 } else {
719 0.0
720 };
721 let average_batch_size = if batch_count > 0 {
722 total_batch_size as f64 / batch_count as f64
723 } else {
724 0.0
725 };
726
727 Ok(IngestionReport {
728 total_chunks: segments.len(),
729 processed_chunks,
730 failed_chunks,
731 resumed_from_chunk,
732 duration_ms,
733 throughput_chunks_per_minute,
734 average_batch_size,
735 peak_batch_size,
736 batch_count,
737 adaptive_batching: request.batch_tuning.adaptive,
738 provider: self.provider.provider_name().to_string(),
739 model_version: self.provider.model_version().to_string(),
740 source: request.source.source_label(),
741 })
742 }
743
744 fn embed_with_retry(&self, texts: &[String]) -> Result<Vec<Option<Vec<f32>>>> {
745 let mut pending: Vec<usize> = (0..texts.len()).collect();
746 let mut resolved = vec![None; texts.len()];
747 let mut attempt = 0usize;
748 let mut backoff_ms = self.retry_policy.initial_backoff_ms.max(1);
749
750 while !pending.is_empty() && attempt <= self.retry_policy.max_retries {
751 let current_texts = pending
752 .iter()
753 .map(|idx| texts[*idx].clone())
754 .collect::<Vec<_>>();
755 let responses = self.provider.embed_batch(¤t_texts)?;
756
757 if responses.len() != current_texts.len() {
758 return Err(SqlRiteError::EmbeddingProvider(
759 "provider returned mismatched batch length".to_string(),
760 ));
761 }
762
763 let mut next_pending = Vec::new();
764 for (slot, response) in responses.into_iter().enumerate() {
765 let original_idx = pending[slot];
766 match response {
767 Ok(embedding) => resolved[original_idx] = Some(embedding),
768 Err(_) => next_pending.push(original_idx),
769 }
770 }
771
772 pending = next_pending;
773 if !pending.is_empty() {
774 attempt += 1;
775 if attempt <= self.retry_policy.max_retries {
776 thread::sleep(Duration::from_millis(backoff_ms));
777 backoff_ms =
778 (backoff_ms.saturating_mul(2)).min(self.retry_policy.max_backoff_ms);
779 }
780 }
781 }
782
783 Ok(resolved)
784 }
785
786 fn load_resume_checkpoint(&self, job_id: &str, source_id: &str) -> Result<Option<usize>> {
787 let Some(path) = &self.checkpoint_path else {
788 return Ok(None);
789 };
790
791 let Some(checkpoint) = IngestionCheckpoint::load(path)? else {
792 return Ok(None);
793 };
794
795 if checkpoint.job_id == job_id && checkpoint.source_id == source_id {
796 Ok(Some(checkpoint.next_chunk_index))
797 } else {
798 Ok(None)
799 }
800 }
801
802 fn save_checkpoint(
803 &self,
804 job_id: &str,
805 source_id: &str,
806 next_chunk_index: usize,
807 ) -> Result<()> {
808 let Some(path) = &self.checkpoint_path else {
809 return Ok(());
810 };
811
812 let checkpoint = IngestionCheckpoint {
813 job_id: job_id.to_string(),
814 source_id: source_id.to_string(),
815 next_chunk_index,
816 updated_unix_ms: now_unix_ms(),
817 };
818 checkpoint.save(path)
819 }
820
821 fn clear_checkpoint(&self) -> Result<()> {
822 let Some(path) = &self.checkpoint_path else {
823 return Ok(());
824 };
825
826 IngestionCheckpoint::clear(path)
827 }
828}
829
830fn chunk_content(text: &str, strategy: &ChunkingStrategy) -> Vec<Segment> {
831 match strategy {
832 ChunkingStrategy::Fixed {
833 max_chars,
834 overlap_chars,
835 } => chunk_fixed(text, *max_chars, *overlap_chars),
836 ChunkingStrategy::HeadingAware {
837 max_chars,
838 overlap_chars,
839 } => chunk_heading_aware(text, *max_chars, *overlap_chars),
840 ChunkingStrategy::Semantic { max_chars } => chunk_semantic(text, *max_chars),
841 }
842}
843
844fn chunk_fixed(text: &str, max_chars: usize, overlap_chars: usize) -> Vec<Segment> {
845 if text.is_empty() {
846 return Vec::new();
847 }
848
849 let max_chars = max_chars.max(1);
850 let overlap_chars = overlap_chars.min(max_chars.saturating_sub(1));
851
852 let mut segments = Vec::new();
853 let mut start = 0usize;
854
855 while start < text.len() {
856 let mut end = (start + max_chars).min(text.len());
857 while end > start && !text.is_char_boundary(end) {
858 end -= 1;
859 }
860 if end <= start {
861 end = text.len();
862 }
863
864 segments.push(Segment {
865 content: text[start..end].to_string(),
866 start,
867 end,
868 heading: None,
869 });
870
871 if end == text.len() {
872 break;
873 }
874
875 let mut next_start = end.saturating_sub(overlap_chars);
876 while next_start > start && !text.is_char_boundary(next_start) {
877 next_start -= 1;
878 }
879 if next_start <= start {
880 next_start = end;
881 }
882 start = next_start;
883 }
884
885 segments
886}
887
/// Splits `text` into sections at Markdown-style `#` heading lines, then
/// applies fixed windowing within each section; every resulting chunk carries
/// its section's heading. Falls back to plain fixed chunking when no sections
/// are found.
fn chunk_heading_aware(text: &str, max_chars: usize, overlap_chars: usize) -> Vec<Segment> {
    if text.is_empty() {
        return Vec::new();
    }

    // (start, end, heading) byte spans for each section.
    let mut sections = Vec::new();
    let mut offset = 0usize;
    let mut section_start = 0usize;
    let mut heading: Option<String> = None;

    // split_inclusive keeps the trailing '\n', so `offset` tracks exact byte
    // positions of each line start.
    for line in text.split_inclusive('\n') {
        let line_start = offset;
        offset += line.len();
        let trimmed = line.trim_start();
        if trimmed.starts_with('#') {
            // Close the previous section (text before the first heading keeps
            // heading == None).
            if line_start > section_start {
                sections.push((section_start, line_start, heading.clone()));
            }
            // Heading text with the leading '#' run and whitespace stripped.
            heading = Some(trimmed.trim().trim_start_matches('#').trim().to_string());
            section_start = line_start;
        }
    }

    // Close the final section, which includes the last heading line itself.
    if section_start < text.len() {
        sections.push((section_start, text.len(), heading));
    }

    if sections.is_empty() {
        return chunk_fixed(text, max_chars, overlap_chars);
    }

    let mut segments = Vec::new();
    for (start, end, heading) in sections {
        let section_text = &text[start..end];
        for mut part in chunk_fixed(section_text, max_chars, overlap_chars) {
            // Rebase the sub-chunk offsets from section-relative to
            // document-relative positions.
            part.start += start;
            part.end += start;
            part.heading = heading.clone();
            segments.push(part);
        }
    }

    segments
}
932
/// Packs whole sentences (delimited by `.`, `!` or `?`) into chunks of at
/// most `max_chars` bytes; a single sentence longer than `max_chars` becomes
/// its own oversized chunk. Falls back to fixed chunking when no sentence
/// boundaries exist or every chunk trims to empty.
fn chunk_semantic(text: &str, max_chars: usize) -> Vec<Segment> {
    if text.is_empty() {
        return Vec::new();
    }

    let max_chars = max_chars.max(1);
    let mut sentence_bounds = Vec::new();
    let mut sentence_start = 0usize;

    // Collect (start, end) byte spans per sentence, terminator included;
    // trailing text without a terminator becomes a final pseudo-sentence.
    for (idx, ch) in text.char_indices() {
        if matches!(ch, '.' | '!' | '?') {
            let end = idx + ch.len_utf8();
            if end > sentence_start {
                sentence_bounds.push((sentence_start, end));
                sentence_start = end;
            }
        }
    }
    if sentence_start < text.len() {
        sentence_bounds.push((sentence_start, text.len()));
    }

    if sentence_bounds.is_empty() {
        return chunk_fixed(text, max_chars, 0);
    }

    let mut segments = Vec::new();
    let mut current_start = sentence_bounds[0].0;
    let mut current_end = sentence_bounds[0].0;

    // Greedily extend the current chunk; flush when adding the next sentence
    // would exceed `max_chars` (unless the chunk is still empty).
    for (start, end) in sentence_bounds {
        if end.saturating_sub(current_start) > max_chars && current_end > current_start {
            // NOTE(review): `start`/`end` record the untrimmed span while
            // `content` is trimmed, so the span can be slightly wider than
            // the stored text — confirm downstream consumers expect this.
            segments.push(Segment {
                content: text[current_start..current_end].trim().to_string(),
                start: current_start,
                end: current_end,
                heading: None,
            });
            current_start = start;
        }
        current_end = end;
    }

    // Flush the final chunk.
    if current_end > current_start {
        segments.push(Segment {
            content: text[current_start..current_end].trim().to_string(),
            start: current_start,
            end: current_end,
            heading: None,
        });
    }

    if segments.is_empty() {
        chunk_fixed(text, max_chars, 0)
    } else {
        // Drop chunks whose content trimmed down to nothing.
        segments
            .into_iter()
            .filter(|segment| !segment.content.is_empty())
            .collect()
    }
}
994
995fn merge_metadata(base: &Value, enrichment: &MetadataEnrichment<'_>) -> Result<Value> {
996 let mut metadata_obj = match base {
997 Value::Object(map) => map.clone(),
998 _ => Map::new(),
999 };
1000
1001 metadata_obj.insert(
1002 "tenant".to_string(),
1003 Value::String(enrichment.tenant_id.to_string()),
1004 );
1005 metadata_obj.insert(
1006 "content_hash".to_string(),
1007 Value::String(enrichment.content_hash.clone()),
1008 );
1009 metadata_obj.insert(
1010 "source_start".to_string(),
1011 Value::Number(serde_json::Number::from(enrichment.source_start as u64)),
1012 );
1013 metadata_obj.insert(
1014 "source_end".to_string(),
1015 Value::Number(serde_json::Number::from(enrichment.source_end as u64)),
1016 );
1017 metadata_obj.insert(
1018 "embedding_provider".to_string(),
1019 Value::String(enrichment.provider.to_string()),
1020 );
1021 metadata_obj.insert(
1022 "embedding_model_version".to_string(),
1023 Value::String(enrichment.model_version.to_string()),
1024 );
1025
1026 if let Some(heading) = enrichment.heading
1027 && !heading.is_empty()
1028 {
1029 metadata_obj.insert("heading".to_string(), Value::String(heading.to_string()));
1030 }
1031
1032 Ok(Value::Object(metadata_obj))
1033}
1034
1035fn chunk_id_for(tenant_id: &str, doc_id: &str, start: usize, end: usize, content: &str) -> String {
1036 let mut seed = Vec::new();
1037 seed.extend_from_slice(tenant_id.as_bytes());
1038 seed.push(0);
1039 seed.extend_from_slice(doc_id.as_bytes());
1040 seed.push(0);
1041 seed.extend_from_slice(start.to_string().as_bytes());
1042 seed.push(0);
1043 seed.extend_from_slice(end.to_string().as_bytes());
1044 seed.push(0);
1045 seed.extend_from_slice(content.as_bytes());
1046
1047 let hash = hex64(fnv1a64(&seed));
1048 let mut tenant = tenant_id
1049 .chars()
1050 .filter(|ch| ch.is_ascii_alphanumeric() || *ch == '_')
1051 .collect::<String>();
1052 if tenant.is_empty() {
1053 tenant = "tenant".to_string();
1054 }
1055
1056 format!("{tenant}-{hash}")
1057}
1058
/// Current wall-clock time as milliseconds since the Unix epoch; returns 0 if
/// the system clock reports a pre-epoch time.
fn now_unix_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1065
/// POSTs `payload` as JSON to `endpoint` by shelling out to `curl`, returning
/// the parsed JSON response body.
///
/// curl flags: `-f` turns HTTP error statuses into a non-zero exit, `-sS`
/// silences progress output while keeping error messages on stderr, and
/// `--max-time` enforces the timeout (clamped to at least 1 second).
///
/// # Errors
/// Fails when curl cannot be spawned, exits unsuccessfully (stderr is folded
/// into the message), or the body is not valid JSON.
fn http_post_json(
    endpoint: &str,
    payload: &Value,
    headers: &[(String, String)],
    timeout_secs: u64,
) -> Result<Value> {
    let mut cmd = Command::new("curl");
    cmd.arg("-fsS")
        .arg("-X")
        .arg("POST")
        .arg("--max-time")
        .arg(timeout_secs.max(1).to_string())
        .arg("-H")
        .arg("Content-Type: application/json");

    // Caller-supplied headers (e.g. Authorization).
    for (key, value) in headers {
        cmd.arg("-H").arg(format!("{key}: {value}"));
    }

    let output = cmd
        .arg("-d")
        .arg(payload.to_string())
        .arg(endpoint)
        .output()?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(SqlRiteError::EmbeddingProvider(format!(
            "http request failed for `{endpoint}`: {stderr}"
        )));
    }

    serde_json::from_slice::<Value>(&output.stdout).map_err(|e| {
        SqlRiteError::EmbeddingProvider(format!("invalid json response from `{endpoint}`: {e}"))
    })
}
1102
1103fn parse_openai_embeddings_response(
1104 response: &Value,
1105) -> Result<Vec<std::result::Result<Vec<f32>, String>>> {
1106 let data = response
1107 .get("data")
1108 .and_then(Value::as_array)
1109 .ok_or_else(|| SqlRiteError::EmbeddingProvider("missing `data` array".to_string()))?;
1110
1111 let mut out = Vec::with_capacity(data.len());
1112 for row in data {
1113 let Some(embedding) = row.get("embedding") else {
1114 out.push(Err("missing `embedding` field".to_string()));
1115 continue;
1116 };
1117 out.push(parse_embedding_array(embedding).map_err(|e| e.to_string()));
1118 }
1119 Ok(out)
1120}
1121
1122fn parse_embedding_array(value: &Value) -> Result<Vec<f32>> {
1123 let array = value
1124 .as_array()
1125 .ok_or_else(|| SqlRiteError::EmbeddingProvider("embedding must be an array".to_string()))?;
1126 let mut embedding = Vec::with_capacity(array.len());
1127 for item in array {
1128 let Some(number) = item.as_f64() else {
1129 return Err(SqlRiteError::EmbeddingProvider(
1130 "embedding item must be numeric".to_string(),
1131 ));
1132 };
1133 embedding.push(number as f32);
1134 }
1135 if embedding.is_empty() {
1136 return Err(SqlRiteError::EmbeddingProvider(
1137 "embedding array cannot be empty".to_string(),
1138 ));
1139 }
1140 Ok(embedding)
1141}
1142
/// FNV-1a 64-bit hash (offset basis 0xcbf29ce484222325, prime 0x100000001b3).
fn fnv1a64(bytes: &[u8]) -> u64 {
    bytes.iter().fold(0xcbf29ce484222325u64, |acc, &byte| {
        (acc ^ u64::from(byte)).wrapping_mul(0x100000001b3)
    })
}
1151
/// Formats a `u64` as 16 zero-padded lowercase hex digits.
fn hex64(value: u64) -> String {
    format!("{:016x}", value)
}
1155
/// Scales `vector` in place to unit L2 norm; zero vectors are left untouched
/// to avoid dividing by zero.
fn normalize(vector: &mut [f32]) {
    let magnitude = vector.iter().fold(0.0f32, |acc, v| acc + v * v).sqrt();
    if magnitude > 0.0 {
        for component in vector.iter_mut() {
            *component /= magnitude;
        }
    }
}
1164
1165#[cfg(test)]
1166mod tests {
1167 use super::*;
1168 use crate::RuntimeConfig;
1169 use serde_json::json;
1170 use tempfile::tempdir;
1171
    // 10-byte windows with a 3-byte overlap should step by 7 bytes.
    #[test]
    fn fixed_chunking_respects_overlap() {
        let text = "abcdefghijklmnopqrstuvwxyz";
        let chunks = chunk_fixed(text, 10, 3);
        assert!(chunks.len() >= 3);
        assert_eq!(chunks[0].content, "abcdefghij");
        assert_eq!(chunks[1].content, "hijklmnopq");
    }
1180
    // save -> load -> clear should round-trip through the filesystem.
    #[test]
    fn checkpoint_round_trip() -> Result<()> {
        let dir = tempdir()?;
        let path = dir.path().join("checkpoint.json");
        let checkpoint = IngestionCheckpoint {
            job_id: "job-a".to_string(),
            source_id: "source-a".to_string(),
            next_chunk_index: 42,
            updated_unix_ms: 1,
        };

        checkpoint.save(&path)?;
        let loaded = IngestionCheckpoint::load(&path)?.expect("checkpoint exists");
        assert_eq!(loaded.next_chunk_index, 42);

        // Clearing removes the file, and a subsequent load reports None.
        IngestionCheckpoint::clear(&path)?;
        assert!(IngestionCheckpoint::load(&path)?.is_none());
        Ok(())
    }
1200
1201 #[test]
1202 fn ingestion_request_convenience_builders_set_fields() {
1203 let req = IngestionRequest::from_file("job", "doc", "source", "acme", "README.md")
1204 .with_chunking(ChunkingStrategy::Fixed {
1205 max_chars: 200,
1206 overlap_chars: 20,
1207 })
1208 .with_batch_size(32)
1209 .with_max_batch_size(256)
1210 .with_target_batch_ms(120)
1211 .with_continue_on_partial_failure(true);
1212 assert_eq!(req.batch_size, 32);
1213 assert_eq!(req.batch_tuning.max_batch_size, 256);
1214 assert_eq!(req.batch_tuning.target_batch_ms, 120);
1215 assert!(req.continue_on_partial_failure);
1216 assert!(matches!(req.source, IngestionSource::File { .. }));
1217 }
1218
1219 #[test]
1220 fn parse_openai_response_extracts_embeddings() -> Result<()> {
1221 let response = json!({
1222 "data": [
1223 {"embedding": [0.1, 0.2, 0.3]},
1224 {"embedding": [0.4, 0.5, 0.6]}
1225 ]
1226 });
1227 let parsed = parse_openai_embeddings_response(&response)?;
1228 assert_eq!(parsed.len(), 2);
1229 assert!(parsed.iter().all(std::result::Result::is_ok));
1230 Ok(())
1231 }
1232
1233 #[test]
1234 fn custom_provider_parses_results_schema() -> Result<()> {
1235 let provider = CustomHttpEmbeddingProvider::new("http://localhost:1234", "v1")?;
1236 let response = json!({
1237 "results": [
1238 {"embedding": [0.1, 0.2]},
1239 {"error": "rate_limited"}
1240 ]
1241 });
1242
1243 let parsed = if let Some(results) = response.get("results").and_then(Value::as_array) {
1245 let mut out = Vec::new();
1246 for item in results {
1247 if let Some(error) = item.get("error").and_then(Value::as_str) {
1248 out.push(Err(error.to_string()));
1249 continue;
1250 }
1251 let embedding = item.get("embedding").expect("embedding exists");
1252 out.push(parse_embedding_array(embedding).map_err(|e| e.to_string()));
1253 }
1254 out
1255 } else {
1256 Vec::new()
1257 };
1258
1259 assert_eq!(provider.model_version(), "v1");
1260 assert_eq!(parsed.len(), 2);
1261 assert!(parsed[0].is_ok());
1262 assert!(parsed[1].is_err());
1263 Ok(())
1264 }
1265
1266 #[test]
1267 fn ingestion_is_idempotent_for_same_payload() -> Result<()> {
1268 let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
1269 let provider = DeterministicEmbeddingProvider::new(64, "det-v1")?;
1270 let worker = IngestionWorker::new(&db, provider);
1271
1272 let request = IngestionRequest {
1273 job_id: "job-1".to_string(),
1274 doc_id: "doc-1".to_string(),
1275 source_id: "payload-1".to_string(),
1276 tenant_id: "acme".to_string(),
1277 source: IngestionSource::Direct {
1278 content: "# Intro\nRust agents need deterministic retrieval.\n\n# Details\nSQLite RAG memory is portable.".to_string(),
1279 },
1280 metadata: json!({"kind": "guide"}),
1281 chunking: ChunkingStrategy::HeadingAware {
1282 max_chars: 40,
1283 overlap_chars: 5,
1284 },
1285 batch_size: 4,
1286 batch_tuning: IngestionBatchTuning::default(),
1287 continue_on_partial_failure: false,
1288 };
1289
1290 let first = worker.ingest(request.clone())?;
1291 let count_after_first = db.chunk_count()?;
1292 let second = worker.ingest(request)?;
1293 let count_after_second = db.chunk_count()?;
1294
1295 assert!(first.total_chunks > 0);
1296 assert_eq!(count_after_first, count_after_second);
1297 assert_eq!(second.failed_chunks, 0);
1298 Ok(())
1299 }
1300
1301 #[test]
1302 fn ingestion_resumes_from_checkpoint() -> Result<()> {
1303 let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
1304 let provider = DeterministicEmbeddingProvider::new(32, "det-v1")?;
1305 let dir = tempdir()?;
1306 let checkpoint_path = dir.path().join("ingest.checkpoint.json");
1307
1308 let checkpoint = IngestionCheckpoint {
1309 job_id: "job-resume".to_string(),
1310 source_id: "source-resume".to_string(),
1311 next_chunk_index: 1,
1312 updated_unix_ms: now_unix_ms(),
1313 };
1314 checkpoint.save(&checkpoint_path)?;
1315
1316 let worker = IngestionWorker::new(&db, provider).with_checkpoint_path(&checkpoint_path);
1317 let request = IngestionRequest {
1318 job_id: "job-resume".to_string(),
1319 doc_id: "doc-r".to_string(),
1320 source_id: "source-resume".to_string(),
1321 tenant_id: "acme".to_string(),
1322 source: IngestionSource::Direct {
1323 content: "one two three four five six seven eight nine ten".to_string(),
1324 },
1325 metadata: json!({}),
1326 chunking: ChunkingStrategy::Fixed {
1327 max_chars: 10,
1328 overlap_chars: 0,
1329 },
1330 batch_size: 2,
1331 batch_tuning: IngestionBatchTuning::default(),
1332 continue_on_partial_failure: false,
1333 };
1334
1335 let report = worker.ingest(request)?;
1336 assert!(report.resumed_from_chunk >= 1);
1337 assert!(IngestionCheckpoint::load(&checkpoint_path)?.is_none());
1338 Ok(())
1339 }
1340
1341 #[test]
1342 fn ingestion_reports_throughput_and_adaptive_batching_stats() -> Result<()> {
1343 let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
1344 let provider = DeterministicEmbeddingProvider::new(64, "det-v1")?;
1345 let worker = IngestionWorker::new(&db, provider);
1346
1347 let request = IngestionRequest::from_direct(
1348 "job-batch-stats",
1349 "doc-batch-stats",
1350 "source-batch-stats",
1351 "acme",
1352 "# A\nRust SQLite agents.\n\n# B\nAdaptive batching for ingestion throughput.",
1353 )
1354 .with_chunking(ChunkingStrategy::HeadingAware {
1355 max_chars: 24,
1356 overlap_chars: 4,
1357 })
1358 .with_batch_size(2)
1359 .with_batch_tuning(IngestionBatchTuning {
1360 adaptive: true,
1361 max_batch_size: 8,
1362 target_batch_ms: 100,
1363 });
1364
1365 let report = worker.ingest(request)?;
1366 assert!(report.duration_ms >= 0.0);
1367 assert!(report.throughput_chunks_per_minute >= 0.0);
1368 assert!(report.batch_count > 0);
1369 assert!(report.average_batch_size >= 1.0);
1370 assert!(report.peak_batch_size >= 1);
1371 assert!(report.adaptive_batching);
1372 Ok(())
1373 }
1374}