diff --git a/cli/src/main.rs b/cli/src/main.rs
index 58ca0e2..24a1222 100644
--- a/cli/src/main.rs
+++ b/cli/src/main.rs
@@ -1,4 +1,4 @@
-use libribzip2::stream::{decode_stream, encode_stream};
+use libribzip2::stream::{decode_stream, Encoder};
 use libribzip2::EncodingStrategy;
 use num_cpus;
 use std::fmt;
@@ -119,7 +119,8 @@ fn try_main(opt: Opt) -> Result<(), FileError> {
                 },
             };
             let threads_val = threads.unwrap_or(num_cpus::get());
-            encode_stream(&mut in_file, &mut out_file, threads_val, encoding_strategy);
+            let mut encoder = Encoder::new(in_file, encoding_strategy, threads_val);
+            std::io::copy(&mut encoder, &mut out_file).unwrap();
         }
     }
 }
diff --git a/lib/src/bitwise/bitwriter.rs b/lib/src/bitwise/bitwriter.rs
--- a/lib/src/bitwise/bitwriter.rs
+++ b/lib/src/bitwise/bitwriter.rs
@@ -144,6 +144,62 @@ pub fn increment_symbol(input: Vec<Bit>) -> Vec<Bit> {
     let len = input.len();
     convert_to_code_pad_to_n_bits(convert_to_number(&input) + 1, len)
 }
