1use crate::{
2 bag_of_words::BagOfWordsFeatureGroup,
3 bag_of_words_cosine_similarity::BagOfWordsCosineSimilarityFeatureGroup,
4 identity::IdentityFeatureGroup, normalized::NormalizedFeatureGroup,
5 one_hot_encoded::OneHotEncodedFeatureGroup, word_embedding::WordEmbeddingFeatureGroup,
6 FeatureGroup,
7};
8use ndarray::prelude::*;
9use tangram_table::prelude::*;
10
11pub fn compute_features_array_f32(
13 table: &TableView,
14 feature_groups: &[FeatureGroup],
15 progress: &impl Fn(),
16) -> Array2<f32> {
17 let n_features = feature_groups
18 .iter()
19 .map(|feature_group| feature_group.n_features())
20 .sum::<usize>();
21 let mut features = Array::zeros((table.nrows(), n_features));
22 let mut feature_index = 0;
23 for feature_group in feature_groups.iter() {
24 let n_features_in_group = feature_group.n_features();
25 let slice = s![.., feature_index..feature_index + n_features_in_group];
26 let features = features.slice_mut(slice);
27 compute_features_array_f32_for_feature_group(table, feature_group, features, progress);
28 feature_index += n_features_in_group;
29 }
30 features
31}
32
33fn compute_features_array_f32_for_feature_group(
34 table: &TableView,
35 feature_group: &FeatureGroup,
36 features: ArrayViewMut2<f32>,
37 progress: &impl Fn(),
38) {
39 match &feature_group {
40 FeatureGroup::Identity(feature_group) => {
41 compute_features_array_f32_for_identity_feature_group(
42 table,
43 feature_group,
44 features,
45 progress,
46 )
47 }
48 FeatureGroup::Normalized(feature_group) => {
49 compute_features_array_f32_for_normalized_feature_group(
50 table,
51 feature_group,
52 features,
53 progress,
54 )
55 }
56 FeatureGroup::OneHotEncoded(feature_group) => {
57 compute_features_array_f32_for_one_hot_encoded_feature_group(
58 table,
59 feature_group,
60 features,
61 progress,
62 )
63 }
64 FeatureGroup::BagOfWords(feature_group) => {
65 compute_features_array_f32_for_bag_of_words_feature_group(
66 table,
67 feature_group,
68 features,
69 progress,
70 )
71 }
72 FeatureGroup::BagOfWordsCosineSimilarity(feature_group) => {
73 compute_features_array_f32_for_bag_of_words_cosine_similarity_feature_group(
74 table,
75 feature_group,
76 features,
77 progress,
78 )
79 }
80 FeatureGroup::WordEmbedding(feature_group) => {
81 compute_features_array_f32_for_word_embedding_feature_group(
82 table,
83 feature_group,
84 features,
85 progress,
86 )
87 }
88 }
89}
90
91fn compute_features_array_f32_for_identity_feature_group(
92 table: &TableView,
93 feature_group: &IdentityFeatureGroup,
94 features: ArrayViewMut2<f32>,
95 progress: &impl Fn(),
96) {
97 let source_column = table
99 .columns()
100 .iter()
101 .find(|column| column.name() == Some(&feature_group.source_column_name))
102 .unwrap();
103 feature_group.compute_array_f32(features, source_column.view(), progress);
104}
105
106fn compute_features_array_f32_for_normalized_feature_group(
107 table: &TableView,
108 feature_group: &NormalizedFeatureGroup,
109 features: ArrayViewMut2<f32>,
110 progress: &impl Fn(),
111) {
112 let source_column = table
114 .columns()
115 .iter()
116 .find(|column| column.name() == Some(&feature_group.source_column_name))
117 .unwrap();
118 feature_group.compute_array_f32(features, source_column.view(), progress)
119}
120
121fn compute_features_array_f32_for_one_hot_encoded_feature_group(
122 table: &TableView,
123 feature_group: &OneHotEncodedFeatureGroup,
124 features: ArrayViewMut2<f32>,
125 progress: &impl Fn(),
126) {
127 let source_column = table
129 .columns()
130 .iter()
131 .find(|column| column.name() == Some(&feature_group.source_column_name))
132 .unwrap();
133 feature_group.compute_array_f32(features, source_column.view(), progress);
134}
135
136fn compute_features_array_f32_for_bag_of_words_feature_group(
137 table: &TableView,
138 feature_group: &BagOfWordsFeatureGroup,
139 features: ArrayViewMut2<f32>,
140 progress: &impl Fn(),
141) {
142 let source_column = table
144 .columns()
145 .iter()
146 .find(|column| column.name() == Some(&feature_group.source_column_name))
147 .unwrap();
148 feature_group.compute_array_f32(features, source_column.view(), progress);
149}
150
151fn compute_features_array_f32_for_bag_of_words_cosine_similarity_feature_group(
152 table: &TableView,
153 feature_group: &BagOfWordsCosineSimilarityFeatureGroup,
154 features: ArrayViewMut2<f32>,
155 progress: &impl Fn(),
156) {
157 let source_column_a = table
160 .columns()
161 .iter()
162 .find(|column| column.name().unwrap() == feature_group.source_column_name_a)
163 .unwrap();
164 let source_column_b = table
165 .columns()
166 .iter()
167 .find(|column| column.name().unwrap() == feature_group.source_column_name_b)
168 .unwrap();
169 feature_group.compute_array_f32(
170 features,
171 source_column_a.view(),
172 source_column_b.view(),
173 progress,
174 );
175}
176
177fn compute_features_array_f32_for_word_embedding_feature_group(
178 table: &TableView,
179 feature_group: &WordEmbeddingFeatureGroup,
180 features: ArrayViewMut2<f32>,
181 progress: &impl Fn(),
182) {
183 let source_column = table
185 .columns()
186 .iter()
187 .find(|column| column.name() == Some(&feature_group.source_column_name))
188 .unwrap();
189 feature_group.compute_array_f32(features, source_column.view(), progress);
190}
191
192pub fn compute_features_table(
194 table: &TableView,
195 feature_groups: &[FeatureGroup],
196 progress: &impl Fn(u64),
197) -> Table {
198 let mut features = Table::new(Vec::new(), Vec::new());
199 for feature_group in feature_groups.iter() {
200 compute_features_table_for_feature_group(table, feature_group, &mut features, progress)
201 }
202 features
203}
204
205fn compute_features_table_for_feature_group(
206 table: &TableView,
207 feature_group: &FeatureGroup,
208 features: &mut Table,
209 progress: &impl Fn(u64),
210) {
211 match &feature_group {
212 FeatureGroup::Identity(feature_group) => compute_features_table_for_identity_feature_group(
213 table,
214 feature_group,
215 features,
216 progress,
217 ),
218 FeatureGroup::Normalized(feature_group) => {
219 compute_features_table_for_normalized_feature_group(
220 table,
221 feature_group,
222 features,
223 progress,
224 )
225 }
226 FeatureGroup::OneHotEncoded(_) => unimplemented!(),
227 FeatureGroup::BagOfWords(feature_group) => {
228 compute_features_table_for_bag_of_words_feature_group(
229 table,
230 feature_group,
231 features,
232 progress,
233 )
234 }
235 FeatureGroup::BagOfWordsCosineSimilarity(feature_group) => {
236 compute_features_table_for_bag_of_words_cosine_similarity_feature_group(
237 table,
238 feature_group,
239 features,
240 progress,
241 )
242 }
243 FeatureGroup::WordEmbedding(feature_group) => {
244 compute_features_table_for_word_embedding_feature_group(
245 table,
246 feature_group,
247 features,
248 progress,
249 )
250 }
251 };
252}
253
254fn compute_features_table_for_identity_feature_group(
255 table: &TableView,
256 feature_group: &IdentityFeatureGroup,
257 features: &mut Table,
258 progress: &impl Fn(u64),
259) {
260 let column = table
261 .columns()
262 .iter()
263 .find(|column| column.name().unwrap() == feature_group.source_column_name)
264 .unwrap();
265 let feature_column = feature_group.compute_table(column.view(), progress);
266 features.columns_mut().push(feature_column);
267}
268
269fn compute_features_table_for_normalized_feature_group(
270 table: &TableView,
271 feature_group: &NormalizedFeatureGroup,
272 features: &mut Table,
273 progress: &impl Fn(u64),
274) {
275 let column = table
276 .columns()
277 .iter()
278 .find(|column| column.name().unwrap() == feature_group.source_column_name)
279 .unwrap();
280 let feature_column = feature_group.compute_table(column.view(), progress);
281 features.columns_mut().push(feature_column);
282}
283
284fn compute_features_table_for_bag_of_words_feature_group(
285 table: &TableView,
286 feature_group: &BagOfWordsFeatureGroup,
287 features: &mut Table,
288 progress: &impl Fn(u64),
289) {
290 let source_column = table
292 .columns()
293 .iter()
294 .find(|column| column.name().unwrap() == feature_group.source_column_name)
295 .unwrap();
296 let columns = feature_group.compute_table(source_column.view(), progress);
297 for column in columns {
298 features.columns_mut().push(column);
299 }
300}
301
302fn compute_features_table_for_bag_of_words_cosine_similarity_feature_group(
303 table: &TableView,
304 feature_group: &BagOfWordsCosineSimilarityFeatureGroup,
305 features: &mut Table,
306 progress: &impl Fn(u64),
307) {
308 let source_column_a = table
310 .columns()
311 .iter()
312 .find(|column| column.name().unwrap() == feature_group.source_column_name_a)
313 .unwrap();
314 let source_column_b = table
315 .columns()
316 .iter()
317 .find(|column| column.name().unwrap() == feature_group.source_column_name_b)
318 .unwrap();
319 let column =
320 feature_group.compute_table(source_column_a.view(), source_column_b.view(), progress);
321 features.columns_mut().push(column);
322}
323
324fn compute_features_table_for_word_embedding_feature_group(
325 table: &TableView,
326 feature_group: &WordEmbeddingFeatureGroup,
327 features: &mut Table,
328 progress: &impl Fn(u64),
329) {
330 let source_column = table
332 .columns()
333 .iter()
334 .find(|column| column.name().unwrap() == feature_group.source_column_name)
335 .unwrap();
336 let columns = feature_group.compute_table(source_column.view(), progress);
337 for column in columns {
338 features.columns_mut().push(column);
339 }
340}
341
342pub fn compute_features_array_value<'a>(
343 table: &TableView<'a>,
344 feature_groups: &[FeatureGroup],
345 progress: &impl Fn(),
346) -> Array2<TableValue<'a>> {
347 let n_features = feature_groups.iter().map(|g| g.n_features()).sum::<usize>();
348 let mut features = Array::from_elem((table.nrows(), n_features), TableValue::Unknown);
349 let mut feature_index = 0;
350 for feature_group in feature_groups.iter() {
351 let n_features_in_group = feature_group.n_features();
352 let slice = s![.., feature_index..feature_index + n_features_in_group];
353 let features = features.slice_mut(slice);
354 compute_features_array_value_for_feature_group(table, feature_group, features, progress);
355 feature_index += n_features_in_group;
356 }
357 features
358}
359
360fn compute_features_array_value_for_feature_group(
361 table: &TableView,
362 feature_group: &FeatureGroup,
363 features: ArrayViewMut2<tangram_table::TableValue>,
364 progress: &impl Fn(),
365) {
366 match &feature_group {
367 FeatureGroup::Identity(feature_group) => {
368 compute_features_array_value_for_identity_feature_group(
369 table,
370 feature_group,
371 features,
372 progress,
373 )
374 }
375 FeatureGroup::Normalized(feature_group) => {
376 compute_features_array_value_for_normalized_feature_group(
377 table,
378 feature_group,
379 features,
380 progress,
381 )
382 }
383 FeatureGroup::OneHotEncoded(_) => unimplemented!(),
384 FeatureGroup::BagOfWords(feature_group) => {
385 compute_features_array_value_for_bag_of_words_feature_group(
386 table,
387 feature_group,
388 features,
389 progress,
390 )
391 }
392 FeatureGroup::BagOfWordsCosineSimilarity(feature_group) => {
393 compute_features_array_value_for_bag_of_words_cosine_similarity_feature_group(
394 table,
395 feature_group,
396 features,
397 progress,
398 )
399 }
400 FeatureGroup::WordEmbedding(feature_group) => {
401 compute_features_array_value_for_word_embedding_feature_group(
402 table,
403 feature_group,
404 features,
405 progress,
406 )
407 }
408 }
409}
410
411fn compute_features_array_value_for_identity_feature_group(
412 table: &TableView,
413 feature_group: &IdentityFeatureGroup,
414 features: ArrayViewMut2<tangram_table::TableValue>,
415 progress: &impl Fn(),
416) {
417 let source_column = table
418 .columns()
419 .iter()
420 .find(|column| column.name().unwrap() == feature_group.source_column_name)
421 .unwrap();
422 feature_group.compute_array_value(features, source_column.view(), progress);
423}
424
425fn compute_features_array_value_for_normalized_feature_group(
426 table: &TableView,
427 feature_group: &NormalizedFeatureGroup,
428 features: ArrayViewMut2<tangram_table::TableValue>,
429 progress: &impl Fn(),
430) {
431 let source_column = table
432 .columns()
433 .iter()
434 .find(|column| column.name().unwrap() == feature_group.source_column_name)
435 .unwrap();
436 feature_group.compute_array_value(features, source_column.view(), progress);
437}
438
439fn compute_features_array_value_for_bag_of_words_feature_group(
440 table: &TableView,
441 feature_group: &BagOfWordsFeatureGroup,
442 features: ArrayViewMut2<tangram_table::TableValue>,
443 progress: &impl Fn(),
444) {
445 let source_column = table
447 .columns()
448 .iter()
449 .find(|column| column.name().unwrap() == feature_group.source_column_name)
450 .unwrap();
451 feature_group.compute_array_value(features, source_column.view(), progress);
452}
453
454fn compute_features_array_value_for_bag_of_words_cosine_similarity_feature_group(
455 table: &TableView,
456 feature_group: &BagOfWordsCosineSimilarityFeatureGroup,
457 features: ArrayViewMut2<tangram_table::TableValue>,
458 progress: &impl Fn(),
459) {
460 let source_column_a = table
462 .columns()
463 .iter()
464 .find(|column| column.name().unwrap() == feature_group.source_column_name_a)
465 .unwrap();
466 let source_column_b = table
467 .columns()
468 .iter()
469 .find(|column| column.name().unwrap() == feature_group.source_column_name_b)
470 .unwrap();
471 feature_group.compute_array_value(
472 features,
473 source_column_a.view(),
474 source_column_b.view(),
475 progress,
476 );
477}
478
479fn compute_features_array_value_for_word_embedding_feature_group(
480 table: &TableView,
481 feature_group: &WordEmbeddingFeatureGroup,
482 features: ArrayViewMut2<tangram_table::TableValue>,
483 progress: &impl Fn(),
484) {
485 let source_column = table
487 .columns()
488 .iter()
489 .find(|column| column.name().unwrap() == feature_group.source_column_name)
490 .unwrap();
491 feature_group.compute_array_value(features, source_column.view(), progress);
492}