// spatialbench_arrow/trip.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::conversions::{decimal128_array_from_iter, to_arrow_timestamp_millis};
use crate::{RecordBatchIterator, DEFAULT_BATCH_SIZE};
use arrow::array::{BinaryArray, Int64Array, RecordBatch, TimestampMillisecondArray};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use geo::Geometry;
use geozero::{CoordDimensions, ToWkb};
use spatialbench::generators::{Trip, TripGenerator, TripGeneratorIterator};
use std::sync::{Arc, LazyLock, Mutex};

27// Thread-safe wrapper for TripGeneratorIterator
28struct ThreadSafeTripGenerator {
29    generator: Mutex<TripGeneratorIterator>,
30}
31
32impl ThreadSafeTripGenerator {
33    fn new(generator: TripGenerator) -> Self {
34        Self {
35            generator: Mutex::new(generator.iter()),
36        }
37    }
38
39    fn next_batch(&self, batch_size: usize) -> Vec<Trip> {
40        let mut generator = self.generator.lock().unwrap();
41        generator.by_ref().take(batch_size).collect()
42    }
43}
44
45// This is safe because we're using Mutex for synchronization
46unsafe impl Send for ThreadSafeTripGenerator {}
47unsafe impl Sync for ThreadSafeTripGenerator {}
48
49pub struct TripArrow {
50    generator: ThreadSafeTripGenerator,
51    batch_size: usize,
52    schema: SchemaRef,
53}
54
55impl TripArrow {
56    pub fn new(generator: TripGenerator) -> Self {
57        Self {
58            generator: ThreadSafeTripGenerator::new(generator),
59            batch_size: DEFAULT_BATCH_SIZE,
60            schema: TRIP_SCHEMA.clone(),
61        }
62    }
63
64    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
65        self.batch_size = batch_size;
66        self
67    }
68}
69
70impl RecordBatchIterator for TripArrow {
71    fn schema(&self) -> &SchemaRef {
72        &self.schema
73    }
74}
75
76impl Iterator for TripArrow {
77    type Item = RecordBatch;
78
79    fn next(&mut self) -> Option<Self::Item> {
80        // Get next rows to convert
81        let rows = self.generator.next_batch(self.batch_size);
82        if rows.is_empty() {
83            return None;
84        }
85
86        // Convert column by column
87        let t_tripkey = Int64Array::from_iter_values(rows.iter().map(|row| row.t_tripkey));
88        let t_custkey = Int64Array::from_iter_values(rows.iter().map(|row| row.t_custkey));
89        let t_driverkey = Int64Array::from_iter_values(rows.iter().map(|row| row.t_driverkey));
90        let t_vehiclekey = Int64Array::from_iter_values(rows.iter().map(|row| row.t_vehiclekey));
91        let t_pickuptime = TimestampMillisecondArray::from_iter_values(
92            rows.iter()
93                .map(|row| to_arrow_timestamp_millis(row.t_pickuptime)),
94        );
95        let t_dropofftime = TimestampMillisecondArray::from_iter_values(
96            rows.iter()
97                .map(|row| to_arrow_timestamp_millis(row.t_dropofftime)),
98        );
99        let t_fare = decimal128_array_from_iter(rows.iter().map(|row| row.t_fare));
100        let t_tip = decimal128_array_from_iter(rows.iter().map(|row| row.t_tip));
101        let t_totalamount = decimal128_array_from_iter(rows.iter().map(|row| row.t_totalamount));
102        let t_distance = decimal128_array_from_iter(rows.iter().map(|row| row.t_distance));
103        let t_pickuploc = BinaryArray::from_iter_values(rows.iter().map(|row| {
104            Geometry::Point(row.t_pickuploc)
105                .to_wkb(CoordDimensions::xy())
106                .expect("Failed to convert pickup location to WKB")
107        }));
108        let t_dropoffloc = BinaryArray::from_iter_values(rows.iter().map(|row| {
109            Geometry::Point(row.t_dropoffloc)
110                .to_wkb(CoordDimensions::xy())
111                .expect("Failed to convert dropoff location to WKB")
112        }));
113
114        let batch = RecordBatch::try_new(
115            Arc::clone(&self.schema),
116            vec![
117                Arc::new(t_tripkey),
118                Arc::new(t_custkey),
119                Arc::new(t_driverkey),
120                Arc::new(t_vehiclekey),
121                Arc::new(t_pickuptime),
122                Arc::new(t_dropofftime),
123                Arc::new(t_fare),
124                Arc::new(t_tip),
125                Arc::new(t_totalamount),
126                Arc::new(t_distance),
127                Arc::new(t_pickuploc),
128                Arc::new(t_dropoffloc),
129            ],
130        )
131        .unwrap();
132
133        Some(batch)
134    }
135}
136
137/// Schema for the Trip table
138static TRIP_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(make_trip_schema);
139
140fn make_trip_schema() -> SchemaRef {
141    Arc::new(Schema::new(vec![
142        Field::new("t_tripkey", DataType::Int64, false),
143        Field::new("t_custkey", DataType::Int64, false),
144        Field::new("t_driverkey", DataType::Int64, false),
145        Field::new("t_vehiclekey", DataType::Int64, false),
146        Field::new(
147            "t_pickuptime",
148            DataType::Timestamp(TimeUnit::Millisecond, None),
149            false,
150        ),
151        Field::new(
152            "t_dropofftime",
153            DataType::Timestamp(TimeUnit::Millisecond, None),
154            false,
155        ),
156        Field::new("t_fare", DataType::Decimal128(15, 5), false),
157        Field::new("t_tip", DataType::Decimal128(15, 5), false),
158        Field::new("t_totalamount", DataType::Decimal128(15, 5), false),
159        Field::new("t_distance", DataType::Decimal128(15, 5), false),
160        Field::new("t_pickuploc", DataType::Binary, false),
161        Field::new("t_dropoffloc", DataType::Binary, false),
162    ]))
163}