+
+/// [`BitWriter`] that keeps its output in memory so that it can be pulled out in chunks.
+#[derive(Default)]
+pub struct BufferBitWriter {
+    /// Bits written so far that do not yet fill a whole byte.
+    pending_bits: Vec<Bit>,
+    /// Completed bytes that have not been pulled yet.
+    buffer: Vec<u8>,
+}
+
+impl BufferBitWriter {
+    /// Remove and return up to `limit` bytes from the front of the buffer.
+    ///
+    /// Fewer bytes are returned when fewer are buffered.
+    pub fn pull(&mut self, limit: usize) -> Vec<u8> {
+        let how_many = limit.min(self.buffer.len());
+        // `drain` moves the front bytes out and keeps the rest in the existing
+        // allocation, instead of copying the remainder into a new vector.
+        self.buffer.drain(..how_many).collect()
+    }
+
+    /// Number of complete bytes currently buffered.
+    pub fn content(&self) -> usize {
+        self.buffer.len()
+    }
+}
+
+impl BitWriter for BufferBitWriter {
+    /// Append bits; every complete group of eight becomes one buffered byte.
+    fn write_bits(&mut self, bits_to_write: &[Bit]) -> Result<(), ()> {
+        self.pending_bits.extend_from_slice(bits_to_write);
+        let mut chunks = self.pending_bits.chunks_exact(8);
+        for chunk in &mut chunks {
+            let number = convert_to_number(chunk);
+            self.buffer.push(number as u8);
+        }
+        // Fewer than eight bits are left over; keep them for the next call.
+        self.pending_bits = chunks.remainder().to_vec();
+        Ok(())
+    }
+
+    /// Pad the pending bits with zeros up to a full byte and buffer that byte.
+    fn finalize(&mut self) -> Result<(), ()> {
+        if self.pending_bits.is_empty() {
+            return Ok(());
+        }
+        // `write_bits` leaves fewer than eight pending bits, so this only pads.
+        self.pending_bits.resize(8, Bit::Zero);
+        let byte = convert_to_number(&self.pending_bits);
+        self.buffer.push(byte as u8);
+        self.pending_bits.clear();
+
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
diff --git a/lib/src/stream/mod.rs b/lib/src/stream/mod.rs
--- a/lib/src/stream/mod.rs
+++ b/lib/src/stream/mod.rs
@@ -1,5 +1,6 @@
 use crate::bitwise::bitreader::BitReaderImpl;
 use crate::bitwise::bitwriter::convert_to_code_pad_to_byte;
+use crate::bitwise::bitwriter::BufferBitWriter;
 use crate::block::block_encoder::crc_as_bytes;
 
 
@@ -12,7 +13,6 @@ use std::thread;
 
 use crate::bitwise::bitreader::BitReader;
 use crate::bitwise::bitwriter::BitWriter;
-use crate::bitwise::bitwriter::BitWriterImpl;
 use crate::block::block_decoder::decode_block;
 use crate::block::block_encoder::generate_block_data;
 use crate::block::crc32::crc32;
@@ -24,6 +24,182 @@ use super::block::rle::rle_augment;
 use super::block::rle::rle_total_size;
 use super::block::symbol_statistics::EncodingStrategy;
 
+/// Streaming bzip2 encoder. Wraps a reader of uncompressed bytes and is itself a
+/// [`Read`] that yields the compressed stream.
+///
+/// # Examples
+///
+/// ```rust
+/// use libribzip2::EncodingStrategy;
+/// use libribzip2::stream::Encoder;
+/// use std::io::Cursor;
+///
+/// let num_threads = 4;
+/// let encoding_strategy = EncodingStrategy::Single;
+///
+/// let reader = Cursor::new(vec![1, 2, 3, 4]);
+/// let mut writer = vec![];
+///
+/// let mut encoder = Encoder::new(reader, encoding_strategy, num_threads);
+/// std::io::copy(&mut encoder, &mut writer).unwrap();
+/// ```
+pub struct Encoder<T> {
+    /// Source of the uncompressed bytes.
+    reader: T,
+    /// Number of worker threads spawned per encoding round; at least one.
+    num_threads: usize,
+    encoding_strategy: EncodingStrategy,
+    /// Encoded output that has not been handed to the caller yet.
+    bit_writer: BufferBitWriter,
+    /// Running CRC that is passed to every worker flush and written into the footer.
+    total_crc: u32,
+    /// Set once the wrapped reader has reported end of input.
+    finalized: bool,
+    /// Set once the stream footer has been written.
+    encoded: bool,
+    /// Set once the file header has been written.
+    initialized: bool,
+}
+
+impl<T> Encoder<T>
+where
+    T: Read,
+{
+    /// Create an encoder that compresses everything `reader` yields.
+    ///
+    /// `num_threads` is the number of worker threads used per encoding round. Zero is
+    /// treated as one, because a round without workers would never make progress.
+    pub fn new(reader: T, encoding_strategy: EncodingStrategy, num_threads: usize) -> Self {
+        Encoder {
+            reader,
+            num_threads: num_threads.max(1),
+            encoding_strategy,
+            bit_writer: Default::default(),
+            total_crc: 0,
+            finalized: false,
+            encoded: false,
+            initialized: false,
+        }
+    }
+}
+
+impl<T> Read for Encoder<T>
+where
+    T: Read,
+{
+    /// Fill `buf` with the next bytes of the compressed stream.
+    ///
+    /// Input is only read and encoded while the buffered output cannot fill `buf`.
+    /// Returns `Ok(0)` once the complete stream has been handed out.
+    ///
+    /// # Errors
+    ///
+    /// Errors of the wrapped reader are passed on unchanged. The input consumed for the
+    /// unfinished block is lost, so the encoder should not be used after an error.
+    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+        const RLE_LIMIT: usize = 900_000;
+
+        if !self.initialized {
+            self.bit_writer.write_bits(&file_header()).unwrap();
+            self.initialized = true;
+        }
+
+        // Calls that can be served from the buffered output skip this branch and
+        // therefore never pay for spawning worker threads.
+        if !self.finalized && self.bit_writer.content() <= buf.len() {
+            let encoding_strategy = self.encoding_strategy;
+            let mut worker_threads = (0..self.num_threads)
+                .map(|num| WorkerThread::spawn(&format!("Thread {}", num), encoding_strategy))
+                .collect::<Vec<_>>();
+            let mut read_error = None;
+
+            while read_error.is_none()
+                && !self.finalized
+                && self.bit_writer.content() <= buf.len()
+            {
+                'dispatch: for worker_thread in worker_threads.iter_mut() {
+                    let mut block = vec![];
+                    let mut rle_data = vec![];
+                    let mut rle_total_count = 0;
+                    let mut rle_count = 0;
+                    let mut rle_last_char = None;
+                    while rle_total_count < RLE_LIMIT {
+                        // RLE can blow up 4chars to 5, hence we keep a safety margin
+                        let to_take = RLE_LIMIT.saturating_sub(rle_data.len()) * 4 / 5;
+                        if to_take == 0 {
+                            // No room is left in this block. A zero-length read would
+                            // be indistinguishable from end of input, so stop here.
+                            break;
+                        }
+                        let mut buf_current = vec![];
+                        match self
+                            .reader
+                            .by_ref()
+                            .take(to_take as u64)
+                            .read_to_end(&mut buf_current)
+                        {
+                            Ok(0) => {
+                                self.finalized = true;
+                                break;
+                            }
+                            Ok(_) => {}
+                            Err(err) => {
+                                // Remember the error; pending workers are flushed first.
+                                read_error = Some(err);
+                                break 'dispatch;
+                            }
+                        }
+                        let rle_result = rle(&buf_current, rle_count, rle_last_char);
+                        let mut rle_next = rle_result.data;
+                        let rle_next_count = rle_result.counter;
+                        let rle_next_char = rle_result.last_byte;
+
+                        let next_data_len = rle_data.len() + rle_next.len();
+                        rle_total_count =
+                            rle_total_size(next_data_len, rle_next_count, rle_next_char);
+
+                        rle_data.append(&mut rle_next);
+                        block.append(&mut buf_current);
+                        rle_count = rle_next_count;
+                        rle_last_char = rle_next_char;
+                    }
+
+                    if block.is_empty() {
+                        break;
+                    }
+
+                    let rle_total = rle_augment(&rle_data, rle_count, rle_last_char);
+                    let computed_crc = crc32(&block);
+                    worker_thread.send_work((computed_crc, rle_total));
+                }
+
+                for worker_thread in worker_threads.iter_mut() {
+                    if worker_thread.pending {
+                        worker_thread.flush_work_buffer(&mut self.bit_writer, &mut self.total_crc);
+                    }
+                }
+            }
+
+            if let Some(err) = read_error {
+                return Err(err);
+            }
+        }
+
+        if self.finalized && !self.encoded {
+            self.bit_writer
+                .write_bits(&stream_footer(self.total_crc))
+                .unwrap();
+            self.bit_writer.finalize().unwrap();
+            self.encoded = true;
+        }
+
+        let res = self.bit_writer.pull(buf.len());
+        buf[0..res.len()].copy_from_slice(&res);
+
+        Ok(res.len())
+    }
+}
+
 fn stream_footer(crc: u32) -> Vec<u8> {
     let mut out = vec![];
 
@@ -100,85 +276,6 @@ impl WorkerThread {
     }
 }
 
