// tract_gpu/memory/schema.rs

use std::cmp::Reverse;
use std::fmt;
use std::fmt::Debug;
use tract_core::internal::num_integer::Integer;
use tract_core::internal::*;

use crate::fact::DeviceTypedFactExt;
use crate::sync::{DeviceSync, DeviceSyncKind};

9/// Requirement for node outputs from a memory perspective.
10#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11pub struct NodeMemReq {
12    pub node: usize,
13    pub lifetime: Lifetime,
14    pub mem_size: TDim,
15}
16
/// Half-open range of execution steps `[start, end)` during which a node output is alive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Lifetime {
    /// First step (inclusive) at which the output is alive.
    pub start: usize,
    /// First step (exclusive) at which the output is no longer alive.
    pub end: usize,
}

impl Lifetime {
    /// Returns `true` when the two lifetimes share no step, i.e. the corresponding outputs
    /// are never alive at the same time.
    pub fn is_disjoint(&self, other: &Lifetime) -> bool {
        self.start >= other.end || other.start >= self.end
    }

    /// Returns `true` if `step` lies within `[start, end)`.
    pub fn is_alive_at_step(&self, step: usize) -> bool {
        self.start <= step && step < self.end
    }

    /// Returns `true` when the lifetime covers no step.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of steps covered by the lifetime.
    ///
    /// An inverted lifetime (`end < start`) covers no step (`is_alive_at_step` is false for
    /// every step), so it reports `0` instead of underflowing, which would panic in debug
    /// builds and wrap to a huge value in release builds.
    pub fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }
}

41fn next_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<TVec<&'a TypedNode>> {
42    if node.outputs.is_empty() {
43        return None;
44    };
45
46    Some(
47        node.outputs
48            .iter()
49            .flat_map(|o| {
50                o.successors.iter().map(|succ| &model.nodes()[succ.node]).collect::<Vec<_>>()
51            })
52            .collect(),
53    )
54}
55
56pub fn eval_device_mem_req_for_nodes(
57    model: &TypedModel,
58    order: &[usize],
59) -> TractResult<TVec<NodeMemReq>> {
60    let outputs = model.output_outlets()?.to_vec();
61    let flush_lists = order::build_flush_list(model, order, &outputs, |node| {
62        let Ok(facts) = model.node_output_facts(node.id) else { return false };
63
64        let cpu_sync_in_next_nodes = next_nodes(model, node).is_some_and(|nodes| {
65            nodes.iter().any(|it| {
66                it.op_as::<DeviceSync>().is_some_and(|op| op.kind == DeviceSyncKind::ToHost)
67            })
68        });
69
70        !cpu_sync_in_next_nodes
71            && node.op.name() != "MetalDynKVCache"
72            && node.op.name() != "CudaDynKVCache"
73            && facts
74                .iter()
75                .any(|it| it.to_device_fact().map(|it| it.is_from_device()).unwrap_or(false))
76    });
77    let mut scoped_nodes = tvec![];
78
79    for (step, n) in order.iter().enumerate() {
80        let lifetime_start = step;
81
82        let lifetime_end = flush_lists
83            .iter()
84            .enumerate()
85            .find(|(_step, flush_list)| flush_list.contains(n))
86            .map(|it| usize::min(it.0 + 1, order.len()));
87        // Ignore nodes that won't be flushed from Device.
88        let Some(lifetime_end) = lifetime_end else {
89            continue;
90        };
91
92        let out_device_tmp_facts = model
93            .node_output_facts(*n)?
94            .into_iter()
95            .flat_map(|it| it.to_device_fact().ok())
96            .filter(|it| it.is_from_device())
97            .collect::<TVec<_>>();
98
99        if out_device_tmp_facts.is_empty() {
100            continue;
101        }
102
103        scoped_nodes.push(NodeMemReq {
104            node: *n,
105            lifetime: Lifetime { start: lifetime_start, end: lifetime_end },
106            mem_size: out_device_tmp_facts.iter().map(|it| it.mem_size()).sum::<TDim>(),
107        })
108    }
109
110    Ok(scoped_nodes)
111}
112
113/// A partition is a list of node that have disjoint memory requirement from a lifetime
114/// perspective.
115#[derive(Debug, Clone, PartialEq, Eq, Hash)]
116pub struct Partition {
117    pub nodes: Vec<NodeMemReq>,
118}
119
120impl Partition {
121    pub fn eval_size_to_i64(&self, symbols: &SymbolValues) -> TractResult<i64> {
122        let mut max_size = self
123            .nodes
124            .iter()
125            .map(|it| it.mem_size.eval_to_i64(symbols))
126            .collect::<TractResult<Vec<_>>>()?
127            .into_iter()
128            .max()
129            .unwrap_or(0);
130        max_size = Integer::next_multiple_of(&max_size, &(vector_size() as i64));
131        Ok(max_size)
132    }
133
134    pub fn size(&self) -> TDim {
135        TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect()).simplify()
136    }
137
138    pub fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool {
139        self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime))
140    }
141
142    pub fn find_node_alive_at_step(&self, step: usize) -> Option<&NodeMemReq> {
143        self.nodes.iter().find(|it| it.lifetime.is_alive_at_step(step))
144    }
145}
146
/// Memory schema resolved for concrete symbol values, for a model containing GPU operators.
///
/// Unlike the symbolic schema it is resolved from, every offset and size here is a plain
/// number.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceResolvedMemSchema {
    /// Offset of each node's outputs inside the memory area, indexed by node id; `None` for
    /// nodes that are not managed by the schema.
    pub offsets_by_node: Vec<Option<usize>>,
    /// Total size of the memory area holding every partition.
    pub memory_size: usize,
}

