//! sqlrite `ingest.rs` — document ingestion pipeline: content sources,
//! chunking strategies, embedding providers, and the checkpointed,
//! adaptive-batching ingestion worker.
1use crate::{ChunkInput, Result, SqlRite, SqlRiteError};
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value, json};
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::process::Command;
7use std::thread;
8use std::time::{Duration, Instant};
9
/// Where ingestion content comes from.
#[derive(Debug, Clone)]
pub enum IngestionSource {
    /// Content supplied inline by the caller.
    Direct { content: String },
    /// Content read from a local file path.
    File { path: PathBuf },
    /// Content fetched over HTTP(S); retrieval shells out to `curl` (see `load_content`).
    Url { url: String },
}
16
17impl IngestionSource {
18    pub fn load_content(&self) -> Result<String> {
19        match self {
20            Self::Direct { content } => Ok(content.clone()),
21            Self::File { path } => Ok(fs::read_to_string(path)?),
22            Self::Url { url } => {
23                let output = Command::new("curl").arg("-fsSL").arg(url).output()?;
24                if !output.status.success() {
25                    return Err(SqlRiteError::UnsupportedOperation(format!(
26                        "url ingestion failed for `{url}`"
27                    )));
28                }
29                String::from_utf8(output.stdout).map_err(|_| {
30                    SqlRiteError::UnsupportedOperation("url content is not valid UTF-8".to_string())
31                })
32            }
33        }
34    }
35
36    fn source_label(&self) -> String {
37        match self {
38            Self::Direct { .. } => "direct".to_string(),
39            Self::File { path } => path.display().to_string(),
40            Self::Url { url } => url.clone(),
41        }
42    }
43}
44
/// How source text is split into chunks before embedding.
#[derive(Debug, Clone)]
pub enum ChunkingStrategy {
    /// Fixed-size windows with byte overlap between consecutive chunks.
    Fixed {
        max_chars: usize,
        overlap_chars: usize,
    },
    /// Split on markdown `#` heading lines first, then fixed-chunk each section.
    HeadingAware {
        max_chars: usize,
        overlap_chars: usize,
    },
    /// Accumulate sentence-like spans (split after `.`/`!`/`?`) up to `max_chars`.
    Semantic {
        max_chars: usize,
    },
}
59
60impl Default for ChunkingStrategy {
61    fn default() -> Self {
62        Self::HeadingAware {
63            max_chars: 1200,
64            overlap_chars: 120,
65        }
66    }
67}
68
/// Durable record of ingestion progress, serialized as JSON on disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestionCheckpoint {
    /// Job the checkpoint belongs to.
    pub job_id: String,
    /// Source the checkpoint belongs to.
    pub source_id: String,
    /// Index of the first chunk not yet processed.
    pub next_chunk_index: usize,
    /// Last update time (unix epoch, milliseconds).
    pub updated_unix_ms: u64,
}
76
77impl IngestionCheckpoint {
78    pub fn load(path: impl AsRef<Path>) -> Result<Option<Self>> {
79        let path = path.as_ref();
80        if !path.exists() {
81            return Ok(None);
82        }
83        let payload = fs::read_to_string(path)?;
84        let checkpoint = serde_json::from_str::<Self>(&payload)
85            .map_err(|e| SqlRiteError::InvalidIngestionCheckpoint(e.to_string()))?;
86        Ok(Some(checkpoint))
87    }
88
89    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
90        let path = path.as_ref();
91        if let Some(parent) = path.parent()
92            && !parent.as_os_str().is_empty()
93        {
94            fs::create_dir_all(parent)?;
95        }
96
97        let payload = serde_json::to_string_pretty(self)?;
98        let temp = path.with_extension("tmp");
99        fs::write(&temp, payload)?;
100        fs::rename(temp, path)?;
101        Ok(())
102    }
103
104    pub fn clear(path: impl AsRef<Path>) -> Result<()> {
105        let path = path.as_ref();
106        if path.exists() {
107            fs::remove_file(path)?;
108        }
109        Ok(())
110    }
111}
112
/// Retry/backoff settings for embedding batches.
#[derive(Debug, Clone)]
pub struct EmbeddingRetryPolicy {
    /// Additional attempts after the first failure.
    pub max_retries: usize,
    /// Backoff before the first retry, in milliseconds (doubled each retry).
    pub initial_backoff_ms: u64,
    /// Cap on the backoff, in milliseconds.
    pub max_backoff_ms: u64,
}
119
120impl Default for EmbeddingRetryPolicy {
121    fn default() -> Self {
122        Self {
123            max_retries: 3,
124            initial_backoff_ms: 50,
125            max_backoff_ms: 1_000,
126        }
127    }
128}
129
/// A batch embedding backend.
pub trait EmbeddingProvider {
    /// Stable identifier of the backend (e.g. "deterministic_local").
    fn provider_name(&self) -> &str;
    /// Version string of the model; recorded into chunk metadata.
    fn model_version(&self) -> &str;

