tract_transformers/
lib.rs1pub mod ops;
2mod rewriter;
3use rewriter::*;
4use tract_nnef::internal::*;
5use tract_nnef::tract_core::transform::ModelTransform;
6
7pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
8 match name {
9 "detect-rms-norm" => Some(Box::new(RmsNormTransform)),
10 "detect-apply-rope" => Some(Box::new(ApplyRopeTransform)),
11 "detect-silu" => Some(Box::new(SiluTransform)),
12 "detect-scaled-masked-softmax" => Some(Box::new(ScaledMaskedSoftmaxTransform)),
13 "detect-gelu-approx" => Some(Box::new(GeluTransform)),
14 "detect-kv-cache" => Some(Box::new(KeyValueCacheTransform)),
15 "detect-sdpa-kv-cache-broadcast" => Some(Box::new(SdpaFuseKvCacheBroadcastTransform)),
16 "transformers-detect-all" => Some(Box::new(TransformersTransform)),
17 _ => None,
18 }
19}
20
21pub fn register(registry: &mut Registry) {
22 registry.transforms = Box::new(|s| Ok(get_transform(s)));
23
24 ops::rms_norm::register(registry);
25 ops::silu::register(registry);
26 ops::gelu_approximate::register(registry);
27 ops::apply_rope::register(registry);
28 ops::scaled_masked_softmax::register(registry);
29 ops::sdpa::register(registry);
30}
31
/// Extension trait for turning on the `tract_transformers` NNEF extension on a
/// framework object.
pub trait WithTractTransformers {
    /// Builder-style form: takes `self` by value, enables the extension, and
    /// returns the updated value.
    fn with_tract_transformers(self) -> Self;

    /// Enables the `tract_transformers` extension on `self` in place.
    fn enable_tract_transformers(&mut self);
}
36
37impl WithTractTransformers for tract_nnef::framework::Nnef {
38 fn enable_tract_transformers(&mut self) {
39 self.enable_tract_core();
40 self.registries.push(tract_transformers_registry());
41 }
42
43 fn with_tract_transformers(mut self) -> Self {
44 self.enable_tract_transformers();
45 self
46 }
47}
48
49pub fn tract_transformers_registry() -> Registry {
50 let mut reg = Registry::new("tract_transformers")
51 .with_doc("Extension `tract_transformers` extends NNEF with operators")
52 .with_doc("for transformer networks.")
53 .with_doc("")
54 .with_doc("Add `extension tract_transformers` to `graph.nnef`");
55
56 register(&mut reg);
57 reg
58}