tract_transformers/
lib.rs

1pub mod ops;
2mod rewriter;
3use rewriter::*;
4use tract_nnef::internal::*;
5use tract_nnef::tract_core::transform::ModelTransform;
6
7pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
8    match name {
9        "detect-rms-norm" => Some(Box::new(RmsNormTransform)),
10        "detect-apply-rope" => Some(Box::new(ApplyRopeTransform)),
11        "detect-silu" => Some(Box::new(SiluTransform)),
12        "detect-scaled-masked-softmax" => Some(Box::new(ScaledMaskedSoftmaxTransform)),
13        "detect-gelu-approx" => Some(Box::new(GeluTransform)),
14        "detect-kv-cache" => Some(Box::new(KeyValueCacheTransform)),
15        "detect-sdpa-kv-cache-broadcast" => Some(Box::new(SdpaFuseKvCacheBroadcastTransform)),
16        "transformers-detect-all" => Some(Box::new(TransformersTransform)),
17        _ => None,
18    }
19}
20
21pub fn register(registry: &mut Registry) {
22    registry.transforms = Box::new(|s| Ok(get_transform(s)));
23
24    ops::rms_norm::register(registry);
25    ops::silu::register(registry);
26    ops::gelu_approximate::register(registry);
27    ops::apply_rope::register(registry);
28    ops::scaled_masked_softmax::register(registry);
29    ops::sdpa::register(registry);
30}
31
/// Extension trait for enabling the `tract_transformers` extension on a framework object.
pub trait WithTractTransformers {
    /// Enables the extension in place, through a mutable borrow.
    fn enable_tract_transformers(&mut self);

    /// Consuming variant: takes `self` by value and returns a `Self`.
    fn with_tract_transformers(self) -> Self;
}
36
37impl WithTractTransformers for tract_nnef::framework::Nnef {
38    fn enable_tract_transformers(&mut self) {
39        self.enable_tract_core();
40        self.registries.push(tract_transformers_registry());
41    }
42
43    fn with_tract_transformers(mut self) -> Self {
44        self.enable_tract_transformers();
45        self
46    }
47}
48
49pub fn tract_transformers_registry() -> Registry {
50    let mut reg = Registry::new("tract_transformers")
51        .with_doc("Extension `tract_transformers` extends NNEF with operators")
52        .with_doc("for transformer networks.")
53        .with_doc("")
54        .with_doc("Add `extension tract_transformers` to `graph.nnef`");
55
56    register(&mut reg);
57    reg
58}