155/// This struct represent a memory schema for node output memory that are handled
156/// by a GPU.
157#[derive(Debug, Clone, PartialEq, Eq, Hash)]
158pub struct DeviceMemSchema {
159    /// Total numbef in the model.
160    pub model_num_nodes: usize,
161    pub by_partition: Vec<Partition>,
162    // vec![vec![Option<NodeMemReq>; num_partitions]; num_steps].
163    pub by_steps: Vec<Vec<Option<NodeMemReq>>>,
164}
165
166impl DeviceMemSchema {
167    /// Returns memory size of each inner partitions.
168    pub fn size_by_partition(&self) -> Vec<TDim> {
169        self.by_partition.iter().map(|it| it.size()).collect()
170    }
171
172    /// Evaluate memory size by partition for given symbol values.
173    pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult<Vec<i64>> {
174        self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect()
175    }
176
177    /// Returns total memory size required for the schema.
178    pub fn memory_size(&self) -> TDim {
179        self.by_partition.iter().map(|it| it.size()).sum()
180    }
181
182    /// Evaluate memory size required for the schema for given symbol values.
183    pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
184        self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum()
185    }
186
187    /// Compute offsets for each node for given symbols. Node ids
188    /// are indexes in the returned vector.
189    pub fn compute_offset_by_node(
190        &self,
191        symbols: &SymbolValues,
192    ) -> TractResult<Vec<Option<usize>>> {
193        let mut cursor = 0;
194        let mut offset_by_node = vec![None; self.model_num_nodes];
195
196        for partition in self.by_partition.iter() {
197            for node_mem in partition.nodes.iter() {
198                offset_by_node[node_mem.node] = Some(cursor);
199            }
200            cursor += partition.eval_size_to_i64(symbols)? as usize;
201        }
202
203        Ok(offset_by_node)
204    }
205
206    /// Evaluate peak memory size for given symbols. The return value is lower or equal to the memory
207    /// size of the schema. The difference between peak memory size and memory size represents the
208    /// memory fragmentation introduced by the schema.
209    pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
210        Ok(self
211            .by_steps
212            .iter()
213            .map(|active_nodes| {
214                active_nodes
215                    .iter()
216                    .flatten()
217                    .map(|it| it.mem_size.clone())
218                    .sum::<TDim>()
219                    .eval_to_i64(symbols)
220            })
221            .collect::<TractResult<Vec<_>>>()?
222            .into_iter()
223            .max()
224            .unwrap_or(0))
225    }
226
227    /// Evaluate the usage for given symbols as the ratio between
228    /// schema memory size and peak memory size. A value of 1.0 means
229    /// that the schema doesn't introduce memory fragmentation.
230    pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult<f32> {
231        let memory_size = self.eval_memory_size(symbols)? as f32;
232        let peak_memory_size = self.eval_peak_memory_size(symbols)? as f32;
233        Ok(peak_memory_size / memory_size)
234    }
235}
236
237impl fmt::Display for DeviceMemSchema {
238    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
239        for (step, mem_step) in self.by_steps.iter().enumerate() {
240            writeln!(
241                fmt,
242                "step: {:5} => |{}|",
243                step,
244                mem_step
245                    .iter()
246                    .map(|n| -> String {
247                        n.as_ref()
248                            .map(|it| format!("{:^7}", it.node))
249                            .unwrap_or(format!("{:^7}", "*"))
250                    })
251                    .collect::<Vec<String>>()
252                    .join("|")
253            )?;
254        }
255        writeln!(fmt, "memory_size: {}", self.memory_size())?;
256        Ok(())
257    }
258}
259
260impl DeviceMemSchema {
261    /// Resolve Memory schema with given symbols.
262    pub fn resolve(&self, symbols: &SymbolValues) -> TractResult<DeviceResolvedMemSchema> {
263        Ok(DeviceResolvedMemSchema {
264            offsets_by_node: self.compute_offset_by_node(symbols)?,
265            memory_size: self.eval_memory_size(symbols)?.try_into()?,
266        })
267    }
268
269    /// Build a memory schema for given model and execution order. The hint is used to optimize
270    /// the memory schema because it is based on symbolic dimensions. That doesn't mean it will be
271    /// optimal for all possible values for symbolic dimensions.
272    pub fn build(
273        model: &TypedModel,
274        order: &[usize],
275        hint: &SymbolValues,
276    ) -> TractResult<DeviceMemSchema> {
277        let mut nodes_mem_req = eval_device_mem_req_for_nodes(model, order)?;
278
279        let hinted_mem_size = nodes_mem_req
280            .iter()
281            .map(|node_mem| Ok((node_mem.node, node_mem.mem_size.eval_to_i64(hint)?)))
282            .collect::<TractResult<HashMap<usize, i64>>>()?;
283
284        nodes_mem_req.sort_by(|lhs, rhs| {
285            let lhs_hint_mem_size = hinted_mem_size.get(&lhs.node);
286            let rhs_hint_mem_size = hinted_mem_size.get(&rhs.node);
287            lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse()
288        });
289
290        let mut partitions: Vec<Partition> = vec![];
291        for node_mem in nodes_mem_req {
292            // Find partitions where node lifetime is disjoint from existing.
293            let mut available = partitions
294                .iter_mut()
295                .filter(|it| it.has_no_conflict_with_lifetime(&node_mem.lifetime))
296                .collect::<Vec<_>>();
297
298            available.sort_by_cached_key(|n| {
299                -n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.node)).sum::<i64>()
300            });
301
302            match available.first_mut() {
303                Some(available) => {
304                    available.nodes.push(node_mem);
305                }
306                None => partitions.push(Partition { nodes: vec![node_mem] }),
307            }
308        }
309
310        let by_steps: Vec<Vec<Option<NodeMemReq>>> = (0..order.len())
311            .map(|step| {
312                let mem_step: Vec<_> =
313                    partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect();
314                ensure!(mem_step.len() <= partitions.len());
315                Ok(mem_step)
316            })
317            .collect::<TractResult<Vec<_>>>()?;
318
319        Ok(DeviceMemSchema {
320            model_num_nodes: model.nodes().len(),
321            by_partition: partitions,
322            by_steps,
323        })
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand to build a `NodeMemReq` alive during `[start, end)`.
    fn mem_req(node: usize, start: usize, end: usize, mem_size: TDim) -> NodeMemReq {
        NodeMemReq { node, lifetime: Lifetime { start, end }, mem_size }
    }

    #[test]
    fn test_lifetime_is_disjoint() {
        let first = Lifetime { start: 0, end: 5 };
        let second = Lifetime { start: 5, end: 10 };
        let overlapping = Lifetime { start: 3, end: 7 };

        // Touching half-open ranges do not overlap, whatever the argument order.
        assert!(first.is_disjoint(&second));
        assert!(second.is_disjoint(&first));
        // Partial overlaps on either side.
        assert!(!first.is_disjoint(&overlapping));
        assert!(!overlapping.is_disjoint(&second));
    }

    #[test]
    fn test_lifetime_is_alive_at_step() {
        let lifetime = Lifetime { start: 5, end: 10 };

        for step in [5, 7, 9] {
            assert!(lifetime.is_alive_at_step(step));
        }
        // Just before the start, and at the exclusive end.
        for step in [4, 10] {
            assert!(!lifetime.is_alive_at_step(step));
        }
    }

    #[test]
    fn test_empty_lifetime() {
        let lifetime = Lifetime { start: 5, end: 5 };
        assert_eq!(lifetime.len(), 0);
        assert!(lifetime.is_empty());
    }

    #[test]
    fn test_node_mem_req_basic() {
        let req = mem_req(1, 0, 5, 1000.into());

        assert_eq!(req.node, 1);
        assert_eq!(req.lifetime, Lifetime { start: 0, end: 5 });
        assert_eq!(req.mem_size.to_i64().unwrap(), 1000);
    }

    #[test]
    fn test_partition_has_no_conflict() {
        let partition = Partition { nodes: vec![mem_req(1, 0, 5, 1000.into())] };

        assert!(partition.has_no_conflict_with_lifetime(&Lifetime { start: 5, end: 10 }));
        assert!(!partition.has_no_conflict_with_lifetime(&Lifetime { start: 3, end: 7 }));
    }

    #[test]
    fn test_partition_find_node() {
        let early = mem_req(1, 0, 5, 1000.into());
        let late = mem_req(2, 5, 10, 2000.into());
        let partition = Partition { nodes: vec![early.clone(), late.clone()] };

        assert_eq!(partition.find_node_alive_at_step(3), Some(&early));
        assert_eq!(partition.find_node_alive_at_step(7), Some(&late));
        assert_eq!(partition.find_node_alive_at_step(10), None);
    }
}