1use std::fmt;
2use std::fmt::Debug;
3use tract_core::internal::num_integer::Integer;
4use tract_core::internal::*;
5
6use crate::fact::DeviceTypedFactExt;
7use crate::sync::{DeviceSync, DeviceSyncKind};
8
/// One device buffer that must be planned in the memory pool.
///
/// An outlet whose device fact reports several buffer sizes produces one `NodeMemReq`
/// per buffer, all sharing the same `outlet_id` and `lifetime`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeMemReq {
    /// Output (node, slot) that produces this buffer.
    pub outlet_id: OutletId,
    /// Execution steps during which the buffer must stay allocated.
    pub lifetime: Lifetime,
    /// Symbolic size of this single buffer.
    pub mem_size: TDim,
}
16
/// Half-open range of execution steps `[start, end)` during which a buffer is alive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Lifetime {
    /// First step (inclusive) at which the buffer is alive.
    pub start: usize,
    /// First step (exclusive) at which the buffer is no longer alive.
    pub end: usize,
}

impl Lifetime {
    /// True when the two ranges share no step.
    ///
    /// Ranges that merely touch (`a.end == b.start`) are disjoint, because `end` is exclusive.
    pub fn is_disjoint(&self, other: &Lifetime) -> bool {
        self.start >= other.end || other.start >= self.end
    }

    /// True when `step` lies in `[start, end)`.
    pub fn is_alive_at_step(&self, step: usize) -> bool {
        self.start <= step && step < self.end
    }

