tract_transformers/
lib.rs1pub mod ops;
2mod rewriter;
3use rewriter::*;
4use tract_nnef::internal::*;
5use tract_nnef::tract_core::transform::ModelTransform;
6
7pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
8 match name {
9 "detect-rms-norm" => Some(Box::new(RmsNormTransform)),
10 "detect-apply-rope" => Some(Box::new(ApplyRopeTransform)),
11 "detect-silu" => Some(Box::new(SiluTransform)),
12 "detect-scaled-masked-softmax" => Some(Box::new(ScaledMaskedSoftmaxTransform)),
13 "detect-gelu-approx" => Some(Box::new(GeluTransform)),
14 "transformers-detect-all" => Some(Box::new(TransformersTransform)),
15 _ => None,
16 }
17}
18
19pub fn register(registry: &mut Registry) {
20 registry.transforms = Box::new(|s| Ok(get_transform(s)));
21
22 ops::rms_norm::register(registry);
23 ops::silu::register(registry);
24 ops::gelu_approximate::register(registry);
25 ops::apply_rope::register(registry);
26 ops::scaled_masked_softmax::register(registry);
27}
28
/// Extension trait for turning on the `tract_transformers` NNEF extension.
pub trait WithTractTransformers {
    /// By-value, chainable counterpart of
    /// [`WithTractTransformers::enable_tract_transformers`].
    ///
    /// Implementations are expected to enable the extension and hand `self` back.
    fn with_tract_transformers(self) -> Self;

    /// Enables the `tract_transformers` extension in place.
    fn enable_tract_transformers(&mut self);
}
33
34impl WithTractTransformers for tract_nnef::framework::Nnef {
35 fn enable_tract_transformers(&mut self) {
36 self.enable_tract_core();
37 self.registries.push(tract_transformers_registry());
38 }
39
40 fn with_tract_transformers(mut self) -> Self {
41 self.enable_tract_transformers();
42 self
43 }
44}
45
46pub fn tract_transformers_registry() -> Registry {
47 let mut reg = Registry::new("tract_transformers")
48 .with_doc("Extension `tract_transformers` extends NNEF with operators")
49 .with_doc("for transformer networks.")
50 .with_doc("")
51 .with_doc("Add `extension tract_transformers` to `graph.nnef`");
52
53 register(&mut reg);
54 reg
55}