honeycomb_benches/
shift.rs

1//! # Description
2//!
3//! ## Routine
4//!
//! The algorithm fetches all vertices that are not on the border of the map, then fetches the
//! identifiers of each vertex's neighbors. Then, for each vertex:
7//!
8//! - compute the average between neighbors
9//! - overwrite current vertex value with computed average
10//!
11//! ## Benchmark
12//!
//! This binary is meant to be used to evaluate the scalability of geometry-only kernels. It is
//! parallelized using rayon, and the number of threads used for execution can be controlled using
//! `taskset`. By controlling this, and the grid size, we can evaluate both strong and weak
//! scaling characteristics.
17
18use honeycomb::kernels::remeshing::move_vertex_to_average;
19use rayon::prelude::*;
20
21use honeycomb::core::stm::{Transaction, TransactionControl};
22use honeycomb::prelude::{
23    CMap2, CMapBuilder, CoordsFloat, DartIdType, NULL_DART_ID, OrbitPolicy, VertexIdType,
24};
25
26use crate::cli::ShiftArgs;
27use crate::utils::hash_file;
28
29pub fn bench_shift<T: CoordsFloat>(args: ShiftArgs) -> CMap2<T> {
30    let mut instant = std::time::Instant::now();
31    let input_map = args.input.to_str().unwrap();
32    let input_hash = hash_file(input_map).unwrap();
33    let map: CMap2<T> = if input_map.ends_with(".cmap") {
34        CMapBuilder::<2, T>::from_cmap_file(input_map)
35            .build()
36            .unwrap()
37    } else if input_map.ends_with(".vtk") {
38        CMapBuilder::<2, T>::from_vtk_file(input_map)
39            .build()
40            .unwrap()
41    } else {
42        panic!(
43            "E: Unknown file format; only .cmap or .vtk files are supported for map initialization"
44        );
45    };
46    let build_time = instant.elapsed();
47
48    if args.no_conflict {
49        todo!("TODO: require a partitioning algorithm")
50    } else {
51        instant = std::time::Instant::now();
52        // fetch all vertices that are not on the boundary of the map
53        let tmp: Vec<(VertexIdType, Vec<VertexIdType>)> = map
54            .iter_vertices()
55            .filter_map(|v| {
56                if map
57                    .orbit(OrbitPolicy::Vertex, v as DartIdType)
58                    .any(|d| map.beta::<2>(d) == NULL_DART_ID)
59                {
60                    None
61                } else {
62                    Some((
63                        v,
64                        map.orbit(OrbitPolicy::Vertex, v as DartIdType)
65                            .map(|d| map.vertex_id(map.beta::<2>(d)))
66                            .collect(),
67                    ))
68                }
69            })
70            .collect();
71        let n_v = tmp.len();
72        let graph_time = instant.elapsed();
73        let n_threads = std::thread::available_parallelism()
74            .map(|v| v.get())
75            .unwrap_or(1);
76
77        println!("| shift benchmark");
78        println!("|-> input      : {input_map} (hash: {input_hash:#0x})");
79        println!("|-> backend    : rayon-iter with {n_threads} thread(s)",);
80        println!("|-> # of rounds: {}", args.n_rounds.get());
81        println!("|-+ init time  :");
82        println!("| |->   map built in {}ms", build_time.as_millis());
83        println!("| |-> graph built in {}ms", graph_time.as_millis());
84
85        println!(" Round | process_time | throughput(vertex/s) | n_transac_retry");
86        // main loop
87        let mut round = 0;
88        let mut process_time;
89        loop {
90            instant = std::time::Instant::now();
91            let n_retry: u32 = tmp
92                .par_iter()
93                .map(|(vid, neigh)| {
94                    let mut n = 0;
95                    Transaction::with_control(
96                        |_| {
97                            n += 1;
98                            TransactionControl::Retry
99                        },
100                        |trans| move_vertex_to_average(trans, &map, *vid, neigh),
101                    );
102                    n
103                })
104                .sum();
105            process_time = instant.elapsed().as_secs_f64();
106            println!(
107                " {:>5} | {:>12.6e} | {:>20.6e} | {:>15}",
108                round,
109                process_time,
110                n_v as f64 / process_time,
111                n_retry,
112            );
113
114            round += 1;
115            if round >= args.n_rounds.get() {
116                break;
117            }
118        }
119    }
120
121    map
122}