//! tract_gpu/memory/schema.rs
//!
//! Memory planning schema for node output buffers held in device (GPU) memory.
1use std::fmt;
2use std::fmt::Debug;
3use tract_core::internal::num_integer::Integer;
4use tract_core::internal::*;
5
6use crate::fact::DeviceTypedFactExt;
7use crate::sync::{DeviceSync, DeviceSyncKind};
8
/// Requirement for node outputs from a memory perspective.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeMemReq {
    /// Output slot of the node this requirement refers to.
    pub outlet_id: OutletId,
    /// Execution-step interval during which the buffer must stay alive.
    pub lifetime: Lifetime,
    /// Required buffer size, possibly symbolic.
    pub mem_size: TDim,
}
16
/// Half-open interval of execution steps `[start, end)` during which a
/// buffer is in use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Lifetime {
    /// First step (inclusive) at which the buffer is alive.
    pub start: usize,
    /// First step (exclusive) after which the buffer is no longer needed.
    pub end: usize,
}
22
23impl Lifetime {
24    pub fn is_disjoint(&self, other: &Lifetime) -> bool {
25        self.start >= other.end || other.start >= self.end
26    }
27
28    pub fn is_alive_at_step(&self, step: usize) -> bool {
29        self.start <= step && step < self.end
30    }
31
32    pub fn is_empty(&self) -> bool {
33        self.len() == 0
34    }
35
36    pub fn len(&self) -> usize {
37        self.end - self.start
38    }
39}
40
41fn next_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<TVec<&'a TypedNode>> {
42    if node.outputs.is_empty() {
43        return None;
44    };
45
46    Some(
47        node.outputs
48            .iter()
49            .flat_map(|o| {
50                o.successors.iter().map(|succ| &model.nodes()[succ.node]).collect::<Vec<_>>()
51            })
52            .collect(),
53    )
54}
55
/// Compute, for each node of `order`, the device memory requirements of its
/// outputs: one `NodeMemReq` per buffer, with a lifetime expressed in
/// execution steps.
///
/// Only nodes producing device-resident temporary buffers are reported; a
/// node whose output is immediately synced back to host, or that never
/// appears in a flush list, yields no entry.
pub fn eval_device_mem_req_for_nodes(
    model: &TypedModel,
    order: &[usize],
) -> TractResult<TVec<NodeMemReq>> {
    let outputs = model.output_outlets()?.to_vec();
    // A node needs flushing when it produces a device fact that is not
    // state-owned and is not immediately copied back to host by a successor.
    let flush_lists = order::build_flush_list(model, order, &outputs, |node| {
        let Ok(facts) = model.node_output_facts(node.id) else { return false };

        // Is one of the direct successors a device-to-host sync op?
        let cpu_sync_in_next_nodes = next_nodes(model, node).is_some_and(|nodes| {
            nodes.iter().any(|it| {
                it.op_as::<DeviceSync>().is_some_and(|op| op.kind == DeviceSyncKind::ToHost)
            })
        });

        !cpu_sync_in_next_nodes
            && facts.iter().any(|it| {
                it.as_device_fact()
                    .map(|it| it.is_from_device() && !it.is_state_owned())
                    .unwrap_or(false)
            })
    });
    let mut scoped_nodes = tvec![];

    for (step, n) in order.iter().enumerate() {
        // A buffer becomes alive at the step its producer runs...
        let lifetime_start = step;

        // ...and dies right after the step whose flush list contains the node.
        let lifetime_end = flush_lists
            .iter()
            .enumerate()
            .find(|(_step, flush_list)| flush_list.contains(n))
            .map(|it| usize::min(it.0 + 1, order.len()));
        // Ignore nodes that won't be flushed from Device.
        let Some(lifetime_end) = lifetime_end else {
            continue;
        };

        let out_device_tmp_facts = model
            .node_output_facts(*n)?
            .into_iter()
            .flat_map(|it| it.as_device_fact())
            .filter(|it| it.is_from_device())
            .collect::<TVec<_>>();

        if out_device_tmp_facts.is_empty() {
            continue;
        }

        // One requirement per buffer of each device output.
        // NOTE(review): `slot` enumerates device facts only, so it may not
        // match the node's actual output slot when some outputs are not
        // device facts — TODO confirm.
        for (slot, fact) in out_device_tmp_facts.iter().enumerate() {
            let outlet_id = OutletId { node: *n, slot };
            for buff_size in fact.buffer_sizes() {
                scoped_nodes.push(NodeMemReq {
                    outlet_id,
                    lifetime: Lifetime { start: lifetime_start, end: lifetime_end },
                    mem_size: buff_size,
                })
            }
        }
    }

    Ok(scoped_nodes)
}
117
118fn collect_opaque_facts(model: &TypedModel) -> TractResult<Vec<NodeOpaqueFacts>> {
119    let mut res: Vec<TVec<Option<Box<dyn OpaqueFact>>>> = vec![];
120    for node in model.nodes() {
121        let mut tmp: TVec<Option<Box<dyn OpaqueFact>>> = tvec![];
122        for fact in model.node_output_facts(node.id)? {
123            if let Some(dev_fact) = fact.as_device_fact() {
124                tmp.push(dev_fact.opaque_fact.clone());
125            }
126        }
127        res.push(tmp);
128    }
129    Ok(res)
130}
131
/// A partition is a list of nodes that have disjoint memory requirements from a lifetime
/// perspective, and which can therefore share a single memory region.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Partition {
    /// Requirements assigned to this partition; lifetimes are pairwise disjoint.
    pub nodes: Vec<NodeMemReq>,
}
138
139impl Partition {
140    pub fn eval_size_to_i64(&self, symbols: &SymbolValues) -> TractResult<i64> {
141        let mut max_size = self
142            .nodes
143            .iter()
144            .map(|it| it.mem_size.eval_to_i64(symbols))
145            .collect::<TractResult<Vec<_>>>()?
146            .into_iter()
147            .max()
148            .unwrap_or(0);
149        max_size = Integer::next_multiple_of(&max_size, &(vector_size() as i64));
150        Ok(max_size)
151    }
152
153    pub fn size(&self) -> TDim {
154        TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect()).simplify()
155    }
156
157    pub fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool {
158        self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime))
159    }
160
161    pub fn find_node_alive_at_step(&self, step: usize) -> Option<&NodeMemReq> {
162        self.nodes.iter().find(|it| it.lifetime.is_alive_at_step(step))
163    }
164}
165
/// Opaque facts attached to one node's device output facts, in output order.
type NodeOpaqueFacts = TVec<Option<Box<dyn OpaqueFact>>>;
/// This struct represents a resolved memory schema for a model that contains
/// GPU operators. This schema is concrete.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceResolvedMemSchema {
    /// Offsets into the shared memory arena, indexed by node id, then output
    /// slot, then buffer. `None` for nodes without device memory requirements.
    pub offsets_by_node: Vec<Option<TVec<TVec<usize>>>>,
    /// Total memory size required by the schema.
    pub memory_size: usize,
}
174
/// This struct represents a memory schema for node output memory that is handled
/// by a GPU.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DeviceMemSchema {
    /// Total number of nodes in the model.
    pub model_num_nodes: usize,
    /// Partitions of node requirements with pairwise-disjoint lifetimes.
    pub by_partition: Vec<Partition>,
    // vec![vec![Option<NodeMemReq>; num_partitions]; num_steps].
    /// For each execution step, the requirement alive in each partition (if any).
    pub by_steps: Vec<Vec<Option<NodeMemReq>>>,
    /// Per-node opaque facts collected from device output facts.
    pub opaque_facts: Vec<NodeOpaqueFacts>,
}
186
187impl DeviceMemSchema {
188    /// Returns memory size of each inner partitions.
189    pub fn size_by_partition(&self) -> Vec<TDim> {
190        self.by_partition.iter().map(|it| it.size()).collect()
191    }
192
193    /// Evaluate memory size by partition for given symbol values.
194    pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult<Vec<i64>> {
195        self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect()
196    }
197
198    /// Returns total memory size required for the schema.
199    pub fn memory_size(&self) -> TDim {
200        self.by_partition.iter().map(|it| it.size()).sum()
201    }
202
203    /// Evaluate memory size required for the schema for given symbol values.
204    pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
205        self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum()
206    }
207
208    /// Compute offsets for each node for given symbols. Node ids
209    /// are indexes in the returned vector.
210    pub fn compute_offset_by_node(
211        &self,
212        symbols: &SymbolValues,
213    ) -> TractResult<Vec<Option<TVec<TVec<usize>>>>> {
214        let mut cursor = 0;
215        let mut offset_by_outlet: Vec<Option<TVec<TVec<usize>>>> = vec![None; self.model_num_nodes];
216
217        for partition in &self.by_partition {
218            for node_mem in &partition.nodes {
219                let node = node_mem.outlet_id.node;
220                let slot = node_mem.outlet_id.slot;
221
222                let slots: &mut TVec<TVec<usize>> =
223                    offset_by_outlet[node].get_or_insert_with(|| tvec![tvec!()]);
224
225                if slot < 1 {
226                    slots[slot].push(cursor);
227                } else {
228                    if slots.len() <= slot {
229                        slots.resize_with(slot + 1, TVec::<usize>::new);
230                    }
231                    slots[slot].push(cursor);
232                }
233            }
234            cursor += partition.eval_size_to_i64(symbols)? as usize;
235        }
236
237        Ok(offset_by_outlet)
238    }
239
240    /// Evaluate peak memory size for given symbols. The return value is lower or equal to the memory
241    /// size of the schema. The difference between peak memory size and memory size represents the
242    /// memory fragmentation introduced by the schema.
243    pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
244        Ok(self
245            .by_steps
246            .iter()
247            .map(|active_nodes| {
248                active_nodes
249                    .iter()
250                    .flatten()
251                    .map(|it| it.mem_size.clone())
252                    .sum::<TDim>()
253                    .eval_to_i64(symbols)
254            })
255            .collect::<TractResult<Vec<_>>>()?
256            .into_iter()
257            .max()
258            .unwrap_or(0))
259    }
260
261    /// Evaluate the usage for given symbols as the ratio between
262    /// schema memory size and peak memory size. A value of 1.0 means
263    /// that the schema doesn't introduce memory fragmentation.
264    pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult<f32> {
265        let memory_size = self.eval_memory_size(symbols)? as f32;
266        let peak_memory_size = self.eval_peak_memory_size(symbols)? as f32;
267        Ok(peak_memory_size / memory_size)
268    }
269}
270
impl fmt::Display for DeviceMemSchema {
    /// Render one line per step showing, for each partition, the `node/slot`
    /// alive at that step (or `*` when the partition is idle), followed by
    /// the total symbolic memory size.
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for (step, mem_step) in self.by_steps.iter().enumerate() {
            writeln!(
                fmt,
                "step: {:5} => |{}|",
                step,
                mem_step
                    .iter()
                    .map(|n| -> String {
                        n.as_ref()
                            .map(|it| format!("{:^7}/{:^7}", it.outlet_id.node, it.outlet_id.slot))
                            .unwrap_or(format!("{:^7}", "*"))
                    })
                    .collect::<Vec<String>>()
                    .join("|")
            )?;
        }
        writeln!(fmt, "memory_size: {}", self.memory_size())?;
        Ok(())
    }
}
293
impl DeviceMemSchema {
    /// Resolve Memory schema with given symbols, producing concrete offsets
    /// and a concrete total memory size.
    pub fn resolve(&self, symbols: &SymbolValues) -> TractResult<DeviceResolvedMemSchema> {
        Ok(DeviceResolvedMemSchema {
            offsets_by_node: self.compute_offset_by_node(symbols)?,
            memory_size: self.eval_memory_size(symbols)?.try_into()?,
        })
    }

