1#![doc = include_str!("../README.md")]
2
3use polars::export::arrow::array::Utf8ViewArray;
6use polars::prelude::*;
7use pyo3::prelude::*;
8use pyo3_polars::derive::polars_expr;
9use rayon::prelude::*;
10use serde::Deserialize;
11use smallvec::SmallVec;
12use std::borrow::Cow;
13use std::collections::{HashMap, HashSet};
14use std::path::Path;
15use std::sync::{Arc, LazyLock, Mutex};
16use std::time::{Duration, Instant};
17use tokmat::extractor::{Extractor, MatchMode, ParseOutput};
18use tokmat::tel::CompiledPattern;
19use tokmat::token_model::TokenModel;
20use tokmat::tokenizer::{split_input_tokens, tokenize_with_model};
21
/// Process-wide cache of loaded models keyed by model path, so repeated
/// expression calls over the same model reuse one parsed `ModelContext`.
static CONTEXT_CACHE: LazyLock<Mutex<HashMap<String, Arc<ModelContext>>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
24
/// Everything derived from one loaded token model that the expressions need:
/// the model itself, its extractor, feature-availability flags, the compact
/// u8 codecs for type/class ids, and the category lists used for enum dtypes.
struct ModelContext {
    model: TokenModel,
    extractor: Extractor,
    features: ModelFeatures,
    type_codec: CompactValueCodec,
    class_codec: CompactValueCodec,
    type_enum_values: Vec<String>,
    class_enum_values: Vec<String>,
}
34
/// Which of the built-in token type names the loaded model defines
/// (snapshot of `TokenModel::available_names()` taken at load time).
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Copy, Default)]
struct ModelFeatures {
    has_postalcode: bool,
    has_num: bool,
    has_alpha: bool,
    has_num_extended: bool,
    has_alpha_extended: bool,
    has_alpha_num: bool,
    has_alpha_num_extended: bool,
}
46
47impl ModelFeatures {
48 fn from_model(model: &TokenModel) -> Self {
49 let available = model.available_names();
50 Self {
51 has_postalcode: available.contains("POSTALCODE"),
52 has_num: available.contains("NUM"),
53 has_alpha: available.contains("ALPHA"),
54 has_num_extended: available.contains("NUM_EXTENDED"),
55 has_alpha_extended: available.contains("ALPHA_EXTENDED"),
56 has_alpha_num: available.contains("ALPHA_NUM"),
57 has_alpha_num_extended: available.contains("ALPHA_NUM_EXTENDED"),
58 }
59 }
60}
61
/// Per-stage timing accumulators for the compact (class-id based) extract
/// path. NOTE: despite the `_ns` suffixes the fields store `Duration`s; they
/// are converted to nanoseconds only when printed.
#[derive(Debug, Default, Clone, Copy)]
struct CompactExtractProfile {
    rows: usize,
    token_view_ns: Duration,
    class_id_decode_ns: Duration,
    raw_join_ns: Duration,
    parse_ns: Duration,
}
70
/// Result of extracting one row-range chunk in the parallel path: one column
/// of values per capture name, the complement column, and the chunk's
/// profiling totals (merged by the caller).
#[derive(Debug)]
struct ChunkExtractOutput {
    field_values: Vec<Vec<Option<String>>>,
    complements: Vec<Option<String>>,
    compact_profile: CompactExtractProfile,
}
77
/// Reserved id (255) meaning "not in the codec vocabulary; fall back to the
/// raw token text".
const RAW_TOKEN_SENTINEL: u8 = u8::MAX;
79
/// Bidirectional mapping between vocabulary strings and compact u8 ids;
/// `values_by_id[id]` and `ids_by_value` are kept mutually consistent.
#[derive(Debug)]
struct CompactValueCodec {
    values_by_id: Vec<String>,
    ids_by_value: HashMap<String, u8>,
}
85
86impl CompactValueCodec {
87 fn new(values: impl IntoIterator<Item = String>, label: &str) -> PolarsResult<Self> {
88 let mut values_by_id = values.into_iter().collect::<Vec<_>>();
89 values_by_id.sort();
90 values_by_id.dedup();
91
92 if values_by_id.len() >= usize::from(RAW_TOKEN_SENTINEL) {
93 polars_bail!(
94 ComputeError:
95 "{} vocabulary has {} entries; exceeds UInt8 compact encoding capacity",
96 label,
97 values_by_id.len()
98 );
99 }
100
101 let ids_by_value = values_by_id
102 .iter()
103 .enumerate()
104 .map(|(index, value)| {
105 (
106 value.clone(),
107 u8::try_from(index).expect("codec ids are bounded to u8 by construction"),
108 )
109 })
110 .collect();
111
112 Ok(Self {
113 values_by_id,
114 ids_by_value,
115 })
116 }
117
118 fn encode_known_or_raw(&self, value: &str) -> u8 {
119 self.ids_by_value
120 .get(value)
121 .copied()
122 .unwrap_or(RAW_TOKEN_SENTINEL)
123 }
124
125 fn decode_or_fallback_ref<'a>(&'a self, id: u8, raw_token: &'a str) -> Cow<'a, str> {
126 if id == RAW_TOKEN_SENTINEL {
127 Cow::Borrowed(raw_token)
128 } else {
129 match self.values_by_id.get(id as usize) {
130 Some(value) => Cow::Borrowed(value),
131 None => Cow::Borrowed(raw_token),
132 }
133 }
134 }
135}
136
/// Validated, resolved shape of the tokenize output struct: which columns to
/// emit and how each string-list column should be typed.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Copy)]
struct TokenizeLayout {
    include_raw_value: bool,
    include_types: bool,
    include_classes: bool,
    include_type_ids: bool,
    include_class_ids: bool,
    token_output: StringListOutput,
    type_output: StringListOutput,
    class_output: StringListOutput,
}
149
impl TokenizeLayout {
    /// True when type strings must be collected: either the `types` column
    /// was requested directly, or `type_ids` were (the staged path uses this
    /// to decide whether to keep per-row type values).
    fn needs_type_values(self) -> bool {
        self.include_types || self.include_type_ids
    }
}
155
/// Column-oriented accumulator for the staged tokenize path; each `Option`
/// column is `Some` exactly when the corresponding `TokenizeLayout` flag
/// requested it.
#[allow(clippy::struct_field_names)]
#[derive(Debug)]
struct TokenizedColumns {
    raw_values: Option<Vec<Option<String>>>,
    token_values: Vec<Option<Vec<String>>>,
    type_values: Option<Vec<Option<Vec<String>>>>,
    class_values: Option<Vec<Option<Vec<String>>>>,
    type_id_values: Option<Vec<Option<Vec<u8>>>>,
    class_id_values: Option<Vec<Option<Vec<u8>>>>,
}
166
/// One row's tokenize result in the staged path; optional parts mirror the
/// layout flags (produced by `tokenize_row`, defined elsewhere in the file).
#[derive(Debug)]
struct TokenizedRow {
    raw_value: Option<String>,
    tokens: Vec<String>,
    types: Option<Vec<String>>,
    classes: Option<Vec<String>>,
    type_ids: Option<Vec<u8>>,
    class_ids: Option<Vec<u8>>,
}
176
/// Requested element dtype for a list-of-strings output column. Deserialized
/// from the Python kwargs as "string" / "categorical" / "enum".
#[derive(Debug, Clone, Copy, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum StringListOutput {
    #[default]
    String,
    Categorical,
    Enum,
}
185
/// Rust-facing entry point: wraps a cached model context and offers
/// tokenize/extract convenience methods over polars `Series`.
#[derive(Clone)]
pub struct TokmatPolars {
    context: Arc<ModelContext>,
}
191
192impl TokmatPolars {
193 pub fn from_model_path(model_path: impl AsRef<Path>) -> PolarsResult<Self> {
199 let model_path = model_path.as_ref().to_string_lossy().into_owned();
200 let context = get_or_load_context(&model_path)?;
201 Ok(Self { context })
202 }
203
204 pub fn tokenize_series(&self, input: &Series) -> PolarsResult<Series> {
211 tokenize_series_with_context(
212 input,
213 &self.context,
214 TokenizeLayout {
215 include_raw_value: true,
216 include_types: true,
217 include_classes: true,
218 include_type_ids: false,
219 include_class_ids: false,
220 token_output: StringListOutput::String,
221 type_output: StringListOutput::String,
222 class_output: StringListOutput::String,
223 },
224 )
225 }
226
227 pub fn extract_series(&self, input: &Series, pattern: &str) -> PolarsResult<Series> {
234 self.extract_series_with_mode(input, pattern, MatchMode::Whole)
235 }
236
237 pub fn extract_series_with_mode(
244 &self,
245 input: &Series,
246 pattern: &str,
247 mode: MatchMode,
248 ) -> PolarsResult<Series> {
249 extract_series_with_context(input, &self.context, pattern, mode)
250 }
251
252 pub fn capture_field_names(&self, pattern: &str) -> PolarsResult<Vec<String>> {
258 let _ = &self.context;
259 capture_field_names_from_pattern(pattern)
260 }
261}
262
/// Raw kwargs received from the Python side for `tokenize_expr`; validated
/// and normalized into a `TokenizeLayout` via `layout()`.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Deserialize)]
struct TokenizeKwargs {
    model_path: String,
    #[serde(default = "default_true")]
    include_raw_value: bool,
    #[serde(default = "default_true")]
    include_types: bool,
    #[serde(default = "default_true")]
    include_classes: bool,
    #[serde(default)]
    include_type_ids: bool,
    #[serde(default)]
    include_class_ids: bool,
    #[serde(default)]
    token_output: StringListOutput,
    #[serde(default)]
    type_output: StringListOutput,
    #[serde(default)]
    class_output: StringListOutput,
}
284
/// Serde default for the `include_*` kwargs that are on by default.
const fn default_true() -> bool {
    true
}
288
289impl TokenizeKwargs {
290 fn layout(&self) -> PolarsResult<TokenizeLayout> {
291 if self.token_output == StringListOutput::Enum
292 || self.type_output == StringListOutput::Enum
293 || self.class_output == StringListOutput::Enum
294 {
295 polars_bail!(
296 InvalidOperation:
297 "enum list output is not supported in tokmat-polars; use 'string' or 'categorical'"
298 );
299 }
300 Ok(TokenizeLayout {
301 include_raw_value: self.include_raw_value,
302 include_types: self.include_types,
303 include_classes: self.include_classes,
304 include_type_ids: self.include_type_ids,
305 include_class_ids: self.include_class_ids,
306 token_output: self.token_output,
307 type_output: self.type_output,
308 class_output: self.class_output,
309 })
310 }
311}
312
/// Serde-facing mirror of tokmat's `MatchMode`, deserialized from lowercase
/// strings ("whole" / "start" / "end" / "any").
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum MatchModeKwarg {
    #[default]
    Whole,
    Start,
    End,
    Any,
}
322
impl From<MatchModeKwarg> for MatchMode {
    /// One-to-one mapping from the kwargs enum onto tokmat's `MatchMode`.
    fn from(mode: MatchModeKwarg) -> Self {
        match mode {
            MatchModeKwarg::Whole => Self::Whole,
            MatchModeKwarg::Start => Self::Start,
            MatchModeKwarg::End => Self::End,
            MatchModeKwarg::Any => Self::Any,
        }
    }
}
333
/// Kwargs received from the Python side for `extract_expr`; `mode` defaults
/// to whole-string matching.
#[derive(Debug, Clone, Deserialize)]
struct ExtractKwargs {
    model_path: String,
    pattern: String,
    #[serde(default)]
    mode: MatchModeKwarg,
}
341
/// Plugin entry point registered with polars; the output dtype is computed by
/// `tokenize_output_type`. Delegates to `tokenize_expr_impl`.
#[allow(clippy::needless_pass_by_value)]
#[polars_expr(output_type_func_with_kwargs=tokenize_output_type)]
fn tokenize_expr(inputs: &[Series], kwargs: TokenizeKwargs) -> PolarsResult<Series> {
    tokenize_expr_impl(inputs, &kwargs)
}
347
/// Plugin entry point registered with polars; the output dtype is computed by
/// `extract_output_type`. Delegates to `extract_expr_impl`.
#[allow(clippy::needless_pass_by_value)]
#[polars_expr(output_type_func_with_kwargs=extract_output_type)]
fn extract_expr(inputs: &[Series], kwargs: ExtractKwargs) -> PolarsResult<Series> {
    extract_expr_impl(inputs, kwargs)
}
353
/// Python module initializer. Intentionally empty: the expressions above are
/// registered through the polars plugin mechanism, not as module functions.
#[allow(clippy::unnecessary_wraps)]
#[pymodule]
fn tokmat_polars(_py: Python<'_>, _module: &Bound<'_, PyModule>) -> PyResult<()> {
    Ok(())
}
359
360#[allow(clippy::needless_pass_by_value, clippy::unnecessary_wraps)]
361fn tokenize_output_type(input_fields: &[Field], kwargs: TokenizeKwargs) -> PolarsResult<Field> {
362 let output_name = output_field_name(input_fields, "tokenized");
363 let context = get_or_load_context(&kwargs.model_path)?;
364 let layout = kwargs.layout()?;
365 Ok(Field::new(
366 output_name,
367 DataType::Struct(tokenize_fields(&context, layout)?),
368 ))
369}
370
371#[allow(clippy::needless_pass_by_value)]
372fn extract_output_type(input_fields: &[Field], kwargs: ExtractKwargs) -> PolarsResult<Field> {
373 let output_name = output_field_name(input_fields, "extracted");
374 let capture_names = capture_field_names_from_pattern(&kwargs.pattern)?;
375 Ok(Field::new(
376 output_name,
377 DataType::Struct(extract_fields(&capture_names)),
378 ))
379}
380
381fn tokenize_expr_impl(inputs: &[Series], kwargs: &TokenizeKwargs) -> PolarsResult<Series> {
382 let input = single_input(inputs, "tokenize_expr")?;
383 let context = get_or_load_context(&kwargs.model_path)?;
384 tokenize_series_with_context(input, &context, kwargs.layout()?)
385}
386
387fn extract_expr_impl(inputs: &[Series], kwargs: ExtractKwargs) -> PolarsResult<Series> {
388 let input = single_input(inputs, "extract_expr")?;
389 let context = get_or_load_context(&kwargs.model_path)?;
390 extract_series_with_context(input, &context, &kwargs.pattern, kwargs.mode.into())
391}
392
393fn tokenize_series_with_context(
394 input: &Series,
395 context: &ModelContext,
396 layout: TokenizeLayout,
397) -> PolarsResult<Series> {
398 if can_tokenize_direct(layout) {
399 return tokenize_series_with_context_direct(input, context, layout);
400 }
401
402 tokenize_series_with_context_staged(input, context, layout)
403}
404
405fn can_tokenize_direct(layout: TokenizeLayout) -> bool {
406 layout.token_output == StringListOutput::String
407 && (!layout.include_types || layout.type_output == StringListOutput::String)
408 && (!layout.include_classes || layout.class_output == StringListOutput::String)
409}
410
/// Row-at-a-time tokenization that materializes owned Rust values into
/// `TokenizedColumns` first and assembles the struct series at the end; used
/// when any list column needs a non-string output dtype.
fn tokenize_series_with_context_staged(
    input: &Series,
    context: &ModelContext,
    layout: TokenizeLayout,
) -> PolarsResult<Series> {
    let row_count = input.len();
    // Allocate only the columns the layout asks for; absent columns stay
    // `None` and are skipped in the loop below.
    let mut columns = TokenizedColumns {
        raw_values: layout
            .include_raw_value
            .then(|| Vec::with_capacity(row_count)),
        token_values: Vec::with_capacity(row_count),
        // Note: type values are collected when either types or type ids were
        // requested (`needs_type_values`), while class values are gated on
        // `include_classes` alone.
        type_values: layout
            .needs_type_values()
            .then(|| Vec::with_capacity(row_count)),
        class_values: layout
            .include_classes
            .then(|| Vec::with_capacity(row_count)),
        type_id_values: layout
            .include_type_ids
            .then(|| Vec::with_capacity(row_count)),
        class_id_values: layout
            .include_class_ids
            .then(|| Vec::with_capacity(row_count)),
    };

    for value in input.str()? {
        if let Some(raw_value) = value {
            // Non-null row: tokenize once and fan the parts out to whichever
            // columns are active.
            let tokenized = tokenize_row(raw_value, context, layout);
            if let Some(raw_values) = columns.raw_values.as_mut() {
                raw_values.push(tokenized.raw_value);
            }
            columns.token_values.push(Some(tokenized.tokens));
            if let Some(type_values) = columns.type_values.as_mut() {
                type_values.push(tokenized.types);
            }
            if let Some(class_values) = columns.class_values.as_mut() {
                class_values.push(tokenized.classes);
            }
            if let Some(type_id_values) = columns.type_id_values.as_mut() {
                type_id_values.push(tokenized.type_ids);
            }
            if let Some(class_id_values) = columns.class_id_values.as_mut() {
                class_id_values.push(tokenized.class_ids);
            }
        } else {
            // Null row: push a null into every active column so all columns
            // stay the same length.
            if let Some(raw_values) = columns.raw_values.as_mut() {
                raw_values.push(None);
            }
            columns.token_values.push(None);
            if let Some(type_values) = columns.type_values.as_mut() {
                type_values.push(None);
            }
            if let Some(class_values) = columns.class_values.as_mut() {
                class_values.push(None);
            }
            if let Some(type_id_values) = columns.type_id_values.as_mut() {
                type_id_values.push(None);
            }
            if let Some(class_id_values) = columns.class_id_values.as_mut() {
                class_id_values.push(None);
            }
        }
    }

    build_tokenized_struct_series(input.name().clone(), columns, context, layout)
}
477
/// Single-pass tokenization straight into Arrow list builders, avoiding the
/// per-row owned `Vec`s of the staged path. Only valid when every list
/// column is plain `String` (see `can_tokenize_direct`).
#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
fn tokenize_series_with_context_direct(
    input: &Series,
    context: &ModelContext,
    layout: TokenizeLayout,
) -> PolarsResult<Series> {
    let row_count = input.len();
    // Total input bytes, used as the value-capacity hint for the string list
    // builders (tokens/types/classes are roughly bounded by the input size).
    let input_total_bytes = input
        .str()?
        .into_iter()
        .flatten()
        .map(str::len)
        .sum::<usize>();

    // One builder per requested column; `None` means the column is omitted.
    let mut raw_values = layout
        .include_raw_value
        .then(|| Vec::with_capacity(row_count));
    let mut token_builder =
        ListStringChunkedBuilder::new("tokens".into(), row_count, input_total_bytes);
    let mut type_builder = layout
        .include_types
        .then(|| ListStringChunkedBuilder::new("types".into(), row_count, input_total_bytes));
    let mut class_builder = layout
        .include_classes
        .then(|| ListStringChunkedBuilder::new("classes".into(), row_count, input_total_bytes));
    // Capacity heuristic: assume ~8 tokens per row for the id columns.
    let mut type_id_builder = layout.include_type_ids.then(|| {
        ListPrimitiveChunkedBuilder::<UInt8Type>::new(
            "type_ids".into(),
            row_count,
            row_count * 8,
            DataType::UInt8,
        )
    });
    let mut class_id_builder = layout.include_class_ids.then(|| {
        ListPrimitiveChunkedBuilder::<UInt8Type>::new(
            "class_ids".into(),
            row_count,
            row_count * 8,
            DataType::UInt8,
        )
    });

    for value in input.str()? {
        if let Some(raw_value) = value {
            if let Some(values) = raw_values.as_mut() {
                values.push(Some(raw_value.to_string()));
            }

            let tokens = split_input_tokens(raw_value);
            token_builder.append_values_iter(tokens.iter().map(String::as_str));

            // Classification is needed if any of types/classes/ids were
            // requested; per-row scratch buffers exist only for the columns
            // actually being emitted.
            let needs_row_types = layout.include_types || layout.include_type_ids;
            let needs_row_classes = layout.include_classes || layout.include_class_ids;

            let mut row_types = layout
                .include_types
                .then(SmallVec::<[Cow<'_, str>; 12]>::new);
            let mut row_classes = layout
                .include_classes
                .then(SmallVec::<[Cow<'_, str>; 12]>::new);
            let mut row_type_ids = layout.include_type_ids.then(SmallVec::<[u8; 12]>::new);
            let mut row_class_ids = layout.include_class_ids.then(SmallVec::<[u8; 12]>::new);

            if needs_row_types || needs_row_classes {
                for token in &tokens {
                    let token_type = classify_token_ref(token, &context.model, context.features);
                    if let Some(values) = row_types.as_mut() {
                        values.push(token_type.clone());
                    }
                    if let Some(values) = row_type_ids.as_mut() {
                        values.push(context.type_codec.encode_known_or_raw(token_type.as_ref()));
                    }

                    // Class resolution order: whitespace tokens classify as
                    // themselves, then the model's explicit class lookup,
                    // then fall back to the token's type name.
                    let class_value = if token.chars().all(char::is_whitespace) {
                        Cow::Borrowed(token.as_str())
                    } else if let Some(value) = context.model.token_class_lookup().get(token) {
                        Cow::Borrowed(value.as_str())
                    } else {
                        Cow::Owned(token_type.as_ref().to_string())
                    };

                    if let Some(values) = row_classes.as_mut() {
                        values.push(class_value.clone());
                    }
                    if let Some(values) = row_class_ids.as_mut() {
                        values.push(
                            context
                                .class_codec
                                .encode_known_or_raw(class_value.as_ref()),
                        );
                    }
                }
            }

            // Flush the per-row scratch buffers into the builders. The
            // `else` arms append an empty list; with the current gating the
            // buffer exists whenever its builder does, so they are defensive.
            if let Some(builder) = type_builder.as_mut() {
                if let Some(values) = row_types {
                    builder.append_values_iter(values.iter().map(AsRef::as_ref));
                } else {
                    builder.append_values_iter(std::iter::empty::<&str>());
                }
            }
            if let Some(builder) = class_builder.as_mut() {
                if let Some(values) = row_classes {
                    builder.append_values_iter(values.iter().map(AsRef::as_ref));
                } else {
                    builder.append_values_iter(std::iter::empty::<&str>());
                }
            }
            if let Some(builder) = type_id_builder.as_mut() {
                if let Some(values) = row_type_ids {
                    builder.append_slice(values.as_slice());
                } else {
                    builder.append_slice(&[]);
                }
            }
            if let Some(builder) = class_id_builder.as_mut() {
                if let Some(values) = row_class_ids {
                    builder.append_slice(values.as_slice());
                } else {
                    builder.append_slice(&[]);
                }
            }
        } else {
            // Null input row: append null to every active column.
            if let Some(values) = raw_values.as_mut() {
                values.push(None);
            }
            token_builder.append_null();
            if let Some(builder) = type_builder.as_mut() {
                builder.append_null();
            }
            if let Some(builder) = class_builder.as_mut() {
                builder.append_null();
            }
            if let Some(builder) = type_id_builder.as_mut() {
                builder.append_null();
            }
            if let Some(builder) = class_id_builder.as_mut() {
                builder.append_null();
            }
        }
    }

    // Field order here must match `tokenize_fields`.
    let mut fields = Vec::new();
    if let Some(raw_values) = raw_values {
        fields.push(
            StringChunked::from_iter_options(
                "raw_value".into(),
                raw_values.iter().map(|value| value.as_deref()),
            )
            .into_series(),
        );
    }
    fields.push(token_builder.finish().into_series());
    if let Some(mut builder) = type_builder {
        fields.push(builder.finish().into_series());
    }
    if let Some(mut builder) = class_builder {
        fields.push(builder.finish().into_series());
    }
    if let Some(mut builder) = type_id_builder {
        fields.push(builder.finish().into_series());
    }
    if let Some(mut builder) = class_id_builder {
        fields.push(builder.finish().into_series());
    }

    Ok(StructChunked::from_series(input.name().clone(), row_count, fields.iter())?.into_series())
}
646
647fn extract_series_with_context(
648 input: &Series,
649 context: &ModelContext,
650 pattern: &str,
651 mode: MatchMode,
652) -> PolarsResult<Series> {
653 let capture_names = capture_field_names_from_pattern(pattern)?;
654
655 match input.dtype() {
656 DataType::String => {
657 extract_from_string_series(input, context, pattern, mode, &capture_names)
658 }
659 DataType::Struct(_) => {
660 extract_from_tokenized_series(input, context, pattern, mode, &capture_names)
661 }
662 dtype => {
663 polars_bail!(
664 InvalidOperation:
665 "extract_series expected a String or tokenized Struct column, got {:?}",
666 dtype
667 )
668 }
669 }
670}
671
/// Fetch the cached `ModelContext` for `model_path`, loading and caching it
/// on first use.
///
/// The cache mutex is held for the whole load, so concurrent callers for the
/// same path block rather than loading the model twice.
///
/// # Errors
/// Fails if the cache mutex is poisoned, the model fails to load, or a codec
/// vocabulary exceeds the u8 id space.
fn get_or_load_context(model_path: &str) -> PolarsResult<Arc<ModelContext>> {
    let mut cache = CONTEXT_CACHE
        .lock()
        .map_err(|error| polars_err!(ComputeError: "context cache poisoned: {}", error))?;

    if let Some(context) = cache.get(model_path) {
        return Ok(Arc::clone(context));
    }

    let model = TokenModel::load(Path::new(model_path)).map_err(|error| {
        polars_err!(
            ComputeError:
            "failed to load tokmat model from '{}': {}",
            model_path,
            error
        )
    })?;
    let extractor = Extractor::new(
        model.token_definitions().clone(),
        model.token_class_list().clone(),
    );

    // Type vocabulary: the token definition names. Class vocabulary: type
    // names plus the explicit class-lookup values (a class can fall back to
    // a type name). The codecs dedup and sort internally.
    let type_vocab = model
        .token_definitions()
        .iter()
        .map(|(name, _)| name.clone())
        .collect::<Vec<_>>();
    let class_vocab = model
        .token_definitions()
        .iter()
        .map(|(name, _)| name.clone())
        .chain(model.token_class_lookup().values().cloned())
        .collect::<Vec<_>>();
    let type_codec = CompactValueCodec::new(type_vocab.clone(), "type")?;
    let class_codec = CompactValueCodec::new(class_vocab.clone(), "class")?;
    // A single-space category is appended because whitespace tokens are
    // classified as themselves (see the direct tokenizer's class handling).
    let type_enum_values = enum_categories(type_vocab.into_iter().chain([" ".to_string()]));
    let class_enum_values = enum_categories(class_vocab.into_iter().chain([" ".to_string()]));

    let features = ModelFeatures::from_model(&model);
    let context = Arc::new(ModelContext {
        model,
        extractor,
        features,
        type_codec,
        class_codec,
        type_enum_values,
        class_enum_values,
    });
    cache.insert(model_path.to_string(), Arc::clone(&context));
    Ok(context)
}
723
724fn single_input<'a>(inputs: &'a [Series], function_name: &str) -> PolarsResult<&'a Series> {
725 match inputs {
726 [input] => Ok(input),
727 _ => polars_bail!(
728 InvalidOperation:
729 "{} expected exactly one input column, got {}",
730 function_name,
731 inputs.len()
732 ),
733 }
734}
735
736fn output_field_name(input_fields: &[Field], fallback: &str) -> PlSmallStr {
737 input_fields
738 .first()
739 .map_or_else(|| fallback.into(), |field| field.name().clone())
740}
741
/// Collapse a value stream into its sorted, deduplicated category list.
fn enum_categories(values: impl IntoIterator<Item = String>) -> Vec<String> {
    let mut unique = Vec::from_iter(values);
    // Order among equal strings is irrelevant (dedup removes them), so the
    // unstable sort is safe here.
    unique.sort_unstable();
    unique.dedup();
    unique
}
748
749fn enum_dtype(values: &[String]) -> DataType {
750 let categories = Utf8ViewArray::from_slice(
751 values
752 .iter()
753 .map(|value| Some(value.as_str()))
754 .collect::<Vec<_>>(),
755 );
756 create_enum_dtype(categories)
757}
758
759fn list_output_dtype(
760 output: StringListOutput,
761 enum_values: Option<&[String]>,
762) -> PolarsResult<DataType> {
763 match output {
764 StringListOutput::String => Ok(DataType::String),
765 StringListOutput::Categorical => {
766 Ok(DataType::Categorical(None, CategoricalOrdering::default()))
767 }
768 StringListOutput::Enum => enum_values
769 .map(enum_dtype)
770 .ok_or_else(|| polars_err!(InvalidOperation: "enum output requires fixed categories")),
771 }
772}
773
774fn cast_string_list_series(
775 series: Series,
776 output: StringListOutput,
777 enum_values: Option<&[String]>,
778) -> PolarsResult<Series> {
779 let target_dtype = DataType::List(Box::new(list_output_dtype(output, enum_values)?));
780 if series.dtype() == &target_dtype {
781 Ok(series)
782 } else {
783 series.cast(&target_dtype)
784 }
785}
786
787fn build_output_string_list_series(
788 name: &str,
789 rows: Vec<Option<Vec<String>>>,
790 output: StringListOutput,
791 enum_values: Option<&[String]>,
792) -> PolarsResult<Series> {
793 match output {
794 StringListOutput::String => Ok(build_string_list_series(name, rows)),
795 StringListOutput::Categorical => {
796 cast_string_list_series(build_string_list_series(name, rows), output, enum_values)
797 }
798 StringListOutput::Enum => build_enum_list_series(
799 name,
800 rows,
801 enum_values.ok_or_else(
802 || polars_err!(InvalidOperation: "enum output requires fixed categories"),
803 )?,
804 ),
805 }
806}
807
/// Schema of the tokenize output struct for `layout`. Field order here must
/// match the order the tokenize implementations push their series.
fn tokenize_fields(context: &ModelContext, layout: TokenizeLayout) -> PolarsResult<Vec<Field>> {
    let mut fields = Vec::with_capacity(6);
    if layout.include_raw_value {
        fields.push(Field::new("raw_value".into(), DataType::String));
    }
    // Tokens never get an enum dtype (no fixed vocabulary), hence `None`.
    fields.push(Field::new(
        "tokens".into(),
        DataType::List(Box::new(list_output_dtype(layout.token_output, None)?)),
    ));
    if layout.include_types {
        fields.push(Field::new(
            "types".into(),
            DataType::List(Box::new(list_output_dtype(
                layout.type_output,
                Some(&context.type_enum_values),
            )?)),
        ));
    }
    if layout.include_classes {
        fields.push(Field::new(
            "classes".into(),
            DataType::List(Box::new(list_output_dtype(
                layout.class_output,
                Some(&context.class_enum_values),
            )?)),
        ));
    }
    if layout.include_type_ids {
        fields.push(Field::new(
            "type_ids".into(),
            DataType::List(Box::new(DataType::UInt8)),
        ));
    }
    if layout.include_class_ids {
        fields.push(Field::new(
            "class_ids".into(),
            DataType::List(Box::new(DataType::UInt8)),
        ));
    }
    Ok(fields)
}
849
850fn extract_fields(capture_names: &[String]) -> Vec<Field> {
851 let mut fields = capture_names
852 .iter()
853 .map(|name| Field::new(name.clone().into(), DataType::String))
854 .collect::<Vec<_>>();
855 fields.push(Field::new("complement".into(), DataType::String));
856 fields
857}
858
/// Compile `pattern` and return the names of its capturing/vanishing groups
/// in first-appearance order, deduplicated.
///
/// # Errors
/// Fails when the TEL pattern does not compile.
fn capture_field_names_from_pattern(pattern: &str) -> PolarsResult<Vec<String>> {
    let compiled = CompiledPattern::compile(pattern).map_err(|error| {
        polars_err!(
            ComputeError:
            "failed to compile TEL pattern '{}': {}",
            pattern,
            error
        )
    })?;

    let mut seen = HashSet::new();
    let mut fields = Vec::new();
    for token_info in compiled.token_info() {
        // Unnamed tokens contribute no output field.
        let Some(name) = token_info.var_name.as_ref() else {
            continue;
        };
        // Only capturing or vanishing groups surface as struct fields.
        if !(token_info.is_capturing_group() || token_info.is_vanishing_group()) {
            continue;
        }
        // Keep first-seen order while dropping repeated names.
        if seen.insert(name.clone()) {
            fields.push(name.clone());
        }
    }

    Ok(fields)
}
885
/// Extract over a plain `String` column: each non-null row is tokenized with
/// the model and then parsed against the pattern; null rows yield null
/// outputs in every capture column.
fn extract_from_string_series(
    input: &Series,
    context: &ModelContext,
    pattern: &str,
    mode: MatchMode,
    capture_names: &[String],
) -> PolarsResult<Series> {
    let mut field_columns = init_extract_columns(capture_names, input.len());
    let mut complements = Vec::with_capacity(input.len());

    for raw_value in input.str()? {
        match raw_value {
            Some(raw_value) => {
                // Tokenize on the fly, then parse tokens + classes.
                let tokenized = tokenize_with_model(raw_value, &context.model);
                let parsed = parse_from_tokenized_parts(
                    context,
                    raw_value,
                    &tokenized.tokens,
                    &tokenized.classes,
                    pattern,
                    mode,
                )?;
                push_parse_output(&mut field_columns, &mut complements, Some(parsed));
            }
            None => push_parse_output(&mut field_columns, &mut complements, None),
        }
    }

    build_extract_struct_series(input.name().clone(), field_columns, &complements)
}
916
/// Extract over an already-tokenized `Struct` column.
///
/// Accepts classes either as strings (`classes`) or as compact u8 ids
/// (`class_ids`, decoded via the model's class codec). When `raw_value` is
/// absent the raw text is reconstructed by joining the tokens. Large inputs
/// are handed off to the chunked parallel path.
///
/// # Errors
/// Fails when required struct fields are missing or a row's fields disagree
/// on nullability.
#[allow(clippy::too_many_lines)]
fn extract_from_tokenized_series(
    input: &Series,
    context: &ModelContext,
    pattern: &str,
    mode: MatchMode,
    capture_names: &[String],
) -> PolarsResult<Series> {
    let struct_chunked = input.struct_()?;
    let fields = struct_chunked.fields_as_series();
    // Index the struct fields by name; the layout may omit optional ones.
    let field_map = fields
        .into_iter()
        .map(|field| (field.name().to_string(), field))
        .collect::<HashMap<_, _>>();

    let raw_field = field_map.get("raw_value");
    let tokens_field = field_map.get("tokens").ok_or_else(|| {
        polars_err!(
            InvalidOperation:
            "tokenized struct is missing required 'tokens' field"
        )
    })?;
    let classes_field = field_map.get("classes");
    let class_ids_field = field_map.get("class_ids");
    // At least one class representation must be present.
    if classes_field.is_none() && class_ids_field.is_none() {
        polars_bail!(
            InvalidOperation:
            "tokenized struct is missing required 'classes' or 'class_ids' field"
        );
    }

    if should_parallelize(input.len()) {
        return extract_from_tokenized_series_parallel(
            input,
            context,
            pattern,
            mode,
            capture_names,
            raw_field,
            tokens_field,
            classes_field,
            class_ids_field,
        );
    }

    // Sequential path: walk all field iterators in lockstep, one row at a
    // time. Optional fields iterate as `None` when absent.
    let mut raw_iter = raw_field
        .map(|field| field.str().map(IntoIterator::into_iter))
        .transpose()?;
    let mut token_iter = tokens_field.list()?.into_iter();
    let mut class_iter = classes_field
        .map(|field| field.list().map(IntoIterator::into_iter))
        .transpose()?;
    let mut class_id_iter = class_ids_field
        .map(|field| field.list().map(IntoIterator::into_iter))
        .transpose()?;
    let mut field_columns = init_extract_columns(capture_names, input.len());
    let mut complements = Vec::with_capacity(input.len());
    // Profiling (opt-in via `profile_enabled`) tracks where compact-path
    // time goes; timings are only taken when enabled.
    let profiling = profile_enabled();
    let mut compact_profile = profiling.then(CompactExtractProfile::default);

    for index in 0..input.len() {
        let raw_value = raw_iter.as_mut().and_then(Iterator::next).flatten();
        let tokens = token_iter.next();
        let classes = class_iter.as_mut().and_then(Iterator::next);
        let class_id_iter_value = class_id_iter.as_mut().and_then(Iterator::next);
        let class_ids = class_id_iter_value;
        match (tokens, classes, class_ids) {
            // Preferred: class strings are present; parse directly.
            (Some(Some(tokens)), Some(Some(classes)), _) => {
                let token_values = list_series_to_strings(&tokens)?;
                let class_values = list_series_to_strings(&classes)?;
                // Reconstruct raw text from tokens when not stored.
                let raw_value_buf = raw_value.map_or_else(
                    || Cow::Owned(join_string_values(&token_values)),
                    Cow::Borrowed,
                );
                let parsed = parse_from_tokenized_parts(
                    context,
                    raw_value_buf,
                    &token_values,
                    &class_values,
                    pattern,
                    mode,
                )?;
                push_parse_output(&mut field_columns, &mut complements, Some(parsed));
            }
            // Compact path: decode class ids back to strings, using the raw
            // token text for ids outside the codec vocabulary.
            (Some(Some(tokens)), _, Some(Some(class_ids))) => {
                let token_view_start = profiling.then(Instant::now);
                let token_values = list_series_to_strings(&tokens)?;
                let token_view_elapsed = elapsed_since(token_view_start);
                let decode_start = profiling.then(Instant::now);
                let class_values = list_series_to_u8(&class_ids)?
                    .into_iter()
                    .zip(token_values.iter())
                    .map(|(class_id, token)| {
                        context
                            .class_codec
                            .decode_or_fallback_ref(class_id, token.as_str())
                    })
                    .collect::<Vec<_>>();
                let decode_elapsed = elapsed_since(decode_start);
                // Only time the join when it actually happens (raw missing).
                let raw_join_start = profiling
                    .then_some(raw_value.is_none())
                    .filter(|should_join| *should_join)
                    .map(|_| Instant::now());
                let raw_value_buf = raw_value.map_or_else(
                    || Cow::Owned(join_string_values(&token_values)),
                    Cow::Borrowed,
                );
                let raw_join_elapsed = elapsed_since(raw_join_start);
                let parse_start = profiling.then(Instant::now);
                let parsed = parse_from_tokenized_parts(
                    context,
                    raw_value_buf,
                    &token_values,
                    &class_values,
                    pattern,
                    mode,
                )?;
                let parse_elapsed = elapsed_since(parse_start);
                if let Some(profile) = compact_profile.as_mut() {
                    profile.rows += 1;
                    profile.token_view_ns += token_view_elapsed;
                    profile.class_id_decode_ns += decode_elapsed;
                    profile.raw_join_ns += raw_join_elapsed;
                    profile.parse_ns += parse_elapsed;
                }
                push_parse_output(&mut field_columns, &mut complements, Some(parsed));
            }
            // Fully-null row: propagate null outputs.
            (Some(None), Some(None) | None, Some(None) | None) if raw_value.is_none() => {
                push_parse_output(&mut field_columns, &mut complements, None);
            }
            // Anything else means the struct fields disagree on nullability.
            _ => {
                polars_bail!(
                    InvalidOperation:
                    "tokenized struct row {} has inconsistent nullability across fields",
                    index
                )
            }
        }
    }

    if let Some(profile) = compact_profile {
        // Combine the local stage timings with the extractor's own counters.
        if let Ok(stats) = context.extractor.stats() {
            eprintln!(
                "TOKMAT_PROFILE compact rows={} token_view_ns={} class_id_decode_ns={} raw_join_ns={} parse_ns={} tokmat_profiled_rows={} tokmat_total_ns={} tokmat_class_join_ns={} tokmat_class_regex_ns={} tokmat_offset_work_ns={} tokmat_object_join_ns={} tokmat_direct_execution_ns={} tokmat_fallback_regex_ns={}",
                profile.rows,
                profile.token_view_ns.as_nanos(),
                profile.class_id_decode_ns.as_nanos(),
                profile.raw_join_ns.as_nanos(),
                profile.parse_ns.as_nanos(),
                stats.profiled_rows,
                stats.profile_total_ns,
                stats.profile_class_join_ns,
                stats.profile_class_regex_ns,
                stats.profile_offset_work_ns,
                stats.profile_object_join_ns,
                stats.profile_direct_execution_ns,
                stats.profile_fallback_regex_ns,
            );
        }
    }

    build_extract_struct_series(input.name().clone(), field_columns, &complements)
}
1079
/// Parallel counterpart of the sequential tokenized-struct extraction path.
///
/// Partitions the rows into contiguous chunks, runs `process_extract_chunk`
/// for each chunk on the rayon pool, then reassembles the per-capture columns
/// and complements in row order. When profiling is enabled, the per-chunk
/// counters are merged and printed together with the extractor's own stats.
///
/// # Errors
/// Propagates the first error produced by any chunk, or by the final struct
/// assembly.
#[allow(clippy::too_many_arguments)]
fn extract_from_tokenized_series_parallel(
    input: &Series,
    context: &ModelContext,
    pattern: &str,
    mode: MatchMode,
    capture_names: &[String],
    raw_field: Option<&Series>,
    tokens_field: &Series,
    classes_field: Option<&Series>,
    class_ids_field: Option<&Series>,
) -> PolarsResult<Series> {
    let row_count = input.len();
    let chunk_size = parallel_chunk_size(row_count);
    let profiling = profile_enabled();
    // Half-open [start, end) row ranges that together cover every row.
    let chunk_ranges = (0..row_count)
        .step_by(chunk_size)
        .map(|start| (start, (start + chunk_size).min(row_count)))
        .collect::<Vec<_>>();

    // Clone the field series once so the parallel closure can borrow owned
    // copies (Series clones are cheap, Arc-backed in polars).
    let raw_series = raw_field.cloned();
    let tokens_series = tokens_field.clone();
    let classes_series = classes_field.cloned();
    let class_ids_series = class_ids_field.cloned();

    let chunk_results = chunk_ranges
        .into_par_iter()
        .map(|(start, end)| {
            process_extract_chunk(
                raw_series.as_ref(),
                &tokens_series,
                classes_series.as_ref(),
                class_ids_series.as_ref(),
                context,
                pattern,
                mode,
                capture_names,
                profiling,
                start,
                end,
            )
        })
        .collect::<Vec<_>>();

    // Stitch chunks back together; the ordered collect above preserves chunk
    // order, so extending in iteration order keeps rows in their original
    // positions.
    let mut field_values = capture_names
        .iter()
        .map(|_| Vec::with_capacity(row_count))
        .collect::<Vec<_>>();
    let mut complements = Vec::with_capacity(row_count);
    let mut merged_profile = CompactExtractProfile::default();

    for chunk_result in chunk_results {
        let chunk = chunk_result?;
        for (index, values) in chunk.field_values.into_iter().enumerate() {
            field_values[index].extend(values);
        }
        complements.extend(chunk.complements);
        merged_profile.rows += chunk.compact_profile.rows;
        merged_profile.token_view_ns += chunk.compact_profile.token_view_ns;
        merged_profile.class_id_decode_ns += chunk.compact_profile.class_id_decode_ns;
        merged_profile.raw_join_ns += chunk.compact_profile.raw_join_ns;
        merged_profile.parse_ns += chunk.compact_profile.parse_ns;
    }

    if profiling {
        // Best-effort diagnostics: silently skipped if extractor stats are
        // unavailable.
        if let Ok(stats) = context.extractor.stats() {
            eprintln!(
                "TOKMAT_PROFILE compact rows={} token_view_ns={} class_id_decode_ns={} raw_join_ns={} parse_ns={} tokmat_profiled_rows={} tokmat_total_ns={} tokmat_class_join_ns={} tokmat_class_regex_ns={} tokmat_offset_work_ns={} tokmat_object_join_ns={} tokmat_direct_execution_ns={} tokmat_fallback_regex_ns={}",
                merged_profile.rows,
                merged_profile.token_view_ns.as_nanos(),
                merged_profile.class_id_decode_ns.as_nanos(),
                merged_profile.raw_join_ns.as_nanos(),
                merged_profile.parse_ns.as_nanos(),
                stats.profiled_rows,
                stats.profile_total_ns,
                stats.profile_class_join_ns,
                stats.profile_class_regex_ns,
                stats.profile_offset_work_ns,
                stats.profile_object_join_ns,
                stats.profile_direct_execution_ns,
                stats.profile_fallback_regex_ns,
            );
        }
    }

    let named_field_values = capture_names
        .iter()
        .cloned()
        .zip(field_values)
        .collect::<Vec<_>>();
    build_extract_struct_series(input.name().clone(), named_field_values, &complements)
}
1172
/// Extract TEL captures for the row range `[start, end)` of a tokenized
/// struct input.
///
/// Each row is dispatched on which fields are present:
/// - tokens + classes: decode both string lists and parse directly;
/// - tokens + class_ids: decode compact `u8` class ids through the model's
///   class codec (falling back to the token text for unknown ids), with
///   optional per-phase timing when `profiling` is set;
/// - all fields null (including raw): emit a null output row;
/// - any other combination: bail with an inconsistent-nullability error.
///
/// # Errors
/// Fails on dtype mismatches, nulls inside list values, inconsistent row
/// nullability, or extractor parse errors.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_lines)]
fn process_extract_chunk(
    raw_series: Option<&Series>,
    tokens_series: &Series,
    classes_series: Option<&Series>,
    class_ids_series: Option<&Series>,
    context: &ModelContext,
    pattern: &str,
    mode: MatchMode,
    capture_names: &[String],
    profiling: bool,
    start: usize,
    end: usize,
) -> PolarsResult<ChunkExtractOutput> {
    let raw_utf8 = raw_series.map(Series::str).transpose()?;
    let token_lists = tokens_series.list()?;
    let class_lists = classes_series.map(Series::list).transpose()?;
    let class_id_lists = class_ids_series.map(Series::list).transpose()?;

    let mut field_values = capture_names
        .iter()
        .map(|_| Vec::with_capacity(end - start))
        .collect::<Vec<_>>();
    let mut complements = Vec::with_capacity(end - start);
    let mut compact_profile = CompactExtractProfile::default();
    // Scratch buffer reused across rows by fill_series_u8.
    let mut class_id_values = Vec::new();

    for index in start..end {
        let raw_value = raw_utf8.as_ref().and_then(|values| values.get(index));
        let tokens = token_lists.get_as_series(index);
        let classes = class_lists
            .as_ref()
            .and_then(|values| values.get_as_series(index));
        let class_ids = class_id_lists
            .as_ref()
            .and_then(|values| values.get_as_series(index));

        match (tokens, classes, class_ids) {
            // Full string layout: tokens and class names both present.
            (Some(tokens), Some(classes), _) => {
                let token_values = list_series_to_strings(&tokens)?;
                let class_values = list_series_to_strings(&classes)?;
                // Missing raw value is reconstructed by concatenating tokens.
                let raw_value_buf = raw_value.map_or_else(
                    || Cow::Owned(join_string_values(&token_values)),
                    Cow::Borrowed,
                );
                let parsed = parse_from_tokenized_parts(
                    context,
                    raw_value_buf,
                    &token_values,
                    &class_values,
                    pattern,
                    mode,
                )?;
                push_parse_output_by_index(
                    &mut field_values,
                    &mut complements,
                    capture_names,
                    Some(parsed),
                );
            }
            // Compact layout: class names encoded as u8 ids.
            (Some(tokens), _, Some(class_ids)) => {
                let token_view_start = profiling.then(Instant::now);
                let token_values = list_series_to_strings(&tokens)?;
                let token_view_elapsed = elapsed_since(token_view_start);

                let decode_start = profiling.then(Instant::now);
                fill_series_u8(&class_ids, &mut class_id_values)?;
                // Decode each id; unknown ids fall back to the token text.
                let class_values = class_id_values
                    .iter()
                    .zip(token_values.iter())
                    .map(|(class_id, token)| {
                        context
                            .class_codec
                            .decode_or_fallback_ref(*class_id, token.as_str())
                    })
                    .collect::<Vec<_>>();
                let decode_elapsed = elapsed_since(decode_start);

                // Only time the raw join when it actually happens (raw value
                // absent); otherwise the timer stays None and records zero.
                let raw_join_start = profiling
                    .then_some(raw_value.is_none())
                    .filter(|should_join| *should_join)
                    .map(|_| Instant::now());
                let raw_value_buf = raw_value.map_or_else(
                    || Cow::Owned(join_string_values(&token_values)),
                    Cow::Borrowed,
                );
                let raw_join_elapsed = elapsed_since(raw_join_start);

                let parse_start = profiling.then(Instant::now);
                let parsed = parse_from_tokenized_parts(
                    context,
                    raw_value_buf,
                    &token_values,
                    &class_values,
                    pattern,
                    mode,
                )?;
                let parse_elapsed = elapsed_since(parse_start);

                compact_profile.rows += 1;
                compact_profile.token_view_ns += token_view_elapsed;
                compact_profile.class_id_decode_ns += decode_elapsed;
                compact_profile.raw_join_ns += raw_join_elapsed;
                compact_profile.parse_ns += parse_elapsed;

                push_parse_output_by_index(
                    &mut field_values,
                    &mut complements,
                    capture_names,
                    Some(parsed),
                );
            }
            // Entirely null row: propagate a null output row.
            (None, None, None) if raw_value.is_none() => {
                push_parse_output_by_index(
                    &mut field_values,
                    &mut complements,
                    capture_names,
                    None,
                );
            }
            _ => {
                polars_bail!(
                    InvalidOperation:
                    "tokenized struct row {} has inconsistent nullability across fields",
                    index
                )
            }
        }
    }

    Ok(ChunkExtractOutput {
        field_values,
        complements,
        compact_profile,
    })
}
1310
1311fn parse_from_tokenized_parts<T: AsRef<str>, C: AsRef<str>>(
1312 context: &ModelContext,
1313 raw_value: impl AsRef<str>,
1314 tokens: &[T],
1315 classes: &[C],
1316 pattern: &str,
1317 mode: MatchMode,
1318) -> PolarsResult<ParseOutput> {
1319 context
1320 .extractor
1321 .parse_tokens_with_views(raw_value.as_ref(), tokens, classes, pattern, mode)
1322 .map_err(|error| {
1323 polars_err!(
1324 ComputeError:
1325 "failed to extract TEL pattern '{}': {}",
1326 pattern,
1327 error
1328 )
1329 })
1330}
1331
1332fn list_series_to_strings(series: &Series) -> PolarsResult<Vec<String>> {
1333 match series.dtype() {
1334 DataType::String => series
1335 .str()?
1336 .into_iter()
1337 .map(|value| {
1338 value.map(ToString::to_string).ok_or_else(
1339 || polars_err!(InvalidOperation: "list values may not contain nulls"),
1340 )
1341 })
1342 .collect(),
1343 DataType::Categorical(_, _) | DataType::Enum(_, _) => {
1344 let casted = series.cast(&DataType::String)?;
1345 casted
1346 .str()?
1347 .into_iter()
1348 .map(|value| {
1349 value.map(ToString::to_string).ok_or_else(
1350 || polars_err!(InvalidOperation: "list values may not contain nulls"),
1351 )
1352 })
1353 .collect()
1354 }
1355 dtype => polars_bail!(
1356 InvalidOperation:
1357 "expected String, Categorical, or Enum list values, got {:?}",
1358 dtype
1359 ),
1360 }
1361}
1362
1363fn list_series_to_u8(series: &Series) -> PolarsResult<Vec<u8>> {
1364 series
1365 .u8()?
1366 .into_iter()
1367 .map(|value| {
1368 value.ok_or_else(|| polars_err!(InvalidOperation: "list values may not contain nulls"))
1369 })
1370 .collect()
1371}
1372
1373fn fill_series_u8(series: &Series, buffer: &mut Vec<u8>) -> PolarsResult<()> {
1374 buffer.clear();
1375 buffer.extend(
1376 series
1377 .u8()?
1378 .into_iter()
1379 .map(|value| {
1380 value.ok_or_else(
1381 || polars_err!(InvalidOperation: "list values may not contain nulls"),
1382 )
1383 })
1384 .collect::<PolarsResult<Vec<_>>>()?,
1385 );
1386 Ok(())
1387}
1388
/// Concatenate string-like values into a single `String`, preallocating the
/// exact byte capacity to avoid reallocation.
fn join_string_values<T: AsRef<str>>(values: &[T]) -> String {
    let byte_count: usize = values.iter().map(|value| value.as_ref().len()).sum();
    values
        .iter()
        .fold(String::with_capacity(byte_count), |mut joined, value| {
            joined.push_str(value.as_ref());
            joined
        })
}
1397
/// Whether TOKMAT_PROFILE profiling output is enabled: the variable must be
/// set to a non-empty value other than "0".
fn profile_enabled() -> bool {
    matches!(
        std::env::var("TOKMAT_PROFILE"),
        Ok(value) if !value.is_empty() && value != "0"
    )
}
1403
1404fn should_parallelize(row_count: usize) -> bool {
1405 let rayon_enabled = std::env::var("TOKMAT_ENABLE_RAYON")
1406 .map(|value| value != "0" && !value.is_empty())
1407 .unwrap_or(false);
1408
1409 rayon_enabled
1410 && std::env::var("TOKMAT_DISABLE_RAYON").is_err()
1411 && rayon::current_num_threads() > 1
1412 && row_count >= 100_000
1413}
1414
1415fn parallel_chunk_size(row_count: usize) -> usize {
1416 let threads = rayon::current_num_threads().max(1);
1417 (row_count / (threads * 4)).max(50_000)
1418}
1419
/// Elapsed time since an optional start instant; `None` yields zero, which
/// lets profiling timers be no-ops when profiling is disabled.
fn elapsed_since(start: Option<Instant>) -> Duration {
    match start {
        Some(instant) => instant.elapsed(),
        None => Duration::ZERO,
    }
}
1423
1424fn build_tokenized_struct_series(
1425 name: PlSmallStr,
1426 columns: TokenizedColumns,
1427 context: &ModelContext,
1428 layout: TokenizeLayout,
1429) -> PolarsResult<Series> {
1430 let row_count = columns.token_values.len();
1431 let mut fields = Vec::new();
1432
1433 if let Some(raw_values) = columns.raw_values {
1434 fields.push(
1435 StringChunked::from_iter_options(
1436 "raw_value".into(),
1437 raw_values.iter().map(|value| value.as_deref()),
1438 )
1439 .into_series(),
1440 );
1441 }
1442
1443 fields.push(build_output_string_list_series(
1444 "tokens",
1445 columns.token_values,
1446 layout.token_output,
1447 None,
1448 )?);
1449
1450 if let Some(type_values) = columns.type_values {
1451 fields.push(build_output_string_list_series(
1452 "types",
1453 type_values,
1454 layout.type_output,
1455 Some(&context.type_enum_values),
1456 )?);
1457 }
1458 if let Some(class_values) = columns.class_values {
1459 fields.push(build_output_string_list_series(
1460 "classes",
1461 class_values,
1462 layout.class_output,
1463 Some(&context.class_enum_values),
1464 )?);
1465 }
1466 if let Some(type_id_values) = columns.type_id_values {
1467 fields.push(build_u8_list_series("type_ids", type_id_values));
1468 }
1469 if let Some(class_id_values) = columns.class_id_values {
1470 fields.push(build_u8_list_series("class_ids", class_id_values));
1471 }
1472
1473 Ok(StructChunked::from_series(name, row_count, fields.iter())?.into_series())
1474}
1475
1476fn build_extract_struct_series(
1477 name: PlSmallStr,
1478 field_columns: Vec<(String, Vec<Option<String>>)>,
1479 complements: &[Option<String>],
1480) -> PolarsResult<Series> {
1481 let row_count = complements.len();
1482
1483 let mut field_series = field_columns
1484 .into_iter()
1485 .map(|(field_name, values)| {
1486 StringChunked::from_iter_options(
1487 field_name.into(),
1488 values.iter().map(|value| value.as_deref()),
1489 )
1490 .into_series()
1491 })
1492 .collect::<Vec<_>>();
1493
1494 field_series.push(
1495 StringChunked::from_iter_options(
1496 "complement".into(),
1497 complements.iter().map(|value| value.as_deref()),
1498 )
1499 .into_series(),
1500 );
1501
1502 Ok(StructChunked::from_series(name, row_count, field_series.iter())?.into_series())
1503}
1504
/// Allocate one empty, pre-sized output column per capture name.
fn init_extract_columns(
    capture_names: &[String],
    row_count: usize,
) -> Vec<(String, Vec<Option<String>>)> {
    let mut columns = Vec::with_capacity(capture_names.len());
    for name in capture_names {
        columns.push((name.clone(), Vec::with_capacity(row_count)));
    }
    columns
}
1514
1515fn push_parse_output(
1516 field_columns: &mut [(String, Vec<Option<String>>)],
1517 complements: &mut Vec<Option<String>>,
1518 output: Option<ParseOutput>,
1519) {
1520 if let Some(output) = output {
1521 for (field_name, values) in field_columns.iter_mut() {
1522 values.push(output.fields.get(field_name).cloned());
1523 }
1524 complements.push(Some(output.complement));
1525 } else {
1526 for (_, values) in field_columns.iter_mut() {
1527 values.push(None);
1528 }
1529 complements.push(None);
1530 }
1531}
1532
1533fn push_parse_output_by_index(
1534 field_values: &mut [Vec<Option<String>>],
1535 complements: &mut Vec<Option<String>>,
1536 capture_names: &[String],
1537 output: Option<ParseOutput>,
1538) {
1539 if let Some(output) = output {
1540 for (index, name) in capture_names.iter().enumerate() {
1541 field_values[index].push(output.fields.get(name).cloned());
1542 }
1543 complements.push(Some(output.complement));
1544 } else {
1545 for values in field_values.iter_mut() {
1546 values.push(None);
1547 }
1548 complements.push(None);
1549 }
1550}
1551
1552fn build_string_list_series(name: &str, rows: Vec<Option<Vec<String>>>) -> Series {
1553 let row_count = rows.len();
1554 let values_capacity = rows
1555 .iter()
1556 .flatten()
1557 .map(|values| values.iter().map(String::len).sum::<usize>())
1558 .sum();
1559 let mut builder = ListStringChunkedBuilder::new(name.into(), row_count, values_capacity);
1560 for row in rows {
1561 match row {
1562 Some(values) => builder.append_values_iter(values.iter().map(String::as_str)),
1563 None => builder.append_null(),
1564 }
1565 }
1566 builder.finish().into_series()
1567}
1568
/// Build a `List(Enum)` series: each non-null row is cast to the enum dtype,
/// rows are collected into a `ListChunked`, and the result is re-labelled
/// with the full `List(Enum)` dtype.
#[allow(unsafe_code)]
fn build_enum_list_series(
    name: &str,
    rows: Vec<Option<Vec<String>>>,
    enum_values: &[String],
) -> PolarsResult<Series> {
    let enum_dtype = enum_dtype(enum_values);
    let rows = rows
        .into_iter()
        .map(|row| {
            row.map(|values| Series::new(PlSmallStr::EMPTY, values).cast(&enum_dtype))
                .transpose()
        })
        .collect::<PolarsResult<Vec<_>>>()?;
    let base = rows.into_iter().collect::<ListChunked>().into_series();
    let list_dtype = DataType::List(Box::new(enum_dtype));
    // SAFETY: every inner series was cast to `enum_dtype` above, so the
    // chunks' physical layout should already match `list_dtype`; this call
    // only attaches the logical dtype without transforming the data.
    // NOTE(review): soundness rests on `cast` producing the enum's physical
    // representation for all rows — confirm against the polars version in use.
    Ok(unsafe {
        Series::from_chunks_and_dtype_unchecked(name.into(), base.chunks().clone(), &list_dtype)
    })
}
1592
1593fn build_u8_list_series(name: &str, rows: Vec<Option<Vec<u8>>>) -> Series {
1594 let row_count = rows.len();
1595 let values_capacity = rows.iter().flatten().map(Vec::len).sum();
1596 let mut builder = ListPrimitiveChunkedBuilder::<UInt8Type>::new(
1597 name.into(),
1598 row_count,
1599 values_capacity,
1600 DataType::UInt8,
1601 );
1602 for row in rows {
1603 match row {
1604 Some(values) => builder.append_slice(&values),
1605 None => builder.append_null(),
1606 }
1607 }
1608 builder.finish().into_series()
1609}
1610
/// Classify a single token into a token-type name, borrowing where possible.
///
/// Checks run in priority order, each gated on the model actually declaring
/// the corresponding type name (`ModelFeatures`): POSTALCODE, NUM, ALPHA,
/// NUM_EXTENDED, ALPHA_EXTENDED, then ALPHA_NUM / ALPHA_NUM_EXTENDED. A token
/// that matches nothing falls through to the model's compiled regex patterns;
/// if none of those match either, the token text itself is returned.
fn classify_token_ref<'a>(
    token: &'a str,
    model: &'a TokenModel,
    features: ModelFeatures,
) -> Cow<'a, str> {
    // Empty tokens classify as themselves.
    if token.is_empty() {
        return Cow::Borrowed(token);
    }

    // POSTALCODE: after stripping whitespace, '-' and '_', exactly six
    // characters alternating letter/digit (A1A1A1 shape).
    if token.is_ascii() && features.has_postalcode {
        let compact: String = token
            .chars()
            .filter(|character| {
                !character.is_whitespace() && *character != '-' && *character != '_'
            })
            .collect();
        let chars: Vec<char> = compact.chars().collect();
        if chars.len() == 6
            && chars[0].is_ascii_alphabetic()
            && chars[1].is_ascii_digit()
            && chars[2].is_ascii_alphabetic()
            && chars[3].is_ascii_digit()
            && chars[4].is_ascii_alphabetic()
            && chars[5].is_ascii_digit()
        {
            return Cow::Borrowed("POSTALCODE");
        }
    }

    // NUM: all ASCII digits.
    if token.is_ascii()
        && token.chars().all(|character| character.is_ascii_digit())
        && features.has_num
    {
        return Cow::Borrowed("NUM");
    }

    // ALPHA: all letters (ASCII-only guaranteed by the guard).
    if token.is_ascii() && token.chars().all(char::is_alphabetic) && features.has_alpha {
        return Cow::Borrowed("ALPHA");
    }

    // NUM_EXTENDED: digits plus '-' separators, with at least one digit.
    if token.is_ascii()
        && token
            .chars()
            .all(|character| character.is_ascii_digit() || character == '-')
        && token.chars().any(|character| character.is_ascii_digit())
        && features.has_num_extended
    {
        return Cow::Borrowed("NUM_EXTENDED");
    }

    // ALPHA_EXTENDED: letters plus '-' / apostrophe, with at least one letter.
    if token.is_ascii()
        && token
            .chars()
            .all(|character| character.is_alphabetic() || character == '-' || character == '\'')
        && token.chars().any(char::is_alphabetic)
        && features.has_alpha_extended
    {
        return Cow::Borrowed("ALPHA_EXTENDED");
    }

    // Mixed letter+digit tokens: plain alphanumerics are ALPHA_NUM; tokens
    // containing '-' or apostrophes fall back to ALPHA_NUM_EXTENDED.
    if token.is_ascii()
        && token
            .chars()
            .all(|character| character.is_alphanumeric() || character == '-' || character == '\'')
    {
        let has_alpha = token.chars().any(char::is_alphabetic);
        let has_digit = token.chars().any(|character| character.is_ascii_digit());
        if has_alpha && has_digit {
            if token.chars().all(char::is_alphanumeric) && features.has_alpha_num {
                return Cow::Borrowed("ALPHA_NUM");
            }
            if features.has_alpha_num_extended {
                return Cow::Borrowed("ALPHA_NUM_EXTENDED");
            }
        }
    }

    // Last resort: first compiled model pattern that matches; otherwise the
    // token is its own class.
    model
        .compiled_patterns()
        .iter()
        .find_map(|(name, regex)| {
            regex
                .is_match(token.as_bytes())
                .ok()
                .and_then(|matched| matched.then_some(Cow::Borrowed(name.as_str())))
        })
        .unwrap_or(Cow::Borrowed(token))
}
1699
1700fn tokenize_row(raw_value: &str, context: &ModelContext, layout: TokenizeLayout) -> TokenizedRow {
1701 let tokens = split_input_tokens(raw_value);
1702 let mut type_values = layout
1703 .needs_type_values()
1704 .then(|| Vec::with_capacity(tokens.len()));
1705 let mut class_values = layout
1706 .include_classes
1707 .then(|| Vec::with_capacity(tokens.len()));
1708 let mut type_ids = layout
1709 .include_type_ids
1710 .then(|| Vec::with_capacity(tokens.len()));
1711 let mut class_ids = layout
1712 .include_class_ids
1713 .then(|| Vec::with_capacity(tokens.len()));
1714
1715 for token in &tokens {
1716 let token_type = classify_token_ref(token, &context.model, context.features).into_owned();
1717 if let Some(values) = type_values.as_mut() {
1718 values.push(token_type.clone());
1719 }
1720 if let Some(values) = type_ids.as_mut() {
1721 values.push(context.type_codec.encode_known_or_raw(&token_type));
1722 }
1723
1724 let class_value = if token.chars().all(char::is_whitespace) {
1725 Cow::Borrowed(token.as_str())
1726 } else if let Some(value) = context.model.token_class_lookup().get(token) {
1727 Cow::Borrowed(value.as_str())
1728 } else {
1729 Cow::Borrowed(token_type.as_str())
1730 };
1731
1732 if let Some(values) = class_values.as_mut() {
1733 values.push(class_value.to_string());
1734 }
1735 if let Some(values) = class_ids.as_mut() {
1736 values.push(
1737 context
1738 .class_codec
1739 .encode_known_or_raw(class_value.as_ref()),
1740 );
1741 }
1742 }
1743
1744 TokenizedRow {
1745 raw_value: layout.include_raw_value.then(|| raw_value.to_string()),
1746 tokens,
1747 types: type_values,
1748 classes: class_values,
1749 type_ids,
1750 class_ids,
1751 }
1752}
1753
#[cfg(test)]
mod tests {
    //! Round-trip tests for the tokenize/extract expression helpers and the
    //! Rust-facing plugin API, driven by the fixture model under
    //! `tests/fixtures/model_1`.
    use super::*;

