tract_transformers/
lib.rs1pub mod ops;
2mod rewriter;
3use std::collections::HashSet;
4
5use rewriter::*;
6use tract_nnef::internal::*;
7
8register_simple_model_transform!("detect-rms-norm", RmsNormTransform);
9register_simple_model_transform!("detect-apply-rope", ApplyRopeTransform);
10register_simple_model_transform!("detect-silu", SiluTransform);
11register_simple_model_transform!("detect-scaled-masked-softmax", ScaledMaskedSoftmaxTransform);
12register_simple_model_transform!("detect-gelu-approx", GeluTransform);
13register_simple_model_transform!("detect-kv-cache", KeyValueCacheTransform);
14register_simple_model_transform!(
15 "detect-sdpa-kv-cache-broadcast",
16 SdpaFuseKvCacheBroadcastTransform
17);
18register_simple_model_transform!("transformers-detect-all", TransformersTransform);
19
20pub fn register(registry: &mut Registry) {
21 ops::rms_norm::register(registry);
22 ops::silu::register(registry);
23 ops::gelu_approximate::register(registry);
24 ops::apply_rope::register(registry);
25 ops::scaled_masked_softmax::register(registry);
26 ops::sdpa::register(registry);
27}
28
/// Opt-in switch for the `tract_transformers` extension on a framework object.
///
/// Two flavours are offered so callers can pick whichever fits their code:
/// an in-place mutation and a by-value, chainable variant.
pub trait WithTractTransformers {
    /// Enables the extension on `self` in place.
    fn enable_tract_transformers(&mut self);

    /// Consumes `self` and hands it back, for use in builder-style chains.
    fn with_tract_transformers(self) -> Self;
}
33
34impl WithTractTransformers for tract_nnef::framework::Nnef {
35 fn enable_tract_transformers(&mut self) {
36 self.enable_tract_core();
37 self.registries.push(tract_transformers_registry());
38 }
39
40 fn with_tract_transformers(mut self) -> Self {
41 self.enable_tract_transformers();
42 self
43 }
44}
45
46pub fn tract_transformers_registry() -> Registry {
47 let mut reg = Registry::new("tract_transformers")
48 .with_doc("Extension `tract_transformers` extends NNEF with operators")
49 .with_doc("for transformer networks.")
50 .with_doc("")
51 .with_doc("Add `extension tract_transformers` to `graph.nnef`");
52
53 register(&mut reg);
54 reg
55}
56
57pub fn figure_out_causal_llm_b_s_p(
58 model: &TypedModel,
59) -> TractResult<(Option<Symbol>, Option<Symbol>, Option<Symbol>)> {
60 let token_input = model
64 .inputs
65 .iter()
66 .position(|i| model.outlet_fact(*i).unwrap().datum_type.is_integer())
67 .context("No token input found")?;
68 let tokens_symbols = model.input_fact(token_input)?.shape.volume().symbols();
69 let kv_symbols = if let Some(kv_input) =
70 model.inputs.iter().position(|i| model.outlet_fact(*i).unwrap().datum_type.is_float())
71 {
72 model.input_fact(kv_input)?.shape.volume().symbols()
73 } else {
74 let dummy_session_state = TurnState::default();
76 let mut symbols = HashSet::new();
77 for node in &model.nodes {
78 if let Some((_, fact)) =
79 node.op.state(&dummy_session_state, 0)?.and_then(|state| state.init_tensor_fact())
80 {
81 symbols = fact.shape.volume().symbols();
82 break;
83 }
84 }
85 symbols
86 };
87
88 let b = tokens_symbols.intersection(&kv_symbols).cloned().collect::<HashSet<_>>();
89 let s = tokens_symbols.difference(&b).cloned().collect::<HashSet<_>>();
90 let p = kv_symbols.difference(&b).cloned().collect::<HashSet<_>>();
91 Ok((b.into_iter().next(), s.into_iter().next(), p.into_iter().next()))
92}
93
94pub fn memory_arena_hints_for_causal_llm(model: &TypedModel) -> TractResult<SymbolValues> {
95 let (b, s, p) = figure_out_causal_llm_b_s_p(model)?;
96 let mut values = SymbolValues::default()
97 .with(&s.context("Could not determine sequence_len (S)")?, 1024)
98 .with(&p.context("Could not determine past_sequence_len (P)")?, 0);
99 if let Some(b) = b {
100 values = values.with(&b, 1);
101 }
102 Ok(values)
103}