xgrammar/matcher/
batch_grammar_matcher.rs1use std::pin::Pin;
2
3use autocxx::prelude::*;
4
5use super::GrammarMatcher;
6use crate::{CxxUniquePtr, DLTensor, cxx_utils};
7
8pub struct BatchGrammarMatcher {
12 inner: CxxUniquePtr<crate::FFIBatchGrammarMatcher>,
13}
14
15impl BatchGrammarMatcher {
16 pub fn new(max_threads: i32) -> Self {
22 let ffi_pin = cxx_utils::make_batch_grammar_matcher(max_threads);
23 Self {
24 inner: ffi_pin,
25 }
26 }
27
28 pub fn new_auto() -> Self {
30 Self::new(-1)
31 }
32
33 pub fn batch_fill_next_token_bitmask(
44 &mut self,
45 matchers: &[GrammarMatcher],
46 bitmask: &mut DLTensor,
47 indices: Option<&[i32]>,
48 debug_print: bool,
49 ) {
50 let mut ffi_matcher_vec = cxx_utils::new_grammar_matcher_vector();
52 {
53 let mut vec_pin = ffi_matcher_vec.pin_mut();
54 cxx_utils::grammar_matcher_vec_reserve(
55 vec_pin.as_mut(),
56 matchers.len(),
57 );
58 for matcher in matchers {
59 cxx_utils::grammar_matcher_vec_push(
60 vec_pin.as_mut(),
61 matcher.ffi_ref(),
62 );
63 }
64 }
65
66 let (has_indices, indices_ptr, indices_len) = match indices {
67 Some(slice) if !slice.is_empty() => {
68 (true, slice.as_ptr(), slice.len())
69 },
70 _ => (false, std::ptr::null(), 0usize),
71 };
72
73 unsafe {
74 cxx_utils::batch_matcher_batch_fill_next_token_bitmask(
75 self.inner.as_mut().expect("BatchGrammarMatcher inner is null"),
76 ffi_matcher_vec.as_mut().unwrap().get_unchecked_mut(),
77 bitmask as *mut _,
78 has_indices,
79 indices_ptr,
80 indices_len,
81 debug_print,
82 );
83 }
84 }
85
86 pub fn batch_accept_token(
96 matchers: &[GrammarMatcher],
97 tokens: &[i32],
98 debug_print: bool,
99 ) -> Box<[bool]> {
100 assert_eq!(
101 matchers.len(),
102 tokens.len(),
103 "matchers and tokens must have the same length"
104 );
105
106 let mut ffi_matcher_vec = cxx_utils::new_grammar_matcher_vector();
107 {
108 let mut vec_pin = ffi_matcher_vec.pin_mut();
109 cxx_utils::grammar_matcher_vec_reserve(
110 vec_pin.as_mut(),
111 matchers.len(),
112 );
113 for matcher in matchers {
114 cxx_utils::grammar_matcher_vec_push(
115 vec_pin.as_mut(),
116 matcher.ffi_ref(),
117 );
118 }
119 }
120
121 let result = unsafe {
122 cxx_utils::batch_accept_token(
123 ffi_matcher_vec.as_mut().unwrap().get_unchecked_mut(),
124 tokens.as_ptr(),
125 tokens.len(),
126 debug_print,
127 )
128 };
129
130 result.iter().map(|&b| b != 0).collect::<Vec<_>>().into_boxed_slice()
131 }
132
133 pub fn batch_accept_string(
143 matchers: &[GrammarMatcher],
144 strings: &[impl AsRef<str>],
145 debug_print: bool,
146 ) -> Box<[bool]> {
147 assert_eq!(
148 matchers.len(),
149 strings.len(),
150 "matchers and strings must have the same length"
151 );
152
153 let mut ffi_matcher_vec = cxx_utils::new_grammar_matcher_vector();
154 {
155 let mut vec_pin = ffi_matcher_vec.pin_mut();
156 cxx_utils::grammar_matcher_vec_reserve(
157 vec_pin.as_mut(),
158 matchers.len(),
159 );
160 for matcher in matchers {
161 cxx_utils::grammar_matcher_vec_push(
162 vec_pin.as_mut(),
163 matcher.ffi_ref(),
164 );
165 }
166 }
167
168 let mut cxx_strings = cxx_utils::new_string_vector();
169 {
170 let mut cxx_vec_pin = cxx_strings.pin_mut();
171 cxx_utils::string_vec_reserve(cxx_vec_pin.as_mut(), strings.len());
172 for string in strings.iter() {
173 let bytes = string.as_ref().as_bytes();
174 unsafe {
175 cxx_utils::string_vec_push_bytes(
176 cxx_vec_pin.as_mut(),
177 bytes.as_ptr() as *const i8,
178 bytes.len(),
179 );
180 }
181 }
182 }
183
184 let result = unsafe {
185 cxx_utils::batch_accept_string(
186 ffi_matcher_vec.as_mut().unwrap().get_unchecked_mut(),
187 cxx_strings.as_ref().unwrap(),
188 debug_print,
189 )
190 };
191
192 result.iter().map(|&b| b != 0).collect::<Vec<_>>().into_boxed_slice()
193 }
194}