Skip to main content

tract_transformers/
lib.rs

1pub mod ops;
2mod rewriter;
3use std::collections::HashSet;
4
5use rewriter::*;
6use tract_nnef::internal::*;
7
// Model transforms exposed by this crate, each bound to a string name.
// NOTE(review): `register_simple_model_transform!` is defined outside this file; presumably it
// makes each transform type retrievable by the given name — confirm against its definition.
// The transform types are presumably provided by the glob-imported `rewriter` module — confirm.
8register_simple_model_transform!("detect_apply_rope", ApplyRopeTransform);
9register_simple_model_transform!("detect_diag_gather", DetectDiagGatherTransform);
10register_simple_model_transform!("detect_scaled_masked_softmax", ScaledMaskedSoftmaxTransform);
11register_simple_model_transform!("detect_kv_cache", KeyValueCacheTransform);
12register_simple_model_transform!(
13    "detect_sdpa_kv_cache_broadcast",
14    SdpaFuseKvCacheBroadcastTransform
15);
16register_simple_model_transform!("unfold_kv_cache", UnfoldKeyValueCacheTransform);
// Judging by its name, presumably an umbrella transform running the detections above — TODO confirm.
17register_simple_model_transform!("transformers_detect_all", TransformersTransform);
18
/// Registers the operators of every `ops` submodule into `registry`.
///
/// Each submodule exposes its own `register` function, and they are called one after another.
/// NOTE(review): whether registration order matters is not visible here — keep it unchanged.
19pub fn register(registry: &mut Registry) {
20    ops::apply_rope::register(registry);
21    ops::scaled_masked_softmax::register(registry);
22    ops::sdpa::register(registry);
23    ops::dyn_kv_cache::register(registry);
24    ops::window_kv_cache::register(registry);
25    ops::kv_quant::register(registry);
26}
27
/// Extension trait for enabling the `tract_transformers` NNEF registry on a framework object.
28pub trait WithTractTransformers {
    /// Enables the `tract_transformers` extension in place.
29    fn enable_tract_transformers(&mut self);
    /// Builder-style variant: enables the extension, then returns `self`.
30    fn with_tract_transformers(self) -> Self;
31}
32
33impl WithTractTransformers for tract_nnef::framework::Nnef {
34    fn enable_tract_transformers(&mut self) {
35        self.registries.push(tract_transformers_registry());
36    }
37
38    fn with_tract_transformers(mut self) -> Self {
39        self.enable_tract_transformers();
40        self
41    }
42}
43
44pub fn tract_transformers_registry() -> Registry {
45    let mut reg = Registry::new("tract_transformers")
46        .with_doc("Extension `tract_transformers` extends NNEF with operators")
47        .with_doc("for transformer networks.")
48        .with_doc("")
49        .with_doc("Add `extension tract_transformers` to `graph.nnef`");
50
51    register(&mut reg);
52    reg
53}
54
55pub fn figure_out_causal_llm_b_s_p(
56    model: &TypedModel,
57) -> TractResult<(Option<Symbol>, Option<Symbol>, Option<Symbol>)> {
58    // expectations:
59    // - one input is for tokens, so integer dt (i64 ?) and typically of shape S or 1,S, or B,S
60    // - other inputs are kv cache, some kind of float. shape features both S and P, and B if B is present in tokens
61    let token_input = model
62        .inputs
63        .iter()
64        .position(|i| model.outlet_fact(*i).unwrap().datum_type.is_integer())
65        .context("No token input found")?;
66    let tokens_symbols = model.input_fact(token_input)?.shape.volume().symbols();
67    let kv_symbols = if let Some(kv_input) =
68        model.inputs.iter().position(|i| model.outlet_fact(*i).unwrap().datum_type.is_float())
69    {
70        model.input_fact(kv_input)?.shape.volume().symbols()
71    } else {
72        // Look for KVCache Op
73        let dummy_session_state = TurnState::default();
74        let mut symbols = HashSet::new();
75        for node in &model.nodes {
76            if let Some((_, fact)) =
77                node.op.state(&dummy_session_state, 0)?.and_then(|state| state.init_tensor_fact())
78            {
79                symbols = fact.shape.volume().symbols();
80                break;
81            }
82        }
83        symbols
84    };
85
86    let b = tokens_symbols.intersection(&kv_symbols).cloned().collect::<HashSet<_>>();
87    let s = tokens_symbols.difference(&b).cloned().collect::<HashSet<_>>();
88    let p = kv_symbols.difference(&b).cloned().collect::<HashSet<_>>();
89    Ok((b.into_iter().next(), s.into_iter().next(), p.into_iter().next()))
90}
91
92pub fn memory_arena_hints_for_causal_llm(model: &TypedModel) -> TractResult<SymbolValues> {
93    let (b, s, p) = figure_out_causal_llm_b_s_p(model)?;
94    let mut values = SymbolValues::default()
95        .with(&s.context("Could not determine sequence_len (S)")?, 1024)
96        .with(&p.context("Could not determine past_sequence_len (P)")?, 0);
97    if let Some(b) = b {
98        values = values.with(&b, 1);
99    }
100    Ok(values)
101}