    /// Build a memory schema for given model and execution order. The hint is used to optimize
    /// the memory schema because it is based on symbolic dimensions. That doesn't mean it will be
    /// optimal for all possible values for symbolic dimensions.
    pub fn build(
        model: &TypedModel,
        order: &[usize],
        hint: &SymbolValues,
    ) -> TractResult<DeviceMemSchema> {
        let mut nodes_mem_req = eval_device_mem_req_for_nodes(model, order)?;

        let opaque_facts = collect_opaque_facts(model)?;
        // Concrete size per outlet under the hint, used as greedy weights below.
        // NOTE(review): an outlet can yield several NodeMemReq entries (one per
        // buffer size); later entries overwrite earlier ones in this map —
        // TODO confirm this is intended.
        let hinted_mem_size = nodes_mem_req
            .iter()
            .map(|node_mem| Ok((node_mem.outlet_id, node_mem.mem_size.eval_to_i64(hint)?)))
            .collect::<TractResult<HashMap<OutletId, i64>>>()?;

        // Largest requirements first, so they seed partitions.
        nodes_mem_req.sort_by(|lhs, rhs| {
            let lhs_hint_mem_size = hinted_mem_size.get(&lhs.outlet_id);
            let rhs_hint_mem_size = hinted_mem_size.get(&rhs.outlet_id);
            lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse()
        });

        let mut partitions: Vec<Partition> = vec![];
        for node_mem in nodes_mem_req {
            // Find partitions where node lifetime is disjoint from existing.
            let mut available = partitions
                .iter_mut()
                .filter(|it| it.has_no_conflict_with_lifetime(&node_mem.lifetime))
                .collect::<Vec<_>>();

            // Prefer the heaviest compatible partition (negated sum => descending order).
            available.sort_by_cached_key(|n| {
                -n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.outlet_id)).sum::<i64>()
            });

