tract_transformers/
lib.rs

1pub mod ops;
2mod rewriter;
3use rewriter::*;
4use tract_nnef::internal::*;
5use tract_nnef::tract_core::transform::ModelTransform;
6
7pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
8    match name {
9        "detect-rms-norm" => Some(Box::new(RmsNormTransform)),
10        "detect-apply-rope" => Some(Box::new(ApplyRopeTransform)),
11        "detect-silu" => Some(Box::new(SiluTransform)),
12        "detect-scaled-masked-softmax" => Some(Box::new(ScaledMaskedSoftmaxTransform)),
13        "detect-gelu-approx" => Some(Box::new(GeluTransform)),
14        "transformers-detect-all" => Some(Box::new(TransformersTransform)),
15        _ => None,
16    }
17}
18
19pub fn register(registry: &mut Registry) {
20    registry.transforms = Box::new(|s| Ok(get_transform(s)));
21
22    ops::rms_norm::register(registry);
23    ops::silu::register(registry);
24    ops::gelu_approximate::register(registry);
25    ops::apply_rope::register(registry);
26    ops::scaled_masked_softmax::register(registry);
27}
28
/// Extension trait for turning on the `tract_transformers` NNEF extension.
///
/// Provides both an in-place form and a consuming, builder-style form.
pub trait WithTractTransformers {
    /// Enables the `tract_transformers` extension on `self` in place.
    fn enable_tract_transformers(&mut self);

    /// Consuming counterpart of
    /// [`WithTractTransformers::enable_tract_transformers`]: enables the
    /// extension and hands `self` back for chaining.
    fn with_tract_transformers(self) -> Self;
}
33
34impl WithTractTransformers for tract_nnef::framework::Nnef {
35    fn enable_tract_transformers(&mut self) {
36        self.enable_tract_core();
37        self.registries.push(tract_transformers_registry());
38    }
39
40    fn with_tract_transformers(mut self) -> Self {
41        self.enable_tract_transformers();
42        self
43    }
44}
45
46pub fn tract_transformers_registry() -> Registry {
47    let mut reg = Registry::new("tract_transformers")
48        .with_doc("Extension `tract_transformers` extends NNEF with operators")
49        .with_doc("for transformer networks.")
50        .with_doc("")
51        .with_doc("Add `extension tract_transformers` to `graph.nnef`");
52
53    register(&mut reg);
54    reg
55}