    /// Absolute path to the fixture model directory used by every test.
    fn fixture_model_path() -> String {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/model_1")
            .to_string_lossy()
            .into_owned()
    }

    /// Kwargs for the original string-based output layout: raw value plus
    /// string token/type/class lists, no compact id columns.
    fn legacy_tokenize_kwargs() -> TokenizeKwargs {
        TokenizeKwargs {
            model_path: fixture_model_path(),
            include_raw_value: true,
            include_types: true,
            include_classes: true,
            include_type_ids: false,
            include_class_ids: false,
            token_output: StringListOutput::String,
            type_output: StringListOutput::String,
            class_output: StringListOutput::String,
        }
    }

    /// Legacy layout yields the four expected struct fields and splits the
    /// input into the expected tokens; the null input row must be accepted.
    #[test]
    fn tokenize_helper_returns_struct_with_expected_fields() {
        let input = Series::new("address".into(), &[Some("123 MAIN ST"), None]);
        let output = tokenize_expr_impl(&[input], &legacy_tokenize_kwargs())
            .expect("tokenize should succeed");

        let struct_chunked = output.struct_().expect("tokenize output should be struct");
        let fields = struct_chunked.fields_as_series();
        let field_names = fields
            .iter()
            .map(|field| field.name().as_str())
            .collect::<Vec<_>>();
        assert_eq!(field_names, vec!["raw_value", "tokens", "types", "classes"]);

        let token_field = fields
            .iter()
            .find(|field| field.name().as_str() == "tokens")
            .expect("tokens field should exist");
        let first_tokens = token_field
            .list()
            .expect("tokens should be a list")
            .into_iter()
            .next()
            .expect("first row should exist")
            .expect("first row should be non-null");
        let token_values =
            list_series_to_strings(&first_tokens).expect("list conversion should work");
        assert!(token_values.contains(&"123".to_string()));
        assert!(token_values.contains(&"MAIN".to_string()));
    }

