1use std::iter;
2
3use num::Integer;
4
5use crate::{
6 chain::{ChainId, SubchainIndex},
7 cli::call_expr::{Argument, CallExpr, Span},
8 job::{add_transform_to_job, Job, JobData},
9 operators::operator::TransformInstatiation,
10 options::session_setup::SessionSetupData,
11 record_data::{
12 group_track::{GroupTrackIterId, GroupTrackIterRef},
13 iter_hall::IterKind,
14 },
15 typeline_error::TypelineError,
16 utils::indexing_type::IndexingType,
17};
18
19use super::{
20 errors::OperatorCreationError,
21 nop::create_op_nop,
22 operator::{
23 Operator, OperatorDataId, OperatorId, OperatorInstantiation,
24 OperatorOffsetInChain, PreboundOutputsMap,
25 },
26 transform::{Transform, TransformId, TransformState},
27};
28
/// Operator that re-groups its input into fixed-size chunks of `stride`
/// records each and runs `subchain` once per chunk.
pub struct OpChunks {
    /// The operators forming the per-chunk subchain, each with its source
    /// span. Drained (via `std::mem::take`) into the session in `setup`.
    pub subchain: Vec<(Box<dyn Operator>, Span)>,
    /// Index of the subchain within the owning chain. Initialized to
    /// `SubchainIndex::MAX_VALUE` as a placeholder and assigned in `setup`.
    pub subchain_idx: SubchainIndex,
    /// Number of records per chunk. Validated to be non-zero at creation
    /// (see `create_op_chunks_with_spans`).
    pub stride: usize,
}
/// Header transform: walks the parent group track and emits chunk groups of
/// at most `stride` records on the output group track.
pub struct TfChunksHeader {
    // Iterator over the parent (input) group track; persisted via
    // `store_iter` between `update` calls so progress carries across batches.
    parent_group_track_iter: GroupTrackIterId,
    // Chunk size in records (copied from `OpChunks::stride`).
    stride: usize,
    // Remaining capacity of the chunk currently being filled. Equal to
    // `stride` when no chunk is open (see `append_prev` in `update`).
    curr_stride_rem: usize,
    // True when the next emitted chunk is the first one of a new parent
    // group (drives `parent_group_advancement`).
    starting_new_group: bool,
}
40pub struct TfChunksTrailer {}
41
42pub fn create_op_chunks_with_spans(
43 stride: usize,
44 stride_span: Span,
45 subchain: impl IntoIterator<Item = (Box<dyn Operator>, Span)>,
46) -> Result<Box<dyn Operator>, OperatorCreationError> {
47 if stride == 0 {
48 return Err(OperatorCreationError::new(
49 "chunk stride cannot be zero",
50 stride_span,
51 ));
52 }
53
54 let mut subchain = subchain.into_iter().collect::<Vec<_>>();
55 if subchain.is_empty() {
56 subchain.push((create_op_nop(), Span::Generated));
57 }
58 Ok(Box::new(OpChunks {
59 subchain,
60 subchain_idx: SubchainIndex::MAX_VALUE,
61 stride,
62 }))
63}
64
impl Operator for OpChunks {
    fn default_name(&self) -> super::operator::OperatorName {
        "chunks".into()
    }

    // The operator produces no extra output fields of its own.
    fn output_count(
        &self,
        _sess: &crate::context::SessionData,
        _op_id: OperatorId,
    ) -> usize {
        0
    }

    fn has_dynamic_outputs(
        &self,
        _sess: &crate::context::SessionData,
        _op_id: OperatorId,
    ) -> bool {
        false
    }

    // Records flow through unchanged; only the grouping changes.
    fn output_field_kind(
        &self,
        _sess: &crate::context::SessionData,
        _op_id: OperatorId,
    ) -> super::operator::OutputFieldKind {
        super::operator::OutputFieldKind::SameAsInput
    }

    /// Registers the operator and hands the stashed subchain over to the
    /// session, recording the subchain index it will be stored under.
    fn setup(
        &mut self,
        sess: &mut SessionSetupData,
        op_data_id: OperatorDataId,
        chain_id: ChainId,
        offset_in_chain: OperatorOffsetInChain,
        span: Span,
    ) -> Result<OperatorId, TypelineError> {
        let op_id = sess.add_op(op_data_id, chain_id, offset_in_chain, span);
        // The subchain will be pushed next, so its index is the current
        // `next_idx` of the chain's subchain list.
        self.subchain_idx = sess.chains[chain_id].subchains.next_idx();
        sess.setup_subchain(chain_id, std::mem::take(&mut self.subchain))?;
        Ok(op_id)
    }

