//! trustformers_tokenizers/async_tokenizer.rs
//!
//! Async wrappers that expose synchronous `Tokenizer` implementations through
//! an asynchronous, concurrency-limited API.
1use async_trait::async_trait;
2use futures::future::{BoxFuture, FutureExt};
3use futures::stream::Stream;
4use std::sync::Arc;
5use tokio::sync::{mpsc, Semaphore};
6use tokio::task;
7use trustformers_core::errors::Result;
8use trustformers_core::traits::{TokenizedInput, Tokenizer};
9
/// Async version of the `Tokenizer` trait.
///
/// Implementors must be `Send + Sync` so a single instance can be shared
/// across tasks; all methods take `&self`.
#[async_trait]
pub trait AsyncTokenizer: Send + Sync {
    /// Asynchronously encode a single text.
    async fn encode_async(&self, text: &str) -> Result<TokenizedInput>;

    /// Asynchronously encode a text pair (e.g. for sentence-pair tasks).
    async fn encode_pair_async(&self, text: &str, text2: &str) -> Result<TokenizedInput>;

    /// Asynchronously decode token IDs back to text.
    async fn decode_async(&self, ids: &[u32]) -> Result<String>;

    /// Asynchronously encode multiple texts in parallel.
    /// Results are returned in the same order as `texts`.
    async fn encode_batch_async(&self, texts: &[&str]) -> Result<Vec<TokenizedInput>>;

    /// Asynchronously encode multiple text pairs in parallel.
    /// Results are returned in the same order as `text_pairs`.
    async fn encode_pair_batch_async(
        &self,
        text_pairs: &[(&str, &str)],
    ) -> Result<Vec<TokenizedInput>>;

