Skip to main content

tfhe_hpu_backend/fw/
program.rs

//!
//! Abstraction used to ease FW writing
//!
//! It provides a set of utilities that help FW implementation
//! with a clean and easy-to-read API
6
7use lru::LruCache;
8use std::cell::RefCell;
9use std::collections::HashMap;
10use std::rc::Rc;
11
12use crate::asm;
13
14use super::metavar::{MetaVarCell, MetaVarCellWeak, VarPos};
15use super::FwParameters;
16
17use tracing::trace;
18
19use crate::fw::rtl::config::OpCfg;
20
21#[derive(Debug, Clone)]
22pub struct ProgramInner {
23    uid: usize,
24    pub(crate) params: FwParameters,
25    pub(crate) regs: LruCache<asm::RegId, Option<MetaVarCellWeak>>,
26    pub(crate) heap: LruCache<asm::MemId, Option<MetaVarCellWeak>>,
27    pub(crate) vars: HashMap<usize, MetaVarCellWeak>,
28    pub(crate) stmts: asm::Program<asm::DOp>,
29}
30
31/// ProgramInner constructors
32impl ProgramInner {
33    pub fn new(params: &FwParameters) -> Self {
34        let nb_regs = match std::num::NonZeroUsize::try_from(params.register) {
35            Ok(val) => val,
36            _ => panic!("Error: Number of registers must be >= 0"),
37        };
38        let mut regs = LruCache::<asm::RegId, Option<MetaVarCellWeak>>::new(nb_regs);
39        // At start regs cache is full of unused slot
40        for rid in 0..params.register {
41            regs.put(asm::RegId(rid as u8), None);
42        }
43
44        let nb_heap = match std::num::NonZeroUsize::try_from(params.heap_size) {
45            Ok(val) => val,
46            _ => panic!("Error: Number of heap slot must be >= 0"),
47        };
48        let mut heap = LruCache::<asm::MemId, Option<MetaVarCellWeak>>::new(nb_heap);
49        // At start heap cache is full of unused slot
50        for hid in 0..params.heap_size as u16 {
51            heap.put(asm::MemId::new_heap(hid), None);
52        }
53
54        Self {
55            uid: 0,
56            params: params.clone(),
57            regs,
58            heap,
59            vars: HashMap::new(),
60            stmts: asm::Program::default(),
61        }
62    }
63}
64
65/// Cache handling
66impl ProgramInner {
67    /// Retrieved least-recent-used register entry
68    /// Return associated register id and evicted variable if any
69    /// Warn: Keep cache state unchanged ...
70    pub(crate) fn reg_lru(&mut self) -> (asm::RegId, Option<MetaVarCell>) {
71        let (rid, rdata) = self
72            .regs
73            .peek_lru()
74            .expect("Error: register cache empty. Check register management");
75
76        // Handle evicted slot if any
77        // Convert it in strong reference for later handling
78        let evicted = if let Some(weak_evicted) = rdata {
79            weak_evicted.try_into().ok()
80        } else {
81            None
82        };
83
84        (*rid, evicted)
85    }
86
87    // Tries to get a range of consecutive aligned free registers and falls back
88    // to the range starting a the LRU
89    pub(crate) fn aligned_reg_range(&self, range: usize) -> Option<asm::RegId> {
90        let range = range as u8;
91        let log_size = asm::dop::ceil_ilog2(&range);
92        let mask = (1 << log_size) - 1;
93        let aligned = || {
94            self.regs
95                .iter()
96                .rev()
97                .filter(|(reg, _)| (reg.0 & mask) == 0)
98        };
99        let rid = aligned()
100            .filter(|(reg, _)| {
101                let reg = reg.0;
102                (reg..reg + range).all(|reg| {
103                    self.regs
104                        .peek(&asm::RegId(reg))
105                        .is_some_and(|r| r.is_none())
106                })
107            })
108            .map(|(reg, _)| *reg)
109            .next();
110        rid.or_else(|| {
111            aligned()
112                .filter(|(reg, _)| {
113                    let reg = reg.0;
114                    (reg + 1..reg + range).all(|reg| self.regs.peek(&asm::RegId(reg)).is_some())
115                })
116                .map(|(i, _)| *i)
117                .next()
118        })
119    }
120
121    // Retrieves the indicated RID
122    // The cache state is unchanged
123    pub(crate) fn reg(&mut self, rid: &asm::RegId) -> Option<MetaVarCell> {
124        let rdata = self
125            .regs
126            .peek(rid)
127            .unwrap_or_else(|| panic!("Error register {rid:} is not available"));
128
129        if let Some(weak_evicted) = rdata {
130            weak_evicted.try_into().ok()
131        } else {
132            None
133        }
134    }
135
136    // Insert the MetaVar in the indicated cache slot and return any evicted
137    // value
138    pub(crate) fn reg_swap_force(
139        &mut self,
140        rid: &asm::RegId,
141        var: MetaVarCell,
142    ) -> Option<MetaVarCell> {
143        // Find lru slot
144        let evicted = self.reg(rid);
145
146        // Update cache state
147        *(self.regs.get_mut(rid).expect("Update an `unused` register")) = Some((&var).into());
148
149        evicted
150    }
151
152    /// Release register entry
153    pub(crate) fn reg_promote(&mut self, rid: asm::RegId) {
154        // Update cache state
155        // Put this slot in front of all `empty` slot instead of in lru pos
156        self.regs.promote(&rid);
157        let demote_order = self
158            .regs
159            .iter()
160            .filter(|(_, var)| var.is_none())
161            .map(|(rid, _)| *rid)
162            .collect::<Vec<_>>();
163        demote_order.into_iter().for_each(|rid| {
164            self.regs.demote(&rid);
165        });
166    }
167
168    /// Release register entry
169    pub(crate) fn reg_release(&mut self, rid: asm::RegId) {
170        trace!(target: "Program", "Release Reg {rid}");
171
172        *(self
173            .regs
174            .get_mut(&rid)
175            .expect("Release an `unused` register")) = None;
176
177        self.reg_promote(rid);
178    }
179
180    /// Notify register access to update LRU state
181    pub(crate) fn reg_access(&mut self, rid: asm::RegId) {
182        self.regs.promote(&rid);
183    }
184
185    /// Retrieved least-recent-used heap entry
186    /// Return associated heap id and evicted variable if any
187    /// Warn: Keep cache state unchanged ...
188    fn heap_lru(&mut self) -> (asm::MemId, Option<MetaVarCell>) {
189        let (mid, rdata) = self
190            .heap
191            .peek_lru()
192            .expect("Error: heap cache empty. Check register management");
193
194        // Handle evicted slot if any
195        // Convert it in strong reference for later handling
196        let evicted = if let Some(weak_evicted) = rdata {
197            weak_evicted.try_into().ok()
198        } else {
199            None
200        };
201
202        (*mid, evicted)
203    }
204
205    /// Release register entry
206    pub(crate) fn heap_release(&mut self, mid: asm::MemId) {
207        trace!(target: "Program", "Release Heap {mid}");
208        match mid {
209            asm::MemId::Heap { .. } => {
210                *(self
211                    .heap
212                    .get_mut(&mid)
213                    .expect("Release an `unused` heap slot")) = None;
214                // Update cache state
215                // Put this slot in front of all `empty` slot instead of in lru pos
216                self.heap.promote(&mid);
217                let demote_order = self
218                    .heap
219                    .iter()
220                    .filter(|(_mid, var)| var.is_none())
221                    .map(|(mid, _)| *mid)
222                    .collect::<Vec<_>>();
223                demote_order.into_iter().for_each(|mid| {
224                    self.heap.demote(&mid);
225                });
226            }
227            _ => { /*Only release Heap slot*/ }
228        }
229    }
230
231    /// Notify heap access to update LRU state
232    pub(crate) fn heap_access(&mut self, mid: asm::MemId) {
233        match mid {
234            asm::MemId::Heap { .. } => {
235                self.heap.promote(&mid);
236            }
237            _ => { /* Do Nothing slot do not below to heap*/ }
238        }
239    }
240
241    /// Insert MetaVar in cache and return evicted value if any
242    pub(crate) fn heap_swap_lru(&mut self, var: MetaVarCell) -> (asm::MemId, Option<MetaVarCell>) {
243        // Find lru slot
244        let (mid, evicted) = self.heap_lru();
245
246        // Update cache state
247        *(self
248            .heap
249            .get_mut(&mid)
250            .expect("Update an `unused` heap slot")) = Some((&var).into());
251
252        (mid, evicted)
253    }
254
255    /// Adds the given register for use
256    pub(super) fn reg_put(&mut self, rid: asm::RegId, meta: Option<MetaVarCellWeak>) {
257        assert!(self.regs.peek(&rid).is_none());
258        self.regs.put(rid, meta);
259    }
260}
261
262/// MetaVar handling
263impl ProgramInner {
264    /// Create MetaVar from an optional argument
265    fn var_from(&mut self, from: Option<VarPos>, ref_to_self: Program) -> MetaVarCell {
266        // Create MetaVar
267        let uid = self.uid;
268        self.uid += 1;
269
270        // Construct tfhe params
271        let tfhe_params: asm::DigitParameters = self.params.clone().into();
272        let var = MetaVarCell::new(ref_to_self, uid, from, tfhe_params);
273
274        // Register in var store
275        self.vars.insert(uid, (&var).into());
276
277        var
278    }
279
280    pub fn new_var(&mut self, ref_to_self: Program) -> MetaVarCell {
281        self.var_from(None, ref_to_self)
282    }
283}
284
285#[derive(Clone)]
286pub struct Program {
287    inner: Rc<RefCell<ProgramInner>>,
288}
289
290impl std::ops::Deref for Program {
291    type Target = Rc<RefCell<ProgramInner>>;
292
293    fn deref(&self) -> &Self::Target {
294        &self.inner
295    }
296}
297
298#[derive(Clone)]
299pub struct StmtLink {
300    prog: Program,
301    pos: Vec<usize>,
302}
303
304impl std::fmt::Debug for StmtLink {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
306        f.debug_struct("StmtLink").field("pos", &self.pos).finish()
307    }
308}
309
310impl StmtLink {
311    pub fn empty(prog: Program) -> StmtLink {
312        StmtLink {
313            prog,
314            pos: Vec::new(),
315        }
316    }
317
318    pub fn to_flush(&mut self) {
319        if let Some(pos) = self.pos.first() {
320            let mut borrow = self.prog.borrow_mut();
321            let dop = borrow.stmts.get_stmt_mut(*pos);
322            dop.to_flush();
323        }
324    }
325}
326
327impl Program {
328    pub fn new(params: &FwParameters) -> Self {
329        Self {
330            inner: Rc::new(RefCell::new(ProgramInner::new(params))),
331        }
332    }
333
334    pub fn params(&self) -> FwParameters {
335        self.inner.borrow().params.clone()
336    }
337
338    pub fn op_cfg(&self) -> OpCfg {
339        self.inner.borrow().params.op_cfg()
340    }
341
342    pub fn op_name(&self) -> Option<String> {
343        self.inner.borrow().params.op_name.clone()
344    }
345
346    pub fn set_op(&mut self, opname: &str) {
347        self.inner.borrow_mut().params.set_op(opname);
348    }
349
350    pub fn push_comment(&mut self, comment: String) {
351        self.inner.borrow_mut().stmts.push_comment(comment)
352    }
353
354    // pub fn get_stmts(&self) -> Vec<asm::DOp> {
355    //     self.inner.borrow().stmts.clone()
356    // }
357
358    pub fn var_from(&mut self, from: Option<VarPos>) -> MetaVarCell {
359        self.inner.borrow_mut().var_from(from, self.clone())
360    }
361
362    pub fn new_var(&mut self) -> MetaVarCell {
363        self.var_from(None)
364    }
365
366    /// Easy way to create new imm value
367    pub fn new_imm(&mut self, imm: usize) -> MetaVarCell {
368        let arg = Some(VarPos::Imm(asm::ImmId::Cst(imm as u16)));
369        self.var_from(arg)
370    }
371
372    /// Easy way to create constant backed in register
373    pub fn new_cst(&mut self, cst: usize) -> MetaVarCell {
374        let mut var = self.var_from(None);
375        var.reg_alloc_mv();
376        // Force val to 0 then add cst value
377        var -= var.clone();
378        if cst != 0 {
379            let imm = self.new_imm(cst);
380            var += imm;
381        }
382
383        var
384    }
385
386    /// Create templated arguments
387    /// kind is used to specify if it's bind to src/dst or immediate template
388    /// pos_id is used to bind the template to an IOp operand position
389    // TODO pass the associated operand or immediate to obtain the inner blk properties instead of
390    // using the global one
391    pub fn iop_template_var(&mut self, kind: asm::OperandKind, pos_id: u8) -> Vec<MetaVarCell> {
392        let nb_blk = self.params().blk_w() as u8;
393        match kind {
394            asm::OperandKind::Src => {
395                // Digit in iop arg are contiguous
396                (0..nb_blk)
397                    .map(|bid| {
398                        let mid = asm::MemId::new_src(pos_id, bid);
399                        self.var_from(Some(VarPos::Mem(mid)))
400                    })
401                    .collect::<Vec<_>>()
402            }
403            asm::OperandKind::Dst => {
404                // Digit in iop arg are contiguous
405                (0..nb_blk)
406                    .map(|bid| {
407                        let mid = asm::MemId::new_dst(pos_id, bid);
408                        self.var_from(Some(VarPos::Mem(mid)))
409                    })
410                    .collect::<Vec<_>>()
411            }
412            asm::OperandKind::Imm => (0..nb_blk)
413                .map(|bid| {
414                    let iid = asm::ImmId::new_var(pos_id, bid);
415                    self.var_from(Some(VarPos::Imm(iid)))
416                })
417                .collect::<Vec<_>>(),
418            asm::OperandKind::Unknown => panic!("Template var required a known kind"),
419        }
420    }
421
422    pub fn push_stmt(&mut self, asm: asm::dop::DOp) -> StmtLink {
423        let pos = self.borrow_mut().stmts.push_stmt_pos(asm);
424        StmtLink {
425            prog: self.clone(),
426            pos: vec![pos],
427        }
428    }
429}
430
431#[derive(PartialEq, Eq, Debug)]
432pub enum AtomicRegType {
433    NewRange(usize),
434    Existing(asm::RegId),
435    None,
436}
437
438// Register utilities
439impl Program {
440    /// Bulk reserve
441    /// Evict value from cache in a bulk manner. This enable to prevent false dependency of bulk
442    /// operations when cache is almost full Enforce that at least bulk_size register is `free`
443    pub(crate) fn reg_bulk_reserve(&self, bulk_size: usize) {
444        // Iter from Lru -> MRu and take bulk_size regs
445        let to_evict = self
446            .inner
447            .borrow()
448            .regs
449            .iter()
450            .rev()
451            .take(bulk_size)
452            .filter(|(_, var)| var.is_some())
453            .map(|(_, var)| var.as_ref().unwrap().clone())
454            .collect::<Vec<_>>();
455
456        // Evict metavar to heap and release
457        to_evict.into_iter().for_each(|var| {
458            // Evict in memory if needed
459            if let Ok(cell) = MetaVarCell::try_from(&var) {
460                cell.heap_alloc_mv(true);
461            }
462        });
463    }
464
465    /// Removes the given register from use
466    pub fn reg_pop(&self, rid: &asm::RegId) -> Option<MetaVarCellWeak> {
467        self.inner.borrow_mut().regs.pop(rid).unwrap()
468    }
469
470    /// Adds the given register for use
471    pub fn reg_put(&self, rid: asm::RegId, meta: Option<MetaVarCellWeak>) {
472        self.inner.borrow_mut().reg_put(rid, meta);
473    }
474
475    // Inspects the register cache and yields the requested register ranges, if
476    // possible. This does not touch the cache state.
477    pub fn atomic_reg_range(&self, ranges: &[AtomicRegType]) -> Option<Vec<asm::RegId>> {
478        let mut borrow = self.inner.borrow_mut();
479
480        // Clone the register cache to restore it at the end
481        let backup = borrow.regs.clone();
482
483        // Remove first all already allocated ranges
484        ranges.iter().for_each(|r| {
485            if let AtomicRegType::Existing(rid) = r {
486                borrow.regs.pop(rid);
487            }
488        });
489
490        let result: Option<Vec<_>> = ranges
491            .iter()
492            .map(|r| {
493                match r {
494                    AtomicRegType::NewRange(r) => borrow.aligned_reg_range(*r).inspect(|rid| {
495                        borrow.regs.pop(rid);
496                    }),
497                    AtomicRegType::Existing(rid) => Some(*rid),
498                    // To ignore
499                    AtomicRegType::None => Some(asm::RegId::default()),
500                }
501            })
502            .collect();
503
504        // Restore the cache state
505        borrow.regs = backup;
506
507        result
508    }
509}
510
511impl From<Program> for asm::Program<asm::DOp> {
512    fn from(value: Program) -> Self {
513        let inner = value.inner.borrow();
514        inner.stmts.clone()
515    }
516}
517
518/// Syntax sugar to help user wrap PbsLut in MetaVarCell
519#[macro_export]
520macro_rules! new_pbs {
521    (
522        $prog:ident, $pbs: literal
523    ) => {
524        ::paste::paste! {
525            $prog.var_from(Some(metavar::VarPos::Pbs(asm::dop::[<Pbs $pbs:camel>]::default().into())))
526        }
527    };
528}
529
530/// To get an asm PBS from its name
531#[macro_export]
532macro_rules! pbs_by_name {
533    (
534        $pbs: literal
535    ) => {
536        ::paste::paste! {
537            asm::Pbs::[<$pbs:camel>](asm::dop::[<Pbs $pbs:camel>]::default())
538        }
539    };
540}