//! Asynchronous tokenization utilities.
//!
//! This module wraps synchronous [`Tokenizer`] implementations so that
//! encoding and decoding run on Tokio's blocking thread pool, with a
//! semaphore bounding how many tokenizations run concurrently.

use async_trait::async_trait;
use futures::future::{BoxFuture, FutureExt};
use futures::stream::Stream;
use std::sync::Arc;
use tokio::sync::{mpsc, Semaphore};
use tokio::task;
use trustformers_core::errors::Result;
use trustformers_core::traits::{TokenizedInput, Tokenizer};

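/// A tokenizer whose operations run asynchronously.
///
/// Implementations are expected to offload CPU-bound tokenization work (for
/// example onto a blocking thread pool) so callers on an async runtime are
/// never blocked.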
#[async_trait]
pub trait AsyncTokenizer: Send + Sync {
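    /// Encodes a single text, returning the resulting [`TokenizedInput`].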
    async fn encode_async(&self, text: &str) -> Result<TokenizedInput>;

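    /// Encodes a pair of texts (e.g. for sequence-pair tasks) into a single input.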
    async fn encode_pair_async(&self, text: &str, text2: &str) -> Result<TokenizedInput>;

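    /// Decodes a sequence of token IDs back into a string.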
    async fn decode_async(&self, ids: &[u32]) -> Result<String>;

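    /// Encodes a batch of texts concurrently, returning results in input order.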
    async fn encode_batch_async(&self, texts: &[&str]) -> Result<Vec<TokenizedInput>>;

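    /// Encodes a batch of text pairs concurrently, returning results in input order.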
    async fn encode_pair_batch_async(
        &self,
        text_pairs: &[(&str, &str)],
    ) -> Result<Vec<TokenizedInput>>;

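    /// Encodes texts as a stream, yielding each result as it completes. The
    /// yielded order may therefore differ from the input order.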
    fn encode_stream<'a>(
        &'a self,
        texts: Vec<String>,
    ) -> BoxFuture<'a, Result<Box<dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin>>>;
}

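/// Wraps a synchronous [`Tokenizer`] and implements [`AsyncTokenizer`] by
/// offloading each operation to `tokio::task::spawn_blocking`, with a
/// semaphore bounding the number of concurrent tokenizations.
///
/// A minimal usage sketch (marked `ignore` since the import paths depend on
/// this crate's layout):
///
/// ```ignore
/// // Assumes `tokenizer` is any `Tokenizer` implementation, e.g. this
/// // crate's WordPieceTokenizer.
/// let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));
/// let encoded = async_tokenizer.encode_async("Hello world").await?;
/// let decoded = async_tokenizer.decode_async(&encoded.input_ids).await?;
/// ```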
pub struct AsyncTokenizerWrapper<T> {
    tokenizer: Arc<T>,
    max_concurrent_tasks: usize,
    task_semaphore: Arc<Semaphore>,
}

impl<T> AsyncTokenizerWrapper<T>
where
    T: Tokenizer + Send + Sync + 'static,
{
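    /// Creates a new wrapper. If `max_concurrent_tasks` is `None`, the limit
    /// defaults to twice the number of logical CPUs.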
    pub fn new(tokenizer: T, max_concurrent_tasks: Option<usize>) -> Self {
        let max_tasks = max_concurrent_tasks.unwrap_or(num_cpus::get() * 2);
        Self {
            tokenizer: Arc::new(tokenizer),
            max_concurrent_tasks: max_tasks,
            task_semaphore: Arc::new(Semaphore::new(max_tasks)),
        }
    }

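    /// Replaces the concurrency limit, rebuilding the internal semaphore.
    /// Permits already held by in-flight tasks are unaffected.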
    pub fn with_max_concurrent_tasks(mut self, max_tasks: usize) -> Self {
        self.max_concurrent_tasks = max_tasks;
        self.task_semaphore = Arc::new(Semaphore::new(max_tasks));
        self
    }

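    /// Returns a reference to the wrapped tokenizer.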
    pub fn inner(&self) -> &Arc<T> {
        &self.tokenizer
    }
}

#[async_trait]
impl<T> AsyncTokenizer for AsyncTokenizerWrapper<T>
where
    T: Tokenizer + Send + Sync + 'static,
{
    async fn encode_async(&self, text: &str) -> Result<TokenizedInput> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let text = text.to_string();
        // Hold a permit for the duration of the blocking task so that at most
        // `max_concurrent_tasks` tokenizations run at once.
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                "Failed to acquire semaphore permit".to_string(),
            )
        })?;

        // Run the CPU-bound encoding on the blocking thread pool.
        task::spawn_blocking(move || tokenizer.encode(&text)).await.map_err(|e| {
            trustformers_core::errors::TrustformersError::other(format!("Task join error: {}", e))
        })?
    }

    async fn encode_pair_async(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let text = text.to_string();
        let text2 = text2.to_string();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                "Failed to acquire semaphore permit".to_string(),
            )
        })?;

        task::spawn_blocking(move || tokenizer.encode_pair(&text, &text2))
            .await
            .map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!(
                    "Task join error: {}",
                    e
                ))
            })?
    }

    async fn decode_async(&self, ids: &[u32]) -> Result<String> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let ids = ids.to_vec();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                "Failed to acquire semaphore permit".to_string(),
            )
        })?;

        task::spawn_blocking(move || tokenizer.decode(&ids)).await.map_err(|e| {
            trustformers_core::errors::TrustformersError::other(format!("Task join error: {}", e))
        })?
    }

    async fn encode_batch_async(&self, texts: &[&str]) -> Result<Vec<TokenizedInput>> {
        let mut tasks = Vec::new();

        // Spawn one task per text; the shared semaphore caps how many run at once.
        for text in texts {
            let tokenizer = Arc::clone(&self.tokenizer);
            let text = text.to_string();
            let semaphore = Arc::clone(&self.task_semaphore);

            let task = task::spawn(async move {
                let _permit = semaphore.acquire().await.map_err(|_| {
                    trustformers_core::errors::TrustformersError::other(
                        "Failed to acquire semaphore permit".to_string(),
                    )
                })?;

                task::spawn_blocking(move || tokenizer.encode(&text)).await.map_err(|e| {
                    trustformers_core::errors::TrustformersError::other(format!(
                        "Task join error: {}",
                        e
                    ))
                })?
            });

            tasks.push(task);
        }

        // Await in spawn order so results line up with the input texts.
        let mut results = Vec::with_capacity(texts.len());
        for task in tasks {
            let result = task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!(
                    "Task join error: {}",
                    e
                ))
            })??;
            results.push(result);
        }

        Ok(results)
    }

    async fn encode_pair_batch_async(
        &self,
        text_pairs: &[(&str, &str)],
    ) -> Result<Vec<TokenizedInput>> {
        let mut tasks = Vec::new();

        for (text1, text2) in text_pairs {
            let tokenizer = Arc::clone(&self.tokenizer);
            let text1 = text1.to_string();
            let text2 = text2.to_string();
            let semaphore = Arc::clone(&self.task_semaphore);

            let task = task::spawn(async move {
                let _permit = semaphore.acquire().await.map_err(|_| {
                    trustformers_core::errors::TrustformersError::other(
                        "Failed to acquire semaphore permit".to_string(),
                    )
                })?;

                task::spawn_blocking(move || tokenizer.encode_pair(&text1, &text2))
                    .await
                    .map_err(|e| {
                        trustformers_core::errors::TrustformersError::other(format!(
                            "Task join error: {}",
                            e
                        ))
                    })?
            });

            tasks.push(task);
        }

        let mut results = Vec::with_capacity(text_pairs.len());
        for task in tasks {
            let result = task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!(
                    "Task join error: {}",
                    e
                ))
            })??;
            results.push(result);
        }

        Ok(results)
    }

    fn encode_stream<'a>(
        &'a self,
        texts: Vec<String>,
    ) -> BoxFuture<'a, Result<Box<dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin>>> {
        async move {
            let (tx, rx) = mpsc::unbounded_channel();
            let tokenizer = Arc::clone(&self.tokenizer);
            let semaphore = Arc::clone(&self.task_semaphore);

            // Producer task: spawn one encoding task per text. Results are sent
            // as they complete, so the stream order may differ from input order.
            task::spawn(async move {
                for text in texts {
                    let tokenizer = Arc::clone(&tokenizer);
                    let semaphore = Arc::clone(&semaphore);
                    let tx = tx.clone();

                    task::spawn(async move {
                        let result = async {
                            let _permit = semaphore.acquire().await.map_err(|_| {
                                trustformers_core::errors::TrustformersError::other(
                                    "Failed to acquire semaphore permit".to_string(),
                                )
                            })?;

                            task::spawn_blocking(move || tokenizer.encode(&text))
                                .await
                                .map_err(|e| {
                                    trustformers_core::errors::TrustformersError::other(format!(
                                        "Task join error: {}",
                                        e
                                    ))
                                })?
                        }
                        .await;

                        // Ignore send errors: the receiver may have been dropped.
                        let _ = tx.send(result);
                    });
                }
            });

            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            Ok(Box::new(stream) as Box<dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin>)
        }
        .boxed()
    }
}

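/// Configuration for [`ConfigurableAsyncTokenizer`].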
#[derive(Debug, Clone)]
pub struct AsyncTokenizerConfig {
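    /// Maximum number of tokenization tasks allowed to run concurrently.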
    pub max_concurrent_tasks: usize,

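    /// Buffer size of the channel backing [`AsyncTokenizer::encode_stream`].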
    pub stream_buffer_size: usize,

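    /// Per-task timeout in milliseconds; `None` disables the timeout.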
    pub task_timeout_ms: Option<u64>,

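    /// Whether task cancellation is enabled. Not consulted anywhere in this
    /// module; presumably reserved for future use.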
    pub enable_cancellation: bool,
}

impl Default for AsyncTokenizerConfig {
    fn default() -> Self {
        Self {
            max_concurrent_tasks: num_cpus::get() * 2,
            stream_buffer_size: 1000,
            task_timeout_ms: None,
            enable_cancellation: false,
        }
    }
}

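/// An [`AsyncTokenizer`] with tunable concurrency, stream buffering, and
/// optional per-task timeouts, driven by an [`AsyncTokenizerConfig`].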
pub struct ConfigurableAsyncTokenizer<T> {
    tokenizer: Arc<T>,
    config: AsyncTokenizerConfig,
    task_semaphore: Arc<Semaphore>,
}

impl<T> ConfigurableAsyncTokenizer<T>
where
    T: Tokenizer + Send + Sync + 'static,
{
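    /// Creates a tokenizer with the given configuration, sizing the internal
    /// semaphore from `config.max_concurrent_tasks`.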
    pub fn new(tokenizer: T, config: AsyncTokenizerConfig) -> Self {
        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
        Self {
            tokenizer: Arc::new(tokenizer),
            config,
            task_semaphore: semaphore,
        }
    }

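    /// Replaces the configuration and rebuilds the semaphore to match the new
    /// concurrency limit. Permits held by in-flight tasks are unaffected.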
    pub fn update_config(&mut self, config: AsyncTokenizerConfig) {
        self.task_semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
        self.config = config;
    }

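    /// Returns the current configuration.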
    pub fn config(&self) -> &AsyncTokenizerConfig {
        &self.config
    }

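    /// Encodes a large batch in chunks of `max_concurrent_tasks`, invoking
    /// `progress_callback(completed, total)` after each chunk finishes.
    ///
    /// A usage sketch (marked `ignore` since the import paths depend on this
    /// crate's layout):
    ///
    /// ```ignore
    /// let results = tokenizer
    ///     .encode_large_batch_with_progress(&texts, |done, total| {
    ///         println!("{done}/{total}");
    ///     })
    ///     .await?;
    /// ```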
    pub async fn encode_large_batch_with_progress<F>(
        &self,
        texts: &[&str],
        mut progress_callback: F,
    ) -> Result<Vec<TokenizedInput>>
    where
        F: FnMut(usize, usize) + Send + 'static,
    {
        let total = texts.len();
        let mut completed = 0;
        let mut results = Vec::with_capacity(total);

        // Process the input in chunks sized to the concurrency limit so the
        // callback fires at a reasonable granularity.
        let chunk_size = self.config.max_concurrent_tasks.max(1);

        for chunk in texts.chunks(chunk_size) {
            let chunk_results = self.encode_batch_async(chunk).await?;
            results.extend(chunk_results);

            completed += chunk.len();
            progress_callback(completed, total);
        }

        Ok(results)
    }
}

#[async_trait]
impl<T> AsyncTokenizer for ConfigurableAsyncTokenizer<T>
where
    T: Tokenizer + Send + Sync + 'static,
{
    async fn encode_async(&self, text: &str) -> Result<TokenizedInput> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let text = text.to_string();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                "Failed to acquire semaphore permit".to_string(),
            )
        })?;

        let encoding_task = task::spawn_blocking(move || tokenizer.encode(&text));

        // Apply the configured timeout if one is set; otherwise wait indefinitely.
        if let Some(timeout_ms) = self.config.task_timeout_ms {
            match tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), encoding_task)
                .await
            {
                Ok(result) => result.map_err(|e| {
                    trustformers_core::errors::TrustformersError::other(format!(
                        "Task join error: {}",
                        e
                    ))
                })?,
                Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                    "Tokenization timeout".to_string(),
                )),
            }
        } else {
            encoding_task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!(
                    "Task join error: {}",
                    e
                ))
            })?
        }
    }

    async fn encode_pair_async(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let text = text.to_string();
        let text2 = text2.to_string();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                "Failed to acquire semaphore permit".to_string(),
            )
        })?;

        let encoding_task = task::spawn_blocking(move || tokenizer.encode_pair(&text, &text2));

        if let Some(timeout_ms) = self.config.task_timeout_ms {
            match tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), encoding_task)
                .await
            {
                Ok(result) => result.map_err(|e| {
                    trustformers_core::errors::TrustformersError::other(format!(
                        "Task join error: {}",
                        e
                    ))
                })?,
                Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                    "Tokenization timeout".to_string(),
                )),
            }
        } else {
            encoding_task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!(
                    "Task join error: {}",
                    e
                ))
            })?
        }
    }

    async fn decode_async(&self, ids: &[u32]) -> Result<String> {
        let tokenizer = Arc::clone(&self.tokenizer);
        let ids = ids.to_vec();
        let _permit = self.task_semaphore.acquire().await.map_err(|_| {
            trustformers_core::errors::TrustformersError::other(
                "Failed to acquire semaphore permit".to_string(),
            )
        })?;

        let decoding_task = task::spawn_blocking(move || tokenizer.decode(&ids));

        if let Some(timeout_ms) = self.config.task_timeout_ms {
            match tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), decoding_task)
                .await
            {
                Ok(result) => result.map_err(|e| {
                    trustformers_core::errors::TrustformersError::other(format!(
                        "Task join error: {}",
                        e
                    ))
                })?,
                Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                    "Decoding timeout".to_string(),
                )),
            }
        } else {
            decoding_task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!(
                    "Task join error: {}",
                    e
                ))
            })?
        }
    }

    async fn encode_batch_async(&self, texts: &[&str]) -> Result<Vec<TokenizedInput>> {
        let mut tasks = Vec::new();

        for text in texts {
            let tokenizer = Arc::clone(&self.tokenizer);
            let text = text.to_string();
            let semaphore = Arc::clone(&self.task_semaphore);
            let timeout_ms = self.config.task_timeout_ms;

            let task = task::spawn(async move {
                let _permit = semaphore.acquire().await.map_err(|_| {
                    trustformers_core::errors::TrustformersError::other(
                        "Failed to acquire semaphore permit".to_string(),
                    )
                })?;

                let encoding_task = task::spawn_blocking(move || tokenizer.encode(&text));

                if let Some(timeout_ms) = timeout_ms {
                    match tokio::time::timeout(
                        std::time::Duration::from_millis(timeout_ms),
                        encoding_task,
                    )
                    .await
                    {
                        Ok(result) => result.map_err(|e| {
                            trustformers_core::errors::TrustformersError::other(format!(
                                "Task join error: {}",
                                e
                            ))
                        })?,
                        Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                            "Tokenization timeout".to_string(),
                        )),
                    }
                } else {
                    encoding_task.await.map_err(|e| {
                        trustformers_core::errors::TrustformersError::other(format!(
                            "Task join error: {}",
                            e
                        ))
                    })?
                }
            });

            tasks.push(task);
        }

        let mut results = Vec::with_capacity(texts.len());
        for task in tasks {
            let result = task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!(
                    "Task join error: {}",
                    e
                ))
            })??;
            results.push(result);
        }

        Ok(results)
    }

    async fn encode_pair_batch_async(
        &self,
        text_pairs: &[(&str, &str)],
    ) -> Result<Vec<TokenizedInput>> {
        let mut tasks = Vec::new();

        for (text1, text2) in text_pairs {
            let tokenizer = Arc::clone(&self.tokenizer);
            let text1 = text1.to_string();
            let text2 = text2.to_string();
            let semaphore = Arc::clone(&self.task_semaphore);
            let timeout_ms = self.config.task_timeout_ms;

            let task = task::spawn(async move {
                let _permit = semaphore.acquire().await.map_err(|_| {
                    trustformers_core::errors::TrustformersError::other(
                        "Failed to acquire semaphore permit".to_string(),
                    )
                })?;

                let encoding_task =
                    task::spawn_blocking(move || tokenizer.encode_pair(&text1, &text2));

                if let Some(timeout_ms) = timeout_ms {
                    match tokio::time::timeout(
                        std::time::Duration::from_millis(timeout_ms),
                        encoding_task,
                    )
                    .await
                    {
                        Ok(result) => result.map_err(|e| {
                            trustformers_core::errors::TrustformersError::other(format!(
                                "Task join error: {}",
                                e
                            ))
                        })?,
                        Err(_) => Err(trustformers_core::errors::TrustformersError::other(
                            "Tokenization timeout".to_string(),
                        )),
                    }
                } else {
                    encoding_task.await.map_err(|e| {
                        trustformers_core::errors::TrustformersError::other(format!(
                            "Task join error: {}",
                            e
                        ))
                    })?
                }
            });

            tasks.push(task);
        }

        let mut results = Vec::with_capacity(text_pairs.len());
        for task in tasks {
            let result = task.await.map_err(|e| {
                trustformers_core::errors::TrustformersError::other(format!(
                    "Task join error: {}",
                    e
                ))
            })??;
            results.push(result);
        }

        Ok(results)
    }

    fn encode_stream<'a>(
        &'a self,
        texts: Vec<String>,
    ) -> BoxFuture<'a, Result<Box<dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin>>> {
        async move {
            // A zero-capacity buffer would panic in `mpsc::channel`, so clamp to at least 1.
            let (tx, rx) = mpsc::channel(self.config.stream_buffer_size.max(1));
            let tokenizer = Arc::clone(&self.tokenizer);
            let semaphore = Arc::clone(&self.task_semaphore);
            let timeout_ms = self.config.task_timeout_ms;

            task::spawn(async move {
                for text in texts {
                    let tokenizer = Arc::clone(&tokenizer);
                    let semaphore = Arc::clone(&semaphore);
                    let tx = tx.clone();

                    task::spawn(async move {
                        let result = async {
                            let _permit = semaphore.acquire().await.map_err(|_| {
                                trustformers_core::errors::TrustformersError::other(
                                    "Failed to acquire semaphore permit".to_string(),
                                )
                            })?;

                            let encoding_task =
                                task::spawn_blocking(move || tokenizer.encode(&text));

                            if let Some(timeout_ms) = timeout_ms {
                                match tokio::time::timeout(
                                    std::time::Duration::from_millis(timeout_ms),
                                    encoding_task,
                                )
                                .await
                                {
                                    Ok(result) => result.map_err(|e| {
                                        trustformers_core::errors::TrustformersError::other(
                                            format!("Task join error: {}", e),
                                        )
                                    })?,
                                    Err(_) => {
                                        Err(trustformers_core::errors::TrustformersError::other(
                                            "Tokenization timeout".to_string(),
                                        ))
                                    },
                                }
                            } else {
                                encoding_task.await.map_err(|e| {
                                    trustformers_core::errors::TrustformersError::other(format!(
                                        "Task join error: {}",
                                        e
                                    ))
                                })?
                            }
                        }
                        .await;

                        // Ignore send errors: the receiver may have been dropped.
                        let _ = tx.send(result).await;
                    });
                }
            });

            let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
            Ok(Box::new(stream) as Box<dyn Stream<Item = Result<TokenizedInput>> + Send + Unpin>)
        }
        .boxed()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::wordpiece::WordPieceTokenizer;
    use futures::StreamExt;
    use std::time::Instant;

    #[tokio::test]
    async fn test_async_tokenizer_wrapper() {
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("[CLS]".to_string(), 1);
        vocab.insert("[SEP]".to_string(), 2);
        vocab.insert("[PAD]".to_string(), 3);
        vocab.insert("[MASK]".to_string(), 4);
        vocab.insert("hello".to_string(), 5);
        vocab.insert("world".to_string(), 6);

        let tokenizer = WordPieceTokenizer::new(vocab, true);
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));

        let result = async_tokenizer
            .encode_async("Hello world")
            .await
            .expect("Operation failed in test");
        assert!(!result.input_ids.is_empty());
    }

    #[tokio::test]
    async fn test_batch_async_encoding() {
        let tokenizer = WordPieceTokenizer::from_pretrained("bert-base-uncased")
            .expect("Operation failed in test");
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));

        let texts = vec!["Hello world", "This is a test", "Async tokenization"];
        let results = async_tokenizer
            .encode_batch_async(&texts)
            .await
            .expect("Operation failed in test");

        assert_eq!(results.len(), texts.len());
        for result in &results {
            assert!(!result.input_ids.is_empty());
        }
    }

    #[tokio::test]
    async fn test_configurable_async_tokenizer() {
        let tokenizer = WordPieceTokenizer::from_pretrained("bert-base-uncased")
            .expect("Operation failed in test");
        let config = AsyncTokenizerConfig {
            max_concurrent_tasks: 2,
            stream_buffer_size: 100,
            task_timeout_ms: Some(5000),
            enable_cancellation: true,
        };
        let async_tokenizer = ConfigurableAsyncTokenizer::new(tokenizer, config);

        let result = async_tokenizer
            .encode_async("Hello world")
            .await
            .expect("Operation failed in test");
        assert!(!result.input_ids.is_empty());
    }

    #[tokio::test]
    async fn test_async_decode() {
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("[CLS]".to_string(), 1);
        vocab.insert("[SEP]".to_string(), 2);
        vocab.insert("[PAD]".to_string(), 3);
        vocab.insert("[MASK]".to_string(), 4);
        vocab.insert("hello".to_string(), 5);
        vocab.insert("world".to_string(), 6);

        let tokenizer = WordPieceTokenizer::new(vocab, true);
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));

        let encoded = async_tokenizer
            .encode_async("Hello world")
            .await
            .expect("Operation failed in test");
        let decoded = async_tokenizer
            .decode_async(&encoded.input_ids)
            .await
            .expect("Operation failed in test");

        assert!(!decoded.is_empty());
        assert!(
            decoded.to_lowercase().contains("hello") || decoded.to_lowercase().contains("world")
        );
    }

    #[tokio::test]
    async fn test_stream_encoding() {
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("[CLS]".to_string(), 1);
        vocab.insert("[SEP]".to_string(), 2);
        vocab.insert("[PAD]".to_string(), 3);
        vocab.insert("[MASK]".to_string(), 4);
        vocab.insert("hello".to_string(), 5);
        vocab.insert("world".to_string(), 6);
        vocab.insert("this".to_string(), 7);
        vocab.insert("is".to_string(), 8);
        vocab.insert("a".to_string(), 9);
        vocab.insert("test".to_string(), 10);
        vocab.insert("async".to_string(), 11);
        vocab.insert("tokenization".to_string(), 12);

        let tokenizer = WordPieceTokenizer::new(vocab, true);
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(4));

        let texts = vec![
            "Hello world".to_string(),
            "This is a test".to_string(),
            "Async tokenization".to_string(),
        ];

        let mut stream = async_tokenizer
            .encode_stream(texts.clone())
            .await
            .expect("Operation failed in test");
        let mut results = Vec::new();

        while let Some(result) = stream.next().await {
            results.push(result.expect("Operation failed in test"));
        }

        assert_eq!(results.len(), texts.len());
    }

    #[tokio::test]
    async fn test_large_batch_with_progress() {
        let tokenizer = WordPieceTokenizer::from_pretrained("bert-base-uncased")
            .expect("Operation failed in test");
        let config = AsyncTokenizerConfig::default();
        let async_tokenizer = ConfigurableAsyncTokenizer::new(tokenizer, config);

        let texts: Vec<&str> = (0..100)
            .map(|i| {
                if i % 2 == 0 {
                    "Hello world"
                } else {
                    "This is a test"
                }
            })
            .collect();

        let progress_updates = Arc::new(std::sync::Mutex::new(Vec::new()));
        let progress_updates_clone = Arc::clone(&progress_updates);

        let results = async_tokenizer
            .encode_large_batch_with_progress(&texts, move |completed, total| {
                progress_updates_clone
                    .lock()
                    .expect("lock should not be poisoned")
                    .push((completed, total));
            })
            .await
            .expect("Operation failed in test");

        assert_eq!(results.len(), texts.len());

        let updates = progress_updates.lock().expect("lock should not be poisoned");
        assert!(!updates.is_empty());
        assert_eq!(
            updates.last().expect("Operation failed in test").0,
            texts.len()
        );
        assert_eq!(
            updates.last().expect("Operation failed in test").1,
            texts.len()
        );
    }

    #[tokio::test]
    async fn test_concurrent_performance() {
        let tokenizer = WordPieceTokenizer::from_pretrained("bert-base-uncased")
            .expect("Operation failed in test");
        let async_tokenizer = AsyncTokenizerWrapper::new(tokenizer, Some(8));

        let texts: Vec<&str> = (0..50)
            .map(|i| {
                if i % 2 == 0 {
                    "Hello world from async tokenization"
                } else {
                    "This is a performance test"
                }
            })
            .collect();

        let start = Instant::now();
        let results = async_tokenizer
            .encode_batch_async(&texts)
            .await
            .expect("Operation failed in test");
        let duration = start.elapsed();

        assert_eq!(results.len(), texts.len());
        println!("Encoded {} texts in {:?}", texts.len(), duration);

        for result in &results {
            assert!(!result.input_ids.is_empty());
            assert_eq!(result.input_ids.len(), result.attention_mask.len());
        }
    }
}
861}