    /// Stream-based encoding for large datasets.
    ///
    /// NOTE(review): item ordering is implementation-defined; the wrappers in
    /// this file spawn one task per text and emit results as they complete,
    /// so items may arrive out of input order.
    fn encode_stream<'a>(
        &'a self,
        texts: Vec<String>,
    ) -> BoxFuture<'a, Result<Box<dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin>>>;
}
37
/// Wrapper that adds async capabilities to any synchronous tokenizer.
///
/// Tokenization work runs on tokio's blocking thread pool via
/// `spawn_blocking`; `task_semaphore` caps how many tokenization tasks may
/// run concurrently.
pub struct AsyncTokenizerWrapper<T> {
    // Shared so every spawned task can own a cheap handle to the tokenizer.
    tokenizer: Arc<T>,
    // Mirrors the semaphore's configured permit count
    // (see `with_max_concurrent_tasks`).
    max_concurrent_tasks: usize,
    task_semaphore: Arc<Semaphore>,
}
44
45impl<T> AsyncTokenizerWrapper<T>
46where
47    T: Tokenizer + Send + Sync + 'static,
48{
49    /// Create a new async wrapper around a synchronous tokenizer
50    pub fn new(tokenizer: T, max_concurrent_tasks: Option<usize>) -> Self {
51        let max_tasks = max_concurrent_tasks.unwrap_or(num_cpus::get() * 2);
52        Self {
53            tokenizer: Arc::new(tokenizer),
54            max_concurrent_tasks: max_tasks,
55            task_semaphore: Arc::new(Semaphore::new(max_tasks)),
56        }
57    }
58
59    /// Set the maximum number of concurrent tasks
60    pub fn with_max_concurrent_tasks(mut self, max_tasks: usize) -> Self {
61        self.max_concurrent_tasks = max_tasks;
62        self.task_semaphore = Arc::new(Semaphore::new(max_tasks));
63        self
64    }
65
66    /// Get the underlying synchronous tokenizer
67    pub fn inner(&self) -> &Arc<T> {
68        &self.tokenizer
69    }
70}
71
#[async_trait]
impl<T> AsyncTokenizer for AsyncTokenizerWrapper<T>
where
    T: Tokenizer + Send + Sync + 'static,
{
    async fn encode_async(&self, text: &str) -> Result<TokenizedInput> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let text = text.to_string();
        // Hold a permit for the duration of the blocking call so at most
        // `max_concurrent_tasks` tokenizations run at once. `acquire` only
        // errors if the semaphore has been closed, which never happens here.
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
            )
        })?;

        // Run the synchronous tokenizer on the blocking thread pool. The `?`
        // propagates a join error; the tokenizer's own Result is returned
        // as the trailing expression.
        task::spawn_blocking(move || tokenizer.encode(&text)).await.map_err(|e| {
            trustformers_core::errors::TrustformersError::other(
                anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
            )
        })?
    }

    async fn encode_pair_async(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
        let tokenizer = Arc::clone(&self.tokenizer);
        // Own both strings so the closure can be 'static.
        let text = text.to_string();
        let text2 = text2.to_string();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
            )
        })?;

        task::spawn_blocking(move || tokenizer.encode_pair(&text, &text2))
            .await
            .map_err(|e| {
                trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                )
            })?
    }

    async fn decode_async(&self, ids: &[u32]) -> Result<String> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let ids = ids.to_vec();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
            )
        })?;

        task::spawn_blocking(move || tokenizer.decode(&ids)).await.map_err(|e| {
            trustformers_core::errors::TrustformersError::other(
                anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
            )
        })?
    }

    async fn encode_batch_async(&self, texts: &[&str]) -> Result<Vec<TokenizedInput>> {
        let mut tasks = Vec::new();

        // Spawn one task per text; the shared semaphore (acquired inside each
        // task) is what actually bounds concurrency.
        for text in texts {
            let tokenizer = Arc::clone(&self.tokenizer);
            let text = text.to_string();
            let semaphore = Arc::clone(&self.task_semaphore);

            let task = task::spawn(async move {
                let _permit = semaphore.acquire().await.map_err(|_| {
                    trustformers_core::errors::TrustformersError::other(
                        anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
                    )
                })?;

                task::spawn_blocking(move || tokenizer.encode(&text)).await.map_err(|e| {
                    trustformers_core::errors::TrustformersError::other(
                        anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                    )
                })?
            });

            tasks.push(task);
        }

        // Await join handles in spawn order, so results preserve input order.
        // `??`: outer `?` unwraps the join Result, inner `?` the tokenization
        // Result. The first failure aborts collection and is returned.
        let mut results = Vec::with_capacity(texts.len());
        for task in tasks {
            let result = task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                )
            })??;
            results.push(result);
        }

        Ok(results)
    }

    async fn encode_pair_batch_async(
        &self,
        text_pairs: &[(&str, &str)],
    ) -> Result<Vec<TokenizedInput>> {
        let mut tasks = Vec::new();

        // Same fan-out/fan-in pattern as `encode_batch_async`, for pairs.
        for (text1, text2) in text_pairs {
            let tokenizer = Arc::clone(&self.tokenizer);
            let text1 = text1.to_string();
            let text2 = text2.to_string();
            let semaphore = Arc::clone(&self.task_semaphore);

            let task = task::spawn(async move {
                let _permit = semaphore.acquire().await.map_err(|_| {
                    trustformers_core::errors::TrustformersError::other(
                        anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
                    )
                })?;

                task::spawn_blocking(move || tokenizer.encode_pair(&text1, &text2))
                    .await
                    .map_err(|e| {
                        trustformers_core::errors::TrustformersError::other(
                            anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                        )
                    })?
            });

            tasks.push(task);
        }

        let mut results = Vec::with_capacity(text_pairs.len());
        for task in tasks {
            let result = task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                )
            })??;
            results.push(result);
        }

        Ok(results)
    }

    fn encode_stream<'a>(
        &'a self,
        texts: Vec<String>,
    ) -> BoxFuture<'a, Result<Box<dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin>>> {
        async move {
            // Unbounded channel: producers never block, so there is no
            // backpressure — memory use is limited only by how fast the
            // consumer drains the stream.
            let (tx, rx) = mpsc::unbounded_channel();
            let tokenizer = Arc::clone(&self.tokenizer);
            let semaphore = Arc::clone(&self.task_semaphore);

            // Spawn a task to process all texts. Each text gets its own task,
            // so results are sent in COMPLETION order, which may differ from
            // input order. The stream terminates once every `tx` clone is
            // dropped (i.e. all per-text tasks have finished).
            task::spawn(async move {
                for text in texts {
                    let tokenizer = Arc::clone(&tokenizer);
                    let semaphore = Arc::clone(&semaphore);
                    let tx = tx.clone();

                    task::spawn(async move {
                        let result = async {
                            let _permit = semaphore.acquire().await.map_err(|_| {
                                trustformers_core::errors::TrustformersError::other(
                                    anyhow::anyhow!("Failed to acquire semaphore permit")
                                        .to_string(),
                                )
                            })?;

                            task::spawn_blocking(move || tokenizer.encode(&text)).await.map_err(
                                |e| {
                                    trustformers_core::errors::TrustformersError::other(
                                        anyhow::anyhow!(format!("Task join error: {}", e))
                                            .to_string(),
                                    )
                                },
                            )?
                        }
                        .await;

                        // Receiver may already be dropped; ignore send errors.
                        let _ = tx.send(result);
                    });
                }
            });

            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            Ok(Box::new(stream)
                as Box<
                    dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin,
                >)
        }
        .boxed()
    }
}
260
/// Configuration for async tokenization operations.
#[derive(Debug, Clone)]
pub struct AsyncTokenizerConfig {
    /// Maximum number of concurrent tokenization tasks.
    pub max_concurrent_tasks: usize,