    /// Embed each text. Per-item failures are `Err(String)` so a batch can
    /// partially succeed; the outer `Result` is for transport-level failure.
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<std::result::Result<Vec<f32>, String>>>;
}
136
/// Offline, dependency-free embedding provider that hashes tokens into a
/// fixed-size bucket vector. Deterministic, so suitable for tests.
#[derive(Debug, Clone)]
pub struct DeterministicEmbeddingProvider {
    // Length of every produced vector; must be > 0 (enforced in `new`).
    dimension: usize,
    // Reported via `EmbeddingProvider::model_version`.
    model_version: String,
}
142
143impl DeterministicEmbeddingProvider {
144    pub fn new(dimension: usize, model_version: impl Into<String>) -> Result<Self> {
145        if dimension == 0 {
146            return Err(SqlRiteError::EmbeddingProvider(
147                "dimension must be greater than 0".to_string(),
148            ));
149        }
150        Ok(Self {
151            dimension,
152            model_version: model_version.into(),
153        })
154    }
155
156    fn embed_text(&self, text: &str) -> Vec<f32> {
157        let mut vector = vec![0.0f32; self.dimension];
158        for token in text
159            .split(|ch: char| !ch.is_ascii_alphanumeric())
160            .filter(|token| !token.is_empty())
161        {
162            let hash = fnv1a64(token.as_bytes());
163            let idx = (hash % self.dimension as u64) as usize;
164            vector[idx] += 1.0;
165        }
166
167        // Keep a deterministic fallback signal for empty-token content.
168        if vector.iter().all(|v| *v == 0.0) {
169            vector[0] = 1.0;
170        }
171
172        normalize(&mut vector);
173        vector
174    }
175}
176
177impl EmbeddingProvider for DeterministicEmbeddingProvider {
178    fn provider_name(&self) -> &str {
179        "deterministic_local"
180    }
181
182    fn model_version(&self) -> &str {
183        &self.model_version
184    }
185
186    fn embed_batch(&self, texts: &[String]) -> Result<Vec<std::result::Result<Vec<f32>, String>>> {
187        Ok(texts
188            .iter()
189            .map(|text| Ok(self.embed_text(text)))
190            .collect::<Vec<_>>())
191    }
192}
193
/// Embedding provider speaking the OpenAI `/embeddings` wire format over HTTP.
#[derive(Debug, Clone)]
pub struct OpenAiCompatibleEmbeddingProvider {
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Sent as `Authorization: Bearer <api_key>`.
    api_key: String,
    // Model name sent in the request body.
    model: String,
    // Reported model version (set to the model name by `new`).
    model_version: String,
    // HTTP timeout in seconds (default 30).
    timeout_secs: u64,
}
202
203impl OpenAiCompatibleEmbeddingProvider {
204    pub fn new(
205        endpoint: impl Into<String>,
206        api_key: impl Into<String>,
207        model: impl Into<String>,
208    ) -> Result<Self> {
209        let endpoint = endpoint.into();
210        let api_key = api_key.into();
211        let model = model.into();
212        if endpoint.trim().is_empty() || api_key.trim().is_empty() || model.trim().is_empty() {
213            return Err(SqlRiteError::EmbeddingProvider(
214                "endpoint, api_key and model are required".to_string(),
215            ));
216        }
217        Ok(Self {
218            endpoint,
219            api_key,
220            model_version: model.clone(),
221            model,
222            timeout_secs: 30,
223        })
224    }
225
226    pub fn from_env(
227        endpoint: impl Into<String>,
228        model: impl Into<String>,
229        api_key_env: &str,
230    ) -> Result<Self> {
231        let api_key = std::env::var(api_key_env).map_err(|_| {
232            SqlRiteError::EmbeddingProvider(format!("missing required env var `{api_key_env}`"))
233        })?;
234        Self::new(endpoint, api_key, model)
235    }
236
237    pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
238        self.timeout_secs = timeout_secs.max(1);
239        self
240    }
241}
242
243impl EmbeddingProvider for OpenAiCompatibleEmbeddingProvider {
244    fn provider_name(&self) -> &str {
245        "openai_compatible_http"
246    }
247
248    fn model_version(&self) -> &str {
249        &self.model_version
250    }
251
252    fn embed_batch(&self, texts: &[String]) -> Result<Vec<std::result::Result<Vec<f32>, String>>> {
253        let payload = json!({
254            "model": self.model,
255            "input": texts,
256        });
257        let response = http_post_json(
258            &self.endpoint,
259            &payload,
260            &[(
261                "Authorization".to_string(),
262                format!("Bearer {}", self.api_key),
263            )],
264            self.timeout_secs,
265        )?;
266
267        parse_openai_embeddings_response(&response)
268    }
269}
270
/// Embedding provider for arbitrary HTTP services with configurable request
/// and response JSON field names.
#[derive(Debug, Clone)]
pub struct CustomHttpEmbeddingProvider {
    // Full URL of the embedding endpoint.
    endpoint: String,
    // Optional `model` field added to the request body.
    model: Option<String>,
    // Reported model version.
    model_version: String,
    // Request body field holding the input texts (default: "inputs").
    input_field: String,
    // Response field holding the vectors (default: "embeddings").
    embeddings_field: String,
    // Extra HTTP headers sent with every request.
    headers: Vec<(String, String)>,
    // HTTP timeout in seconds (default 30).
    timeout_secs: u64,
}
281
282impl CustomHttpEmbeddingProvider {
283    pub fn new(endpoint: impl Into<String>, model_version: impl Into<String>) -> Result<Self> {
284        let endpoint = endpoint.into();
285        if endpoint.trim().is_empty() {
286            return Err(SqlRiteError::EmbeddingProvider(
287                "endpoint is required".to_string(),
288            ));
289        }
290        Ok(Self {
291            endpoint,
292            model: None,
293            model_version: model_version.into(),
294            input_field: "inputs".to_string(),
295            embeddings_field: "embeddings".to_string(),
296            headers: Vec::new(),
297            timeout_secs: 30,
298        })
299    }
300
301    pub fn with_model(mut self, model: impl Into<String>) -> Self {
302        self.model = Some(model.into());
303        self
304    }
305
306    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
307        self.headers.push((key.into(), value.into()));
308        self
309    }
310
311    pub fn with_fields(
312        mut self,
313        input_field: impl Into<String>,
314        embeddings_field: impl Into<String>,
315    ) -> Self {
316        self.input_field = input_field.into();
317        self.embeddings_field = embeddings_field.into();
318        self
319    }
320
321    pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
322        self.timeout_secs = timeout_secs.max(1);
323        self
324    }
325}
326
327impl EmbeddingProvider for CustomHttpEmbeddingProvider {
328    fn provider_name(&self) -> &str {
329        "custom_http"
330    }
331
332    fn model_version(&self) -> &str {
333        &self.model_version
334    }
335
336    fn embed_batch(&self, texts: &[String]) -> Result<Vec<std::result::Result<Vec<f32>, String>>> {
337        let mut payload = serde_json::Map::new();
338        payload.insert(
339            self.input_field.clone(),
340            Value::Array(texts.iter().cloned().map(Value::String).collect()),
341        );
342        if let Some(model) = &self.model {
343            payload.insert("model".to_string(), Value::String(model.clone()));
344        }
345
346        let response = http_post_json(
347            &self.endpoint,
348            &Value::Object(payload),
349            &self.headers,
350            self.timeout_secs,
351        )?;
352
353        if let Some(vectors) = response
354            .get(&self.embeddings_field)
355            .and_then(Value::as_array)
356        {
357            let mut out = Vec::with_capacity(vectors.len());
358            for vector in vectors {
359                out.push(parse_embedding_array(vector).map_err(|e| e.to_string()));
360            }
361            return Ok(out);
362        }
363
364        if let Some(results) = response.get("results").and_then(Value::as_array) {
365            let mut out = Vec::with_capacity(results.len());
366            for item in results {
367                if let Some(error) = item.get("error").and_then(Value::as_str) {
368                    out.push(Err(error.to_string()));
369                    continue;
370                }
371                let Some(embedding) = item.get("embedding") else {
372                    out.push(Err("missing `embedding` field".to_string()));
373                    continue;
374                };
375                out.push(parse_embedding_array(embedding).map_err(|e| e.to_string()));
376            }
377            return Ok(out);
378        }
379
380        if response.get("data").is_some() {
381            return parse_openai_embeddings_response(&response);
382        }
383
384        Err(SqlRiteError::EmbeddingProvider(
385            "unsupported custom embedding response schema".to_string(),
386        ))
387    }
388}
389
/// Everything needed to run one ingestion job for a single document.
#[derive(Debug, Clone)]
pub struct IngestionRequest {
    /// Identifier of the ingestion job (used to match resume checkpoints).
    pub job_id: String,
    /// Document id the produced chunks belong to.
    pub doc_id: String,
    /// Identifier of the source (also used to match resume checkpoints).
    pub source_id: String,
    /// Tenant owning the document; must be non-blank.
    pub tenant_id: String,
    /// Where the raw content comes from.
    pub source: IngestionSource,
    /// Base metadata merged into every chunk's metadata.
    pub metadata: Value,
    /// How the content is split into chunks.
    pub chunking: ChunkingStrategy,
    /// Initial embedding batch size.
    pub batch_size: usize,
    /// Adaptive batch-sizing knobs.
    pub batch_tuning: IngestionBatchTuning,
    /// When true, failed chunks are skipped instead of aborting the job.
    pub continue_on_partial_failure: bool,
}
403
/// Knobs controlling adaptive embedding-batch sizing.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct IngestionBatchTuning {
    /// Grow/shrink the batch size based on measured batch duration.
    pub adaptive: bool,
    /// Hard upper bound on the batch size.
    pub max_batch_size: usize,
    /// Target wall-clock duration per batch, in milliseconds.
    pub target_batch_ms: u64,
}
410
411impl Default for IngestionBatchTuning {
412    fn default() -> Self {
413        Self {
414            adaptive: true,
415            max_batch_size: 1024,
416            target_batch_ms: 80,
417        }
418    }
419}
420
421impl IngestionRequest {
422    pub fn from_direct(
423        job_id: impl Into<String>,
424        doc_id: impl Into<String>,
425        source_id: impl Into<String>,
426        tenant_id: impl Into<String>,
427        content: impl Into<String>,
428    ) -> Self {
429        Self {
430            job_id: job_id.into(),
431            doc_id: doc_id.into(),
432            source_id: source_id.into(),
433            tenant_id: tenant_id.into(),
434            source: IngestionSource::Direct {
435                content: content.into(),
436            },
437            metadata: json!({}),
438            chunking: ChunkingStrategy::default(),
439            batch_size: 64,
440            batch_tuning: IngestionBatchTuning::default(),
441            continue_on_partial_failure: false,
442        }
443    }
444
445    pub fn from_file(
446        job_id: impl Into<String>,
447        doc_id: impl Into<String>,
448        source_id: impl Into<String>,
449        tenant_id: impl Into<String>,
450        path: impl Into<PathBuf>,
451    ) -> Self {
452        Self {
453            job_id: job_id.into(),
454            doc_id: doc_id.into(),
455            source_id: source_id.into(),
456            tenant_id: tenant_id.into(),
457            source: IngestionSource::File { path: path.into() },
458            metadata: json!({}),
459            chunking: ChunkingStrategy::default(),
460            batch_size: 64,
461            batch_tuning: IngestionBatchTuning::default(),
462            continue_on_partial_failure: false,
463        }
464    }
465
466    pub fn from_url(
467        job_id: impl Into<String>,
468        doc_id: impl Into<String>,
469        source_id: impl Into<String>,
470        tenant_id: impl Into<String>,
471        url: impl Into<String>,
472    ) -> Self {
473        Self {
474            job_id: job_id.into(),
475            doc_id: doc_id.into(),
476            source_id: source_id.into(),
477            tenant_id: tenant_id.into(),
478            source: IngestionSource::Url { url: url.into() },
479            metadata: json!({}),
480            chunking: ChunkingStrategy::default(),
481            batch_size: 64,
482            batch_tuning: IngestionBatchTuning::default(),
483            continue_on_partial_failure: false,
484        }
485    }
486
487    pub fn with_metadata(mut self, metadata: Value) -> Self {
488        self.metadata = metadata;
489        self
490    }
491
492    pub fn with_chunking(mut self, chunking: ChunkingStrategy) -> Self {
493        self.chunking = chunking;
494        self
495    }
496
497    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
498        self.batch_size = batch_size;
499        self
500    }
501
502    pub fn with_batch_tuning(mut self, batch_tuning: IngestionBatchTuning) -> Self {
503        self.batch_tuning = batch_tuning;
504        self
505    }
506
507    pub fn with_adaptive_batching(mut self, enabled: bool) -> Self {
508        self.batch_tuning.adaptive = enabled;
509        self
510    }
511
512    pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self {
513        self.batch_tuning.max_batch_size = max_batch_size.max(1);
514        self
515    }
516
517    pub fn with_target_batch_ms(mut self, target_batch_ms: u64) -> Self {
518        self.batch_tuning.target_batch_ms = target_batch_ms.max(1);
519        self
520    }
521
522    pub fn with_continue_on_partial_failure(mut self, enabled: bool) -> Self {
523        self.continue_on_partial_failure = enabled;
524        self
525    }
526}
527
/// Summary of a completed ingestion run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestionReport {
    /// Number of segments produced by chunking.
    pub total_chunks: usize,
    /// Chunks successfully embedded and upserted.
    pub processed_chunks: usize,
    /// Chunks whose embedding still failed after retries.
    pub failed_chunks: usize,
    /// Chunk index the run resumed from (0 for a fresh run).
    pub resumed_from_chunk: usize,
    /// Total wall-clock duration in milliseconds.
    pub duration_ms: f64,
    /// Processed chunks per minute.
    pub throughput_chunks_per_minute: f64,
    /// Mean batch size over the run.
    pub average_batch_size: f64,
    /// Largest batch used.
    pub peak_batch_size: usize,
    /// Number of embedding batches issued.
    pub batch_count: usize,
    /// Whether adaptive batch sizing was enabled.
    pub adaptive_batching: bool,
    /// Embedding provider name.
    pub provider: String,
    /// Embedding model version.
    pub model_version: String,
    /// Human-readable source label.
    pub source: String,
}
544
/// A contiguous slice of the source text, with byte offsets into the original
/// content and, when available, the enclosing markdown heading.
#[derive(Debug, Clone)]
struct Segment {
    // Chunk text (`chunk_semantic` stores it trimmed; offsets stay untrimmed).
    content: String,
    // Byte offset of the segment start in the source text.
    start: usize,
    // Byte offset one past the segment end.
    end: usize,
    // Heading of the enclosing section, set by heading-aware chunking.
    heading: Option<String>,
}
552
/// Per-chunk fields merged into the request metadata by `merge_metadata`.
struct MetadataEnrichment<'a> {
    // Tenant owning the chunk.
    tenant_id: &'a str,
    // Hex-encoded FNV-1a hash of the chunk content.
    content_hash: String,
    // Byte range of the chunk within the source text.
    source_start: usize,
    source_end: usize,
    // Enclosing heading, when heading-aware chunking found one.
    heading: Option<&'a str>,
    // Provider name and model version that produced the embedding.
    provider: &'a str,
    model_version: &'a str,
}
562
/// Drives ingestion against a `SqlRite` database using an embedding provider.
pub struct IngestionWorker<'a, P: EmbeddingProvider> {
    // Target database for chunk upserts.
    db: &'a SqlRite,
    // Embedding backend used for every batch.
    provider: P,
    // Retry/backoff policy applied per embedding batch.
    retry_policy: EmbeddingRetryPolicy,
    // When set, progress is checkpointed here for resumable ingestion.
    checkpoint_path: Option<PathBuf>,
}
569
impl<'a, P: EmbeddingProvider> IngestionWorker<'a, P> {
    /// Build a worker with the default retry policy and no checkpointing.
    pub fn new(db: &'a SqlRite, provider: P) -> Self {
        Self {
            db,
            provider,
            retry_policy: EmbeddingRetryPolicy::default(),
            checkpoint_path: None,
        }
    }