    /// Compact layout: only tokens + class_ids fields are emitted.
    #[test]
    fn tokenize_helper_can_emit_compact_class_ids() {
        let input = Series::new("address".into(), ["123 MAIN ST"]);
        let output = tokenize_expr_impl(
            &[input],
            &TokenizeKwargs {
                model_path: fixture_model_path(),
                include_raw_value: false,
                include_types: false,
                include_classes: false,
                include_type_ids: false,
                include_class_ids: true,
                token_output: StringListOutput::String,
                type_output: StringListOutput::String,
                class_output: StringListOutput::String,
            },
        )
        .expect("compact tokenize should succeed");

        let struct_chunked = output.struct_().expect("tokenize output should be struct");
        let fields = struct_chunked.fields_as_series();
        let field_names = fields
            .iter()
            .map(|field| field.name().as_str())
            .collect::<Vec<_>>();
        assert_eq!(field_names, vec!["tokens", "class_ids"]);
    }

    /// Categorical output mode produces List(Categorical) dtypes for tokens,
    /// types, and classes.
    #[test]
    fn tokenize_helper_can_emit_categorical_lists() {
        let input = Series::new("address".into(), ["123 MAIN ST"]);
        let output = tokenize_expr_impl(
            &[input],
            &TokenizeKwargs {
                model_path: fixture_model_path(),
                include_raw_value: false,
                include_types: true,
                include_classes: true,
                include_type_ids: false,
                include_class_ids: false,
                token_output: StringListOutput::Categorical,
                type_output: StringListOutput::Categorical,
                class_output: StringListOutput::Categorical,
            },
        )
        .expect("categorical tokenize should succeed");

        let struct_chunked = output.struct_().expect("tokenize output should be struct");
        let fields = struct_chunked.fields_as_series();
        assert_eq!(
            fields[0].dtype(),
            &DataType::List(Box::new(DataType::Categorical(
                None,
                CategoricalOrdering::default(),
            )))
        );
        assert_eq!(
            fields[1].dtype(),
            &DataType::List(Box::new(DataType::Categorical(
                None,
                CategoricalOrdering::default(),
            )))
        );
        assert_eq!(
            fields[2].dtype(),
            &DataType::List(Box::new(DataType::Categorical(
                None,
                CategoricalOrdering::default(),
            )))
        );
    }