    /// Instantiates the transform sandwich for `chunks`:
    /// header (re-groups into chunks) -> subchain transforms -> trailer
    /// (restores the parent grouping).
    fn build_transforms<'a>(
        &'a self,
        job: &mut Job<'a>,
        tf_state: &mut TransformState,
        op_id: OperatorId,
        prebound_outputs: &PreboundOutputsMap,
    ) -> super::operator::TransformInstatiation<'a> {
        let chain_id =
            job.job_data.session_data.operator_bases[op_id].chain_id;
        let subchain_id = job.job_data.session_data.chains[chain_id].subchains
            [self.subchain_idx];
        // First operator of the subchain, if any (subchains are made
        // non-empty at creation, but this stays defensive).
        let sc_start_op_id = job.job_data.session_data.chains[subchain_id]
            .operators
            .first();
        let ms_id = tf_state.match_set_id;
        let desired_batch_size = tf_state.desired_batch_size;
        let input_field = tf_state.input_field;

        let ms = &mut job.job_data.match_set_mgr.match_sets[ms_id];
        let next_actor_id = ms.action_buffer.borrow().next_actor_ref();
        let parent_group_track = tf_state.input_group_track_id;

        // Peek the id the header transform will get so the group track
        // iterator can be tagged with it before the transform exists.
        let header_tf_id_peek = job.job_data.tf_mgr.transforms.peek_claim_id();

        let parent_group_track_iter =
            job.job_data.group_track_manager.claim_group_track_iter(
                parent_group_track,
                next_actor_id.get_id(),
                IterKind::Transform(header_tf_id_peek),
            );
        // New child track that will hold the chunk groups, parented to the
        // incoming track.
        let group_track = job.job_data.group_track_manager.add_group_track(
            &job.job_data.match_set_mgr,
            Some(parent_group_track),
            ms_id,
            next_actor_id,
        );
        tf_state.output_group_track_id = group_track;

        // If the subchain is empty the trailer reads the original input
        // field; otherwise it is overwritten below.
        let mut trailer_output_field = input_field;

        let header_tf_id = add_transform_to_job(
            &mut job.job_data,
            &mut job.transform_data,
            tf_state.clone(),
            Box::new(TfChunksHeader {
                parent_group_track_iter,
                stride: self.stride,
                // No chunk is open yet: full stride remaining.
                curr_stride_rem: self.stride,
                starting_new_group: false,
            }),
        );
        // The peeked id must match the actually claimed one.
        debug_assert!(header_tf_id_peek == header_tf_id);

        let mut out_tf_id = header_tf_id;
        let mut out_ms_id = ms_id;
        let mut out_group_track = group_track;

        if let Some(&op_id) = sc_start_op_id {
            // Build the subchain's transforms and splice them between
            // header and trailer.
            let instantiation = job.setup_transforms_from_op(
                ms_id,
                op_id,
                input_field,
                group_track,
                None,
                prebound_outputs,
            );
            trailer_output_field = instantiation.next_input_field;
            job.job_data.tf_mgr.transforms[header_tf_id].successor =
                Some(instantiation.tfs_begin);

            out_tf_id = instantiation.tfs_end;
            out_ms_id = instantiation.next_match_set;
            out_group_track = instantiation.next_group_track;
        }

        // The trailer uses the field as both input and output field of its
        // TransformState, hence two additional references.
        job.job_data
            .field_mgr
            .inc_field_refcount(trailer_output_field, 2);

        let mut trailer_tf_state = TransformState::new(
            trailer_output_field,
            trailer_output_field,
            out_ms_id,
            desired_batch_size,
            Some(op_id),
            out_group_track,
        );

        // The trailer's output track is parented like the original input
        // track (one level above it), undoing the chunk nesting.
        let parent_group_track_parent =
            job.job_data.group_track_manager.group_tracks[parent_group_track]
                .borrow()
                .parent_group_track_id();
        let next_actor_ref = job.job_data.match_set_mgr.match_sets[out_ms_id]
            .action_buffer
            .borrow()
            .next_actor_ref();
        trailer_tf_state.output_group_track_id =
            job.job_data.group_track_manager.add_group_track(
                &job.job_data.match_set_mgr,
                parent_group_track_parent,
                out_ms_id,
                next_actor_ref,
            );

        #[cfg(feature = "debug_state")]
        {
            job.job_data.group_track_manager.group_tracks
                [trailer_tf_state.output_group_track_id]
                .borrow_mut()
                .corresponding_header = Some(group_track);
        }

        let trailer_tf_id = add_transform_to_job(
            &mut job.job_data,
            &mut job.transform_data,
            trailer_tf_state,
            Box::new(TfChunksTrailer {}),
        );
        job.job_data.tf_mgr.transforms[out_tf_id].successor =
            Some(trailer_tf_id);

        // NOTE(review): `next_group_track` reports the original parent
        // track rather than the trailer's freshly created output track —
        // looks intentional (restoring the pre-chunks grouping for
        // downstream operators) but worth confirming against callers.
        TransformInstatiation::Multiple(OperatorInstantiation {
            tfs_begin: header_tf_id,
            tfs_end: trailer_tf_id,
            next_input_field: trailer_output_field,
            next_group_track: parent_group_track,
            next_match_set: out_ms_id,
        })
    }

    /// Liveness analysis hook: the subchain is a call target, so the basic
    /// block is split at this operator.
    fn update_bb_for_op(
        &self,
        sess: &crate::context::SessionData,
        ld: &mut crate::liveness_analysis::LivenessData,
        _op_id: OperatorId,
        op_n: super::operator::OffsetInChain,
        cn: &crate::chain::Chain,
        bb_id: crate::liveness_analysis::BasicBlockId,
    ) -> bool {
        ld.basic_blocks[bb_id]
            .calls
            .push(cn.subchains[self.subchain_idx].into_bb_id());
        ld.split_bb_at_call(sess, bb_id, op_n);
        true
    }
}
254
255pub fn create_op_chunks(
256 stride: usize,
257 subchain: impl IntoIterator<Item = Box<dyn Operator>>,
258) -> Result<Box<dyn Operator>, OperatorCreationError> {
259 create_op_chunks_with_spans(
260 stride,
261 Span::Generated,
262 subchain.into_iter().map(|v| (v, Span::Generated)),
263 )
264}
265
266pub fn parse_op_chunks(
267 sess: &mut SessionSetupData,
268 arg: &mut Argument,
269) -> Result<Box<dyn Operator>, TypelineError> {
270 let expr = CallExpr::from_argument_mut(arg)?;
271
272 let stride_arg = expr.require_nth_arg(0, "stride")?;
273
274 let stride = stride_arg.expect_int(expr.op_name, true)?;
275 let stride_span = stride_arg.span;
276
277 let args = std::mem::take(arg.expect_arg_array_mut()?);
278
279 let mut subchain = Vec::new();
280 for arg in args.into_iter().skip(2) {
281 let span = arg.span;
282 let op = sess.parse_argument(arg)?;
283 subchain.push((op, span));
284 }
285
286 Ok(create_op_chunks_with_spans(stride, stride_span, subchain)?)
287}
288
impl<'a> Transform<'a> for TfChunksHeader {
    fn display_name(
        &self,
        _jd: &JobData,
        _tf_id: TransformId,
    ) -> super::transform::DefaultTransformName {
        "chunks_header".into()
    }
    /// Consumes a batch from the parent group track and emits chunk groups
    /// of at most `stride` records on the output track, carrying a
    /// partially-filled chunk (`curr_stride_rem`) across batches.
    fn update(&mut self, jd: &mut JobData<'a>, tf_id: TransformId) {
        let (batch_size, ps) = jd.tf_mgr.claim_batch(tf_id);
        if batch_size == 0 {
            // Nothing to chunk; just forward the stream state.
            jd.tf_mgr.submit_batch(
                tf_id,
                batch_size,
                ps.group_to_truncate,
                ps.input_done,
            );
            return;
        }
        let tf = &jd.tf_mgr.transforms[tf_id];

        let in_group_track_id = tf.input_group_track_id;
        let out_group_track_id = tf.output_group_track_id;

        let mut group_track = jd
            .group_track_manager
            .borrow_group_track_mut(out_group_track_id);

        // Bring the output track up to date before appending to it.
        group_track.apply_field_actions(&jd.match_set_mgr);
        let mut parent_record_group_iter =
            jd.group_track_manager.lookup_group_track_iter(
                GroupTrackIterRef {
                    track_id: in_group_track_id,
                    iter_id: self.parent_group_track_iter,
                },
                &jd.match_set_mgr,
            );

        let stride = self.stride;

        // Records of this batch still to be distributed into chunks.
        let mut size_rem = batch_size;

        // Records available before the current parent group (or batch) ends.
        let gs_rem = parent_record_group_iter.group_len_rem().min(size_rem);

        // A previous batch left a chunk partially filled; top it up first.
        let append_prev = self.curr_stride_rem != stride;

        // How much we may append to that open chunk: bounded by both the
        // chunk's remaining capacity and the current parent group.
        let appendable = self.curr_stride_rem.min(gs_rem);

        if append_prev {
            // The open chunk is necessarily the last emitted group.
            let idx = group_track.group_lengths.len() - 1;
            group_track.group_lengths.add_value(idx, appendable);
            parent_record_group_iter.next_n_fields(appendable);
            self.curr_stride_rem -= appendable;
            size_rem -= appendable;
            let eog = parent_record_group_iter.is_end_of_group(ps.input_done);
            if eog && parent_record_group_iter.try_next_group() {
                // Parent group ended: the open chunk is closed early and
                // the next chunk starts fresh.
                self.curr_stride_rem = stride;
            }
            if (!eog && self.curr_stride_rem > 0) || size_rem == 0 {
                // Either the chunk is still open (and the group continues)
                // or the batch is exhausted — persist progress and stop.
                parent_record_group_iter
                    .store_iter(self.parent_group_track_iter);
                jd.tf_mgr.submit_batch_ready_for_more(tf_id, batch_size, ps);
                return;
            }
        }

        // Ensure the length storage can represent values up to `stride`.
        group_track
            .group_lengths
            .promote_to_size_class_of_value(self.stride);

        // Chunk one parent group per iteration until the batch is used up.
        loop {
            let gs_rem =
                parent_record_group_iter.group_len_rem().min(size_rem);
            parent_record_group_iter.next_n_fields(gs_rem);

            // Split the group's records into full chunks plus an optional
            // trailing partial chunk.
            let (full_groups, partial_group) = gs_rem.div_rem(&stride);
            let have_partial_group = partial_group != 0;
            let group_count = full_groups + usize::from(have_partial_group);
            // Advancement is 1 for the first chunk of a new parent group,
            // 0 for all chunks that stay within the same parent group.
            group_track
                .parent_group_advancement
                .push_back_truncated(usize::from(self.starting_new_group));
            self.starting_new_group = false;

            // NOTE(review): if `gs_rem` is 0 (empty parent group reaching
            // this loop), `group_count - 1` underflows — presumably empty
            // groups can't occur here; TODO confirm that invariant.
            group_track
                .parent_group_advancement
                .extend_truncated(iter::repeat(0).take(group_count - 1));
            group_track
                .group_lengths
                .extend_truncated(iter::repeat(stride).take(full_groups));
            if have_partial_group {
                group_track.group_lengths.push_back_truncated(partial_group);
            }
            size_rem -= gs_rem;

            if size_rem == 0 {
                // Remember how much of the last (possibly partial) chunk is
                // still unfilled for the next batch.
                self.curr_stride_rem = stride - partial_group;
                break;
            }
            parent_record_group_iter.next_group();
            self.starting_new_group = true;
        }
        // Persist iterator progress for the next `update` call.
        parent_record_group_iter.store_iter(self.parent_group_track_iter);

        jd.tf_mgr.submit_batch_ready_for_more(tf_id, batch_size, ps);
    }
}
396impl<'a> Transform<'a> for TfChunksTrailer {
397 fn display_name(
398 &self,
399 _jd: &JobData,
400 _tf_id: TransformId,
401 ) -> super::transform::DefaultTransformName {
402 "chunks_trailer".into()
403 }
404 fn update(&mut self, jd: &mut JobData<'a>, tf_id: TransformId) {
405 let (batch_size, ps) = jd.tf_mgr.claim_all(tf_id);
406
407 let tf = &jd.tf_mgr.transforms[tf_id];
408
409 let in_group_track_id = tf.input_group_track_id;
410 let out_group_track_id = tf.output_group_track_id;
411
412 jd.group_track_manager.pass_on_leading_groups_to_parent(
413 &jd.match_set_mgr,
414 in_group_track_id,
415 batch_size,
416 ps.input_done,
417 out_group_track_id,
418 );
419
420 jd.tf_mgr.submit_batch(
421 tf_id,
422 batch_size,
423 ps.group_to_truncate,
424 ps.input_done,
425 );
426 }
427}