            match available.first_mut() {
                Some(available) => {
                    available.nodes.push(node_mem);
                }
                None => partitions.push(Partition { nodes: vec![node_mem] }),
            }
        }

        // For each step, record which requirement (if any) occupies each partition.
        let by_steps: Vec<Vec<Option<NodeMemReq>>> = (0..order.len())
            .map(|step| {
                let mem_step: Vec<_> =
                    partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect();
                // NOTE(review): mem_step.len() == partitions.len() by
                // construction, so this ensure! is trivially satisfied.
                ensure!(mem_step.len() <= partitions.len());
                Ok(mem_step)
            })
            .collect::<TractResult<Vec<_>>>()?;

        Ok(DeviceMemSchema {
            model_num_nodes: model.nodes().len(),
            by_partition: partitions,
            by_steps,
            opaque_facts,
        })
    }
}
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lifetime_is_disjoint() {
        let early = Lifetime { start: 0, end: 5 };
        let late = Lifetime { start: 5, end: 10 };
        let middle = Lifetime { start: 3, end: 7 };

        // Touching intervals are disjoint (half-open semantics), in both directions.
        assert!(early.is_disjoint(&late));
        assert!(late.is_disjoint(&early));
        // Overlapping intervals are not.
        assert!(!early.is_disjoint(&middle));
        assert!(!middle.is_disjoint(&late));
    }

    #[test]
    fn test_lifetime_is_alive_at_step() {
        let lt = Lifetime { start: 5, end: 10 };

        // Start is inclusive, end is exclusive.
        assert!(!lt.is_alive_at_step(4));
        assert!(lt.is_alive_at_step(5));
        assert!(lt.is_alive_at_step(7));
        assert!(lt.is_alive_at_step(9));
        assert!(!lt.is_alive_at_step(10));
    }

    #[test]
    fn test_empty_lifetime() {
        // A zero-length interval covers no step.
        let lt = Lifetime { start: 5, end: 5 };
        assert!(lt.is_empty());
        assert_eq!(lt.len(), 0);
    }

    #[test]
    fn test_node_mem_req_basic() {
        let req = NodeMemReq {
            outlet_id: OutletId { node: 1, slot: 0 },
            lifetime: Lifetime { start: 0, end: 5 },
            mem_size: 1000.into(),
        };

        assert_eq!(req.outlet_id.node, 1);
        assert_eq!(req.lifetime.start, 0);
        assert_eq!(req.lifetime.end, 5);
        assert_eq!(req.mem_size.to_i64().unwrap(), 1000);
    }

    #[test]
    fn test_partition_has_no_conflict() {
        let member = NodeMemReq {
            outlet_id: OutletId { node: 1, slot: 0 },
            lifetime: Lifetime { start: 0, end: 5 },
            mem_size: 1000.into(),
        };
        let partition = Partition { nodes: vec![member] };

        // Disjoint lifetime: no conflict. Overlapping lifetime: conflict.
        assert!(partition.has_no_conflict_with_lifetime(&Lifetime { start: 5, end: 10 }));
        assert!(!partition.has_no_conflict_with_lifetime(&Lifetime { start: 3, end: 7 }));
    }

    #[test]
    fn test_partition_find_node() {
        let first = NodeMemReq {
            outlet_id: OutletId { node: 1, slot: 0 },
            lifetime: Lifetime { start: 0, end: 5 },
            mem_size: 1000.into(),
        };
        let second = NodeMemReq {
            outlet_id: OutletId { node: 2, slot: 0 },
            lifetime: Lifetime { start: 5, end: 10 },
            mem_size: 2000.into(),
        };
        let partition = Partition { nodes: vec![first.clone(), second.clone()] };

        // Lookup resolves to the member whose lifetime covers the step.
        assert_eq!(partition.find_node_alive_at_step(3), Some(&first));
        assert_eq!(partition.find_node_alive_at_step(7), Some(&second));
        assert_eq!(partition.find_node_alive_at_step(10), None);
    }
}
449}