    /// Extract can run directly on a raw string column (tokenizing inline).
    #[test]
    fn extract_helper_accepts_raw_string_input() {
        let input = Series::new("address".into(), ["123 MAIN ST"]);
        let output = extract_expr_impl(
            &[input],
            ExtractKwargs {
                model_path: fixture_model_path(),
                pattern: "<<CIVIC#>> <<STREET@+>> <<TYPE::STREETTYPE>>".to_string(),
                mode: MatchModeKwarg::default(),
            },
        )
        .expect("extract should succeed");

        let struct_chunked = output.struct_().expect("extract output should be struct");
        let fields = struct_chunked.fields_as_series();
        let civic = fields
            .iter()
            .find(|field| field.name().as_str() == "CIVIC")
            .expect("CIVIC field should exist")
            .str()
            .expect("CIVIC field should be string")
            .get(0);
        let street = fields
            .iter()
            .find(|field| field.name().as_str() == "STREET")
            .expect("STREET field should exist")
            .str()
            .expect("STREET field should be string")
            .get(0);
        let street_type = fields
            .iter()
            .find(|field| field.name().as_str() == "TYPE")
            .expect("TYPE field should exist")
            .str()
            .expect("TYPE field should be string")
            .get(0);

        assert_eq!(civic, Some("123"));
        assert_eq!(street, Some("MAIN"));
        assert_eq!(street_type, Some("ST"));
    }

