Skip to main content

tract_gpu/memory/
schema.rs

1use std::fmt;
2use std::fmt::Debug;
3use tract_core::internal::*;
4
5use crate::fact::DeviceTypedFactExt;
6use crate::sync::{DeviceSync, DeviceSyncKind};
7
/// Requirement for node outputs from a memory perspective.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeMemReq {
    /// Id of the node whose outputs need memory.
    pub node: usize,
    /// Steps of the execution order during which this requirement is alive.
    pub lifetime: Lifetime,
    /// Combined (symbolic) memory size of the node's device outputs.
    pub mem_size: TDim,
}
15
/// Half-open range of execution steps `[start, end)` during which a node memory
/// requirement is alive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Lifetime {
    /// First step (inclusive) of the range.
    pub start: usize,
    /// Step (exclusive) at which the range ends.
    pub end: usize,
}

impl Lifetime {
    /// Returns `true` when the two half-open ranges share no step.
    ///
    /// Ranges that merely touch (`self.end == other.start`) are disjoint.
    pub fn is_disjoint(&self, other: &Lifetime) -> bool {
        self.start >= other.end || other.start >= self.end
    }

    /// Returns `true` when `step` lies in `[start, end)`.
    pub fn is_alive_at_step(&self, step: usize) -> bool {
        self.start <= step && step < self.end
    }