    /// Override the embedding retry/backoff policy.
    pub fn with_retry_policy(mut self, retry_policy: EmbeddingRetryPolicy) -> Self {
        self.retry_policy = retry_policy;
        self
    }

    /// Enable resumable ingestion by persisting progress to `path`.
    pub fn with_checkpoint_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.checkpoint_path = Some(path.into());
        self
    }

    /// Run a full ingestion: load the source, chunk it, embed each batch
    /// (with retries), enrich metadata, and upsert chunks into the database.
    /// Resumes from a matching checkpoint when one exists, and adapts the
    /// batch size toward `target_batch_ms` when adaptive batching is on.
    ///
    /// # Errors
    /// Rejects a blank tenant id and zero-valued batch settings; propagates
    /// source, embedding, checkpoint and database errors. With
    /// `continue_on_partial_failure` disabled, a batch containing failed
    /// chunks aborts the run with `EmbeddingBatchPartialFailure`.
    pub fn ingest(&self, request: IngestionRequest) -> Result<IngestionReport> {
        if request.tenant_id.trim().is_empty() {
            return Err(SqlRiteError::InvalidTenantId);
        }
        if request.batch_size == 0 {
            return Err(SqlRiteError::InvalidBenchmarkConfig(
                "ingestion batch_size must be >= 1".to_string(),
            ));
        }
        if request.batch_tuning.max_batch_size == 0 {
            return Err(SqlRiteError::InvalidBenchmarkConfig(
                "ingestion max_batch_size must be >= 1".to_string(),
            ));
        }
        if request.batch_tuning.target_batch_ms == 0 {
            return Err(SqlRiteError::InvalidBenchmarkConfig(
                "ingestion target_batch_ms must be >= 1".to_string(),
            ));
        }

        let ingest_started = Instant::now();
        let source_content = request.source.load_content()?;
        let segments = chunk_content(&source_content, &request.chunking);

        // Resume position is clamped so a stale checkpoint cannot overshoot.
        let resumed_from_chunk = self
            .load_resume_checkpoint(&request.job_id, &request.source_id)?
            .unwrap_or(0)
            .min(segments.len());

        let mut processed_chunks = 0usize;
        let mut failed_chunks = 0usize;
        let mut cursor = resumed_from_chunk;
        let mut batch_count = 0usize;
        let mut peak_batch_size = 0usize;
        let mut total_batch_size = 0usize;
        let mut next_batch_size = request
            .batch_size
            .max(1)
            .min(request.batch_tuning.max_batch_size);

        // Process the segments in (possibly adaptively sized) batches.
        while cursor < segments.len() {
            let batch_started = Instant::now();
            let remaining = segments.len().saturating_sub(cursor);
            let planned_batch = next_batch_size.min(remaining).max(1);
            let end = (cursor + planned_batch).min(segments.len());
            let batch = &segments[cursor..end];
            batch_count += 1;
            peak_batch_size = peak_batch_size.max(batch.len());
            total_batch_size += batch.len();
            let texts = batch
                .iter()
                .map(|segment| segment.content.clone())
                .collect::<Vec<_>>();

            let embedded = self.embed_with_retry(&texts)?;
            let mut upserts = Vec::with_capacity(batch.len());
            let mut batch_failed_chunks = 0usize;

            for (idx, segment) in batch.iter().enumerate() {
                // `None` means the chunk still failed after all retries.
                let Some(embedding) = embedded[idx].clone() else {
                    batch_failed_chunks += 1;
                    failed_chunks += 1;
                    continue;
                };

                let chunk_id = chunk_id_for(
                    &request.tenant_id,
                    &request.doc_id,
                    segment.start,
                    segment.end,
                    &segment.content,
                );

                let metadata = merge_metadata(
                    &request.metadata,
                    &MetadataEnrichment {
                        tenant_id: &request.tenant_id,
                        content_hash: hex64(fnv1a64(segment.content.as_bytes())),
                        source_start: segment.start,
                        source_end: segment.end,
                        heading: segment.heading.as_deref(),
                        provider: self.provider.provider_name(),
                        model_version: self.provider.model_version(),
                    },
                )?;

                upserts.push(ChunkInput {
                    id: chunk_id,
                    doc_id: request.doc_id.clone(),
                    content: segment.content.clone(),
                    embedding,
                    metadata,
                    source: Some(request.source_id.clone()),
                });
            }

            if !upserts.is_empty() {
                self.db.ingest_chunks(&upserts)?;
                processed_chunks += upserts.len();
            }

            cursor = end;
            // NOTE(review): the checkpoint advances past this whole batch
            // before the partial-failure check below, so on resume the failed
            // chunks of this batch are skipped rather than retried — confirm
            // this is intended.
            self.save_checkpoint(&request.job_id, &request.source_id, cursor)?;
            let batch_duration_ms = batch_started.elapsed().as_secs_f64() * 1000.0;

            if batch_failed_chunks > 0 && !request.continue_on_partial_failure {
                return Err(SqlRiteError::EmbeddingBatchPartialFailure {
                    failed: failed_chunks,
                });
            }

            // Grow ~1.5x when the batch finished comfortably under target;
            // halve on overshoot or any per-chunk failure.
            if request.batch_tuning.adaptive {
                let target_ms = request.batch_tuning.target_batch_ms as f64;
                if batch_failed_chunks == 0 && batch_duration_ms <= target_ms * 0.60 {
                    let grown = next_batch_size
                        .saturating_add(next_batch_size / 2)
                        .saturating_add(1);
                    next_batch_size = grown.min(request.batch_tuning.max_batch_size).max(1);
                } else if batch_duration_ms > target_ms || batch_failed_chunks > 0 {
                    next_batch_size = (next_batch_size / 2).max(1);
                }
            }
        }

        self.clear_checkpoint()?;
        let duration_ms = ingest_started.elapsed().as_secs_f64() * 1000.0;
        let throughput_chunks_per_minute = if duration_ms > 0.0 {
            (processed_chunks as f64 / (duration_ms / 1000.0)) * 60.0
        } else {
            0.0
        };
        let average_batch_size = if batch_count > 0 {
            total_batch_size as f64 / batch_count as f64
        } else {
            0.0
        };

        Ok(IngestionReport {
            total_chunks: segments.len(),
            processed_chunks,
            failed_chunks,
            resumed_from_chunk,
            duration_ms,
            throughput_chunks_per_minute,
            average_batch_size,
            peak_batch_size,
            batch_count,
            adaptive_batching: request.batch_tuning.adaptive,
            provider: self.provider.provider_name().to_string(),
            model_version: self.provider.model_version().to_string(),
            source: request.source.source_label(),
        })
    }

    /// Embed `texts`, retrying only the failed entries with exponential
    /// backoff. Returns one `Option` per input; `None` means the entry still
    /// failed after all retries.
    fn embed_with_retry(&self, texts: &[String]) -> Result<Vec<Option<Vec<f32>>>> {
        let mut pending: Vec<usize> = (0..texts.len()).collect();
        let mut resolved = vec![None; texts.len()];
        let mut attempt = 0usize;
        let mut backoff_ms = self.retry_policy.initial_backoff_ms.max(1);

        while !pending.is_empty() && attempt <= self.retry_policy.max_retries {
            let current_texts = pending
                .iter()
                .map(|idx| texts[*idx].clone())
                .collect::<Vec<_>>();
            let responses = self.provider.embed_batch(&current_texts)?;

            if responses.len() != current_texts.len() {
                return Err(SqlRiteError::EmbeddingProvider(
                    "provider returned mismatched batch length".to_string(),
                ));
            }

            // Map per-slot results back to their original indices; failures
            // are queued for the next attempt.
            let mut next_pending = Vec::new();
            for (slot, response) in responses.into_iter().enumerate() {
                let original_idx = pending[slot];
                match response {
                    Ok(embedding) => resolved[original_idx] = Some(embedding),
                    Err(_) => next_pending.push(original_idx),
                }
            }

            pending = next_pending;
            if !pending.is_empty() {
                attempt += 1;
                if attempt <= self.retry_policy.max_retries {
                    thread::sleep(Duration::from_millis(backoff_ms));
                    backoff_ms =
                        (backoff_ms.saturating_mul(2)).min(self.retry_policy.max_backoff_ms);
                }
            }
        }

        Ok(resolved)
    }

    /// Return the resume index when a checkpoint exists and matches this
    /// job/source pair; otherwise `None`.
    fn load_resume_checkpoint(&self, job_id: &str, source_id: &str) -> Result<Option<usize>> {
        let Some(path) = &self.checkpoint_path else {
            return Ok(None);
        };

        let Some(checkpoint) = IngestionCheckpoint::load(path)? else {
            return Ok(None);
        };

        if checkpoint.job_id == job_id && checkpoint.source_id == source_id {
            Ok(Some(checkpoint.next_chunk_index))
        } else {
            Ok(None)
        }
    }

    /// Persist progress after a batch; no-op when checkpointing is disabled.
    fn save_checkpoint(
        &self,
        job_id: &str,
        source_id: &str,
        next_chunk_index: usize,
    ) -> Result<()> {
        let Some(path) = &self.checkpoint_path else {
            return Ok(());
        };

        let checkpoint = IngestionCheckpoint {
            job_id: job_id.to_string(),
            source_id: source_id.to_string(),
            next_chunk_index,
            updated_unix_ms: now_unix_ms(),
        };
        checkpoint.save(path)
    }

    /// Remove the checkpoint file after a successful run (no-op when
    /// checkpointing is disabled).
    fn clear_checkpoint(&self) -> Result<()> {
        let Some(path) = &self.checkpoint_path else {
            return Ok(());
        };

        IngestionCheckpoint::clear(path)
    }
}
829
830fn chunk_content(text: &str, strategy: &ChunkingStrategy) -> Vec<Segment> {
831    match strategy {
832        ChunkingStrategy::Fixed {
833            max_chars,
834            overlap_chars,
835        } => chunk_fixed(text, *max_chars, *overlap_chars),
836        ChunkingStrategy::HeadingAware {
837            max_chars,
838            overlap_chars,
839        } => chunk_heading_aware(text, *max_chars, *overlap_chars),
840        ChunkingStrategy::Semantic { max_chars } => chunk_semantic(text, *max_chars),
841    }
842}
843
844fn chunk_fixed(text: &str, max_chars: usize, overlap_chars: usize) -> Vec<Segment> {
845    if text.is_empty() {
846        return Vec::new();
847    }
848
849    let max_chars = max_chars.max(1);
850    let overlap_chars = overlap_chars.min(max_chars.saturating_sub(1));
851
852    let mut segments = Vec::new();
853    let mut start = 0usize;
854
855    while start < text.len() {
856        let mut end = (start + max_chars).min(text.len());
857        while end > start && !text.is_char_boundary(end) {
858            end -= 1;
859        }
860        if end <= start {
861            end = text.len();
862        }
863
864        segments.push(Segment {
865            content: text[start..end].to_string(),
866            start,
867            end,
868            heading: None,
869        });
870
871        if end == text.len() {
872            break;
873        }
874
875        let mut next_start = end.saturating_sub(overlap_chars);
876        while next_start > start && !text.is_char_boundary(next_start) {
877            next_start -= 1;
878        }
879        if next_start <= start {
880            next_start = end;
881        }
882        start = next_start;
883    }
884
885    segments
886}
887
888fn chunk_heading_aware(text: &str, max_chars: usize, overlap_chars: usize) -> Vec<Segment> {
889    if text.is_empty() {
890        return Vec::new();
891    }
892
893    let mut sections = Vec::new();
894    let mut offset = 0usize;
895    let mut section_start = 0usize;
896    let mut heading: Option<String> = None;
897
898    for line in text.split_inclusive('\n') {
899        let line_start = offset;
900        offset += line.len();
901        let trimmed = line.trim_start();
902        if trimmed.starts_with('#') {
903            if line_start > section_start {
904                sections.push((section_start, line_start, heading.clone()));
905            }
906            heading = Some(trimmed.trim().trim_start_matches('#').trim().to_string());
907            section_start = line_start;
908        }
909    }
910
911    if section_start < text.len() {
912        sections.push((section_start, text.len(), heading));
913    }
914
915    if sections.is_empty() {
916        return chunk_fixed(text, max_chars, overlap_chars);
917    }
918
919    let mut segments = Vec::new();
920    for (start, end, heading) in sections {
921        let section_text = &text[start..end];
922        for mut part in chunk_fixed(section_text, max_chars, overlap_chars) {
923            part.start += start;
924            part.end += start;
925            part.heading = heading.clone();
926            segments.push(part);
927        }
928    }
929
930    segments
931}
932
/// Chunk text by accumulating sentence-like spans (split after `.`, `!`, `?`)
/// until adding a span would push the chunk past `max_chars`.
///
/// NOTE(review): segment `content` is trimmed but `start`/`end` keep the
/// untrimmed byte range, so the offsets may include surrounding whitespace —
/// confirm downstream consumers expect that.
fn chunk_semantic(text: &str, max_chars: usize) -> Vec<Segment> {
    if text.is_empty() {
        return Vec::new();
    }

    let max_chars = max_chars.max(1);
    // Collect (start, end) byte ranges of sentence-like spans.
    let mut sentence_bounds = Vec::new();
    let mut sentence_start = 0usize;

    for (idx, ch) in text.char_indices() {
        if matches!(ch, '.' | '!' | '?') {
            let end = idx + ch.len_utf8();
            if end > sentence_start {
                sentence_bounds.push((sentence_start, end));
                sentence_start = end;
            }
        }
    }
    // Trailing text without terminal punctuation forms a final span.
    if sentence_start < text.len() {
        sentence_bounds.push((sentence_start, text.len()));
    }

    if sentence_bounds.is_empty() {
        return chunk_fixed(text, max_chars, 0);
    }

    let mut segments = Vec::new();
    let mut current_start = sentence_bounds[0].0;
    let mut current_end = sentence_bounds[0].0;

    for (start, end) in sentence_bounds {
        // Flush the open chunk when adding this sentence would exceed the cap.
        if end.saturating_sub(current_start) > max_chars && current_end > current_start {
            segments.push(Segment {
                content: text[current_start..current_end].trim().to_string(),
                start: current_start,
                end: current_end,
                heading: None,
            });
            current_start = start;
        }
        current_end = end;
    }

    if current_end > current_start {
        segments.push(Segment {
            content: text[current_start..current_end].trim().to_string(),
            start: current_start,
            end: current_end,
            heading: None,
        });
    }

    if segments.is_empty() {
        chunk_fixed(text, max_chars, 0)
    } else {
        // Whitespace-only spans trim to empty strings; drop them.
        segments
            .into_iter()
            .filter(|segment| !segment.content.is_empty())
            .collect()
    }
}
994
995fn merge_metadata(base: &Value, enrichment: &MetadataEnrichment<'_>) -> Result<Value> {
996    let mut metadata_obj = match base {
997        Value::Object(map) => map.clone(),
998        _ => Map::new(),
999    };
1000
1001    metadata_obj.insert(
1002        "tenant".to_string(),
1003        Value::String(enrichment.tenant_id.to_string()),
1004    );
1005    metadata_obj.insert(
1006        "content_hash".to_string(),
1007        Value::String(enrichment.content_hash.clone()),
1008    );
1009    metadata_obj.insert(
1010        "source_start".to_string(),
1011        Value::Number(serde_json::Number::from(enrichment.source_start as u64)),
1012    );
1013    metadata_obj.insert(
1014        "source_end".to_string(),
1015        Value::Number(serde_json::Number::from(enrichment.source_end as u64)),
1016    );
1017    metadata_obj.insert(
1018        "embedding_provider".to_string(),
1019        Value::String(enrichment.provider.to_string()),
1020    );
1021    metadata_obj.insert(
1022        "embedding_model_version".to_string(),
1023        Value::String(enrichment.model_version.to_string()),
1024    );
1025
1026    if let Some(heading) = enrichment.heading
1027        && !heading.is_empty()
1028    {
1029        metadata_obj.insert("heading".to_string(), Value::String(heading.to_string()));
1030    }
1031
1032    Ok(Value::Object(metadata_obj))
1033}
1034
1035fn chunk_id_for(tenant_id: &str, doc_id: &str, start: usize, end: usize, content: &str) -> String {
1036    let mut seed = Vec::new();
1037    seed.extend_from_slice(tenant_id.as_bytes());
1038    seed.push(0);
1039    seed.extend_from_slice(doc_id.as_bytes());
1040    seed.push(0);
1041    seed.extend_from_slice(start.to_string().as_bytes());
1042    seed.push(0);
1043    seed.extend_from_slice(end.to_string().as_bytes());
1044    seed.push(0);
1045    seed.extend_from_slice(content.as_bytes());
1046
1047    let hash = hex64(fnv1a64(&seed));
1048    let mut tenant = tenant_id
1049        .chars()
1050        .filter(|ch| ch.is_ascii_alphanumeric() || *ch == '_')
1051        .collect::<String>();
1052    if tenant.is_empty() {
1053        tenant = "tenant".to_string();
1054    }
1055
1056    format!("{tenant}-{hash}")
1057}
1058
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads before the epoch, so callers never
/// have to handle a time error.
fn now_unix_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1065
1066fn http_post_json(
1067    endpoint: &str,
1068    payload: &Value,
1069    headers: &[(String, String)],
1070    timeout_secs: u64,
1071) -> Result<Value> {
1072    let mut cmd = Command::new("curl");
1073    cmd.arg("-fsS")
1074        .arg("-X")
1075        .arg("POST")
1076        .arg("--max-time")
1077        .arg(timeout_secs.max(1).to_string())
1078        .arg("-H")
1079        .arg("Content-Type: application/json");
1080
1081    for (key, value) in headers {
1082        cmd.arg("-H").arg(format!("{key}: {value}"));
1083    }
1084
1085    let output = cmd
1086        .arg("-d")
1087        .arg(payload.to_string())
1088        .arg(endpoint)
1089        .output()?;
1090
1091    if !output.status.success() {
1092        let stderr = String::from_utf8_lossy(&output.stderr);
1093        return Err(SqlRiteError::EmbeddingProvider(format!(
1094            "http request failed for `{endpoint}`: {stderr}"
1095        )));
1096    }
1097
1098    serde_json::from_slice::<Value>(&output.stdout).map_err(|e| {
1099        SqlRiteError::EmbeddingProvider(format!("invalid json response from `{endpoint}`: {e}"))
1100    })
1101}
1102
1103fn parse_openai_embeddings_response(
1104    response: &Value,
1105) -> Result<Vec<std::result::Result<Vec<f32>, String>>> {
1106    let data = response
1107        .get("data")
1108        .and_then(Value::as_array)
1109        .ok_or_else(|| SqlRiteError::EmbeddingProvider("missing `data` array".to_string()))?;
1110
1111    let mut out = Vec::with_capacity(data.len());
1112    for row in data {
1113        let Some(embedding) = row.get("embedding") else {
1114            out.push(Err("missing `embedding` field".to_string()));
1115            continue;
1116        };
1117        out.push(parse_embedding_array(embedding).map_err(|e| e.to_string()));
1118    }
1119    Ok(out)
1120}
1121
1122fn parse_embedding_array(value: &Value) -> Result<Vec<f32>> {
1123    let array = value
1124        .as_array()
1125        .ok_or_else(|| SqlRiteError::EmbeddingProvider("embedding must be an array".to_string()))?;
1126    let mut embedding = Vec::with_capacity(array.len());
1127    for item in array {
1128        let Some(number) = item.as_f64() else {
1129            return Err(SqlRiteError::EmbeddingProvider(
1130                "embedding item must be numeric".to_string(),
1131            ));
1132        };
1133        embedding.push(number as f32);
1134    }
1135    if embedding.is_empty() {
1136        return Err(SqlRiteError::EmbeddingProvider(
1137            "embedding array cannot be empty".to_string(),
1138        ));
1139    }
1140    Ok(embedding)
1141}
1142
/// 64-bit FNV-1a hash: start from the offset basis, then XOR each byte in
/// and multiply by the FNV prime (wrapping).
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;

    bytes.iter().fold(OFFSET_BASIS, |hash, &byte| {
        (hash ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
1151
/// Renders a `u64` as exactly 16 lowercase hex digits, most significant
/// nibble first (zero-padded).
fn hex64(value: u64) -> String {
    (0..16)
        .rev()
        .map(|shift| {
            let nibble = ((value >> (shift * 4)) & 0xf) as u32;
            char::from_digit(nibble, 16).expect("nibble is always < 16")
        })
        .collect()
}
1155
/// Scales `vector` in place to unit L2 length.
///
/// A zero-norm vector (including the empty slice) is left untouched to
/// avoid dividing by zero.
fn normalize(vector: &mut [f32]) {
    let magnitude = vector.iter().fold(0.0f32, |acc, v| acc + v * v).sqrt();
    if magnitude > 0.0 {
        vector.iter_mut().for_each(|v| *v /= magnitude);
    }
}
1164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RuntimeConfig;
    use serde_json::json;
    use tempfile::tempdir;

    // Fixed chunking advances by (max_chars - overlap_chars): the second
    // chunk starts at offset 7, re-covering the last 3 chars of the first.
    #[test]
    fn fixed_chunking_respects_overlap() {
        let text = "abcdefghijklmnopqrstuvwxyz";
        let chunks = chunk_fixed(text, 10, 3);
        assert!(chunks.len() >= 3);
        assert_eq!(chunks[0].content, "abcdefghij");
        assert_eq!(chunks[1].content, "hijklmnopq");
    }

    // Save -> load -> clear round-trip for the on-disk ingestion checkpoint;
    // clearing must make subsequent loads return None.
    #[test]
    fn checkpoint_round_trip() -> Result<()> {
        let dir = tempdir()?;
        let path = dir.path().join("checkpoint.json");
        let checkpoint = IngestionCheckpoint {
            job_id: "job-a".to_string(),
            source_id: "source-a".to_string(),
            next_chunk_index: 42,
            updated_unix_ms: 1,
        };

        checkpoint.save(&path)?;
        let loaded = IngestionCheckpoint::load(&path)?.expect("checkpoint exists");
        assert_eq!(loaded.next_chunk_index, 42);

        IngestionCheckpoint::clear(&path)?;
        assert!(IngestionCheckpoint::load(&path)?.is_none());
        Ok(())
    }

    // The builder-style `with_*` methods should each set exactly the field
    // they name, and `from_file` should produce a File source.
    #[test]
    fn ingestion_request_convenience_builders_set_fields() {
        let req = IngestionRequest::from_file("job", "doc", "source", "acme", "README.md")
            .with_chunking(ChunkingStrategy::Fixed {
                max_chars: 200,
                overlap_chars: 20,
            })
            .with_batch_size(32)
            .with_max_batch_size(256)
            .with_target_batch_ms(120)
            .with_continue_on_partial_failure(true);
        assert_eq!(req.batch_size, 32);
        assert_eq!(req.batch_tuning.max_batch_size, 256);
        assert_eq!(req.batch_tuning.target_batch_ms, 120);
        assert!(req.continue_on_partial_failure);
        assert!(matches!(req.source, IngestionSource::File { .. }));
    }

    // A well-formed OpenAI-style `data` array yields one Ok embedding per row.
    #[test]
    fn parse_openai_response_extracts_embeddings() -> Result<()> {
        let response = json!({
            "data": [
                {"embedding": [0.1, 0.2, 0.3]},
                {"embedding": [0.4, 0.5, 0.6]}
            ]
        });
        let parsed = parse_openai_embeddings_response(&response)?;
        assert_eq!(parsed.len(), 2);
        assert!(parsed.iter().all(std::result::Result::is_ok));
        Ok(())
    }

    // Exercises the custom provider's `results` schema: a row with an
    // `error` string becomes Err, a row with an embedding becomes Ok. The
    // parsing branch is re-implemented inline here to mirror the provider's
    // private behavior without issuing a real HTTP request.
    #[test]
    fn custom_provider_parses_results_schema() -> Result<()> {
        let provider = CustomHttpEmbeddingProvider::new("http://localhost:1234", "v1")?;
        let response = json!({
            "results": [
                {"embedding": [0.1, 0.2]},
                {"error": "rate_limited"}
            ]
        });

        // Use internal parser contract by matching the same branch behavior.
        let parsed = if let Some(results) = response.get("results").and_then(Value::as_array) {
            let mut out = Vec::new();
            for item in results {
                if let Some(error) = item.get("error").and_then(Value::as_str) {
                    out.push(Err(error.to_string()));
                    continue;
                }
                let embedding = item.get("embedding").expect("embedding exists");
                out.push(parse_embedding_array(embedding).map_err(|e| e.to_string()));
            }
            out
        } else {
            Vec::new()
        };

        assert_eq!(provider.model_version(), "v1");
        assert_eq!(parsed.len(), 2);
        assert!(parsed[0].is_ok());
        assert!(parsed[1].is_err());
        Ok(())
    }

    // Ingesting the exact same payload twice must not grow the chunk count
    // (chunk ids are content-derived, so re-ingestion overwrites in place).
    #[test]
    fn ingestion_is_idempotent_for_same_payload() -> Result<()> {
        let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
        let provider = DeterministicEmbeddingProvider::new(64, "det-v1")?;
        let worker = IngestionWorker::new(&db, provider);

        let request = IngestionRequest {
            job_id: "job-1".to_string(),
            doc_id: "doc-1".to_string(),
            source_id: "payload-1".to_string(),
            tenant_id: "acme".to_string(),
            source: IngestionSource::Direct {
                content: "# Intro\nRust agents need deterministic retrieval.\n\n# Details\nSQLite RAG memory is portable.".to_string(),
            },
            metadata: json!({"kind": "guide"}),
            chunking: ChunkingStrategy::HeadingAware {
                max_chars: 40,
                overlap_chars: 5,
            },
            batch_size: 4,
            batch_tuning: IngestionBatchTuning::default(),
            continue_on_partial_failure: false,
        };

        let first = worker.ingest(request.clone())?;
        let count_after_first = db.chunk_count()?;
        let second = worker.ingest(request)?;
        let count_after_second = db.chunk_count()?;

        assert!(first.total_chunks > 0);
        assert_eq!(count_after_first, count_after_second);
        assert_eq!(second.failed_chunks, 0);
        Ok(())
    }

    // Pre-seeds a checkpoint (job/source ids matching the request) so the
    // worker should skip chunk 0, report where it resumed, and delete the
    // checkpoint file on successful completion.
    #[test]
    fn ingestion_resumes_from_checkpoint() -> Result<()> {
        let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
        let provider = DeterministicEmbeddingProvider::new(32, "det-v1")?;
        let dir = tempdir()?;
        let checkpoint_path = dir.path().join("ingest.checkpoint.json");

        let checkpoint = IngestionCheckpoint {
            job_id: "job-resume".to_string(),
            source_id: "source-resume".to_string(),
            next_chunk_index: 1,
            updated_unix_ms: now_unix_ms(),
        };
        checkpoint.save(&checkpoint_path)?;

        let worker = IngestionWorker::new(&db, provider).with_checkpoint_path(&checkpoint_path);
        let request = IngestionRequest {
            job_id: "job-resume".to_string(),
            doc_id: "doc-r".to_string(),
            source_id: "source-resume".to_string(),
            tenant_id: "acme".to_string(),
            source: IngestionSource::Direct {
                content: "one two three four five six seven eight nine ten".to_string(),
            },
            metadata: json!({}),
            chunking: ChunkingStrategy::Fixed {
                max_chars: 10,
                overlap_chars: 0,
            },
            batch_size: 2,
            batch_tuning: IngestionBatchTuning::default(),
            continue_on_partial_failure: false,
        };

        let report = worker.ingest(request)?;
        assert!(report.resumed_from_chunk >= 1);
        assert!(IngestionCheckpoint::load(&checkpoint_path)?.is_none());
        Ok(())
    }

    // Smoke-checks the report's throughput and adaptive-batching fields:
    // only loose bounds are asserted since timings vary between runs.
    #[test]
    fn ingestion_reports_throughput_and_adaptive_batching_stats() -> Result<()> {
        let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
        let provider = DeterministicEmbeddingProvider::new(64, "det-v1")?;
        let worker = IngestionWorker::new(&db, provider);

        let request = IngestionRequest::from_direct(
            "job-batch-stats",
            "doc-batch-stats",
            "source-batch-stats",
            "acme",
            "# A\nRust SQLite agents.\n\n# B\nAdaptive batching for ingestion throughput.",
        )
        .with_chunking(ChunkingStrategy::HeadingAware {
            max_chars: 24,
            overlap_chars: 4,
        })
        .with_batch_size(2)
        .with_batch_tuning(IngestionBatchTuning {
            adaptive: true,
            max_batch_size: 8,
            target_batch_ms: 100,
        });

        let report = worker.ingest(request)?;
        assert!(report.duration_ms >= 0.0);
        assert!(report.throughput_chunks_per_minute >= 0.0);
        assert!(report.batch_count > 0);
        assert!(report.average_batch_size >= 1.0);
        assert!(report.peak_batch_size >= 1);
        assert!(report.adaptive_batching);
        Ok(())
    }
}