    /// Extract accepts the legacy tokenized struct (raw + tokens + classes);
    /// a full match leaves an empty complement.
    #[test]
    fn extract_helper_accepts_tokenized_struct_input() {
        let tokenized = tokenize_expr_impl(
            &[Series::new("address".into(), ["123 MAIN ST"])],
            &legacy_tokenize_kwargs(),
        )
        .expect("tokenize should succeed");

        let output = extract_expr_impl(
            &[tokenized],
            ExtractKwargs {
                model_path: fixture_model_path(),
                pattern: "<<CIVIC#>> <<STREET@+>> <<TYPE::STREETTYPE>>".to_string(),
                mode: MatchModeKwarg::default(),
            },
        )
        .expect("extract should succeed");

        let struct_chunked = output.struct_().expect("extract output should be struct");
        let fields = struct_chunked.fields_as_series();
        let complement = fields
            .iter()
            .find(|field| field.name().as_str() == "complement")
            .expect("complement field should exist")
            .str()
            .expect("complement should be string")
            .get(0);

        assert_eq!(complement, Some(""));
    }

    /// A hand-built struct with only tokens + classes (no raw_value/types)
    /// must still be accepted by extract.
    #[test]
    fn extract_helper_accepts_tokenized_struct_without_raw_value_or_types() {
        let tokens = build_string_list_series(
            "tokens",
            [Some(vec![
                "123".to_string(),
                " ".to_string(),
                "MAIN".to_string(),
                " ".to_string(),
                "ST".to_string(),
            ])]
            .to_vec(),
        );
        let classes = build_string_list_series(
            "classes",
            [Some(vec![
                "NUM".to_string(),
                " ".to_string(),
                "ALPHA".to_string(),
                " ".to_string(),
                "STREETTYPE".to_string(),
            ])]
            .to_vec(),
        );
        let tokenized = StructChunked::from_series("address".into(), 1, [tokens, classes].iter())
            .expect("struct should build")
            .into_series();

        let output = extract_expr_impl(
            &[tokenized],
            ExtractKwargs {
                model_path: fixture_model_path(),
                pattern: "<<CIVIC#>> <<STREET@+>> <<TYPE::STREETTYPE>>".to_string(),
                mode: MatchModeKwarg::default(),
            },
        )
        .expect("extract should succeed");

        let struct_chunked = output.struct_().expect("extract output should be struct");
        let fields = struct_chunked.fields_as_series();
        let civic = fields
            .iter()
            .find(|field| field.name().as_str() == "CIVIC")
            .expect("CIVIC field should exist")
            .str()
            .expect("CIVIC field should be string")
            .get(0);

        assert_eq!(civic, Some("123"));
    }

