Skip to main content

xz_embed/
batch_manager.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use futures::future;
5use tokio::sync::Semaphore;
6use tracing::{debug, warn};
7
8use crate::config::RetryConfig;
9use crate::error::EmbedError;
10use crate::traits::EmbeddingModel;
11
/// Concurrent batch manager.
///
/// Splits a large list of texts into smaller batches and sends the Embedding
/// API requests concurrently.
/// Has built-in retry, rate limiting, and progress callbacks.
pub struct ConcurrentBatchManager {
    /// Embedding model that performs the actual `embed` calls; held in an
    /// `Arc` so each spawned batch task can own a handle to it.
    embedder: Arc<dyn EmbeddingModel>,
    /// Number of texts per batch (never more than `embedder.max_batch_size()`).
    batch_size: usize,
    /// Maximum number of batches processed concurrently.
    max_concurrency: usize,
    /// Retry policy.
    retry: RetryConfig,
}
25
26impl ConcurrentBatchManager {
27    pub fn new(embedder: Box<dyn EmbeddingModel>, batch_size: usize, max_concurrency: usize) -> Self {
28        Self {
29            embedder: Arc::from(embedder),
30            batch_size,
31            max_concurrency,
32            retry: RetryConfig::default(),
33        }
34    }
35
36    pub fn with_retry(mut self, retry: RetryConfig) -> Self {
37        self.retry = retry;
38        self
39    }
40
41    /// 将文本列表拆分为多个子批次
42    fn chunk_texts(&self, texts: &[impl AsRef<str>]) -> Vec<Vec<String>> {
43        let max_batch = self
44            .batch_size
45            .min(self.embedder.max_batch_size());
46        texts
47            .chunks(max_batch)
48            .map(|chunk| chunk.iter().map(|t| t.as_ref().to_string()).collect())
49            .collect()
50    }
51
52    /// 嵌入全部文本,返回顺序与输入一致
53    pub async fn embed_all(
54        &self,
55        texts: &[impl AsRef<str>],
56    ) -> Result<Vec<Vec<f32>>, EmbedError> {
57        self.embed_all_with_progress(texts, |_, _| {}).await
58    }
59
60    /// 带进度回调的嵌入
61    pub async fn embed_all_with_progress(
62        &self,
63        texts: &[impl AsRef<str>],
64        on_batch_done: impl Fn(usize, usize),
65    ) -> Result<Vec<Vec<f32>>, EmbedError> {
66        let batches = self.chunk_texts(texts);
67        let total_batches = batches.len();
68
69        if total_batches == 0 {
70            return Ok(vec![]);
71        }
72
73        let semaphore = Arc::new(Semaphore::new(self.max_concurrency));
74        let mut handles = Vec::with_capacity(total_batches);
75
76        for (i, batch) in batches.into_iter().enumerate() {
77            let permit = semaphore.clone().acquire_owned().await.map_err(|e| {
78                EmbedError::Config(format!("获取信号量失败: {e}"))
79            })?;
80
81            let embedder = self.embedder.clone();
82            let retry = self.retry.clone();
83
84            handles.push(tokio::spawn(async move {
85                let _permit = permit;
86                let texts_refs: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
87                let result = retry_with_backoff(|| embedder.embed(&texts_refs), &retry).await;
88                (i, result)
89            }));
90        }
91
92        let mut ordered_results: Vec<Option<Vec<Vec<f32>>>> = vec![None; total_batches];
93        let mut errors = Vec::new();
94
95        for handle in handles {
96            match handle.await {
97                Ok((idx, Ok(vectors))) => {
98                    ordered_results[idx] = Some(vectors);
99                    on_batch_done(idx + 1, total_batches);
100                }
101                Ok((idx, Err(e))) => {
102                    warn!(target: "xz_embed", batch = idx, error = %e, "batch embedding failed");
103                    errors.push(e);
104                }
105                Err(e) => {
106                    errors.push(EmbedError::Config(format!("task join error: {e}")));
107                }
108            }
109        }
110
111        if !errors.is_empty() {
112            return Err(errors.remove(0));
113        }
114
115        let all_vectors: Vec<Vec<f32>> = ordered_results
116            .into_iter()
117            .filter_map(|r| r)
118            .flatten()
119            .collect();
120
121        debug!(
122            target: "xz_embed",
123            total_texts = texts.len(),
124            total_batches,
125            total_vectors = all_vectors.len(),
126            "embed_all completed"
127        );
128
129        Ok(all_vectors)
130    }
131}
132
133async fn retry_with_backoff<F, Fut, T>(
134    f: F,
135    config: &RetryConfig,
136) -> Result<T, EmbedError>
137where
138    F: Fn() -> Fut,
139    Fut: std::future::Future<Output = Result<T, EmbedError>>,
140{
141    let mut attempt = 0;
142    let mut backoff_ms = config.initial_backoff_ms;
143
144    loop {
145        match f().await {
146            Ok(result) => return Ok(result),
147            Err(e) if e.is_retryable() && attempt < config.max_retries => {
148                attempt += 1;
149                debug!(
150                    target: "xz_embed",
151                    attempt,
152                    backoff_ms,
153                    error = %e,
154                    "retrying embedding request"
155                );
156                tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
157                backoff_ms =
158                    (backoff_ms as f64 * config.backoff_multiplier).min(config.max_backoff_ms as f64)
159                        as u64;
160            }
161            Err(e) => return Err(e),
162        }
163    }
164}