xz_embed/
batch_manager.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use futures::future;
5use tokio::sync::Semaphore;
6use tracing::{debug, warn};
7
8use crate::config::RetryConfig;
9use crate::error::EmbedError;
10use crate::traits::EmbeddingModel;
11
12pub struct ConcurrentBatchManager {
17 embedder: Arc<dyn EmbeddingModel>,
18 batch_size: usize,
20 max_concurrency: usize,
22 retry: RetryConfig,
24}
25
26impl ConcurrentBatchManager {
27 pub fn new(embedder: Box<dyn EmbeddingModel>, batch_size: usize, max_concurrency: usize) -> Self {
28 Self {
29 embedder: Arc::from(embedder),
30 batch_size,
31 max_concurrency,
32 retry: RetryConfig::default(),
33 }
34 }
35
36 pub fn with_retry(mut self, retry: RetryConfig) -> Self {
37 self.retry = retry;
38 self
39 }
40
41 fn chunk_texts(&self, texts: &[impl AsRef<str>]) -> Vec<Vec<String>> {
43 let max_batch = self
44 .batch_size
45 .min(self.embedder.max_batch_size());
46 texts
47 .chunks(max_batch)
48 .map(|chunk| chunk.iter().map(|t| t.as_ref().to_string()).collect())
49 .collect()
50 }
51
52 pub async fn embed_all(
54 &self,
55 texts: &[impl AsRef<str>],
56 ) -> Result<Vec<Vec<f32>>, EmbedError> {
57 self.embed_all_with_progress(texts, |_, _| {}).await
58 }
59
60 pub async fn embed_all_with_progress(
62 &self,
63 texts: &[impl AsRef<str>],
64 on_batch_done: impl Fn(usize, usize),
65 ) -> Result<Vec<Vec<f32>>, EmbedError> {
66 let batches = self.chunk_texts(texts);
67 let total_batches = batches.len();
68
69 if total_batches == 0 {
70 return Ok(vec![]);
71 }
72
73 let semaphore = Arc::new(Semaphore::new(self.max_concurrency));
74 let mut handles = Vec::with_capacity(total_batches);
75
76 for (i, batch) in batches.into_iter().enumerate() {
77 let permit = semaphore.clone().acquire_owned().await.map_err(|e| {
78 EmbedError::Config(format!("获取信号量失败: {e}"))
79 })?;
80
81 let embedder = self.embedder.clone();
82 let retry = self.retry.clone();
83
84 handles.push(tokio::spawn(async move {
85 let _permit = permit;
86 let texts_refs: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
87 let result = retry_with_backoff(|| embedder.embed(&texts_refs), &retry).await;
88 (i, result)
89 }));
90 }
91
92 let mut ordered_results: Vec<Option<Vec<Vec<f32>>>> = vec![None; total_batches];
93 let mut errors = Vec::new();
94
95 for handle in handles {
96 match handle.await {
97 Ok((idx, Ok(vectors))) => {
98 ordered_results[idx] = Some(vectors);
99 on_batch_done(idx + 1, total_batches);
100 }
101 Ok((idx, Err(e))) => {
102 warn!(target: "xz_embed", batch = idx, error = %e, "batch embedding failed");
103 errors.push(e);
104 }
105 Err(e) => {
106 errors.push(EmbedError::Config(format!("task join error: {e}")));
107 }
108 }
109 }
110
111 if !errors.is_empty() {
112 return Err(errors.remove(0));
113 }
114
115 let all_vectors: Vec<Vec<f32>> = ordered_results
116 .into_iter()
117 .filter_map(|r| r)
118 .flatten()
119 .collect();
120
121 debug!(
122 target: "xz_embed",
123 total_texts = texts.len(),
124 total_batches,
125 total_vectors = all_vectors.len(),
126 "embed_all completed"
127 );
128
129 Ok(all_vectors)
130 }
131}
132
133async fn retry_with_backoff<F, Fut, T>(
134 f: F,
135 config: &RetryConfig,
136) -> Result<T, EmbedError>
137where
138 F: Fn() -> Fut,
139 Fut: std::future::Future<Output = Result<T, EmbedError>>,
140{
141 let mut attempt = 0;
142 let mut backoff_ms = config.initial_backoff_ms;
143
144 loop {
145 match f().await {
146 Ok(result) => return Ok(result),
147 Err(e) if e.is_retryable() && attempt < config.max_retries => {
148 attempt += 1;
149 debug!(
150 target: "xz_embed",
151 attempt,
152 backoff_ms,
153 error = %e,
154 "retrying embedding request"
155 );
156 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
157 backoff_ms =
158 (backoff_ms as f64 * config.backoff_multiplier).min(config.max_backoff_ms as f64)
159 as u64;
160 }
161 Err(e) => return Err(e),
162 }
163 }
164}