    /// The compact tokens + class_ids layout round-trips through extract.
    #[test]
    fn extract_helper_accepts_tokenized_struct_with_class_ids() {
        let tokenized = tokenize_expr_impl(
            &[Series::new("address".into(), ["123 MAIN ST"])],
            &TokenizeKwargs {
                model_path: fixture_model_path(),
                include_raw_value: false,
                include_types: false,
                include_classes: false,
                include_type_ids: false,
                include_class_ids: true,
                token_output: StringListOutput::String,
                type_output: StringListOutput::String,
                class_output: StringListOutput::String,
            },
        )
        .expect("compact tokenize should succeed");

        let output = extract_expr_impl(
            &[tokenized],
            ExtractKwargs {
                model_path: fixture_model_path(),
                pattern: "<<CIVIC#>> <<STREET@+>> <<TYPE::STREETTYPE>>".to_string(),
                mode: MatchModeKwarg::default(),
            },
        )
        .expect("extract should succeed");

        let struct_chunked = output.struct_().expect("extract output should be struct");
        let fields = struct_chunked.fields_as_series();
        let civic = fields
            .iter()
            .find(|field| field.name().as_str() == "CIVIC")
            .expect("CIVIC field should exist")
            .str()
            .expect("CIVIC field should be string")
            .get(0);

        assert_eq!(civic, Some("123"));
    }

