typeline_core/operators/
chunks.rs

1use std::iter;
2
3use num::Integer;
4
5use crate::{
6    chain::{ChainId, SubchainIndex},
7    cli::call_expr::{Argument, CallExpr, Span},
8    job::{add_transform_to_job, Job, JobData},
9    operators::operator::TransformInstatiation,
10    options::session_setup::SessionSetupData,
11    record_data::{
12        group_track::{GroupTrackIterId, GroupTrackIterRef},
13        iter_hall::IterKind,
14    },
15    typeline_error::TypelineError,
16    utils::indexing_type::IndexingType,
17};
18
19use super::{
20    errors::OperatorCreationError,
21    nop::create_op_nop,
22    operator::{
23        Operator, OperatorDataId, OperatorId, OperatorInstantiation,
24        OperatorOffsetInChain, PreboundOutputsMap,
25    },
26    transform::{Transform, TransformId, TransformState},
27};
28
29pub struct OpChunks {
30    pub subchain: Vec<(Box<dyn Operator>, Span)>,
31    pub subchain_idx: SubchainIndex,
32    pub stride: usize,
33}
34pub struct TfChunksHeader {
35    parent_group_track_iter: GroupTrackIterId,
36    stride: usize,
37    curr_stride_rem: usize,
38    starting_new_group: bool,
39}
/// Stateless transform appended after the subchain of a `chunks` operator.
/// Its `update` passes leading groups of its input group track on to its
/// output group track.
pub struct TfChunksTrailer {}
41
42pub fn create_op_chunks_with_spans(
43    stride: usize,
44    stride_span: Span,
45    subchain: impl IntoIterator<Item = (Box<dyn Operator>, Span)>,
46) -> Result<Box<dyn Operator>, OperatorCreationError> {
47    if stride == 0 {
48        return Err(OperatorCreationError::new(
49            "chunk stride cannot be zero",
50            stride_span,
51        ));
52    }
53
54    let mut subchain = subchain.into_iter().collect::<Vec<_>>();
55    if subchain.is_empty() {
56        subchain.push((create_op_nop(), Span::Generated));
57    }
58    Ok(Box::new(OpChunks {
59        subchain,
60        subchain_idx: SubchainIndex::MAX_VALUE,
61        stride,
62    }))
63}
64
65impl Operator for OpChunks {
66    fn default_name(&self) -> super::operator::OperatorName {
67        "chunks".into()
68    }
69
70    fn output_count(
71        &self,
72        _sess: &crate::context::SessionData,
73        _op_id: OperatorId,
74    ) -> usize {
75        0
76    }
77
78    fn has_dynamic_outputs(
79        &self,
80        _sess: &crate::context::SessionData,
81        _op_id: OperatorId,
82    ) -> bool {
83        false
84    }
85
86    fn output_field_kind(
87        &self,
88        _sess: &crate::context::SessionData,
89        _op_id: OperatorId,
90    ) -> super::operator::OutputFieldKind {
91        super::operator::OutputFieldKind::SameAsInput
92    }
93
94    fn setup(
95        &mut self,
96        sess: &mut SessionSetupData,
97        op_data_id: OperatorDataId,
98        chain_id: ChainId,
99        offset_in_chain: OperatorOffsetInChain,
100        span: Span,
101    ) -> Result<OperatorId, TypelineError> {
102        let op_id = sess.add_op(op_data_id, chain_id, offset_in_chain, span);
103        self.subchain_idx = sess.chains[chain_id].subchains.next_idx();
104        sess.setup_subchain(chain_id, std::mem::take(&mut self.subchain))?;
105        Ok(op_id)
106    }
107
108    fn build_transforms<'a>(
109        &'a self,
110        job: &mut Job<'a>,
111        tf_state: &mut TransformState,
112        op_id: OperatorId,
113        prebound_outputs: &PreboundOutputsMap,
114    ) -> super::operator::TransformInstatiation<'a> {
115        let chain_id =
116            job.job_data.session_data.operator_bases[op_id].chain_id;
117        let subchain_id = job.job_data.session_data.chains[chain_id].subchains
118            [self.subchain_idx];
119        let sc_start_op_id = job.job_data.session_data.chains[subchain_id]
120            .operators
121            .first();
122        let ms_id = tf_state.match_set_id;
123        let desired_batch_size = tf_state.desired_batch_size;
124        let input_field = tf_state.input_field;
125
126        let ms = &mut job.job_data.match_set_mgr.match_sets[ms_id];
127        let next_actor_id = ms.action_buffer.borrow().next_actor_ref();
128        let parent_group_track = tf_state.input_group_track_id;
129
130        let header_tf_id_peek = job.job_data.tf_mgr.transforms.peek_claim_id();
131
132        let parent_group_track_iter =
133            job.job_data.group_track_manager.claim_group_track_iter(
134                parent_group_track,
135                next_actor_id.get_id(),
136                IterKind::Transform(header_tf_id_peek),
137            );
138        let group_track = job.job_data.group_track_manager.add_group_track(
139            &job.job_data.match_set_mgr,
140            Some(parent_group_track),
141            ms_id,
142            next_actor_id,
143        );
144        tf_state.output_group_track_id = group_track;
145
146        let mut trailer_output_field = input_field;
147
148        let header_tf_id = add_transform_to_job(
149            &mut job.job_data,
150            &mut job.transform_data,
151            tf_state.clone(),
152            Box::new(TfChunksHeader {
153                parent_group_track_iter,
154                stride: self.stride,
155                curr_stride_rem: self.stride,
156                starting_new_group: false,
157            }),
158        );
159        debug_assert!(header_tf_id_peek == header_tf_id);
160
161        let mut out_tf_id = header_tf_id;
162        let mut out_ms_id = ms_id;
163        let mut out_group_track = group_track;
164
165        if let Some(&op_id) = sc_start_op_id {
166            let instantiation = job.setup_transforms_from_op(
167                ms_id,
168                op_id,
169                input_field,
170                group_track,
171                None,
172                prebound_outputs,
173            );
174            trailer_output_field = instantiation.next_input_field;
175            job.job_data.tf_mgr.transforms[header_tf_id].successor =
176                Some(instantiation.tfs_begin);
177
178            out_tf_id = instantiation.tfs_end;
179            out_ms_id = instantiation.next_match_set;
180            out_group_track = instantiation.next_group_track;
181        }
182
183        job.job_data
184            .field_mgr
185            .inc_field_refcount(trailer_output_field, 2);
186
187        let mut trailer_tf_state = TransformState::new(
188            trailer_output_field,
189            trailer_output_field,
190            out_ms_id,
191            desired_batch_size,
192            Some(op_id),
193            out_group_track,
194        );
195
196        let parent_group_track_parent =
197            job.job_data.group_track_manager.group_tracks[parent_group_track]
198                .borrow()
199                .parent_group_track_id();
200        let next_actor_ref = job.job_data.match_set_mgr.match_sets[out_ms_id]
201            .action_buffer
202            .borrow()
203            .next_actor_ref();
204        trailer_tf_state.output_group_track_id =
205            job.job_data.group_track_manager.add_group_track(
206                &job.job_data.match_set_mgr,
207                parent_group_track_parent,
208                out_ms_id,
209                next_actor_ref,
210            );
211
212        #[cfg(feature = "debug_state")]
213        {
214            job.job_data.group_track_manager.group_tracks
215                [trailer_tf_state.output_group_track_id]
216                .borrow_mut()
217                .corresponding_header = Some(group_track);
218        }
219
220        let trailer_tf_id = add_transform_to_job(
221            &mut job.job_data,
222            &mut job.transform_data,
223            trailer_tf_state,
224            Box::new(TfChunksTrailer {}),
225        );
226        job.job_data.tf_mgr.transforms[out_tf_id].successor =
227            Some(trailer_tf_id);
228
229        TransformInstatiation::Multiple(OperatorInstantiation {
230            tfs_begin: header_tf_id,
231            tfs_end: trailer_tf_id,
232            next_input_field: trailer_output_field,
233            next_group_track: parent_group_track,
234            next_match_set: out_ms_id,
235        })
236    }
237
238    fn update_bb_for_op(
239        &self,
240        sess: &crate::context::SessionData,
241        ld: &mut crate::liveness_analysis::LivenessData,
242        _op_id: OperatorId,
243        op_n: super::operator::OffsetInChain,
244        cn: &crate::chain::Chain,
245        bb_id: crate::liveness_analysis::BasicBlockId,
246    ) -> bool {
247        ld.basic_blocks[bb_id]
248            .calls
249            .push(cn.subchains[self.subchain_idx].into_bb_id());
250        ld.split_bb_at_call(sess, bb_id, op_n);
251        true
252    }
253}
254
255pub fn create_op_chunks(
256    stride: usize,
257    subchain: impl IntoIterator<Item = Box<dyn Operator>>,
258) -> Result<Box<dyn Operator>, OperatorCreationError> {
259    create_op_chunks_with_spans(
260        stride,
261        Span::Generated,
262        subchain.into_iter().map(|v| (v, Span::Generated)),
263    )
264}
265
266pub fn parse_op_chunks(
267    sess: &mut SessionSetupData,
268    arg: &mut Argument,
269) -> Result<Box<dyn Operator>, TypelineError> {
270    let expr = CallExpr::from_argument_mut(arg)?;
271
272    let stride_arg = expr.require_nth_arg(0, "stride")?;
273
274    let stride = stride_arg.expect_int(expr.op_name, true)?;
275    let stride_span = stride_arg.span;
276
277    let args = std::mem::take(arg.expect_arg_array_mut()?);
278
279    let mut subchain = Vec::new();
280    for arg in args.into_iter().skip(2) {
281        let span = arg.span;
282        let op = sess.parse_argument(arg)?;
283        subchain.push((op, span));
284    }
285
286    Ok(create_op_chunks_with_spans(stride, stride_span, subchain)?)
287}
288
289impl<'a> Transform<'a> for TfChunksHeader {
290    fn display_name(
291        &self,
292        _jd: &JobData,
293        _tf_id: TransformId,
294    ) -> super::transform::DefaultTransformName {
295        "chunks_header".into()
296    }
297    fn update(&mut self, jd: &mut JobData<'a>, tf_id: TransformId) {
298        let (batch_size, ps) = jd.tf_mgr.claim_batch(tf_id);
299        if batch_size == 0 {
300            jd.tf_mgr.submit_batch(
301                tf_id,
302                batch_size,
303                ps.group_to_truncate,
304                ps.input_done,
305            );
306            return;
307        }
308        let tf = &jd.tf_mgr.transforms[tf_id];
309
310        let in_group_track_id = tf.input_group_track_id;
311        let out_group_track_id = tf.output_group_track_id;
312
313        let mut group_track = jd
314            .group_track_manager
315            .borrow_group_track_mut(out_group_track_id);
316
317        group_track.apply_field_actions(&jd.match_set_mgr);
318        let mut parent_record_group_iter =
319            jd.group_track_manager.lookup_group_track_iter(
320                GroupTrackIterRef {
321                    track_id: in_group_track_id,
322                    iter_id: self.parent_group_track_iter,
323                },
324                &jd.match_set_mgr,
325            );
326
327        let stride = self.stride;
328
329        let mut size_rem = batch_size;
330
331        let gs_rem = parent_record_group_iter.group_len_rem().min(size_rem);
332
333        let append_prev = self.curr_stride_rem != stride;
334
335        let appendable = self.curr_stride_rem.min(gs_rem);
336
337        if append_prev {
338            let idx = group_track.group_lengths.len() - 1;
339            group_track.group_lengths.add_value(idx, appendable);
340            parent_record_group_iter.next_n_fields(appendable);
341            self.curr_stride_rem -= appendable;
342            size_rem -= appendable;
343            let eog = parent_record_group_iter.is_end_of_group(ps.input_done);
344            if eog && parent_record_group_iter.try_next_group() {
345                self.curr_stride_rem = stride;
346            }
347            if (!eog && self.curr_stride_rem > 0) || size_rem == 0 {
348                parent_record_group_iter
349                    .store_iter(self.parent_group_track_iter);
350                jd.tf_mgr.submit_batch_ready_for_more(tf_id, batch_size, ps);
351                return;
352            }
353        }
354
355        group_track
356            .group_lengths
357            .promote_to_size_class_of_value(self.stride);
358
359        loop {
360            let gs_rem =
361                parent_record_group_iter.group_len_rem().min(size_rem);
362            parent_record_group_iter.next_n_fields(gs_rem);
363
364            let (full_groups, partial_group) = gs_rem.div_rem(&stride);
365            let have_partial_group = partial_group != 0;
366            let group_count = full_groups + usize::from(have_partial_group);
367            // FIXME: this is wrong. count skipped zero groups
368            group_track
369                .parent_group_advancement
370                .push_back_truncated(usize::from(self.starting_new_group));
371            self.starting_new_group = false;
372
373            group_track
374                .parent_group_advancement
375                .extend_truncated(iter::repeat(0).take(group_count - 1));
376            group_track
377                .group_lengths
378                .extend_truncated(iter::repeat(stride).take(full_groups));
379            if have_partial_group {
380                group_track.group_lengths.push_back_truncated(partial_group);
381            }
382            size_rem -= gs_rem;
383
384            if size_rem == 0 {
385                self.curr_stride_rem = stride - partial_group;
386                break;
387            }
388            parent_record_group_iter.next_group();
389            self.starting_new_group = true;
390        }
391        parent_record_group_iter.store_iter(self.parent_group_track_iter);
392
393        jd.tf_mgr.submit_batch_ready_for_more(tf_id, batch_size, ps);
394    }
395}
396impl<'a> Transform<'a> for TfChunksTrailer {
397    fn display_name(
398        &self,
399        _jd: &JobData,
400        _tf_id: TransformId,
401    ) -> super::transform::DefaultTransformName {
402        "chunks_trailer".into()
403    }
404    fn update(&mut self, jd: &mut JobData<'a>, tf_id: TransformId) {
405        let (batch_size, ps) = jd.tf_mgr.claim_all(tf_id);
406
407        let tf = &jd.tf_mgr.transforms[tf_id];
408
409        let in_group_track_id = tf.input_group_track_id;
410        let out_group_track_id = tf.output_group_track_id;
411
412        jd.group_track_manager.pass_on_leading_groups_to_parent(
413            &jd.match_set_mgr,
414            in_group_track_id,
415            batch_size,
416            ps.input_done,
417            out_group_track_id,
418        );
419
420        jd.tf_mgr.submit_batch(
421            tf_id,
422            batch_size,
423            ps.group_to_truncate,
424            ps.input_done,
425        );
426    }
427}