tract_transformers/
lib.rs1pub mod ops;
2mod rewriter;
3use std::collections::HashSet;
4
5use rewriter::*;
6use tract_nnef::internal::*;
7
8register_simple_model_transform!("detect_apply_rope", ApplyRopeTransform);
9register_simple_model_transform!("detect_diag_gather", DetectDiagGatherTransform);
10register_simple_model_transform!("detect_scaled_masked_softmax", ScaledMaskedSoftmaxTransform);
11register_simple_model_transform!("detect_kv_cache", KeyValueCacheTransform);
12register_simple_model_transform!(
13 "detect_sdpa_kv_cache_broadcast",
14 SdpaFuseKvCacheBroadcastTransform
15);
16register_simple_model_transform!("unfold_kv_cache", UnfoldKeyValueCacheTransform);
17register_simple_model_transform!("transformers_detect_all", TransformersTransform);
18
19pub fn register(registry: &mut Registry) {
20 ops::apply_rope::register(registry);
21 ops::scaled_masked_softmax::register(registry);
22 ops::sdpa::register(registry);
23 ops::dyn_kv_cache::register(registry);
24 ops::window_kv_cache::register(registry);
25}
26
/// Extension trait for opting a framework object into the `tract_transformers`
/// operators.
pub trait WithTractTransformers {
    /// Enables the `tract_transformers` extension in place.
    fn enable_tract_transformers(&mut self);

    /// Consuming variant of [`WithTractTransformers::enable_tract_transformers`]:
    /// takes `self` by value and returns it, so it can be chained.
    fn with_tract_transformers(self) -> Self;
}
31
32impl WithTractTransformers for tract_nnef::framework::Nnef {
33 fn enable_tract_transformers(&mut self) {
34 self.registries.push(tract_transformers_registry());
35 }
36
37 fn with_tract_transformers(mut self) -> Self {
38 self.enable_tract_transformers();
39 self
40 }
41}
42
43pub fn tract_transformers_registry() -> Registry {
44 let mut reg = Registry::new("tract_transformers")
45 .with_doc("Extension `tract_transformers` extends NNEF with operators")
46 .with_doc("for transformer networks.")
47 .with_doc("")
48 .with_doc("Add `extension tract_transformers` to `graph.nnef`");
49
50 register(&mut reg);
51 reg
52}
53
54pub fn figure_out_causal_llm_b_s_p(
55 model: &TypedModel,
56) -> TractResult<(Option<Symbol>, Option<Symbol>, Option<Symbol>)> {
57 let token_input = model
61 .inputs
62 .iter()
63 .position(|i| model.outlet_fact(*i).unwrap().datum_type.is_integer())
64 .context("No token input found")?;
65 let tokens_symbols = model.input_fact(token_input)?.shape.volume().symbols();
66 let kv_symbols = if let Some(kv_input) =
67 model.inputs.iter().position(|i| model.outlet_fact(*i).unwrap().datum_type.is_float())
68 {
69 model.input_fact(kv_input)?.shape.volume().symbols()
70 } else {
71 let dummy_session_state = TurnState::default();
73 let mut symbols = HashSet::new();
74 for node in &model.nodes {
75 if let Some((_, fact)) =
76 node.op.state(&dummy_session_state, 0)?.and_then(|state| state.init_tensor_fact())
77 {
78 symbols = fact.shape.volume().symbols();
79 break;
80 }
81 }
82 symbols
83 };
84
85 let b = tokens_symbols.intersection(&kv_symbols).cloned().collect::<HashSet<_>>();
86 let s = tokens_symbols.difference(&b).cloned().collect::<HashSet<_>>();
87 let p = kv_symbols.difference(&b).cloned().collect::<HashSet<_>>();
88 Ok((b.into_iter().next(), s.into_iter().next(), p.into_iter().next()))
89}
90
91pub fn memory_arena_hints_for_causal_llm(model: &TypedModel) -> TractResult<SymbolValues> {
92 let (b, s, p) = figure_out_causal_llm_b_s_p(model)?;
93 let mut values = SymbolValues::default()
94 .with(&s.context("Could not determine sequence_len (S)")?, 1024)
95 .with(&p.context("Could not determine past_sequence_len (P)")?, 0);
96 if let Some(b) = b {
97 values = values.with(&b, 1);
98 }
99 Ok(values)
100}