    /// Returns `true` when the range covers no step.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of steps covered by the range.
    ///
    /// An inverted range (`end < start`) covers no step and yields 0 instead of
    /// underflowing (a panic in debug builds, a huge wrapped value in release builds).
    pub fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }
}
39
40fn next_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<TVec<&'a TypedNode>> {
41    if node.outputs.is_empty() {
42        return None;
43    };
44
45    Some(
46        node.outputs
47            .iter()
48            .flat_map(|o| {
49                o.successors.iter().map(|succ| &model.nodes()[succ.node]).collect::<Vec<_>>()
50            })
51            .collect(),
52    )
53}
54
55pub fn eval_device_mem_req_for_nodes(
56    model: &TypedModel,
57    order: &[usize],
58) -> TractResult<TVec<NodeMemReq>> {
59    let outputs = model.output_outlets()?.to_vec();
60    let flush_lists = order::build_flush_list(model, order, &outputs, |node| {
61        let Ok(facts) = model.node_output_facts(node.id) else { return false };
62
63        let cpu_sync_in_next_nodes = next_nodes(model, node).is_some_and(|nodes| {
64            nodes.iter().any(|it| {
65                it.op_as::<DeviceSync>().is_some_and(|op| op.kind == DeviceSyncKind::ToHost)
66            })
67        });
68
69        !cpu_sync_in_next_nodes
70            && facts
71                .iter()
72                .any(|it| it.to_device_fact().map(|it| it.is_from_device()).unwrap_or(false))
73    });
74    let mut scoped_nodes = tvec![];
75
76    for (step, n) in order.iter().enumerate() {
77        let lifetime_start = step;
78        let lifetime_end = flush_lists
79            .iter()
80            .enumerate()
81            .find(|(_step, flush_list)| flush_list.contains(n))
82            .map(|it| usize::min(it.0 + 1, order.len()));
83
84        // Ignore nodes that won't be flushed from Device.
85        let Some(lifetime_end) = lifetime_end else {
86            continue;
87        };
88
89        let out_device_tmp_facts = model
90            .node_output_facts(*n)?
91            .into_iter()
92            .flat_map(|it| it.to_device_fact().ok())
93            .filter(|it| it.is_from_device())
94            .collect::<TVec<_>>();
95
96        if out_device_tmp_facts.is_empty() {
97            continue;
98        }
99
100        scoped_nodes.push(NodeMemReq {
101            node: *n,
102            lifetime: Lifetime { start: lifetime_start, end: lifetime_end },
103            mem_size: out_device_tmp_facts.iter().map(|it| it.mem_size()).sum::<TDim>(),
104        })
105    }
106
107    Ok(scoped_nodes)
108}
109
/// A partition is a list of nodes whose memory requirements have pairwise disjoint
/// lifetimes, so they can reuse the same memory region.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Partition {
    /// Memory requirements placed in this partition.
    pub nodes: Vec<NodeMemReq>,
}
116
117impl Partition {
118    pub fn eval_size_to_i64(&self, symbols: &SymbolValues) -> TractResult<i64> {
119        Ok(self
120            .nodes
121            .iter()
122            .map(|it| it.mem_size.eval_to_i64(symbols))
123            .collect::<TractResult<Vec<_>>>()?
124            .into_iter()
125            .max()
126            .unwrap_or(0))
127    }
128
129    pub fn size(&self) -> TDim {
130        TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect()).simplify()
131    }
132
133    pub fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool {
134        self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime))
135    }
136
137    pub fn find_node_alive_at_step(&self, step: usize) -> Option<&NodeMemReq> {
138        self.nodes.iter().find(|it| it.lifetime.is_alive_at_step(step))
139    }
140}
141
/// This struct represents a resolved memory schema for a model that contains
/// GPU operators. This schema is concrete: its sizes were evaluated for given symbol values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceResolvedMemSchema {
    /// Offset of each node's buffer in the shared memory region, indexed by node id;
    /// `None` for nodes without a scheduled buffer.
    /// NOTE(review): presumably in bytes, same unit as the facts' `mem_size()` — confirm.
    pub offsets_by_node: Vec<Option<usize>>,
    /// Total size of the memory region holding every partition.
    pub memory_size: usize,
}
149
/// This struct represents a memory schema for node outputs whose memory is handled
/// by a GPU. Sizes are symbolic; resolving the schema with symbol values yields
/// concrete offsets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DeviceMemSchema {
    /// Total number of nodes in the model.
    pub model_num_nodes: usize,
    /// Node memory requirements grouped into partitions whose members have disjoint
    /// lifetimes and share one memory region.
    pub by_partition: Vec<Partition>,
    /// Node alive in each partition at every execution step, laid out as
    /// `vec![vec![Option<NodeMemReq>; num_partitions]; num_steps]`
    /// (`None` when the partition is idle at that step).
    pub by_steps: Vec<Vec<Option<NodeMemReq>>>,
}
160
161impl DeviceMemSchema {
162    /// Returns memory size of each inner partitions.
163    pub fn size_by_partition(&self) -> Vec<TDim> {
164        self.by_partition.iter().map(|it| it.size()).collect()
165    }
166
167    /// Evaluate memory size by partition for given symbol values.
168    pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult<Vec<i64>> {
169        self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect()
170    }
171
172    /// Returns total memory size required for the schema.
173    pub fn memory_size(&self) -> TDim {
174        self.by_partition.iter().map(|it| it.size()).sum()
175    }
176
177    /// Evaluate memory size required for the schema for given symbol values.
178    pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
179        self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum()
180    }
181
182    /// Compute offsets for each node for given symbols. Node ids
183    /// are indexes in the returned vector.
184    pub fn compute_offset_by_node(
185        &self,
186        symbols: &SymbolValues,
187    ) -> TractResult<Vec<Option<usize>>> {
188        let mut cursor = 0;
189        let mut offset_by_node = vec![None; self.model_num_nodes];
190
191        for partition in self.by_partition.iter() {
192            for node_mem in partition.nodes.iter() {
193                offset_by_node[node_mem.node] = Some(cursor);
194            }
195            cursor += partition.eval_size_to_i64(symbols)? as usize;
196        }
197
198        Ok(offset_by_node)
199    }
200
201    /// Evaluate peak memory size for given symbols. The return value is lower or equal to the memory
202    /// size of the schema. The difference between peak memory size and memory size represents the
203    /// memory fragmentation introduced by the schema.
204    pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
205        Ok(self
206            .by_steps
207            .iter()
208            .map(|active_nodes| {
209                active_nodes
210                    .iter()
211                    .flatten()
212                    .map(|it| it.mem_size.clone())
213                    .sum::<TDim>()
214                    .eval_to_i64(symbols)
215            })
216            .collect::<TractResult<Vec<_>>>()?
217            .into_iter()
218            .max()
219            .unwrap_or(0))
220    }
221
222    /// Evaluate the usage for given symbols as the ratio between
223    /// schema memory size and peak memory size. A value of 1.0 means
224    /// that the schema doesn't introduce memory fragmentation.
225    pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult<f32> {
226        let memory_size = self.eval_memory_size(symbols)? as f32;
227        let peak_memory_size = self.eval_peak_memory_size(symbols)? as f32;
228        Ok(peak_memory_size / memory_size)
229    }
230}
231
232impl fmt::Display for DeviceMemSchema {
233    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
234        for (step, mem_step) in self.by_steps.iter().enumerate() {
235            writeln!(
236                fmt,
237                "step: {:5} => |{}|",
238                step,
239                mem_step
240                    .iter()
241                    .map(|n| -> String {
242                        n.as_ref()
243                            .map(|it| format!("{:^7}", it.node))
244                            .unwrap_or(format!("{:^7}", "*"))
245                    })
246                    .collect::<Vec<String>>()
247                    .join("|")
248            )?;
249        }
250        writeln!(fmt, "memory_size: {}", self.memory_size())?;
251        Ok(())
252    }
253}
254
255impl DeviceMemSchema {
256    /// Resolve Memory schema with given symbols.
257    pub fn resolve(&self, symbols: &SymbolValues) -> TractResult<DeviceResolvedMemSchema> {
258        Ok(DeviceResolvedMemSchema {
259            offsets_by_node: self.compute_offset_by_node(symbols)?,
260            memory_size: self.eval_memory_size(symbols)?.try_into()?,
261        })
262    }
263
264    /// Build a memory schema for given model and execution order. The hint is used to optimize
265    /// the memory schema because it is based on symbolic dimensions. That doesn't mean it will be
266    /// optimal for all possible values for symbolic dimensions.
267    pub fn build(
268        model: &TypedModel,
269        order: &[usize],
270        hint: &SymbolValues,
271    ) -> TractResult<DeviceMemSchema> {
272        let mut nodes_mem_req = eval_device_mem_req_for_nodes(model, order)?;
273
274        let hinted_mem_size = nodes_mem_req
275            .iter()
276            .map(|node_mem| Ok((node_mem.node, node_mem.mem_size.eval_to_i64(hint)?)))
277            .collect::<TractResult<HashMap<usize, i64>>>()?;
278
279        nodes_mem_req.sort_by(|lhs, rhs| {
280            let lhs_hint_mem_size = hinted_mem_size.get(&lhs.node);
281            let rhs_hint_mem_size = hinted_mem_size.get(&rhs.node);
282            lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse()
283        });
284
285        let mut partitions: Vec<Partition> = vec![];
286        for node_mem in nodes_mem_req {
287            // Find partitions where node lifetime is disjoint from existing.
288            let mut available = partitions
289                .iter_mut()
290                .filter(|it| it.has_no_conflict_with_lifetime(&node_mem.lifetime))
291                .collect::<Vec<_>>();
292
293            available.sort_by_cached_key(|n| {
294                -n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.node)).sum::<i64>()
295            });
296
297            match available.first_mut() {
298                Some(available) => {
299                    available.nodes.push(node_mem);
300                }
301                None => partitions.push(Partition { nodes: vec![node_mem] }),
302            }
303        }
304
305        let by_steps: Vec<Vec<Option<NodeMemReq>>> = (0..order.len())
306            .map(|step| {
307                let mem_step: Vec<_> =
308                    partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect();
309                ensure!(mem_step.len() <= partitions.len());
310                Ok(mem_step)
311            })
312            .collect::<TractResult<Vec<_>>>()?;
313
314        Ok(DeviceMemSchema {
315            model_num_nodes: model.nodes().len(),
316            by_partition: partitions,
317            by_steps,
318        })
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NodeMemReq` for `node`, alive on `[start, end)`, needing `size` memory.
    fn mem_req(node: usize, start: usize, end: usize, size: i64) -> NodeMemReq {
        NodeMemReq { node, lifetime: Lifetime { start, end }, mem_size: size.into() }
    }

    #[test]
    fn test_lifetime_is_disjoint() {
        let first = Lifetime { start: 0, end: 5 };
        let second = Lifetime { start: 5, end: 10 };
        let straddling = Lifetime { start: 3, end: 7 };

        // Touching half-open ranges do not overlap, in either order.
        assert!(first.is_disjoint(&second));
        assert!(second.is_disjoint(&first));
        // A range crossing the boundary overlaps both neighbours.
        assert!(!first.is_disjoint(&straddling));
        assert!(!straddling.is_disjoint(&second));
    }

    #[test]
    fn test_lifetime_is_alive_at_step() {
        let lifetime = Lifetime { start: 5, end: 10 };

        for step in [5, 7, 9] {
            assert!(lifetime.is_alive_at_step(step));
        }
        for step in [4, 10] {
            assert!(!lifetime.is_alive_at_step(step));
        }
    }

    #[test]
    fn test_empty_lifetime() {
        let lifetime = Lifetime { start: 5, end: 5 };
        assert_eq!(lifetime.len(), 0);
        assert!(lifetime.is_empty());
    }

    #[test]
    fn test_node_mem_req_basic() {
        let req = mem_req(1, 0, 5, 1000);

        assert_eq!(req.node, 1);
        assert_eq!(req.lifetime, Lifetime { start: 0, end: 5 });
        assert_eq!(req.mem_size.to_i64().unwrap(), 1000);
    }

    #[test]
    fn test_partition_has_no_conflict() {
        let partition = Partition { nodes: vec![mem_req(1, 0, 5, 1000)] };

        assert!(partition.has_no_conflict_with_lifetime(&Lifetime { start: 5, end: 10 }));
        assert!(!partition.has_no_conflict_with_lifetime(&Lifetime { start: 3, end: 7 }));
    }

    #[test]
    fn test_partition_find_node() {
        let early = mem_req(1, 0, 5, 1000);
        let late = mem_req(2, 5, 10, 2000);
        let partition = Partition { nodes: vec![early.clone(), late.clone()] };

        assert_eq!(partition.find_node_alive_at_step(3), Some(&early));
        assert_eq!(partition.find_node_alive_at_step(7), Some(&late));
        assert_eq!(partition.find_node_alive_at_step(10), None);
    }
}