diff --git a/Cargo.toml b/Cargo.toml index a99feec..a987203 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ext-sort" -version = "0.1.1" +version = "0.1.2" edition = "2021" license = "Unlicense" description = "rust external sort algorithm implementation" @@ -15,6 +15,7 @@ keywords = ["algorithms", "sort", "sorting", "external-sort", "external"] [dependencies] bytesize = { version = "^1.1", optional = true } +clap = { version = "^3.0", features = ["derive"], optional = true } deepsize = { version = "^0.2", optional = true } env_logger = { version = "^0.9", optional = true} log = "^0.4" @@ -30,6 +31,10 @@ rand = "^0.8" [features] memory-limit = ["deepsize"] +[[bin]] +name = "ext-sort" +required-features = ["bytesize", "clap", "env_logger"] + [[example]] name = "quickstart" required-features = ["bytesize", "env_logger"] diff --git a/README.md b/README.md index f2ad6fd..650fd3d 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ Activate `memory-limit` feature of the ext-sort crate on Cargo.toml: ```toml [dependencies] -ext-sort = { version = "^0.1.1", features = ["memory-limit"] } +ext-sort = { version = "^0.1.2", features = ["memory-limit"] } ``` ``` rust diff --git a/src/chunk.rs b/src/chunk.rs index fb6bdfe..966c124 100644 --- a/src/chunk.rs +++ b/src/chunk.rs @@ -22,7 +22,10 @@ impl Error for ExternalChunkError {} impl Display for ExternalChunkError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self) + match self { + ExternalChunkError::IO(err) => write!(f, "{}", err), + ExternalChunkError::SerializationError(err) => write!(f, "{}", err), + } } } diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..d8925c7 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,225 @@ +use std::fs; +use std::io::{self, prelude::*}; +use std::path; +use std::process; + +use bytesize::ByteSize; +use clap::ArgEnum; +use env_logger; +use log; + +use ext_sort::buffer::mem::MemoryLimitedBufferBuilder; +use 
ext_sort::{ExternalSorter, ExternalSorterBuilder}; + +fn main() { + let arg_parser = build_arg_parser(); + + let log_level: LogLevel = arg_parser.value_of_t_or_exit("log_level"); + init_logger(log_level); + + let order: Order = arg_parser.value_of_t_or_exit("sort"); + let tmp_dir: Option<&str> = arg_parser.value_of("tmp_dir"); + let chunk_size = arg_parser.value_of("chunk_size").expect("value is required"); + let threads: Option = arg_parser + .is_present("threads") + .then(|| arg_parser.value_of_t_or_exit("threads")); + + let input = arg_parser.value_of("input").expect("value is required"); + let input_stream = match fs::File::open(input) { + Ok(file) => io::BufReader::new(file), + Err(err) => { + log::error!("input file opening error: {}", err); + process::exit(1); + } + }; + + let output = arg_parser.value_of("output").expect("value is required"); + let mut output_stream = match fs::File::create(output) { + Ok(file) => io::BufWriter::new(file), + Err(err) => { + log::error!("output file creation error: {}", err); + process::exit(1); + } + }; + + let mut sorter_builder = ExternalSorterBuilder::new(); + if let Some(threads) = threads { + sorter_builder = sorter_builder.with_threads_number(threads); + } + + if let Some(tmp_dir) = tmp_dir { + sorter_builder = sorter_builder.with_tmp_dir(path::Path::new(tmp_dir)); + } + + sorter_builder = sorter_builder.with_buffer(MemoryLimitedBufferBuilder::new( + chunk_size.parse::().expect("value is pre-validated").as_u64(), + )); + + let sorter: ExternalSorter = match sorter_builder.build() { + Ok(sorter) => sorter, + Err(err) => { + log::error!("sorter initialization error: {}", err); + process::exit(1); + } + }; + + let compare = |a: &String, b: &String| { + if order == Order::Asc { + a.cmp(&b) + } else { + a.cmp(&b).reverse() + } + }; + + let sorted_stream = match sorter.sort_by(input_stream.lines(), compare) { + Ok(sorted_stream) => sorted_stream, + Err(err) => { + log::error!("data sorting error: {}", err); + 
process::exit(1); + } + }; + + for line in sorted_stream { + let line = match line { + Ok(line) => line, + Err(err) => { + log::error!("sorting stream error: {}", err); + process::exit(1); + } + }; + if let Err(err) = output_stream.write_all(format!("{}\n", line).as_bytes()) { + log::error!("data saving error: {}", err); + process::exit(1); + }; + } + + if let Err(err) = output_stream.flush() { + log::error!("data flushing error: {}", err); + process::exit(1); + } +} + +#[derive(Copy, Clone, clap::ArgEnum)] +enum LogLevel { + Off, + Error, + Warn, + Info, + Debug, + Trace, +} + +impl LogLevel { + pub fn possible_values() -> impl Iterator> { + Self::value_variants().iter().filter_map(|v| v.to_possible_value()) + } +} + +impl std::str::FromStr for LogLevel { + type Err = String; + + fn from_str(s: &str) -> Result { + ::from_str(s, false) + } +} + +#[derive(Copy, Clone, PartialEq, clap::ArgEnum)] +enum Order { + Asc, + Desc, +} + +impl Order { + pub fn possible_values() -> impl Iterator> { + Order::value_variants().iter().filter_map(|v| v.to_possible_value()) + } +} + +impl std::str::FromStr for Order { + type Err = String; + + fn from_str(s: &str) -> Result { + ::from_str(s, false) + } +} + +fn build_arg_parser() -> clap::ArgMatches { + clap::App::new("ext-sort") + .author("Dmitry P. 
") + .about("external sorter") + .arg( + clap::Arg::new("input") + .short('i') + .long("input") + .help("file to be sorted") + .required(true) + .takes_value(true), + ) + .arg( + clap::Arg::new("output") + .short('o') + .long("output") + .help("result file") + .required(true) + .takes_value(true), + ) + .arg( + clap::Arg::new("sort") + .short('s') + .long("sort") + .help("sorting order") + .takes_value(true) + .default_value("asc") + .possible_values(Order::possible_values()), + ) + .arg( + clap::Arg::new("log_level") + .short('l') + .long("loglevel") + .help("logging level") + .takes_value(true) + .default_value("info") + .possible_values(LogLevel::possible_values()), + ) + .arg( + clap::Arg::new("threads") + .short('t') + .long("threads") + .help("number of threads to use for parallel sorting") + .takes_value(true), + ) + .arg( + clap::Arg::new("tmp_dir") + .short('d') + .long("tmp-dir") + .help("directory to be used to store temporary data") + .takes_value(true), + ) + .arg( + clap::Arg::new("chunk_size") + .short('c') + .long("chunk-size") + .help("chunk size") + .required(true) + .takes_value(true) + .validator(|v| match v.parse::() { + Ok(_) => Ok(()), + Err(err) => Err(format!("Chunk size format incorrect: {}", err)), + }), + ) + .get_matches() +} + +fn init_logger(log_level: LogLevel) { + env_logger::Builder::new() + .filter_level(match log_level { + LogLevel::Off => log::LevelFilter::Off, + LogLevel::Error => log::LevelFilter::Error, + LogLevel::Warn => log::LevelFilter::Warn, + LogLevel::Info => log::LevelFilter::Info, + LogLevel::Debug => log::LevelFilter::Debug, + LogLevel::Trace => log::LevelFilter::Trace, + }) + .format_timestamp_millis() + .init(); +} diff --git a/src/merger.rs b/src/merger.rs index 248400b..6d6666c 100644 --- a/src/merger.rs +++ b/src/merger.rs @@ -1,28 +1,80 @@ //! Binary heap merger. +use std::cmp::Ordering; use std::collections::BinaryHeap; use std::error::Error; +/// Value wrapper binding custom compare function to a value. 
+struct OrderedWrapper +where + F: Fn(&T, &T) -> Ordering, +{ + value: T, + compare: F, +} + +impl OrderedWrapper +where + F: Fn(&T, &T) -> Ordering, +{ + fn wrap(value: T, compare: F) -> Self { + OrderedWrapper { value, compare } + } + + fn unwrap(self) -> T { + self.value + } +} + +impl PartialEq for OrderedWrapper +where + F: Fn(&T, &T) -> Ordering, +{ + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl Eq for OrderedWrapper where F: Fn(&T, &T) -> Ordering {} + +impl PartialOrd for OrderedWrapper +where + F: Fn(&T, &T) -> Ordering, +{ + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for OrderedWrapper +where + F: Fn(&T, &T) -> Ordering, +{ + fn cmp(&self, other: &Self) -> Ordering { + (self.compare)(&self.value, &other.value) + } +} + /// Binary heap merger implementation. /// Merges multiple sorted inputs into a single sorted output. /// Time complexity is *m* \* log(*n*) in worst case where *m* is the number of items, /// *n* is the number of chunks (inputs). -pub struct BinaryHeapMerger +pub struct BinaryHeapMerger where - T: Ord, E: Error, + F: Fn(&T, &T) -> Ordering, C: IntoIterator>, { // binary heap is max-heap by default so we reverse it to convert it to min-heap - items: BinaryHeap<(std::cmp::Reverse, usize)>, + items: BinaryHeap<(std::cmp::Reverse>, usize)>, chunks: Vec, initiated: bool, + compare: F, } -impl BinaryHeapMerger +impl BinaryHeapMerger where - T: Ord, E: Error, + F: Fn(&T, &T) -> Ordering, C: IntoIterator>, { /// Creates an instance of a binary heap merger using chunks as inputs. 
@@ -30,7 +82,7 @@ where /// /// # Arguments /// * `chunks` - Chunks to be merged in a single sorted one - pub fn new(chunks: I) -> Self + pub fn new(chunks: I, compare: F) -> Self where I: IntoIterator, { @@ -40,15 +92,16 @@ where return BinaryHeapMerger { chunks, items, + compare, initiated: false, }; } } -impl Iterator for BinaryHeapMerger +impl Iterator for BinaryHeapMerger where - T: Ord, E: Error, + F: Fn(&T, &T) -> Ordering + Copy, C: IntoIterator>, { type Item = Result; @@ -59,7 +112,9 @@ where for (idx, chunk) in self.chunks.iter_mut().enumerate() { if let Some(item) = chunk.next() { match item { - Ok(item) => self.items.push((std::cmp::Reverse(item), idx)), + Ok(item) => self + .items + .push((std::cmp::Reverse(OrderedWrapper::wrap(item, self.compare)), idx)), Err(err) => return Some(Err(err)), } } @@ -70,12 +125,14 @@ where let (result, idx) = self.items.pop()?; if let Some(item) = self.chunks[idx].next() { match item { - Ok(item) => self.items.push((std::cmp::Reverse(item), idx)), + Ok(item) => self + .items + .push((std::cmp::Reverse(OrderedWrapper::wrap(item, self.compare)), idx)), Err(err) => return Some(Err(err)), } } - return Some(Ok(result.0)); + return Some(Ok(result.0.unwrap())); } } @@ -131,7 +188,7 @@ mod test { #[case] chunks: Vec>>, #[case] expected_result: Vec>, ) { - let merger = BinaryHeapMerger::new(chunks); + let merger = BinaryHeapMerger::new(chunks, i32::cmp); let actual_result = merger.collect(); assert!( compare_vectors_of_result::<_, io::Error>(&actual_result, &expected_result), diff --git a/src/sort.rs b/src/sort.rs index 656db56..33342ea 100644 --- a/src/sort.rs +++ b/src/sort.rs @@ -1,6 +1,7 @@ //! External sorter. 
use log; +use std::cmp::Ordering; use std::error::Error; use std::fmt; use std::fmt::{Debug, Display}; @@ -64,7 +65,7 @@ impl Display for SortError { #[derive(Clone)] pub struct ExternalSorterBuilder> where - T: Ord + Send, + T: Send, E: Error, B: ChunkBufferBuilder, C: ExternalChunk, @@ -88,7 +89,7 @@ where impl ExternalSorterBuilder where - T: Ord + Send, + T: Send, E: Error, B: ChunkBufferBuilder, C: ExternalChunk, @@ -137,7 +138,7 @@ where impl Default for ExternalSorterBuilder where - T: Ord + Send, + T: Send, E: Error, B: ChunkBufferBuilder, C: ExternalChunk, @@ -158,7 +159,7 @@ where /// External sorter. pub struct ExternalSorter> where - T: Ord + Send, + T: Send, E: Error, B: ChunkBufferBuilder, C: ExternalChunk, @@ -182,7 +183,7 @@ where impl ExternalSorter where - T: Ord + Send, + T: Send, E: Error, B: ChunkBufferBuilder, C: ExternalChunk, @@ -246,17 +247,42 @@ where return Ok(tmp_dir); } - /// Sorts data from input using external sort algorithm. + /// Sorts data from the input. /// Returns an iterator that can be used to get sorted data stream. + /// + /// # Arguments + /// * `input` - Input stream data to be fetched from pub fn sort( &self, input: I, ) -> Result< - BinaryHeapMerger, + BinaryHeapMerger Ordering + Copy, C>, + SortError, + > + where + T: Ord, + I: IntoIterator>, + { + self.sort_by(input, T::cmp) + } + + /// Sorts data from the input using a custom compare function. + /// Returns an iterator that can be used to get sorted data stream. 
+ /// + /// # Arguments + /// * `input` - Input stream data to be fetched from + /// * `compare` - Function to be used to compare items + pub fn sort_by( + &self, + input: I, + compare: F, + ) -> Result< + BinaryHeapMerger, SortError, > where I: IntoIterator>, + F: Fn(&T, &T) -> Ordering + Sync + Send + Copy, { let mut chunk_buf = self.buffer_builder.build(); let mut external_chunks = Vec::new(); @@ -268,34 +294,39 @@ where } if chunk_buf.is_full() { - external_chunks.push(self.create_chunk(chunk_buf)?); + external_chunks.push(self.create_chunk(chunk_buf, compare)?); chunk_buf = self.buffer_builder.build(); } } if chunk_buf.len() > 0 { - external_chunks.push(self.create_chunk(chunk_buf)?); + external_chunks.push(self.create_chunk(chunk_buf, compare)?); } log::debug!("external sort preparation done"); - return Ok(BinaryHeapMerger::new(external_chunks)); + return Ok(BinaryHeapMerger::new(external_chunks, compare)); } - fn create_chunk( + fn create_chunk( &self, - mut chunk: impl ChunkBuffer, - ) -> Result> { + mut buffer: impl ChunkBuffer, + compare: F, + ) -> Result> + where + F: Fn(&T, &T) -> Ordering + Sync + Send, + { log::debug!("sorting chunk data ..."); self.thread_pool.install(|| { - chunk.par_sort(); + buffer.par_sort_by(compare); }); log::debug!("saving chunk data"); - let external_chunk = ExternalChunk::build(&self.tmp_dir, chunk, self.rw_buf_size).map_err(|err| match err { - ExternalChunkError::IO(err) => SortError::IO(err), - ExternalChunkError::SerializationError(err) => SortError::SerializationError(err), - })?; + let external_chunk = + ExternalChunk::build(&self.tmp_dir, buffer, self.rw_buf_size).map_err(|err| match err { + ExternalChunkError::IO(err) => SortError::IO(err), + ExternalChunkError::SerializationError(err) => SortError::SerializationError(err), + })?; return Ok(external_chunk); } @@ -307,11 +338,14 @@ mod test { use std::path::Path; use rand::seq::SliceRandom; + use rstest::*; use super::{ExternalSorter, ExternalSorterBuilder, 
LimitedBufferBuilder}; - #[test] - fn test_external_sorter() { + #[rstest] + #[case(false)] + #[case(true)] + fn test_external_sorter(#[case] reversed: bool) { let input_sorted = 0..100; let mut input: Vec> = Vec::from_iter(input_sorted.clone().map(|item| Ok(item))); @@ -324,11 +358,21 @@ mod test { .build() .unwrap(); - let result = sorter.sort(input).unwrap(); + let compare = if reversed { + |a: &i32, b: &i32| a.cmp(b).reverse() + } else { + |a: &i32, b: &i32| a.cmp(b) + }; + + let result = sorter.sort_by(input, compare).unwrap(); let actual_result: Result, _> = result.collect(); let actual_result = actual_result.unwrap(); - let expected_result = Vec::from_iter(input_sorted.clone()); + let expected_result = if reversed { + Vec::from_iter(input_sorted.clone().rev()) + } else { + Vec::from_iter(input_sorted.clone()) + }; assert_eq!(actual_result, expected_result) }