    /// In Any mode the unmatched prefix ("ATTN ") ends up in the complement.
    #[test]
    fn extract_helper_respects_any_mode_for_raw_string_input() {
        let input = Series::new("address".into(), ["ATTN 123 MAIN ST"]);
        let output = extract_expr_impl(
            &[input],
            ExtractKwargs {
                model_path: fixture_model_path(),
                pattern: "<<CIVIC#>> <<STREET@+>> <<TYPE::STREETTYPE>>".to_string(),
                mode: MatchModeKwarg::Any,
            },
        )
        .expect("extract should succeed");

        let struct_chunked = output.struct_().expect("extract output should be struct");
        let fields = struct_chunked.fields_as_series();
        let civic = fields
            .iter()
            .find(|field| field.name().as_str() == "CIVIC")
            .expect("CIVIC field should exist")
            .str()
            .expect("CIVIC field should be string")
            .get(0);
        let complement = fields
            .iter()
            .find(|field| field.name().as_str() == "complement")
            .expect("complement field should exist")
            .str()
            .expect("complement should be string")
            .get(0);

        assert_eq!(civic, Some("123"));
        assert_eq!(complement, Some("ATTN "));
    }

    /// The Rust-facing API tokenizes, extracts, and reports capture names.
    #[test]
    fn rust_api_tokenizes_and_extracts() {
        let plugin =
            TokmatPolars::from_model_path(fixture_model_path()).expect("model should load");
        let input = Series::new("address".into(), ["123 MAIN ST"]);

        let tokenized = plugin
            .tokenize_series(&input)
            .expect("tokenize via rust api");
        let extracted = plugin
            .extract_series(&tokenized, "<<CIVIC#>> <<STREET@+>> <<TYPE::STREETTYPE>>")
            .expect("extract via rust api");

        let struct_chunked = extracted
            .struct_()
            .expect("extract output should be struct");
        let fields = struct_chunked.fields_as_series();
        let civic = fields
            .iter()
            .find(|field| field.name().as_str() == "CIVIC")
            .expect("CIVIC field should exist")
            .str()
            .expect("CIVIC field should be string")
            .get(0);

        assert_eq!(civic, Some("123"));
        assert_eq!(
            plugin
                .capture_field_names("<<CIVIC#>> <<STREET@+>> <<TYPE::STREETTYPE>>")
                .expect("capture names"),
            vec![
                "CIVIC".to_string(),
                "STREET".to_string(),
                "TYPE".to_string()
            ]
        );
    }

    /// The Rust-facing API honors an explicit MatchMode override.
    #[test]
    fn rust_api_extracts_with_explicit_match_mode() {
        let plugin =
            TokmatPolars::from_model_path(fixture_model_path()).expect("model should load");
        let input = Series::new("address".into(), ["ATTN 123 MAIN ST"]);

        let extracted = plugin
            .extract_series_with_mode(
                &input,
                "<<CIVIC#>> <<STREET@+>> <<TYPE::STREETTYPE>>",
                MatchMode::Any,
            )
            .expect("extract via rust api");

        let struct_chunked = extracted
            .struct_()
            .expect("extract output should be struct");
        let fields = struct_chunked.fields_as_series();
        let complement = fields
            .iter()
            .find(|field| field.name().as_str() == "complement")
            .expect("complement field should exist")
            .str()
            .expect("complement field should be string")
            .get(0);

        assert_eq!(complement, Some("ATTN "));
    }
}