    /// True when the lifetime covers no step (see [`Self::len`]).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of steps covered by the lifetime.
    ///
    /// An inverted range (`end < start`) covers no step and yields 0. A plain subtraction
    /// would panic in debug builds and wrap to a huge value in release builds.
    pub fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }
}
40
41fn next_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<TVec<&'a TypedNode>> {
42 if node.outputs.is_empty() {
43 return None;
44 };
45
46 Some(
47 node.outputs
48 .iter()
49 .flat_map(|o| {
50 o.successors.iter().map(|succ| &model.nodes()[succ.node]).collect::<Vec<_>>()
51 })
52 .collect(),
53 )
54}
55
56pub fn eval_device_mem_req_for_nodes(
57 model: &TypedModel,
58 order: &[usize],
59) -> TractResult<TVec<NodeMemReq>> {
60 let outputs = model.output_outlets()?.to_vec();
61 let flush_lists = order::build_flush_list(model, order, &outputs, |node| {
62 let Ok(facts) = model.node_output_facts(node.id) else { return false };
63
64 let cpu_sync_in_next_nodes = next_nodes(model, node).is_some_and(|nodes| {
65 nodes.iter().any(|it| {
66 it.op_as::<DeviceSync>().is_some_and(|op| op.kind == DeviceSyncKind::ToHost)
67 })
68 });
69
70 !cpu_sync_in_next_nodes
71 && facts.iter().any(|it| {
72 it.as_device_fact()
73 .map(|it| it.is_from_device() && !it.is_state_owned())
74 .unwrap_or(false)
75 })
76 });
77 let mut scoped_nodes = tvec![];
78
79 for (step, n) in order.iter().enumerate() {
80 let lifetime_start = step;
81
82 let lifetime_end = flush_lists
83 .iter()
84 .enumerate()
85 .find(|(_step, flush_list)| flush_list.contains(n))
86 .map(|it| usize::min(it.0 + 1, order.len()));
87 let Some(lifetime_end) = lifetime_end else {
89 continue;
90 };
91
92 let out_device_tmp_facts = model
93 .node_output_facts(*n)?
94 .into_iter()
95 .flat_map(|it| it.as_device_fact())
96 .filter(|it| it.is_from_device())
97 .collect::<TVec<_>>();
98
99 if out_device_tmp_facts.is_empty() {
100 continue;
101 }
102
103 for (slot, fact) in out_device_tmp_facts.iter().enumerate() {
104 let outlet_id = OutletId { node: *n, slot };
105 for buff_size in fact.buffer_sizes() {
106 scoped_nodes.push(NodeMemReq {
107 outlet_id,
108 lifetime: Lifetime { start: lifetime_start, end: lifetime_end },
109 mem_size: buff_size,
110 })
111 }
112 }
113 }
114
115 Ok(scoped_nodes)
116}
117
118fn collect_opaque_facts(model: &TypedModel) -> TractResult<Vec<NodeOpaqueFacts>> {
119 let mut res: Vec<TVec<Option<Box<dyn OpaqueFact>>>> = vec![];
120 for node in model.nodes() {
121 let mut tmp: TVec<Option<Box<dyn OpaqueFact>>> = tvec![];
122 for fact in model.node_output_facts(node.id)? {
123 if let Some(dev_fact) = fact.as_device_fact() {
124 tmp.push(dev_fact.opaque_fact.clone());
125 }
126 }
127 res.push(tmp);
128 }
129 Ok(res)
130}
131
132#[derive(Debug, Clone, PartialEq, Eq, Hash)]
135pub struct Partition {
136 pub nodes: Vec<NodeMemReq>,
137}
138
139impl Partition {
140 pub fn eval_size_to_i64(&self, symbols: &SymbolValues) -> TractResult<i64> {
141 let mut max_size = self
142 .nodes
143 .iter()
144 .map(|it| it.mem_size.eval_to_i64(symbols))
145 .collect::<TractResult<Vec<_>>>()?
146 .into_iter()
147 .max()
148 .unwrap_or(0);
149 max_size = Integer::next_multiple_of(&max_size, &(vector_size() as i64));
150 Ok(max_size)
151 }
152
153 pub fn size(&self) -> TDim {
154 TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect()).simplify()
155 }
156
157 pub fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool {
158 self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime))
159 }
160
161 pub fn find_node_alive_at_step(&self, step: usize) -> Option<&NodeMemReq> {
162 self.nodes.iter().find(|it| it.lifetime.is_alive_at_step(step))
163 }
164}
165
/// Per-node list of the `opaque_fact` of each output that carries a device fact.
type NodeOpaqueFacts = TVec<Option<Box<dyn OpaqueFact>>>;
/// A `DeviceMemSchema` resolved for concrete symbol values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceResolvedMemSchema {
    /// Indexed by node id. `None` for nodes with no planned request.
    ///
    /// Otherwise the inner list is indexed by output slot, and holds the pool offsets of
    /// that slot's buffers (same unit as `NodeMemReq::mem_size`).
    pub offsets_by_node: Vec<Option<TVec<TVec<usize>>>>,
    /// Total pool size: the sum of every partition's evaluated, aligned size.
    pub memory_size: usize,
}
174
/// Symbolic device memory plan for a model and an execution order.
///
/// Requests whose lifetimes never overlap are grouped into a partition. Each partition is
/// one region of the pool, shared by all its members.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DeviceMemSchema {
    /// Number of nodes in the planned model (length of the per-node offset table).
    pub model_num_nodes: usize,
    /// Partitions, laid out back to back in this order when offsets are computed.
    pub by_partition: Vec<Partition>,
    /// For each step of the execution order, one entry per partition: the request alive in
    /// that partition at that step, if any.
    pub by_steps: Vec<Vec<Option<NodeMemReq>>>,
    /// Indexed by node id: opaque facts of the node's outputs that carry a device fact.
    pub opaque_facts: Vec<NodeOpaqueFacts>,
}
186
187impl DeviceMemSchema {
188 pub fn size_by_partition(&self) -> Vec<TDim> {
190 self.by_partition.iter().map(|it| it.size()).collect()
191 }
192
193 pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult<Vec<i64>> {
195 self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect()
196 }
197
198 pub fn memory_size(&self) -> TDim {
200 self.by_partition.iter().map(|it| it.size()).sum()
201 }
202
203 pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
205 self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum()
206 }
207
208 pub fn compute_offset_by_node(
211 &self,
212 symbols: &SymbolValues,
213 ) -> TractResult<Vec<Option<TVec<TVec<usize>>>>> {
214 let mut cursor = 0;
215 let mut offset_by_outlet: Vec<Option<TVec<TVec<usize>>>> = vec![None; self.model_num_nodes];
216
217 for partition in &self.by_partition {
218 for node_mem in &partition.nodes {
219 let node = node_mem.outlet_id.node;
220 let slot = node_mem.outlet_id.slot;
221
222 let slots: &mut TVec<TVec<usize>> =
223 offset_by_outlet[node].get_or_insert_with(|| tvec![tvec!()]);
224
225 if slot < 1 {
226 slots[slot].push(cursor);
227 } else {
228 if slots.len() <= slot {
229 slots.resize_with(slot + 1, TVec::<usize>::new);
230 }
231 slots[slot].push(cursor);
232 }
233 }
234 cursor += partition.eval_size_to_i64(symbols)? as usize;
235 }
236
237 Ok(offset_by_outlet)
238 }
239
240 pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
244 Ok(self
245 .by_steps
246 .iter()
247 .map(|active_nodes| {
248 active_nodes
249 .iter()
250 .flatten()
251 .map(|it| it.mem_size.clone())
252 .sum::<TDim>()
253 .eval_to_i64(symbols)
254 })
255 .collect::<TractResult<Vec<_>>>()?
256 .into_iter()
257 .max()
258 .unwrap_or(0))
259 }
260
261 pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult<f32> {
265 let memory_size = self.eval_memory_size(symbols)? as f32;
266 let peak_memory_size = self.eval_peak_memory_size(symbols)? as f32;
267 Ok(peak_memory_size / memory_size)
268 }
269}
270
271impl fmt::Display for DeviceMemSchema {
272 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
273 for (step, mem_step) in self.by_steps.iter().enumerate() {
274 writeln!(
275 fmt,
276 "step: {:5} => |{}|",
277 step,
278 mem_step
279 .iter()
280 .map(|n| -> String {
281 n.as_ref()
282 .map(|it| format!("{:^7}/{:^7}", it.outlet_id.node, it.outlet_id.slot))
283 .unwrap_or(format!("{:^7}", "*"))
284 })
285 .collect::<Vec<String>>()
286 .join("|")
287 )?;
288 }
289 writeln!(fmt, "memory_size: {}", self.memory_size())?;
290 Ok(())
291 }
292}
293
294impl DeviceMemSchema {
295 pub fn resolve(&self, symbols: &SymbolValues) -> TractResult<DeviceResolvedMemSchema> {
297 Ok(DeviceResolvedMemSchema {
298 offsets_by_node: self.compute_offset_by_node(symbols)?,
299 memory_size: self.eval_memory_size(symbols)?.try_into()?,
300 })
301 }
302
303 pub fn build(
307 model: &TypedModel,
308 order: &[usize],
309 hint: &SymbolValues,
310 ) -> TractResult<DeviceMemSchema> {
311 let mut nodes_mem_req = eval_device_mem_req_for_nodes(model, order)?;
312
313 let opaque_facts = collect_opaque_facts(model)?;
314 let hinted_mem_size = nodes_mem_req
315 .iter()
316 .map(|node_mem| Ok((node_mem.outlet_id, node_mem.mem_size.eval_to_i64(hint)?)))
317 .collect::<TractResult<HashMap<OutletId, i64>>>()?;
318
319 nodes_mem_req.sort_by(|lhs, rhs| {
320 let lhs_hint_mem_size = hinted_mem_size.get(&lhs.outlet_id);
321 let rhs_hint_mem_size = hinted_mem_size.get(&rhs.outlet_id);
322 lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse()
323 });
324
325 let mut partitions: Vec<Partition> = vec![];
326 for node_mem in nodes_mem_req {
327 let mut available = partitions
329 .iter_mut()
330 .filter(|it| it.has_no_conflict_with_lifetime(&node_mem.lifetime))
331 .collect::<Vec<_>>();
332
333 available.sort_by_cached_key(|n| {
334 -n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.outlet_id)).sum::<i64>()
335 });
336
337 match available.first_mut() {
338 Some(available) => {
339 available.nodes.push(node_mem);
340 }
341 None => partitions.push(Partition { nodes: vec![node_mem] }),
342 }
343 }
344
345 let by_steps: Vec<Vec<Option<NodeMemReq>>> = (0..order.len())
346 .map(|step| {
347 let mem_step: Vec<_> =
348 partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect();
349 ensure!(mem_step.len() <= partitions.len());
350 Ok(mem_step)
351 })
352 .collect::<TractResult<Vec<_>>>()?;
353
354 Ok(DeviceMemSchema {
355 model_num_nodes: model.nodes().len(),
356 by_partition: partitions,
357 by_steps,
358 opaque_facts,
359 })
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot-0 request for `node`, alive over `[start, end)` with the given size.
    fn mem_req(node: usize, start: usize, end: usize, size: i64) -> NodeMemReq {
        NodeMemReq {
            outlet_id: OutletId { node, slot: 0 },
            lifetime: Lifetime { start, end },
            mem_size: size.into(),
        }
    }

    #[test]
    fn test_lifetime_is_disjoint() {
        let head = Lifetime { start: 0, end: 5 };
        let tail = Lifetime { start: 5, end: 10 };
        let middle = Lifetime { start: 3, end: 7 };

        // Ranges that merely touch do not overlap, in either direction.
        assert!(head.is_disjoint(&tail));
        assert!(tail.is_disjoint(&head));
        // `middle` straddles both boundaries.
        assert!(!head.is_disjoint(&middle));
        assert!(!middle.is_disjoint(&tail));
    }

    #[test]
    fn test_lifetime_is_alive_at_step() {
        let lifetime = Lifetime { start: 5, end: 10 };
        let expectations = [(4, false), (5, true), (7, true), (9, true), (10, false)];
        for (step, alive) in expectations {
            assert_eq!(lifetime.is_alive_at_step(step), alive, "step {step}");
        }
    }

    #[test]
    fn test_empty_lifetime() {
        let lifetime = Lifetime { start: 5, end: 5 };
        assert_eq!(lifetime.len(), 0);
        assert!(lifetime.is_empty());
    }

    #[test]
    fn test_node_mem_req_basic() {
        let req = mem_req(1, 0, 5, 1000);
        assert_eq!(req.outlet_id.node, 1);
        assert_eq!((req.lifetime.start, req.lifetime.end), (0, 5));
        assert_eq!(req.mem_size.to_i64().unwrap(), 1000);
    }

    #[test]
    fn test_partition_has_no_conflict() {
        let partition = Partition { nodes: vec![mem_req(1, 0, 5, 1000)] };
        assert!(partition.has_no_conflict_with_lifetime(&Lifetime { start: 5, end: 10 }));
        assert!(!partition.has_no_conflict_with_lifetime(&Lifetime { start: 3, end: 7 }));
    }

    #[test]
    fn test_partition_find_node() {
        let early = mem_req(1, 0, 5, 1000);
        let late = mem_req(2, 5, 10, 2000);
        let partition = Partition { nodes: vec![early.clone(), late.clone()] };

        assert_eq!(partition.find_node_alive_at_step(3), Some(&early));
        assert_eq!(partition.find_node_alive_at_step(7), Some(&late));
        assert_eq!(partition.find_node_alive_at_step(10), None);
    }
}