xgrammar/matcher/
batch_grammar_matcher.rs

1use std::pin::Pin;
2
3use autocxx::prelude::*;
4
5use super::GrammarMatcher;
6use crate::{CxxUniquePtr, DLTensor, cxx_utils};
7
8/// A batch version of GrammarMatcher that can fill the next token bitmask for multiple
9/// matchers in parallel. It utilizes multiple threads to speed up the computation. It is
10/// especially useful when the batch size is large.
11pub struct BatchGrammarMatcher {
12    inner: CxxUniquePtr<crate::FFIBatchGrammarMatcher>,
13}
14
15impl BatchGrammarMatcher {
16    /// Construct the batch grammar matcher.
17    ///
18    /// # Parameters
19    /// - `max_threads`: The maximum number of threads to use for parallel processing.
20    ///   Use -1 for automatic thread count (hardware_concurrency / 2).
21    pub fn new(max_threads: i32) -> Self {
22        let ffi_pin = cxx_utils::make_batch_grammar_matcher(max_threads);
23        Self {
24            inner: ffi_pin,
25        }
26    }
27
28    /// Create a batch grammar matcher with automatic thread count.
29    pub fn new_auto() -> Self {
30        Self::new(-1)
31    }
32
33    /// Fill the next token bitmask for multiple matchers.
34    ///
35    /// # Parameters
36    /// - `matchers`: The list of matchers to fill the bitmask for.
37    /// - `bitmask`: Must be a 2-dimensional int32 tensor with shape (bitmask_batch_size, bitmask_size).
38    ///   Bitmask_batch_size could be larger than the actual batch size to allow padding.
39    ///   Bitmask_size equals to ceil(vocab_size/32).
40    /// - `indices`: A list of indices to specify which rows in the bitmask to fill.
41    ///   If None, fill the bitmask [0..matchers.len()).
42    /// - `debug_print`: Whether to print information about generated bitmask (default: false).
43    pub fn batch_fill_next_token_bitmask(
44        &mut self,
45        matchers: &[GrammarMatcher],
46        bitmask: &mut DLTensor,
47        indices: Option<&[i32]>,
48        debug_print: bool,
49    ) {
50        // Create a C++ vector of GrammarMatcher objects
51        let mut ffi_matcher_vec = cxx_utils::new_grammar_matcher_vector();
52        {
53            let mut vec_pin = ffi_matcher_vec.pin_mut();
54            cxx_utils::grammar_matcher_vec_reserve(
55                vec_pin.as_mut(),
56                matchers.len(),
57            );
58            for matcher in matchers {
59                cxx_utils::grammar_matcher_vec_push(
60                    vec_pin.as_mut(),
61                    matcher.ffi_ref(),
62                );
63            }
64        }
65
66        let (has_indices, indices_ptr, indices_len) = match indices {
67            Some(slice) if !slice.is_empty() => {
68                (true, slice.as_ptr(), slice.len())
69            },
70            _ => (false, std::ptr::null(), 0usize),
71        };
72
73        unsafe {
74            cxx_utils::batch_matcher_batch_fill_next_token_bitmask(
75                self.inner.as_mut().expect("BatchGrammarMatcher inner is null"),
76                ffi_matcher_vec.as_mut().unwrap().get_unchecked_mut(),
77                bitmask as *mut _,
78                has_indices,
79                indices_ptr,
80                indices_len,
81                debug_print,
82            );
83        }
84    }
85
86    /// Accept a batch of tokens for multiple matchers.
87    ///
88    /// # Parameters
89    /// - `matchers`: The list of matchers to accept tokens for.
90    /// - `tokens`: The list of tokens to accept.
91    /// - `debug_print`: Whether to print information about generated bitmask (default: false).
92    ///
93    /// # Returns
94    /// A boxed slice of booleans indicating whether each token was accepted by its corresponding matcher.
95    pub fn batch_accept_token(
96        matchers: &[GrammarMatcher],
97        tokens: &[i32],
98        debug_print: bool,
99    ) -> Box<[bool]> {
100        assert_eq!(
101            matchers.len(),
102            tokens.len(),
103            "matchers and tokens must have the same length"
104        );
105
106        let mut ffi_matcher_vec = cxx_utils::new_grammar_matcher_vector();
107        {
108            let mut vec_pin = ffi_matcher_vec.pin_mut();
109            cxx_utils::grammar_matcher_vec_reserve(
110                vec_pin.as_mut(),
111                matchers.len(),
112            );
113            for matcher in matchers {
114                cxx_utils::grammar_matcher_vec_push(
115                    vec_pin.as_mut(),
116                    matcher.ffi_ref(),
117                );
118            }
119        }
120
121        let result = unsafe {
122            cxx_utils::batch_accept_token(
123                ffi_matcher_vec.as_mut().unwrap().get_unchecked_mut(),
124                tokens.as_ptr(),
125                tokens.len(),
126                debug_print,
127            )
128        };
129
130        result.iter().map(|&b| b != 0).collect::<Vec<_>>().into_boxed_slice()
131    }
132
133    /// Accept a batch of strings for multiple matchers.
134    ///
135    /// # Parameters
136    /// - `matchers`: The list of matchers to accept tokens for.
137    /// - `strings`: The list of strings to accept.
138    /// - `debug_print`: Whether to print information about generated bitmask (default: false).
139    ///
140    /// # Returns
141    /// A boxed slice of booleans indicating whether each string was accepted by its corresponding matcher.
142    pub fn batch_accept_string(
143        matchers: &[GrammarMatcher],
144        strings: &[impl AsRef<str>],
145        debug_print: bool,
146    ) -> Box<[bool]> {
147        assert_eq!(
148            matchers.len(),
149            strings.len(),
150            "matchers and strings must have the same length"
151        );
152
153        let mut ffi_matcher_vec = cxx_utils::new_grammar_matcher_vector();
154        {
155            let mut vec_pin = ffi_matcher_vec.pin_mut();
156            cxx_utils::grammar_matcher_vec_reserve(
157                vec_pin.as_mut(),
158                matchers.len(),
159            );
160            for matcher in matchers {
161                cxx_utils::grammar_matcher_vec_push(
162                    vec_pin.as_mut(),
163                    matcher.ffi_ref(),
164                );
165            }
166        }
167
168        let mut cxx_strings = cxx_utils::new_string_vector();
169        {
170            let mut cxx_vec_pin = cxx_strings.pin_mut();
171            cxx_utils::string_vec_reserve(cxx_vec_pin.as_mut(), strings.len());
172            for string in strings.iter() {
173                let bytes = string.as_ref().as_bytes();
174                unsafe {
175                    cxx_utils::string_vec_push_bytes(
176                        cxx_vec_pin.as_mut(),
177                        bytes.as_ptr() as *const i8,
178                        bytes.len(),
179                    );
180                }
181            }
182        }
183
184        let result = unsafe {
185            cxx_utils::batch_accept_string(
186                ffi_matcher_vec.as_mut().unwrap().get_unchecked_mut(),
187                cxx_strings.as_ref().unwrap(),
188                debug_print,
189            )
190        };
191
192        result.iter().map(|&b| b != 0).collect::<Vec<_>>().into_boxed_slice()
193    }
194}