-/// Encode a stream into a writer. Takes a reader and a writer (i.e. two instances of [std::fs::File]).
-/// The number of threads and the encoding strategy can be specified.
-pub fn encode_stream(
-    mut read: impl Read,
-    mut writer: impl Write,
-    num_threads: usize,
-    encoding_strategy: EncodingStrategy,
-) {
-    let mut bit_writer = BitWriterImpl::from_writer(&mut writer);
-    const RLE_LIMIT: usize = 900_000;
-    let mut total_crc: u32 = 0;
-
-    let mut worker_threads = (0..num_threads)
-        .map(|num| WorkerThread::spawn(&format!("Thread {}", num), encoding_strategy))
-        .collect::<Vec<_>>();
-
-    bit_writer.write_bits(&file_header()).unwrap();
-
-    let mut finalize = false;
-    loop {
-        if finalize {
-            break;
-        }
-        for worker_thread in worker_threads.iter_mut() {
-            let mut buf = vec![];
-            let mut rle_data = vec![];
-            let mut rle_total_count = 0;
-            let mut rle_count = 0;
-            let mut rle_last_char = None;
-            while rle_total_count < RLE_LIMIT {
-                // RLE can blow up 4chars to 5, hence we keep a safety margin
-                let to_take = (RLE_LIMIT - rle_data.len()) * 4 / 5;
-                let mut buf_current = vec![];
-                if let Ok(size) = read
-                    .by_ref()
-                    .take(to_take as u64)
-                    .read_to_end(&mut buf_current)
-                {
-                    if size == 0 {
-                        finalize = true;
-                        break;
-                    }
-                } else {
-                    break;
-                }
-                let rle_result = rle(&buf_current, rle_count, rle_last_char);
-                let mut rle_next = rle_result.data;
-                let rle_next_count = rle_result.counter;
-                let rle_next_char = rle_result.last_byte;
-
-                let next_data_len = rle_data.len() + rle_next.len();
-                rle_total_count = rle_total_size(next_data_len, rle_next_count, rle_next_char);
-
-                rle_data.append(&mut rle_next);
-                buf.append(&mut buf_current);
-                rle_count = rle_next_count;
-                rle_last_char = rle_next_char;
-            }
-
-            if buf.len() == 0 {
-                break;
-            }
-
-            let rle_total = rle_augment(&rle_data, rle_count, rle_last_char);
-            let computed_crc = crc32(&buf);
-            worker_thread.send_work((computed_crc, rle_total));
-        }
-
-        for worker_thread in worker_threads.iter_mut() {
-            if worker_thread.pending {
-                worker_thread.flush_work_buffer(&mut bit_writer, &mut total_crc);
-            }
-        }
-    }
-
-    bit_writer.write_bits(&stream_footer(total_crc)).unwrap();
-    bit_writer.finalize().unwrap();
-}
-
 fn read_file_header(mut bit_reader: impl BitReader) -> Result<(), ()> {
     let res = bit_reader.read_bytes(4)?;
     match &res[..] {
@@ -208,6 +305,23 @@ fn what_next(mut bit_reader: impl BitReader) -> Result {
     }
 }
 
+/// Encode a stream into a writer. Takes a reader and a writer (i.e. two instances of [std::fs::File]).
+/// The number of threads and the encoding strategy can be specified.
+///
+/// # Panics
+///
+/// Panics if reading from `read` or writing to `writer` fails.
+#[deprecated(note = "use `Encoder` together with `std::io::copy` instead")]
+pub fn encode_stream(
+    read: impl Read,
+    mut writer: impl Write,
+    num_threads: usize,
+    encoding_strategy: EncodingStrategy,
+) {
+    let mut encoder = Encoder::new(read, encoding_strategy, num_threads);
+    std::io::copy(&mut encoder, &mut writer).unwrap();
+}
+
 /// Decode a stream into a writer. Takes a reader and a writer (i.e. two instances of [std::fs::File])
 pub fn decode_stream(mut reader: impl Read, mut writer: impl Write) -> Result<(), ()> {
     let mut bit_reader = BitReaderImpl::from_reader(&mut reader);