    /// Buffer size for streaming operations — the capacity of the bounded
    /// channel backing `encode_stream`.
    pub stream_buffer_size: usize,

    /// Timeout for individual tokenization operations (in milliseconds).
    /// `None` disables the timeout.
    pub task_timeout_ms: Option<u64>,

    /// Enable task cancellation on timeout.
    /// NOTE(review): this flag is never consulted in this file — timed-out
    /// blocking work currently runs to completion regardless.
    pub enable_cancellation: bool,
}
276
277impl Default for AsyncTokenizerConfig {
278    fn default() -> Self {
279        Self {
280            max_concurrent_tasks: num_cpus::get() * 2,
281            stream_buffer_size: 1000,
282            task_timeout_ms: None,
283            enable_cancellation: false,
284        }
285    }
286}
287
/// Advanced async tokenizer with configurable behavior: concurrency limit,
/// stream buffering, and optional per-operation timeouts.
pub struct ConfigurableAsyncTokenizer<T> {
    // Shared so every spawned task can own a cheap handle to the tokenizer.
    tokenizer: Arc<T>,
    config: AsyncTokenizerConfig,
    // Permit count mirrors `config.max_concurrent_tasks`.
    task_semaphore: Arc<Semaphore>,
}
294
295impl<T> ConfigurableAsyncTokenizer<T>
296where
297    T: Tokenizer + Send + Sync + 'static,
298{
299    /// Create a new configurable async tokenizer
300    pub fn new(tokenizer: T, config: AsyncTokenizerConfig) -> Self {
301        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
302        Self {
303            tokenizer: Arc::new(tokenizer),
304            config,
305            task_semaphore: semaphore,
306        }
307    }
308
309    /// Update configuration
310    pub fn update_config(&mut self, config: AsyncTokenizerConfig) {
311        self.task_semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
312        self.config = config;
313    }
314
315    /// Get current configuration
316    pub fn config(&self) -> &AsyncTokenizerConfig {
317        &self.config
318    }
319
320    /// Process a large batch with progress reporting
321    pub async fn encode_large_batch_with_progress<F>(
322        &self,
323        texts: &[&str],
324        mut progress_callback: F,
325    ) -> Result<Vec<TokenizedInput>>
326    where
327        F: FnMut(usize, usize) + Send + 'static,
328    {
329        let total = texts.len();
330        let mut completed = 0;
331        let mut results = Vec::with_capacity(total);
332
333        // Process in chunks to avoid overwhelming the system
334        let chunk_size = (self.config.max_concurrent_tasks).max(1);
335
336        for chunk in texts.chunks(chunk_size) {
337            let chunk_results = self.encode_batch_async(chunk).await?;
338            results.extend(chunk_results);
339
340            completed += chunk.len();
341            progress_callback(completed, total);
342        }
343
344        Ok(results)
345    }
346}
347
#[async_trait]
impl<T> AsyncTokenizer for ConfigurableAsyncTokenizer<T>
where
    T: Tokenizer + Send + Sync + 'static,
{
    async fn encode_async(&self, text: &str) -> Result<TokenizedInput> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let text = text.to_string();
        // Permit is held until this function returns, bounding concurrency.
        // `acquire` only errors if the semaphore is closed (never done here).
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
            )
        })?;

        let encoding_task = task::spawn_blocking(move || tokenizer.encode(&text));

        // Optional timeout. NOTE: timing out only stops waiting — a
        // `spawn_blocking` closure cannot be aborted, so the tokenization
        // keeps running to completion in the background after the timeout
        // error is returned (the permit, however, is released on return).
        if let Some(timeout_ms) = self.config.task_timeout_ms {
            match tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), encoding_task)
                .await
            {
                Ok(result) => result.map_err(|e| {
                    trustformers_core::errors::TrustformersError::other(
                        anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                    )
                })?,
                Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!("Tokenization timeout".to_string()).to_string(),
                )),
            }
        } else {
            encoding_task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                )
            })?
        }
    }

    async fn encode_pair_async(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
        let tokenizer = Arc::clone(&self.tokenizer);
        // Own both strings so the closure can be 'static.
        let text = text.to_string();
        let text2 = text2.to_string();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
            )
        })?;

        let encoding_task = task::spawn_blocking(move || tokenizer.encode_pair(&text, &text2));

        // Same timeout semantics as `encode_async`.
        if let Some(timeout_ms) = self.config.task_timeout_ms {
            match tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), encoding_task)
                .await
            {
                Ok(result) => result.map_err(|e| {
                    trustformers_core::errors::TrustformersError::other(
                        anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                    )
                })?,
                Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!("Tokenization timeout".to_string()).to_string(),
                )),
            }
        } else {
            encoding_task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                )
            })?
        }
    }

    async fn decode_async(&self, ids: &[u32]) -> Result<String> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let ids = ids.to_vec();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
            )
        })?;

        let decoding_task = task::spawn_blocking(move || tokenizer.decode(&ids));

        // Same timeout semantics as `encode_async`.
        if let Some(timeout_ms) = self.config.task_timeout_ms {
            match tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), decoding_task)
                .await
            {
                Ok(result) => result.map_err(|e| {
                    trustformers_core::errors::TrustformersError::other(
                        anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                    )
                })?,
                Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!("Decoding timeout".to_string()).to_string(),
                )),
            }
        } else {
            decoding_task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                )
            })?
        }
    }

    async fn encode_batch_async(&self, texts: &[&str]) -> Result<Vec<TokenizedInput>> {
        let mut tasks = Vec::new();

        // One task per text; concurrency is bounded by the shared semaphore,
        // acquired inside each task. Each task applies the configured timeout
        // independently.
        for text in texts {
            let tokenizer = Arc::clone(&self.tokenizer);
            let text = text.to_string();
            let semaphore = Arc::clone(&self.task_semaphore);
            let timeout_ms = self.config.task_timeout_ms;

            let task = task::spawn(async move {
                let _permit = semaphore.acquire().await.map_err(|_| {
                    trustformers_core::errors::TrustformersError::other(
                        anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
                    )
                })?;

                let encoding_task = task::spawn_blocking(move || tokenizer.encode(&text));

                if let Some(timeout_ms) = timeout_ms {
                    match tokio::time::timeout(
                        std::time::Duration::from_millis(timeout_ms),
                        encoding_task,
                    )
                    .await
                    {
                        Ok(result) => result.map_err(|e| {
                            trustformers_core::errors::TrustformersError::other(
                                anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                            )
                        })?,
                        Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                            anyhow::anyhow!("Tokenization timeout").to_string(),
                        )),
                    }
                } else {
                    encoding_task.await.map_err(|e| {
                        trustformers_core::errors::TrustformersError::other(
                            anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                        )
                    })?
                }
            });

            tasks.push(task);
        }

        // Await join handles in spawn order so results preserve input order.
        // `??`: outer `?` unwraps the join Result, inner `?` the tokenization
        // Result.
        let mut results = Vec::with_capacity(texts.len());
        for task in tasks {
            let result = task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                )
            })??;
            results.push(result);
        }

        Ok(results)
    }

    async fn encode_pair_batch_async(
        &self,
        text_pairs: &[(&str, &str)],
    ) -> Result<Vec<TokenizedInput>> {
        let mut tasks = Vec::new();

        // Same fan-out/fan-in pattern as `encode_batch_async`, for pairs.
        for (text1, text2) in text_pairs {
            let tokenizer = Arc::clone(&self.tokenizer);
            let text1 = text1.to_string();
            let text2 = text2.to_string();
            let semaphore = Arc::clone(&self.task_semaphore);
            let timeout_ms = self.config.task_timeout_ms;

            let task = task::spawn(async move {
                let _permit = semaphore.acquire().await.map_err(|_| {
                    trustformers_core::errors::TrustformersError::other(
                        anyhow::anyhow!("Failed to acquire semaphore permit").to_string(),
                    )
                })?;

                let encoding_task =
                    task::spawn_blocking(move || tokenizer.encode_pair(&text1, &text2));

                if let Some(timeout_ms) = timeout_ms {
                    match tokio::time::timeout(
                        std::time::Duration::from_millis(timeout_ms),
                        encoding_task,
                    )
                    .await
                    {
                        Ok(result) => result.map_err(|e| {
                            trustformers_core::errors::TrustformersError::other(
                                anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                            )
                        })?,
                        Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                            anyhow::anyhow!("Tokenization timeout").to_string(),
                        )),
                    }
                } else {
                    encoding_task.await.map_err(|e| {
                        trustformers_core::errors::TrustformersError::other(
                            anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                        )
                    })?
                }
            });

            tasks.push(task);
        }

        let mut results = Vec::with_capacity(text_pairs.len());
        for task in tasks {
            let result = task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(
                    anyhow::anyhow!(format!("Task join error: {}", e)).to_string(),
                )
            })??;
            results.push(result);
        }

        Ok(results)
    }

    fn encode_stream<'a>(
        &'a self,
        texts: Vec<String>,
    ) -> BoxFuture<'a, Result<Box<dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin>>> {
        async move {
            // Bounded channel sized by `stream_buffer_size`: unlike the
            // plain wrapper's unbounded stream, senders await when the buffer
            // is full, giving the consumer backpressure.
            let (tx, rx) = mpsc::channel(self.config.stream_buffer_size);
            let tokenizer = Arc::clone(&self.tokenizer);
            let semaphore = Arc::clone(&self.task_semaphore);
            let timeout_ms = self.config.task_timeout_ms;

            // Spawn a task to process all texts. One task per text, so
            // results are emitted in COMPLETION order, which may differ from
            // input order. The stream ends once every `tx` clone is dropped.
            task::spawn(async move {
                for text in texts {
                    let tokenizer = Arc::clone(&tokenizer);
                    let semaphore = Arc::clone(&semaphore);
                    let tx = tx.clone();

                    task::spawn(async move {
                        let result = async {
                            let _permit = semaphore.acquire().await.map_err(|_| {
                                trustformers_core::errors::TrustformersError::other(
                                    anyhow::anyhow!("Failed to acquire semaphore permit")
                                        .to_string(),
                                )
                            })?;

                            let encoding_task =
                                task::spawn_blocking(move || tokenizer.encode(&text));

                            // Same timeout semantics as `encode_async`: a
                            // timeout abandons the wait but cannot abort the
                            // blocking closure itself.
                            if let Some(timeout_ms) = timeout_ms {
                                match tokio::time::timeout(
                                    std::time::Duration::from_millis(timeout_ms),
                                    encoding_task,
                                )
                                .await
                                {
                                    Ok(result) => result.map_err(|e| {
                                        trustformers_core::errors::TrustformersError::other(
                                            anyhow::anyhow!(format!("Task join error: {}", e))
                                                .to_string(),
                                        )
                                    })?,
                                    Err(_) => {
                                        Err(trustformers_core::errors::TrustformersError::other(
                                            anyhow::anyhow!("Tokenization timeout").to_string(),
                                        ))
                                    },
                                }
                            } else {
                                encoding_task.await.map_err(|e| {
                                    trustformers_core::errors::TrustformersError::other(
                                        anyhow::anyhow!(format!("Task join error: {}", e))
                                            .to_string(),
                                    )
                                })?
                            }
                        }
                        .await;

                        // Receiver may already be dropped; ignore send errors.
                        let _ = tx.send(result).await;
                    });
                }
            });

            let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
            Ok(Box::new(stream)
                as Box<
                    dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin,
                >)
        }
        .boxed()
    }
}
649
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wordpiece::WordPieceTokenizer;
    use futures::StreamExt;
    use std::time::Instant;

    // Wrapper smoke test: encode a single text with a tiny in-memory vocab.
    #[tokio::test]
    async fn test_async_tokenizer_wrapper() {
        // Minimal WordPiece vocab: the five standard special tokens plus two words.
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("[CLS]".to_string(), 1);
        vocab.insert("[SEP]".to_string(), 2);
        vocab.insert("[PAD]".to_string(), 3);
        vocab.insert("[MASK]".to_string(), 4);
        vocab.insert("hello".to_string(), 5);
        vocab.insert("world".to_string(), 6);

        let tokenizer = WordPieceTokenizer::new(vocab, true);
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));

        let result = async_tokenizer
            .encode_async("Hello world")
            .await
            .expect("Operation failed in test");
        assert!(!result.input_ids.is_empty());
    }

    // NOTE(review): `from_pretrained` presumably needs the vocab to be
    // downloadable or cached locally — confirm CI provides it.
    #[tokio::test]
    async fn test_batch_async_encoding() {
        let tokenizer = WordPieceTokenizer::from_pretrained("bert-base-uncased")
            .expect("Operation failed in test");
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));

        let texts = vec!["Hello world", "This is a test", "Async tokenization"];
        let results = async_tokenizer
            .encode_batch_async(&texts)
            .await
            .expect("Operation failed in test");

        // One result per input, each non-empty.
        assert_eq!(results.len(), texts.len());
        for result in &results {
            assert!(!result.input_ids.is_empty());
        }
    }

    // Configurable variant with a 5s per-operation timeout.
    #[tokio::test]
    async fn test_configurable_async_tokenizer() {
        let tokenizer = WordPieceTokenizer::from_pretrained("bert-base-uncased")
            .expect("Operation failed in test");
        let config = AsyncTokenizerConfig {
            max_concurrent_tasks: 2,
            stream_buffer_size: 100,
            task_timeout_ms: Some(5000),
            enable_cancellation: true,
        };
        let async_tokenizer = ConfigurableAsyncTokenizer::new(tokenizer, config);

        let result = async_tokenizer
            .encode_async("Hello world")
            .await
            .expect("Operation failed in test");
        assert!(!result.input_ids.is_empty());
    }

    // Round-trip: encode then decode, and check the words survive.
    #[tokio::test]
    async fn test_async_decode() {
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("[CLS]".to_string(), 1);
        vocab.insert("[SEP]".to_string(), 2);
        vocab.insert("[PAD]".to_string(), 3);
        vocab.insert("[MASK]".to_string(), 4);
        vocab.insert("hello".to_string(), 5);
        vocab.insert("world".to_string(), 6);

        let tokenizer = WordPieceTokenizer::new(vocab, true);
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));

        let encoded = async_tokenizer
            .encode_async("Hello world")
            .await
            .expect("Operation failed in test");
        let decoded = async_tokenizer
            .decode_async(&encoded.input_ids)
            .await
            .expect("Operation failed in test");

        assert!(!decoded.is_empty());
        // Case-insensitive because the tokenizer lowercases (second ctor arg).
        assert!(
            decoded.to_lowercase().contains("hello") || decoded.to_lowercase().contains("world")
        );
    }

    // Streaming path: only the COUNT of results is asserted, since
    // `encode_stream` may yield items out of input order.
    #[tokio::test]
    async fn test_stream_encoding() {
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("[CLS]".to_string(), 1);
        vocab.insert("[SEP]".to_string(), 2);
        vocab.insert("[PAD]".to_string(), 3);
        vocab.insert("[MASK]".to_string(), 4);
        vocab.insert("hello".to_string(), 5);
        vocab.insert("world".to_string(), 6);
        vocab.insert("this".to_string(), 7);
        vocab.insert("is".to_string(), 8);
        vocab.insert("a".to_string(), 9);
        vocab.insert("test".to_string(), 10);
        vocab.insert("async".to_string(), 11);
        vocab.insert("tokenization".to_string(), 12);

        let tokenizer = WordPieceTokenizer::new(vocab, true);
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));

        let texts = vec![
            "Hello world".to_string(),
            "This is a test".to_string(),
            "Async tokenization".to_string(),
        ];

        let mut stream = async_tokenizer
            .encode_stream(texts.clone())
            .await
            .expect("Operation failed in test");
        let mut results = Vec::new();

        // Drain the stream; it ends when all per-text tasks have finished.
        while let Some(result) = stream.next().await {
            results.push(result.expect("Operation failed in test"));
        }

        assert_eq!(results.len(), texts.len());
    }

    // Progress callback must be called at least once and end at (total, total).
    #[tokio::test]
    async fn test_large_batch_with_progress() {
        let tokenizer = WordPieceTokenizer::from_pretrained("bert-base-uncased")
            .expect("Operation failed in test");
        let config = AsyncTokenizerConfig::default();
        let async_tokenizer = ConfigurableAsyncTokenizer::new(tokenizer, config);

        let texts: Vec<&str> = (0..100)
            .map(
                |i| {
                    if i % 2 == 0 {
                        "Hello world"
                    } else {
                        "This is a test"
                    }
                },
            )
            .collect();

        // Collect (completed, total) updates behind a mutex so the closure
        // can record them.
        let progress_updates = Arc::new(std::sync::Mutex::new(Vec::new()));
        let progress_updates_clone = Arc::clone(&progress_updates);

        let results = async_tokenizer
            .encode_large_batch_with_progress(&texts, move |completed, total| {
                progress_updates_clone
                    .lock()
                    .expect("lock should not be poisoned")
                    .push((completed, total));
            })
            .await
            .expect("Operation failed in test");

        assert_eq!(results.len(), texts.len());

        let updates = progress_updates.lock().expect("lock should not be poisoned");
        assert!(!updates.is_empty());
        // Final update must report full completion.
        assert_eq!(
            updates.last().expect("Operation failed in test").0,
            texts.len()
        );
        assert_eq!(
            updates.last().expect("Operation failed in test").1,
            texts.len()
        );
    }

    // Sanity/perf smoke test: only prints the duration, asserts validity.
    #[tokio::test]
    async fn test_concurrent_performance() {
        let tokenizer = WordPieceTokenizer::from_pretrained("bert-base-uncased")
            .expect("Operation failed in test");
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(8));

        let texts: Vec<&str> = (0..50)
            .map(|i| {
                if i % 2 == 0 {
                    "Hello world from async tokenization"
                } else {
                    "This is a performance test"
                }
            })
            .collect();

        let start = Instant::now();
        let results = async_tokenizer
            .encode_batch_async(&texts)
            .await
            .expect("Operation failed in test");
        let duration = start.elapsed();

        assert_eq!(results.len(), texts.len());
        println!("Encoded {} texts in {:?}", texts.len(), duration);

        // Verify all results are valid
        for result in &results {
            assert!(!result.input_ids.is_empty());
            assert_eq!(result.input_ids.len(), result.attention_mask.len());
        }
    }
}