diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 9694194..6c61557 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -35,7 +35,7 @@ jobs: - name: Build LongReadSum shell: bash --login {0} # --login enables PATH variable access - run: make + run: make -d - name: Run tests shell: bash --login {0} diff --git a/Makefile b/Makefile index 7b1e50e..1c23e60 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,9 @@ SRC_DIR := $(CURDIR)/src LIB_DIR := $(CURDIR)/lib # Set the library paths for the compiler +# CONDA_PREFIX ?= $(shell echo $$CONDA_PREFIX) +# LIBRARY_PATHS := -L$(LIB_DIR) -L$(CONDA_PREFIX)/lib +# INCLUDE_PATHS := -I$(INCL_DIR) -I$(CONDA_PREFIX)/include LIBRARY_PATHS := -L$(LIB_DIR) -L/usr/share/miniconda/envs/longreadsum/lib INCLUDE_PATHS := -I$(INCL_DIR) -I/usr/share/miniconda/envs/longreadsum/include @@ -11,9 +14,14 @@ all: swig_build compile # Generate the SWIG Python/C++ wrappers swig_build: + mkdir -p $(LIB_DIR) swig -c++ -python -outdir $(LIB_DIR) -I$(INCL_DIR) -o $(SRC_DIR)/lrst_wrap.cpp $(SRC_DIR)/lrst.i # Compile the C++ shared libraries into lib/ compile: LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):/usr/share/miniconda/envs/longreadsum/lib \ CXXFLAGS="$(INCLUDE_PATHS)" LDFLAGS="$(LIBRARY_PATHS)" python3 setup.py build_ext --build-lib $(LIB_DIR) + +# Clean the build directory +clean: + $(RM) -r $(LIB_DIR)/*.so $(LIB_DIR)/*.py $(SRC_DIR)/lrst_wrap.cpp build/ diff --git a/README.md b/README.md index e5b4bed..8184491 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ MinION R9.4.1 from https://labs.epi2me.io/gm24385-5mc/) ## General usage ``` -longreadsum bam -i $INPUT_FILE -o $OUTPUT_DIRECTORY --ref $REF_GENOME --modprob 0.8 +longreadsum bam -i $INPUT_FILE -o $OUTPUT_DIRECTORY --mod --modprob 0.8 --ref $REF_GENOME ``` # RRMS BAM @@ -258,7 +258,12 @@ longreadsum bam -i $INPUT_FILE -o $OUTPUT_DIRECTORY # ONT POD5 This section describes how to generate QC reports for ONT POD5 (signal) files and their corresponding basecalled BAM files (data shown is HG002 using ONT -R10.4.1 and LSK114 downloaded from the tutorial https://github.com/epi2me-labs/wf-basecalling). +R10.4.1 and LSK114 downloaded from the tutorial +https://github.com/epi2me-labs/wf-basecalling). + +> [!NOTE] +> This requires generating basecalled BAM files with the move table output. For +> example, for [dorado](https://github.com/nanoporetech/dorado), the parameter is `--emit-moves` ![image](https://github.com/user-attachments/assets/62c3c810-5c1a-4124-816b-74245af8b57c) diff --git a/conda/build.sh b/conda/build.sh index 95f8d11..61720bc 100644 --- a/conda/build.sh +++ b/conda/build.sh @@ -3,18 +3,27 @@ # Add the library path to the LD_LIBRARY_PATH export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${PREFIX}/lib +# Ensure the lib directory exists +mkdir -p "${SRC_DIR}"/lib + # Generate the SWIG files +echo "Generating SWIG files..." swig -c++ -python -outdir "${SRC_DIR}"/lib -I"${SRC_DIR}"/include -I"${PREFIX}"/include -o "${SRC_DIR}"/src/lrst_wrap.cpp "${SRC_DIR}"/src/lrst.i # Generate the shared library +echo "Building the shared library..." $PYTHON setup.py -I"${PREFIX}"/include -L"${PREFIX}"/lib install # Create the src directory mkdir -p "${PREFIX}"/src # Copy source files to the bin directory +echo "Copying source files..." cp -r "${SRC_DIR}"/src/*.py "${PREFIX}"/bin # Copy the SWIG generated library to the lib directory +echo "Copying SWIG generated library..." 
cp -r "${SRC_DIR}"/lib/*.py "${PREFIX}"/lib cp -r "${SRC_DIR}"/lib/*.so "${PREFIX}"/lib + +echo "Build complete." diff --git a/conda/meta.yaml b/conda/meta.yaml index edf847d..e1d7004 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -1,14 +1,14 @@ {% set version = "1.4.0" %} -# {% set revision = "b06670513616fd6342233c1c77e6d0bcf138b3bc" %} +{% set revision = "c257da611a4ae2cfcb4c6c42fcb504f808d644f9" %} package: name: longreadsum version: {{ version }} source: - path: ../ - # git_url: https://github.com/WGLab/LongReadSum.git - # git_rev: {{ revision }} + git_url: https://github.com/WGLab/LongReadSum.git + git_rev: {{ revision }} + # path: ../ channels: - conda-forge @@ -29,18 +29,17 @@ requirements: host: - python=3.9 - swig - - hdf5 - htslib=1.20 + - ont_vbz_hdf_plugin # Contains HDF5 as a dependency as well # - jannessp::pod5 # - jannessp::lib-pod5 run: - python=3.9 - numpy - - hdf5 - ont_vbz_hdf_plugin - - htslib=1.20 + - bioconda::htslib=1.20 - plotly - - janessp::pod5 + - jannessp::pod5 - pyarrow # - janessp::lib-pod5 diff --git a/environment.yml b/environment.yml index 6a9e2af..b18d99e 100644 --- a/environment.yml +++ b/environment.yml @@ -1,17 +1,17 @@ name: longreadsum channels: - conda-forge + - jannessp # for pod5 - bioconda - defaults - - jannessp # for pod5 + dependencies: - - python=3.9 + - python - numpy - - hdf5 - ont_vbz_hdf_plugin - - htslib=1.20 + - bioconda::htslib=1.20 - swig - plotly - pytest - - pod5 - - pyarrow \ No newline at end of file + - jannessp::pod5 + - pyarrow diff --git a/include/hts_reader.h b/include/hts_reader.h index 8790e88..bc8f5d8 100644 --- a/include/hts_reader.h +++ b/include/hts_reader.h @@ -38,7 +38,7 @@ class HTSReader { bool reading_complete = false; // Update read and base counts - int updateReadAndBaseCounts(bam1_t* record, Basic_Seq_Statistics* basic_qc, uint64_t *base_quality_distribution); + int updateReadAndBaseCounts(bam1_t* record, Basic_Seq_Statistics& basic_qc, Basic_Seq_Quality_Statistics& seq_quality_info, bool is_primary); // Read the next batch of records from the BAM file int readNextRecords(int batch_size, Output_BAM & output_data, std::mutex & read_mutex, std::unordered_set& read_ids, double base_mod_threshold); @@ -47,9 +47,12 @@ class HTSReader { bool hasNextRecord(); // Return the number of records in the BAM file using the BAM index - int64_t getNumRecords(const std::string &bam_file_name, Output_BAM &final_output, bool mod_analysis, double base_mod_threshold); + int getNumRecords(const std::string &bam_file_name, int thread_count); - std::map getQueryToRefMap(bam1_t *record); + // Run base modification analysis + void runBaseModificationAnalysis(const std::string &bam_filename, Output_BAM& final_output, double base_mod_threshold, int read_count, int sample_count, int thread_count); + + std::map getQueryToRefMap(bam1_t* record); // Add a modification to the base modification map void addModificationToQueryMap(std::map> &base_modifications, int32_t pos, char mod_type, char canonical_base, double likelihood, int strand); diff --git a/include/output_data.h b/include/output_data.h index 53d30bf..2e1610d 100644 --- a/include/output_data.h +++ b/include/output_data.h @@ -14,6 +14,7 @@ Define the output structures for each module. 
#include "input_parameters.h" #include "tin_stats.h" +#include "utils.h" #define MAX_READ_LENGTH 10485760 #define MAX_BASE_QUALITY 100 @@ -78,7 +79,7 @@ class Basic_Seq_Quality_Statistics //std::vector base_quality_distribution; // Array of base quality distribution initialized to 0 uint64_t base_quality_distribution[MAX_BASE_QUALITY] = {ZeroDefault}; - std::vector read_average_base_quality_distribution; + std::vector read_average_base_quality_distribution; // Read average base quality distribution int min_base_quality = MoneDefault; // minimum base quality; int max_base_quality = MoneDefault; // maximum base quality; std::vector pos_quality_distribution; @@ -114,7 +115,7 @@ class Output_FQ : public Output_FA // Define the base modification data structure (modification type, canonical // base, likelihood, strand: 0 for forward, 1 for reverse, and CpG flag: T/F) -using Base_Modification = std::tuple; +// using Base_Modification = std::tuple; // Define the signal-level data structure for POD5 (ts, ns, move table vector) using POD5_Signal_Data = std::tuple>; @@ -156,95 +157,123 @@ class Base_Move_Table }; +// Structures for storing read length vs. base modification rate data +struct ReadModData +{ + int read_length; + std::unordered_map base_mod_rates; // Type-specific base modification rates +}; + // BAM output class Output_BAM : public Output_FQ { -public: - uint64_t num_primary_alignment = ZeroDefault; // the number of primary alignment/ - uint64_t num_secondary_alignment = ZeroDefault; // the number of secondary alignment - uint64_t num_reads_with_secondary_alignment = ZeroDefault; // the number of long reads with the secondary alignment: one read might have multiple seconard alignment - uint64_t num_supplementary_alignment = ZeroDefault; // the number of supplementary alignment - uint64_t num_reads_with_supplementary_alignment = ZeroDefault; // the number of long reads with secondary alignment; - uint64_t num_reads_with_both_secondary_supplementary_alignment = ZeroDefault; // the number of long reads with both secondary and supplementary alignment. - uint64_t forward_alignment = ZeroDefault; // Total number of forward alignments - uint64_t reverse_alignment = ZeroDefault; // Total number of reverse alignments - std::map reads_with_supplementary; // Map of reads with supplementary alignments - std::map reads_with_secondary; // Map of reads with secondary alignments - - // Similar to Output_FA: below are for mapped. 
- uint64_t num_matched_bases = ZeroDefault; // the number of matched bases with = - uint64_t num_mismatched_bases = ZeroDefault; // the number of mismatched bases X - uint64_t num_ins_bases = ZeroDefault; // the number of inserted bases; - uint64_t num_del_bases = ZeroDefault; // the number of deleted bases; - uint64_t num_clip_bases = ZeroDefault; // the number of soft-clipped bases; - - // The number of columns can be calculated by summing over the lengths of M/I/D CIGAR operators - int num_columns = ZeroDefault; // the number of columns - double percent_identity = ZeroDefault; // Percent identity = (num columns - NM) / num columns - std::vector accuracy_per_read; - - // Preprint revisions: Remove all counts with unique positions in the - // reference genome, and only report raw counts - uint64_t modified_prediction_count = ZeroDefault; // Total number of modified base predictions - uint64_t sample_modified_base_count = ZeroDefault; // Total number of modified bases passing the threshold - uint64_t sample_modified_base_count_forward = ZeroDefault; // Total number of modified bases passing the threshold on the forward strand - uint64_t sample_modified_base_count_reverse = ZeroDefault; // Total number of modified bases passing the threshold on the reverse strand - uint64_t sample_cpg_forward_count = ZeroDefault; // Total number of modified bases passing the threshold that are in CpG sites and in the forward strand (non-unique) - uint64_t sample_cpg_reverse_count = ZeroDefault; // Total number of modified bases passing the threshold that are in CpG sites and in the reverse strand (non-unique) - std::map>> sample_c_modified_positions; // chr -> vector of (position, strand) for modified bases passing the threshold - - // Signal data section - int read_count = ZeroDefault; - int base_count = ZeroDefault; - std::unordered_map read_move_table; - - // POD5 signal-level information is stored in a map of read names to a map of - // reference positions to a tuple of (ts, ns, move table vector) - std::unordered_map pod5_signal_data; - - // Dictionary of bam filepath to TIN data - std::unordered_map tin_data; - - Basic_Seq_Statistics mapped_long_read_info; - Basic_Seq_Statistics unmapped_long_read_info; - - Basic_Seq_Quality_Statistics mapped_seq_quality_info; - Basic_Seq_Quality_Statistics unmapped_seq_quality_info; - - // POD5 signal data functions - int getReadCount(); - void addReadMoveTable(std::string read_name, std::string sequence_data_str, std::vector move_table, int start, int end); - std::vector getReadMoveTable(std::string read_id); - std::string getReadSequence(std::string read_id); - int getReadSequenceStart(std::string read_id); - int getReadSequenceEnd(std::string read_id); - - // Add a batch of records to the output - void add(Output_BAM &t_output_bam); - - // Add TIN data for a single BAM file - void addTINData(std::string &bam_file, TINStats &tin_data); - - // Get the TIN mean for a single BAM file - double getTINMean(std::string bam_file); - - // Get the TIN median for a single BAM file - double getTINMedian(std::string bam_file); - - // Get the TIN standard deviation for a single BAM file - double getTINStdDev(std::string bam_file); - - // Get the TIN count for a single BAM file - int getTINCount(std::string bam_file); - - // Calculate QC across all records - void global_sum(); - - // Save the output to a summary text file - void save_summary(std::string &output_file, Input_Para ¶ms, Output_BAM &output_data); - - Output_BAM(); - ~Output_BAM(); + public: + uint64_t 
num_primary_alignment = ZeroDefault; // the number of primary alignments
+ uint64_t num_secondary_alignment = ZeroDefault; // the number of secondary alignments
+ uint64_t num_reads_with_secondary_alignment = ZeroDefault; // the number of long reads with secondary alignments: one read might have multiple secondary alignments
+ uint64_t num_supplementary_alignment = ZeroDefault; // the number of supplementary alignments
+ uint64_t num_reads_with_supplementary_alignment = ZeroDefault; // the number of long reads with supplementary alignments
+ uint64_t num_reads_with_both_secondary_supplementary_alignment = ZeroDefault; // the number of long reads with both secondary and supplementary alignments
+ uint64_t forward_alignment = ZeroDefault; // Total number of forward alignments
+ uint64_t reverse_alignment = ZeroDefault; // Total number of reverse alignments
+ std::map<std::string, bool> reads_with_supplementary; // Map of reads with supplementary alignments
+ std::map<std::string, bool> reads_with_secondary; // Map of reads with secondary alignments
+
+ // Similar to Output_FA: below are for mapped.
+ uint64_t num_matched_bases = ZeroDefault; // the number of matched bases with =
+ uint64_t num_mismatched_bases = ZeroDefault; // the number of mismatched bases X
+ uint64_t num_ins_bases = ZeroDefault; // the number of inserted bases
+ uint64_t num_del_bases = ZeroDefault; // the number of deleted bases
+ uint64_t num_clip_bases = ZeroDefault; // the number of soft-clipped bases
+
+ // The number of columns can be calculated by summing over the lengths of the M/I/D CIGAR operators
+ int num_columns = ZeroDefault; // the number of columns
+ double percent_identity = ZeroDefault; // Percent identity = (num columns - NM) / num columns
+ std::vector<double> accuracy_per_read;
+
+ // Preprint revisions: Remove all counts with unique positions in the
+ // reference genome, and only report raw counts
+ uint64_t modified_prediction_count = ZeroDefault; // Total number of modified base predictions
+ uint64_t sample_modified_base_count = ZeroDefault; // Total number of modified bases passing the threshold
+ uint64_t sample_modified_base_count_forward = ZeroDefault; // Total number of modified bases passing the threshold on the forward strand
+ uint64_t sample_modified_base_count_reverse = ZeroDefault; // Total number of modified bases passing the threshold on the reverse strand
+ uint64_t sample_cpg_forward_count = ZeroDefault; // Total number of modified bases passing the threshold that are in CpG sites and on the forward strand (non-unique)
+ uint64_t sample_cpg_reverse_count = ZeroDefault; // Total number of modified bases passing the threshold that are in CpG sites and on the reverse strand (non-unique)
+ std::map<std::string, std::vector<std::pair<int32_t, int>>> sample_c_modified_positions; // chr -> vector of (position, strand) for modified bases passing the threshold
+
+ // std::pair<std::vector<int>, std::vector<double>> read_length_mod_rate; // Read length vs. base modification rate
+ // std::unordered_map<char, std::pair<std::vector<int>, std::vector<double>>> read_length_mod_rate; // Read length vs. base modification rate for each base modification type
+ std::unordered_map<char, uint64_t> base_mod_counts; // Counts for each base modification type exceeding the threshold
+ std::unordered_map<char, uint64_t> base_mod_counts_forward; // Counts for each base modification type exceeding the threshold on the forward strand
+ std::unordered_map<char, uint64_t> base_mod_counts_reverse; // Counts for each base modification type exceeding the threshold on the reverse strand
+
+ std::unordered_map<char, std::vector<std::pair<double, double>>> read_pct_len_vs_mod_prob; // Read length (%) vs. base modification probability for each base modification type
+
+ // Signal data section
+ int read_count = ZeroDefault;
+ int base_count = ZeroDefault;
+ std::unordered_map<std::string, Base_Move_Table> read_move_table;
+
+ // POD5 signal-level information is stored in a map of read names to a map of
+ // reference positions to a tuple of (ts, ns, move table vector)
+ std::unordered_map<std::string, POD5_Signal_Data> pod5_signal_data;
+
+ std::unordered_map<std::string, TINStats> tin_data; // TIN data for each BAM file
+
+ Basic_Seq_Statistics mapped_long_read_info;
+ Basic_Seq_Statistics unmapped_long_read_info;
+
+ Basic_Seq_Quality_Statistics mapped_seq_quality_info;
+ Basic_Seq_Quality_Statistics unmapped_seq_quality_info;
+
+ std::vector<ReadModData> read_mod_data; // Read length vs. base modification rate
+ std::vector<char> getBaseModTypes(); // Get the types of base modifications found
+ int getReadModDataSize(); // Get the number of read length vs. base modification rate data points
+ int getNthReadModLength(int read_index); // Get the read length for the nth read
+ double getNthReadModRate(int read_index, char mod_type); // Get the base modification rate for the nth read for a specific base modification type
+ uint64_t getModTypeCount(char mod_type); // Get the count of a specific base modification type
+ uint64_t getModTypeCount(char mod_type, int strand); // Get the count of a specific base modification type for a specific strand
+ double getNthReadLenPct(int read_index, char mod_type); // Get the read length percentage for the nth read for a specific base modification type
+ double getNthReadModProb(int read_index, char mod_type); // Get the base modification probability for the nth read for a specific base modification type
+
+ // POD5 signal data functions
+ int getReadCount();
+ void addReadMoveTable(std::string read_name, std::string sequence_data_str, std::vector<int> move_table, int start, int end);
+ std::vector<int> getReadMoveTable(std::string read_id);
+ std::string getReadSequence(std::string read_id);
+ int getReadSequenceStart(std::string read_id);
+ int getReadSequenceEnd(std::string read_id);
+
+ void updateBaseModCounts(char mod_type, int strand); // Update base modification counts for predictions exceeding the threshold
+ void updateBaseModProbabilities(char mod_type, double pct_len, double probability); // Update base modification probabilities
+ void updateReadModRate(int read_length, const std::unordered_map<char, double>& base_mod_rates); // Update read length vs.
base modification rate data + + // Add TIN data for a single BAM file + void addTINData(std::string &bam_file, TINStats &tin_data); + + // TIN mean for a single BAM file + double getTINMean(std::string bam_file); // Get the TIN mean for a single BAM file + + // TIN median for a single BAM file + double getTINMedian(std::string bam_file); + + // TIN standard deviation for a single BAM file + double getTINStdDev(std::string bam_file); + + // TIN count for a single BAM file + int getTINCount(std::string bam_file); + + // Add a batch of records to the output + void add(Output_BAM &t_output_bam); + + // Calculate QC across all records + void global_sum(); + + // Save the output to a summary text file + void save_summary(std::string &output_file, Input_Para ¶ms, Output_BAM &output_data); + + Output_BAM(); + ~Output_BAM(); }; diff --git a/include/seqtxt_module.h b/include/seqtxt_module.h index 1035598..d2f0ce5 100644 --- a/include/seqtxt_module.h +++ b/include/seqtxt_module.h @@ -37,7 +37,7 @@ class SeqTxt_Thread_data { Output_SeqTxt t_output_SeqTxt_; std::string current_line; // Current line being read from the file - size_t read_ss_record(std::ifstream* file_stream, std::map header_columns); + size_t read_ss_record(std::ifstream& file_stream, std::map header_columns); std::map getHeaderColumns(); SeqTxt_Thread_data(Input_Para& ref_input_op, std::map header_columns, int p_thread_id, int p_batch_size); @@ -60,20 +60,13 @@ class SeqTxt_Module{ static std::mutex myMutex_readSeqTxt; static std::mutex myMutex_output; static size_t batch_size_of_record; - Input_Para _input_parameters; - - std::ifstream *input_file_stream; // Stream for the input text file + std::ifstream input_file_stream; // Stream for the input text file std::vector m_threads; - - int has_error; // Methods - // Assign threads - static void SeqTxt_do_thread(std::ifstream* file_stream, Input_Para& ref_input_op, int thread_id, SeqTxt_Thread_data& ref_thread_data, Output_SeqTxt& ref_output); - - // Generate statistics + static void SeqTxt_do_thread(std::ifstream& file_stream, Input_Para& ref_input_op, int thread_id, Output_SeqTxt& ref_output, std::map header_columns, size_t batch_size_of_record); int generateStatistics( Output_SeqTxt& t_output_SeqTxt_info); SeqTxt_Module(Input_Para& _m_input); diff --git a/include/tin.h b/include/tin.h index b7b73f6..195596e 100644 --- a/include/tin.h +++ b/include/tin.h @@ -15,7 +15,7 @@ typedef std::unordered_map getReadDepths(htsFile* bam_file, hts_idx_t* idx, bam_hdr_t* header, std::string chr, int start, int end); diff --git a/include/utils.h b/include/utils.h index 9637f1e..1828dbb 100644 --- a/include/utils.h +++ b/include/utils.h @@ -12,4 +12,6 @@ void printMessage(std::string message); // Print an error message to stderr in a thread-safe manner void printError(std::string message); +void printMemoryUsage(const std::string &functionName); + #endif // UTILS_H diff --git a/lib/__init__.py b/lib/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/setup.py b/setup.py index 3c05b69..f136db5 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ # Set up the module setup(name="longreadsum", - version='1.4.0', + version='1.5.0', author="WGLab", description="""A fast and flexible QC tool for long read sequencing data""", ext_modules=[lrst_mod], diff --git a/src/bam_module.cpp b/src/bam_module.cpp index f3a0573..0014509 100644 --- a/src/bam_module.cpp +++ b/src/bam_module.cpp @@ -80,7 +80,7 @@ int BAM_Module::calculateStatistics(Input_Para &input_params, Output_BAM &final_ std::cout 
<< "Calculating TIN scores for file: " << filepath << std::endl; TINStats tin_stats; - calculateTIN(&tin_stats, gene_bed, input_params.input_files[i], min_cov, sample_size, input_params.output_folder); + calculateTIN(&tin_stats, gene_bed, input_params.input_files[i], min_cov, sample_size, input_params.output_folder, input_params.threads); // Print the TIN stats std::cout << "Number of transcripts: " << tin_stats.num_transcripts << std::endl; @@ -113,7 +113,7 @@ int BAM_Module::calculateStatistics(Input_Para &input_params, Output_BAM &final_ // process base modifications and TINs if available. // Note: This section utilizes one thread. std::cout << "Getting number of records..." << std::endl; - int num_records = reader.getNumRecords(filepath, final_output, mod_analysis, base_mod_threshold); + int num_records = reader.getNumRecords(filepath, thread_count); std::cout << "Number of records = " << num_records << std::endl; // Exit if there are no records @@ -123,6 +123,13 @@ int BAM_Module::calculateStatistics(Input_Para &input_params, Output_BAM &final_ return exit_code; } + // Run base modification analysis if the flag is set + if (mod_analysis){ + std::cout << "Running base modification analysis..." << std::endl; + int sample_count = 10000; + reader.runBaseModificationAnalysis(filepath, final_output, base_mod_threshold, num_records, sample_count, thread_count); + } + // Determine the batch sizes if the user-specified thread count is greater than 1 int batch_size = 0; if (thread_count > 1) { @@ -146,26 +153,24 @@ int BAM_Module::calculateStatistics(Input_Para &input_params, Output_BAM &final_ } // Calculate statistics in batches + printMemoryUsage("Before batch processing"); + while (reader.hasNextRecord()){ - std::cout << "Generating " << thread_count << " thread(s)..." << std::endl; + // Read the next batch of records + // std::cout << "Generating " << thread_count << " thread(s)..." 
<< + // std::endl; + printMessage("Generating " + std::to_string(thread_count) + " thread(s)..."); std::vector thread_vector; for (int thread_index=0; thread_index rrms_read_ids_copy = input_params.rrms_read_ids; - - // Create a thread std::thread t((BAM_Module::batchStatistics), std::ref(reader), batch_size, rrms_read_ids_copy,std::ref(final_output), std::ref(bam_mutex), std::ref(output_mutex), std::ref(cout_mutex), base_mod_threshold); - - // Add the thread to the vector thread_vector.push_back(std::move(t)); } // Join the threads in thread_vector - std::cout<<"Joining threads..."< lock(output_mutex); final_output.add(record_output); - output_mutex.unlock(); + printMemoryUsage("After record processing"); } std::unordered_set BAM_Module::readRRMSFile(std::string rrms_csv_file, bool accepted_reads) @@ -263,7 +267,10 @@ std::unordered_set BAM_Module::readRRMSFile(std::string rrms_csv_fi std::stringstream ss(header); std::string field; // std::cout << "RRMS CSV header:" << std::endl; - while (std::getline(ss, field, ',')){ + + // Split the header fields + char delimiter = ','; + while (std::getline(ss, field, delimiter)){ header_fields.push_back(field); // std::cout << field << std::endl; } @@ -298,7 +305,7 @@ std::unordered_set BAM_Module::readRRMSFile(std::string rrms_csv_fi std::vector fields; std::string field; std::stringstream ss(line); - while (std::getline(ss, field, ',')){ + while (std::getline(ss, field, delimiter)){ fields.push_back(field); } diff --git a/src/cli.py b/src/cli.py index d4c8739..bbee8cf 100644 --- a/src/cli.py +++ b/src/cli.py @@ -97,10 +97,11 @@ def get_common_param(margs): # Set up logging to stdout logging.basicConfig(stream=sys.stdout, level=get_log_level(margs.log_level), - format="%(asctime)s [%(levelname)s] %(message)s") + format="%(asctime)s %(message)s") + # format="%(asctime)s [%(levelname)s] %(message)s") else: logging.basicConfig(level=get_log_level(margs.log_level), - format="%(asctime)s [%(levelname)s] %(message)s", + format="%(asctime)s %(message)s", handlers=[ logging.FileHandler(margs.log), logging.StreamHandler(sys.stdout) @@ -154,7 +155,7 @@ def fq_module(margs): logging.info("Generating HTML report...") plot_filepaths = plot(fq_output, param_dict, 'FASTQ') fq_html_gen = generate_html.ST_HTML_Generator( - [["basic_st", "read_length_bar", "read_length_hist", "base_counts", "base_quality", + [["basic_st", "read_length_bar", "read_length_hist", "gc_content_hist", "base_counts", "base_quality", "read_avg_base_quality"], "FASTQ QC", param_dict], plot_filepaths, static=False) fq_html_gen.generate_html() @@ -164,7 +165,7 @@ def fq_module(margs): def fa_module(margs): - # Run the FASTA filetype module. + """FASTA file input module.""" # Get the filetype-specific parameters param_dict = get_common_param(margs) @@ -192,7 +193,7 @@ def fa_module(margs): logging.info("Generating HTML report...") plot_filepaths = plot(fa_output, param_dict, 'FASTA') fa_html_gen = generate_html.ST_HTML_Generator( - [["basic_st", "read_length_bar", "read_length_hist", "base_counts"], "FASTA QC", + [["basic_st", "read_length_bar", "read_length_hist", "gc_content_hist", "base_counts"], "FASTA QC", param_dict], plot_filepaths, static=True) fa_html_gen.generate_html() logging.info("Done. 
Output files are in %s", param_dict["output_folder"]) @@ -221,9 +222,16 @@ def bam_module(margs): param_dict["ref"] = input_para.ref_genome = ref_genome # Set the base modification flag, and filtering threshold - param_dict["mod"] = input_para.mod_analysis = margs.mod + # param_dict["mod"] = input_para.mod_analysis = margs.mod + if margs.mod: + param_dict["mod"] = input_para.mod_analysis = True + else: + param_dict["mod"] = input_para.mod_analysis = False + mod_prob = margs.modprob - param_dict["modprob"] = input_para.base_mod_threshold = mod_prob + param_dict["modprob"] = mod_prob + input_para.base_mod_threshold = mod_prob + logging.info("Base modification threshold is set to " + str(input_para.base_mod_threshold)) # Set the gene BED file for RNA-seq transcript analysis input_para.gene_bed = margs.genebed if margs.genebed != "" or margs.genebed is not None else "" @@ -245,11 +253,13 @@ def bam_module(margs): plot_filepaths = plot(bam_output, param_dict, 'BAM') # Set the list of QC information to display - qc_info_list = ["basic_st", "read_alignments_bar", "base_alignments_bar", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality"] + qc_info_list = ["basic_st", "read_alignments_bar", "base_alignments_bar", "read_length_bar", "read_length_hist", "gc_content_hist", "base_counts", "base_quality", "read_avg_base_quality"] # If base modifications were found, add the base modification plots # after the first table if bam_output.sample_modified_base_count > 0: + # logging.info("Base modifications found. Adding base modification plots to the HTML report.") + qc_info_list.insert(1, "read_length_mod_rates") # Read length modification rates qc_info_list.insert(1, "base_mods") # If gene BED file was provided, add the TIN plots @@ -298,6 +308,7 @@ def rrms_module(margs): # Set the output prefix param_dict["out_prefix"] = output_prefix + "rrms_" + ("accepted" if filter_type else "rejected") + param_dict["mod"] = input_para.mod_analysis = False # Disable base modification analysis for RRMS (use BAM module for this) # Run the QC module logging.info("Running QC for " + ("accepted" if filter_type else "rejected") + " reads...") @@ -308,10 +319,19 @@ def rrms_module(margs): logging.info("Generating HTML report...") plot_filepaths = plot(bam_output, param_dict, 'BAM') + # Set the list of QC information to display + qc_info_list = ["basic_st", "read_alignments_bar", "base_alignments_bar", "read_length_bar", "read_length_hist", "gc_content_hist", "base_counts", "base_quality"] + + # If base modifications were found, add the base modification + # plots + if bam_output.sample_modified_base_count > 0: + logging.info("Base modifications found. Adding base modification plots to the HTML report.") + qc_info_list.insert(1, "read_length_mod_rates") + qc_info_list.insert(1, "base_mods") + # Generate the HTML report bam_html_gen = generate_html.ST_HTML_Generator( - [["basic_st", "read_alignments_bar", "base_alignments_bar", "read_length_bar", "read_length_hist", "base_counts", "basic_info", - "base_quality"], "BAM QC", param_dict], plot_filepaths, static=False) + [qc_info_list, "BAM QC", param_dict], plot_filepaths, static=False) bam_html_gen.generate_html() logging.info("Done. 
Output files are in %s", param_dict["output_folder"]) @@ -347,7 +367,7 @@ def seqtxt_module(margs): report_title = "Basecall Summary QC" seqtxt_html_gen = generate_html.ST_HTML_Generator( - [["basic_st", "read_length_bar", "read_length_hist", "basic_info"], + [["basic_st", "read_length_bar", "read_length_hist"], report_title, param_dict], plot_filepaths, static=False) seqtxt_html_gen.generate_html() @@ -383,8 +403,8 @@ def fast5_module(margs): logging.info("Generating HTML report...") plot_filepaths = plot(fast5_output, param_dict, 'FAST5') fast5_html_obj = generate_html.ST_HTML_Generator( - [["basic_st", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality", - "read_avg_base_quality"], "FAST5 QC", param_dict], plot_filepaths, static=False) + [["basic_st", "read_length_bar", "read_length_hist", "gc_content_hist", "base_counts", "base_quality"], + "FAST5 QC", param_dict], plot_filepaths, static=False) fast5_html_obj.generate_html() logging.info("Done. Output files are in %s", param_dict["output_folder"]) @@ -429,8 +449,7 @@ def fast5_signal_module(margs): logging.info("Generating HTML report...") plot_filepaths = plot(fast5_output, param_dict, 'FAST5s') fast5_html_obj = generate_html.ST_HTML_Generator( - [["basic_st", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality", - "read_avg_base_quality", "ont_signal"], "FAST5 QC", param_dict], plot_filepaths, static=False) + [["basic_st", "read_length_bar", "read_length_hist", "gc_content_hist", "base_counts", "ont_signal"], "FAST5 QC", param_dict], plot_filepaths, static=False) fast5_html_obj.generate_html(signal_plots=True) logging.info("Done. Output files are in %s", param_dict["output_folder"]) @@ -438,25 +457,6 @@ def fast5_signal_module(margs): logging.error("QC did not generate.") -def set_file_parser_defaults(file_parser): - """Create a parser with default arguments for a specific filetype.""" - file_parser.add_argument("-i", "--input", type=argparse.FileType('r'), default=None, - help="Single input filepath") - file_parser.add_argument("-I", "--inputs", type=str, default=None, - help="Multiple comma-separated input filepaths") - file_parser.add_argument("-P", "--pattern", type=str, default=None, - help="Use pattern matching (*) to specify multiple input files. Enclose the pattern in double quotes.") - file_parser.add_argument("-g", "--log", type=str, default="log_output.log", - help="Log file") - file_parser.add_argument("-G", "--log-level", type=int, default=2, - help="Logging level. 1: DEBUG, 2: INFO, 3: WARNING, 4: ERROR, 5: CRITICAL. Default: 2.") - file_parser.add_argument("-o", "--outputfolder", type=str, default="output_" + prg_name, - help="The output folder.") - file_parser.add_argument("-t", "--threads", type=int, default=1, - help="The number of threads used. Default: 1.") - file_parser.add_argument("-Q", "--outprefix", type=str, default="QC_", - help="The prefix for output filenames. 
Default: `QC_`.") - def pod5_module(margs): """POD5 file input module.""" # Get the filetype-specific parameters @@ -517,13 +517,32 @@ def pod5_module(margs): # plot_filepaths = plot(read_signal_dict, param_dict, 'POD5') webpage_title = "POD5 QC" fast5_html_obj = generate_html.ST_HTML_Generator( - [["basic_st", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality", - "read_avg_base_quality", "ont_signal"], webpage_title, param_dict], plot_filepaths, static=False) + [["basic_st", "read_length_bar", "read_length_hist", "gc_content_hist", "base_counts", "ont_signal"], webpage_title, param_dict], plot_filepaths, static=False) fast5_html_obj.generate_html(signal_plots=True) logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: logging.error("QC did not generate.") + + +def set_file_parser_defaults(file_parser): + """Create a parser with default arguments for a specific filetype.""" + file_parser.add_argument("-i", "--input", type=argparse.FileType('r'), default=None, + help="Single input filepath") + file_parser.add_argument("-I", "--inputs", type=str, default=None, + help="Multiple comma-separated input filepaths") + file_parser.add_argument("-P", "--pattern", type=str, default=None, + help="Use pattern matching (*) to specify multiple input files. Enclose the pattern in double quotes.") + file_parser.add_argument("-g", "--log", type=str, default="log_output.log", + help="Log file") + file_parser.add_argument("-G", "--log-level", type=int, default=2, + help="Logging level. 1: DEBUG, 2: INFO, 3: WARNING, 4: ERROR, 5: CRITICAL. Default: 2.") + file_parser.add_argument("-o", "--outputfolder", type=str, default="output_" + prg_name, + help="The output folder.") + file_parser.add_argument("-t", "--threads", type=int, default=1, + help="The number of threads used. Default: 1.") + file_parser.add_argument("-Q", "--outprefix", type=str, default="QC_", + help="The prefix for output filenames. Default: `QC_`.") # Set up the argument parser @@ -635,8 +654,8 @@ def pod5_module(margs): bam_parser.add_argument("--genebed", type=str, default="", help="Gene BED12 file required for calculating TIN scores from RNA-seq BAM files. Default: None.") -bam_parser.add_argument("--modprob", type=float, default=0.8, - help="Base modification filtering threshold. Above/below this value, the base is considered modified/unmodified. Default: 0.8.") +bam_parser.add_argument("--modprob", type=float, default=0.5, + help="Base modification filtering threshold. Above/below this value, the base is considered modified/unmodified. 
Default: 0.5.") bam_parser.add_argument("--ref", type=str, default="", help="The reference genome FASTA file to use for identifying CpG sites.") diff --git a/src/fast5_module.cpp b/src/fast5_module.cpp index ceab46f..11b2909 100644 --- a/src/fast5_module.cpp +++ b/src/fast5_module.cpp @@ -470,12 +470,6 @@ static int writeSignalQCDetails(const char *input_file, Output_FAST5 &output_dat { int exit_code = 0; -// // Open the CSV files -// std::ofstream raw_csv; -// raw_csv.open(signal_raw_csv); -// std::ofstream qc_csv; -// qc_csv.open(signal_qc_csv); - // Run QC on the HDF5 file //H5::Exception::dontPrint(); // Disable error printing try { @@ -554,11 +548,7 @@ static int writeSignalQCDetails(const char *input_file, Output_FAST5 &output_dat catch (std::exception& e) { std::cerr << "Exception caught : " << e.what() << std::endl; } - -// // Close the CSV files -// raw_csv.close(); -// qc_csv.close(); - + return exit_code; } diff --git a/src/fasta_module.cpp b/src/fasta_module.cpp index 666d369..8b48371 100644 --- a/src/fasta_module.cpp +++ b/src/fasta_module.cpp @@ -6,6 +6,7 @@ FASTA_module.cpp: #include // #include #include +#include // std::round #include #include @@ -92,8 +93,11 @@ static int qc1fasta(const char *input_file, Output_FA &py_output_fa, FILE *read_ long_read_info.total_num_bases += base_count; long_read_info.total_n_cnt += n_count; - read_gc_cnt = 100.0 * gc_count / (double)base_count; - long_read_info.read_gc_content_count[(int)(read_gc_cnt + 0.5)] += 1; + + // Update the per-read GC content distribution + double gc_content_pct = (100.0 * gc_count) / static_cast(base_count); + int gc_content_int = static_cast(std::round(gc_content_pct)); + long_read_info.read_gc_content_count[gc_content_int] += 1; // Remove the newline character from the sequence data size_t pos = sequence_data_str.find_first_of("\r\n"); @@ -168,10 +172,12 @@ static int qc1fasta(const char *input_file, Output_FA &py_output_fa, FILE *read_ long_read_info.read_length_count[(int)base_count] += 1; } - long_read_info.total_num_bases += base_count; - long_read_info.total_n_cnt += n_count; - read_gc_cnt = 100.0 * gc_count / (double)base_count; - long_read_info.read_gc_content_count[(int)(read_gc_cnt + 0.5)] += 1; + long_read_info.total_num_bases += base_count; // Update the total number of bases + + // Update the per-read GC content distribution + double gc_content_pct = (100.0 * gc_count) / static_cast(base_count); + int gc_content_int = static_cast(std::round(gc_content_pct)); + long_read_info.read_gc_content_count[gc_content_int] += 1; // Remove the newline character from the sequence data size_t pos = sequence_data_str.find_first_of("\r\n"); diff --git a/src/fastq_module.cpp b/src/fastq_module.cpp index c45a79d..3d87a32 100644 --- a/src/fastq_module.cpp +++ b/src/fastq_module.cpp @@ -1,28 +1,33 @@ +#include "fastq_module.h" + +#include #include #include -#include -#include -#include -#include +#include // std::sort +#include // std::round + #include -#include +#include +#include -#include "fastq_module.h" +#include +#include +#include "utils.h" int qc1fastq(const char *input_file, char fastq_base_qual_offset, Output_FQ &output_data, FILE *read_details_fp) { int exit_code = 0; int read_len; double read_gc_cnt; - double read_mean_base_qual; Basic_Seq_Statistics &long_read_info = output_data.long_read_info; Basic_Seq_Quality_Statistics &seq_quality_info = output_data.seq_quality_info; long_read_info.total_num_reads = ZeroDefault; // total number of long reads long_read_info.longest_read_length = ZeroDefault; // the 
length of longest reads std::ifstream input_file_stream(input_file); + int count = 0; if (!input_file_stream.is_open()) { fprintf(stderr, "Failed to open file for reading: %s\n", input_file); @@ -33,6 +38,7 @@ int qc1fastq(const char *input_file, char fastq_base_qual_offset, Output_FQ &out { if (line[0] == '@') { + count++; read_name = line.substr(1); read_name = read_name.substr(0, read_name.find_first_of(" \t")); std::getline(input_file_stream, read_seq); @@ -58,10 +64,29 @@ int qc1fastq(const char *input_file, char fastq_base_qual_offset, Output_FQ &out // Store the read length long_read_info.read_lengths.push_back(read_len); + // Access base quality data + char value; + std::vector base_quality_values; + std::istringstream iss(raw_read_qual); + while (iss >> value) + { + int base_quality_value = value - '!'; + base_quality_values.push_back(base_quality_value); + } + + // Ensure that the base quality string has the same length as + // the read sequence + if (base_quality_values.size() != read_len) + { + printError("Error: Base quality string length does not match read sequence length"); + exit_code = 1; + break; + } + // Process base and quality information read_gc_cnt = 0; - read_mean_base_qual = 0; - uint64_t base_quality_value; + int base_quality_value; + double cumulative_base_prob = 0; // Read cumulative base quality probability for (int i = 0; i < read_len; i++) { if (read_seq[i] == 'A' || read_seq[i] == 'a') @@ -82,15 +107,46 @@ int qc1fastq(const char *input_file, char fastq_base_qual_offset, Output_FQ &out { long_read_info.total_tu_cnt += 1; } - base_quality_value = (uint64_t)raw_read_qual[i] - (uint64_t)fastq_base_qual_offset; - seq_quality_info.base_quality_distribution[base_quality_value] += 1; - read_mean_base_qual += (double) base_quality_value; + + // Get the base quality (Phred) value + base_quality_value = base_quality_values[i]; + try { + seq_quality_info.base_quality_distribution[base_quality_value] += 1; + } catch (const std::out_of_range& oor) { + printError("Warning: Base quality value " + std::to_string(base_quality_value) + " exceeds maximum value"); + } + + // Convert the Phred quality value to a probability + double base_quality_prob = pow(10, -base_quality_value / 10.0); + cumulative_base_prob += base_quality_prob; } - read_gc_cnt = 100.0 * read_gc_cnt / (double)read_len; - long_read_info.read_gc_content_count[(int)(read_gc_cnt + 0.5)] += 1; - read_mean_base_qual /= (double) read_len; - seq_quality_info.read_average_base_quality_distribution[(uint)(read_mean_base_qual + 0.5)] += 1; - fprintf(read_details_fp, "%s\t%d\t%.2f\t%.2f\n", read_name.c_str(), read_len, read_gc_cnt, read_mean_base_qual); + + // Calculate the mean base quality probability + cumulative_base_prob /= (double)read_len; + + // Convert the mean base quality probability to a Phred quality + // value + double read_mean_base_qual = -10.0 * log10(cumulative_base_prob); + + // Update the per-read base quality distribution + int read_mean_base_qual_int = static_cast(std::round(read_mean_base_qual)); + try { + seq_quality_info.read_average_base_quality_distribution[read_mean_base_qual_int] += 1; + } catch (const std::out_of_range& oor) { + printError("Warning: Base quality value " + std::to_string(read_mean_base_qual_int) + " exceeds maximum value"); + } + + // Update the per-read GC content distribution + double gc_content_pct = (100.0 * read_gc_cnt) / static_cast(read_len); + int gc_content_int = static_cast(std::round(gc_content_pct)); + try { + long_read_info.read_gc_content_count[gc_content_int] 
+= 1; + } catch (const std::out_of_range& oor) { + printError("Warning: Invalid GC content value " + std::to_string(gc_content_int)); + } + + // Write read details to file + fprintf(read_details_fp, "%s\t%d\t%.2f\t%.2f\n", read_name.c_str(), read_len, gc_content_pct, read_mean_base_qual); } } input_file_stream.close(); @@ -140,10 +196,7 @@ int qc_fastq_files(Input_Para &_input_data, Output_FQ &output_data) output_data.long_read_info.NXX_read_length.resize(101, 0); // NXX_read_length[50] means N50 read length; NXX_read_length[95] means N95 read length; - //output_data.seq_quality_info.base_quality_distribution.resize(256, 0); - // base_quality_distribution[x] means number of bases that quality = x. - - output_data.seq_quality_info.read_average_base_quality_distribution.resize(256, 0); + output_data.seq_quality_info.read_average_base_quality_distribution.resize(MAX_BASE_QUALITY, 0); if (_input_data.user_defined_fastq_base_qual_offset > 0) { fastq_base_qual_offset = _input_data.user_defined_fastq_base_qual_offset; diff --git a/src/generate_html.py b/src/generate_html.py index dcc5be7..dce7237 100644 --- a/src/generate_html.py +++ b/src/generate_html.py @@ -218,6 +218,48 @@ def generate_header(self): li { margin: 10px 0; } +.help-icon { + position: relative; + display: inline-block; + cursor: pointer; + color: #555; + font-size: 18px; /* Adjust size of the icon */ + margin-top: 10px; /* Adjust spacing if needed */ +} + +.help-icon:hover .tooltip { + visibility: visible; + opacity: 1; +} + +.tooltip { + visibility: hidden; + width: 200px; + background-color: #333; + color: #fff; + text-align: left; + border-radius: 4px; + padding: 8px; + font-size: 14px; + position: absolute; + top: 50%; /* Position the tooltip */ + left: 120%; /* Position the tooltip */ + transform: translateY(-50%); + opacity: 0; + transition: opacity 0.3s; + z-index: 1; +} + +.tooltip::after { + content: ''; + position: absolute; + top: 50%; /* Position the arrow in the middle of the tooltip */ + left: 0; /* Position the arrow on the left edge of the tooltip */ + transform: translateY(-50%); + border-width: 5px; + border-style: solid; + border-color: #333 transparent transparent transparent; +} ''') self.html_writer.write("") @@ -237,18 +279,33 @@ def generate_left(self): self.html_writer.write('

Summary<br><br>')
        self.html_writer.write('<br>&nbsp;&nbsp;&nbsp;&nbsp;')
+       # Define ASCII/Unicode icons for error flags
+       error_flag_icon = {
+           True: "⚠",
+           False: "✔",
+       }
+
        # Add links to the right sections
        key_index = 0
        for plot_key in self.image_key_list:
-           self.html_writer.write('<br>&nbsp;&nbsp;•&nbsp;')
+           # Determine the flag icon
+           try:
+               flag = self.plot_filepaths[plot_key]['error_flag']
+           except KeyError:
+               flag = False
+
+           flag_icon = error_flag_icon[flag]
+           self.html_writer.write('<br>&nbsp;&nbsp;•&nbsp;')
+           self.html_writer.write(f'{flag_icon} ')
            self.html_writer.write(
                '' + self.plot_filepaths[plot_key]['title'] + '')
+
            key_index += 1
        self.html_writer.write('<br>&nbsp;&nbsp;•&nbsp;')
        # Add the input files section link
-       self.html_writer.write('<br>&nbsp;&nbsp;•&nbsp;')
+       self.html_writer.write('<br>&nbsp;&nbsp;•&nbsp;')
        self.html_writer.write('Input File List')
        key_index += 1
        self.html_writer.write('<br>&nbsp;&nbsp;•&nbsp;')
@@ -277,17 +334,21 @@ def generate_right(self):
            self.html_writer.write(dynamic_plot)
        except KeyError:
-           logging.error("Missing dynamic plot for %s", plot_key)
+           # See if an image is available
+           try:
+               image_path = self.plot_filepaths[plot_key]['file']
+               self.html_writer.write(f'<img src="{image_path}">{plot_key}')
+           except KeyError:
+               logging.error("Missing plot for %s", plot_key)
        self.html_writer.write('')
        key_index += 1
        self.html_writer.write('<br>&nbsp;&nbsp;&nbsp;&nbsp;')
-       self.html_writer.write('<br><br>&nbsp;&nbsp;&nbsp;&nbsp;File count = ' + str(
+       self.html_writer.write('<br><br>&nbsp;&nbsp;&nbsp;&nbsp;File Count = ' + str(
            len(self.input_para["input_files"])) + '<br><br>&nbsp;&nbsp;&nbsp;&nbsp;')
-       for _af in self.input_para["input_files"]:
-           self.html_writer.write("<br>&nbsp;&nbsp;&nbsp;&nbsp;" + _af)
+       self.html_writer.write("<br>&nbsp;&nbsp;&nbsp;&nbsp;" + "<br>&nbsp;&nbsp;&nbsp;&nbsp;".join([f"{i+1}.\t{af}" for i, af in enumerate(self.input_para["input_files"])]))
        self.html_writer.write('<br><br>
    ') key_index += 1 diff --git a/src/hts_reader.cpp b/src/hts_reader.cpp index 31e9ed9..70a6410 100644 --- a/src/hts_reader.cpp +++ b/src/hts_reader.cpp @@ -12,6 +12,7 @@ Class for reading a set number of records from a BAM file. Used for multi-thread #include #include #include // std::find +#include #include #include "utils.h" @@ -35,51 +36,83 @@ HTSReader::~HTSReader(){ } // Update read and base counts -int HTSReader::updateReadAndBaseCounts(bam1_t* record, Basic_Seq_Statistics *basic_qc, uint64_t *base_quality_distribution){ - int exit_code = 0; - - // Update the total number of reads - basic_qc->total_num_reads++; +int HTSReader::updateReadAndBaseCounts(bam1_t* record, Basic_Seq_Statistics& basic_qc, Basic_Seq_Quality_Statistics& seq_quality_info, bool is_primary) { - // Update read length statistics + // Update read QC + basic_qc.total_num_reads++; // Update the total number of reads int read_length = (int) record->core.l_qseq; - basic_qc->total_num_bases += (uint64_t) read_length; // Update the total number of bases - basic_qc->read_lengths.push_back(read_length); + basic_qc.total_num_bases += (uint64_t) read_length; // Update the total number of bases + basic_qc.read_lengths.push_back(read_length); - // Loop and count the number of each base + // Get base counts, quality, and GC content + double read_gc_count = 0.0; // For GC content calculation + double read_base_total = 0.0; // For GC content calculation + double cumulative_base_prob = 0.0; // For mean base quality probability calculation uint8_t *seq = bam_get_seq(record); for (int i = 0; i < read_length; i++) { // Get the base quality and update the base quality histogram - uint64_t base_quality = (uint64_t)bam_get_qual(record)[i]; - base_quality_distribution[base_quality]++; + int base_quality = (int)bam_get_qual(record)[i]; + seq_quality_info.base_quality_distribution[(uint64_t)base_quality]++; + + // Convert the Phred quality value to a probability + double base_quality_prob = pow(10, -base_quality / 10.0); + cumulative_base_prob += base_quality_prob; // Get the base and update the base count char base = seq_nt16_str[bam_seqi(seq, i)]; switch (base) { case 'A': - basic_qc->total_a_cnt++; + basic_qc.total_a_cnt++; + read_base_total++; break; case 'C': - basic_qc->total_c_cnt++; + basic_qc.total_c_cnt++; + read_gc_count++; + read_base_total++; break; case 'G': - basic_qc->total_g_cnt++; + basic_qc.total_g_cnt++; + read_gc_count++; + read_base_total++; break; case 'T': - basic_qc->total_tu_cnt++; + basic_qc.total_tu_cnt++; + read_base_total++; break; case 'N': - basic_qc->total_n_cnt++; + basic_qc.total_n_cnt++; std::cerr << "Warning: N base found in read " << bam_get_qname(record) << std::endl; break; default: - std::cerr << "Error reading nucleotide: " << base << std::endl; - exit_code = 1; + printError("Invalid base: " + std::to_string(base)); break; } } - return exit_code; + // Calculate the mean base quality probability + cumulative_base_prob /= (double)read_length; + + // Convert the mean base quality probability to a Phred quality value + double read_mean_base_qual = -10.0 * log10(cumulative_base_prob); + + // Update the per-read mean base quality distribution + int read_mean_base_qual_int = static_cast(std::round(read_mean_base_qual)); + try { + seq_quality_info.read_average_base_quality_distribution[read_mean_base_qual_int]++; + } catch (const std::out_of_range& oor) { + printError("Warning: Base quality value " + std::to_string(read_mean_base_qual_int) + " exceeds maximum value"); + } + + // Calculate the read GC 
content percentage if a primary alignment + if (is_primary) { + double gc_content = read_gc_count / read_base_total; + int gc_content_percent = (int) round(gc_content * 100); + std::string query_name = bam_get_qname(record); + // printMessage("Read name: " + query_name + ", GC content: " + std::to_string(gc_content) + ", GC count: " + std::to_string(read_gc_count) + ", Total count: " + std::to_string(read_base_total)); + basic_qc.read_gc_content_count[gc_content_percent]++; + } + + return 0; } // Read the next batch of records from the BAM file and store QC in the output_data object @@ -91,10 +124,10 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu bool read_ids_present = false; if (read_ids.size() > 0){ read_ids_present = true; - printMessage("Filtering reads by read ID"); + // printMessage("Filtering reads by read ID"); - printMessage("Number of read IDs: " + std::to_string(read_ids.size())); - printMessage("First read ID: " + *read_ids.begin()); + // printMessage("Number of read IDs: " + std::to_string(read_ids.size())); + // printMessage("First read ID: " + *read_ids.begin()); // Check if the first read ID has any newlines, carriage returns, tabs, // or spaces if (read_ids.begin()->find_first_of("\n\r\t ") != std::string::npos) { @@ -103,9 +136,6 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu } } - // Access the base quality histogram from the output_data object - uint64_t *base_quality_distribution = output_data.seq_quality_info.base_quality_distribution; - // Do QC on each record and store the results in the output_data object while ((record_count < batch_size) && (exit_code >= 0)) { // Create a record object @@ -154,7 +184,7 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu // Set the atomic flag and print a message if the POD5 tags are // present if (!this->has_pod5_tags.test_and_set()) { - printMessage("POD5 tags found (ts, ns, mv)"); + printMessage("POD5 basecall move table tags found (ts, ns, mv)"); } // Get the ts and ns tags @@ -193,25 +223,23 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu output_data.addReadMoveTable(query_name, seq_str, signal_index_vector, ts, ns); } - // Determine if this is an unmapped read + // Unmapped reads if (record->core.flag & BAM_FUNMAP) { - Basic_Seq_Statistics *basic_qc = &output_data.unmapped_long_read_info; - - // Update read and base QC - this->updateReadAndBaseCounts(record, basic_qc, base_quality_distribution); + Basic_Seq_Statistics& basic_qc = output_data.unmapped_long_read_info; + Basic_Seq_Quality_Statistics& seq_quality_info = output_data.unmapped_seq_quality_info; + this->updateReadAndBaseCounts(record, basic_qc, seq_quality_info, false); } else { - // Set up the basic QC object - Basic_Seq_Statistics *basic_qc = &output_data.mapped_long_read_info; - // Calculate base alignment statistics on non-secondary alignments + Basic_Seq_Statistics& basic_qc = output_data.mapped_long_read_info; + Basic_Seq_Quality_Statistics& seq_quality_info = output_data.seq_quality_info; if (!(record->core.flag & BAM_FSECONDARY)) { // Determine if this is a forward or reverse read if (record->core.flag & BAM_FREVERSE) { - output_data.forward_alignment++; - } else { output_data.reverse_alignment++; + } else { + output_data.forward_alignment++; } // Loop through the cigar string and count the number of insertions, deletions, and matches @@ -258,7 +286,7 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & 
output_data, std::mu output_data.num_mismatched_bases += num_mismatches; } - // Determine if this is a secondary alignment (not included in QC, only read count) + // Secondary alignment (not included in QC, only read count) if (record->core.flag & BAM_FSECONDARY) { output_data.num_secondary_alignment++; @@ -268,7 +296,7 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu // Update the read's secondary alignments (count once per read) output_data.reads_with_secondary[query_name] = true; - // Determine if this is a supplementary alignment (not included in QC, only read count) + // Supplementary alignment (not included in QC, only read count) } else if (record->core.flag & BAM_FSUPPLEMENTARY) { output_data.num_supplementary_alignment++; @@ -278,7 +306,7 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu // Update the read's supplementary alignments (count once per read) output_data.reads_with_supplementary[query_name] = true; - // Determine if this is a primary alignment + // Primary alignment } else if (!(record->core.flag & BAM_FSECONDARY || record->core.flag & BAM_FSUPPLEMENTARY)) { output_data.num_primary_alignment++; // Update the number of primary alignments @@ -318,19 +346,10 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu break; } } - - // Update read and base QC - this->updateReadAndBaseCounts(record, basic_qc, base_quality_distribution); - - // Calculate the percent GC content - int percent_gc = round((basic_qc->total_g_cnt + basic_qc->total_c_cnt) / (double) (basic_qc->total_a_cnt + basic_qc->total_c_cnt + basic_qc->total_g_cnt + basic_qc->total_tu_cnt) * 100); - - // Update the GC content histogram - basic_qc->read_gc_content_count.push_back(percent_gc); + this->updateReadAndBaseCounts(record, basic_qc, seq_quality_info, true); } else { - std::cerr << "Error: Unknown alignment type" << std::endl; - std::cerr << "Flag: " << record->core.flag << std::endl; + printError("Error: Unknown alignment type with flag " + std::to_string(record->core.flag)); } } @@ -349,131 +368,182 @@ bool HTSReader::hasNextRecord(){ } // Return the number of records in the BAM file using the BAM index -int64_t HTSReader::getNumRecords(const std::string & bam_filename, Output_BAM &final_output, bool mod_analysis, double base_mod_threshold) { +int HTSReader::getNumRecords(const std::string& bam_filename, int thread_count) { samFile* bam_file = sam_open(bam_filename.c_str(), "r"); + hts_set_threads(bam_file, thread_count); // Enable multi-threading bam_hdr_t* bam_header = sam_hdr_read(bam_file); bam1_t* bam_record = bam_init1(); - - int64_t num_reads = 0; + int num_reads = 0; while (sam_read1(bam_file, bam_header, bam_record) >= 0) { num_reads++; + } - if (mod_analysis) { - - // Base modification tag analysis - // Follow here to get base modification tags: - // https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/sam_mods.c - // https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/htslib/sam.h#L2274 - hts_base_mod_state *state = hts_base_mod_state_alloc(); - - // Preprint revisions: New data structure that does not require unique - // positions for each base modification - // chr -> vector of (position, strand) for C modified bases passing the threshold - std::vector> c_modified_positions; - - // Parse the base modification tags if a primary alignment - int ret = bam_parse_basemod(bam_record, state); - if (ret >= 0 && !(bam_record->core.flag & 
BAM_FSECONDARY) && !(bam_record->core.flag & BAM_FSUPPLEMENTARY) && !(bam_record->core.flag & BAM_FUNMAP)) { - - // Get the chromosome if alignments are present - bool alignments_present = true; - std::string chr; - std::map query_to_ref_map; - if (bam_record->core.tid < 0) { - alignments_present = false; - } else { - chr = bam_header->target_name[bam_record->core.tid]; + // Close the BAM file + bam_destroy1(bam_record); + bam_hdr_destroy(bam_header); + sam_close(bam_file); - // Get the query to reference position mapping - query_to_ref_map = this->getQueryToRefMap(bam_record); - } + return num_reads; +} - // Get the strand from the alignment flag (hts_base_mod uses 0 for positive and 1 for negative, - // but it always yields 0...) - int strand = (bam_record->core.flag & BAM_FREVERSE) ? 1 : 0; - - // Iterate over the state object to get the base modification tags - // using bam_next_basemod - hts_base_mod mods[10]; - int n = 0; - int32_t pos = 0; - std::vector query_pos; - while ((n=bam_next_basemod(bam_record, state, mods, 10, &pos)) > 0) { - for (int i = 0; i < n; i++) { - // Update the prediction count - final_output.modified_prediction_count++; - - // Note: The modified base value can be a positive char (e.g. 'm', - // 'h') (DNA Mods DB) or negative integer (ChEBI ID): - // https://github.com/samtools/hts-specs/issues/741 - // DNA Mods: https://dnamod.hoffmanlab.org/ - // ChEBI: https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:21839 - // Header line: - // https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/htslib/sam.h#L2215 - - // Determine the probability of the modification (-1 if - // unknown) - double probability = -1; - if (mods[i].qual != -1) { - probability = mods[i].qual / 256.0; - - // If the probability is greater than the threshold, - // update the count - if (probability >= base_mod_threshold) { - final_output.sample_modified_base_count++; - - // Update the modified base count for the strand - if (strand == 0) { - final_output.sample_modified_base_count_forward++; - } else { - final_output.sample_modified_base_count_reverse++; - } +void HTSReader::runBaseModificationAnalysis(const std::string &bam_filename, Output_BAM &final_output, double base_mod_threshold, int read_count, int sample_count, int thread_count) +{ + samFile* bam_file = sam_open(bam_filename.c_str(), "r"); + hts_set_threads(bam_file, thread_count); // Enable multi-threading + bam_hdr_t* bam_header = sam_hdr_read(bam_file); + bam1_t* bam_record = bam_init1(); + int64_t read_index = 0; + + // Create a random number generator and seed it with the current time + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::default_random_engine generator(seed); + + // Create a list of read indices to sample, and only keep the first + // sample_count reads + std::vector read_indices; + for (int i = 0; i < read_count; i++) { + read_indices.push_back(i); + } + std::shuffle(read_indices.begin(), read_indices.end(), generator); + read_indices.resize(sample_count); + std::unordered_set read_indices_set(read_indices.begin(), read_indices.end()); + printMessage("Number of sampled reads for base modification analysis = " + std::to_string(read_indices_set.size())); + + while (sam_read1(bam_file, bam_header, bam_record) >= 0) { + + // if (read_indices_set.find(read_index) == read_indices_set.end()) { + // read_index++; + // continue; + // } + + // Base modification tag analysis + // Follow here to get base modification tags: + // 
https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/sam_mods.c + // https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/htslib/sam.h#L2274 + int read_length = bam_record->core.l_qseq; + hts_base_mod_state *state = hts_base_mod_state_alloc(); + std::vector> c_modified_positions; // C-modified positions for CpG analysis (chr->(position, strand)) + std::unordered_map> base_mod_counts; // Type-specific base modification probabilities (canonical base -> modified base -> [read length %, probability]) + std::unordered_map base_primary_count; // Total base counts for the alignment + + // Parse the base modification tags if a primary alignment + int read_mod_count = 0; + int ret = bam_parse_basemod(bam_record, state); + bool is_primary = !(bam_record->core.flag & BAM_FSECONDARY) && !(bam_record->core.flag & BAM_FSUPPLEMENTARY) && !(bam_record->core.flag & BAM_FUNMAP); + + if (ret >= 0 && is_primary) { + // Get the chromosome if alignments are present + bool alignments_present = true; + std::string chr; + std::map query_to_ref_map; + if (bam_record->core.tid < 0) { + alignments_present = false; + } else { + chr = bam_header->target_name[bam_record->core.tid]; + + // Get the query to reference position mapping + query_to_ref_map = this->getQueryToRefMap(bam_record); + } + + // Get the strand from the alignment flag (hts_base_mod uses 0 for positive and 1 for negative, + // but it always yields 0...) + int strand = (bam_record->core.flag & BAM_FREVERSE) ? 1 : 0; - // Preprint revisions: Store the modified positions - char canonical_base_char = std::toupper(mods[i].canonical_base); - char mod_type = mods[i].modified_base; - if (canonical_base_char == 'C' && mod_type == 'm') { - - // Convert the query position to reference position if available - if (alignments_present) { - if (query_to_ref_map.find(pos) != query_to_ref_map.end()) { - int32_t ref_pos = query_to_ref_map[pos]; - c_modified_positions.push_back(std::make_pair(ref_pos, strand)); - } + // Get the number of each type of base for the read + uint8_t *seq = bam_get_seq(bam_record); + for (int i = 0; i < read_length; i++) { + char base = seq_nt16_str[bam_seqi(seq, i)]; + base_primary_count[std::toupper(base)]++; + } + + // Iterate over the state object to get the base modification tags + // using bam_next_basemod + hts_base_mod mods[10]; + int n = 0; + int32_t pos = 0; + std::vector query_pos; + bool first_mod_found = false; + while ((n=bam_next_basemod(bam_record, state, mods, 10, &pos)) > 0) { + + for (int i = 0; i < n; i++) { + // Update the modified prediction counts + read_mod_count++; // Read-specific count + final_output.modified_prediction_count++; // Cumulative count + char canonical_base_char = std::toupper(mods[i].canonical_base); + char mod_type = mods[i].modified_base; + // base_mod_counts[mod_type]++; // Update the type-specific count + + // Note: The modified base value can be a positive char (e.g. 
'm', + // 'h') (DNA Mods DB) or negative integer (ChEBI ID): + // https://github.com/samtools/hts-specs/issues/741 + // DNA Mods: https://dnamod.hoffmanlab.org/ + // ChEBI: https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:21839 + // Header line: + // https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/htslib/sam.h#L2215 + + // Determine the probability of the modification (-1 if + // unknown) + double probability = -1; + if (mods[i].qual != -1) { + probability = mods[i].qual / 256.0; + + // Update the read length % and probability for the + // modification + double read_len_pct = (double) (pos + 1) / read_length; + // std::cout << "Read length %: " << read_len_pct << ", + // probability: " << probability << std::endl; + + // Update the base modification probabilities for + // sampled reads only (10,000 maximum) + if (read_indices_set.find(read_index) != read_indices_set.end()) { + final_output.updateBaseModProbabilities(mod_type, read_len_pct, probability); // Update the base modification probabilities + } + + // Update counts for predictions exceeding the threshold + if (probability >= base_mod_threshold) { + final_output.updateBaseModCounts(mod_type, strand); // Update the base modification counts + + // Store the modified positions for later CpG + // analysis if a C modification on a primary alignment + if (canonical_base_char == 'C' && mod_type != 'C') { + + // Convert the query position to reference position if available + if (alignments_present) { + if (query_to_ref_map.find(pos) != query_to_ref_map.end()) { + int32_t ref_pos = query_to_ref_map[pos]; + c_modified_positions.push_back(std::make_pair(ref_pos, strand)); } } } } } } + } - // Preprint revisions: Append the modified positions to the output data - if (c_modified_positions.size() > 0) { - // Set the atomic flag and print a message if base - // modification tags are present in the file - if (!this->has_mm_ml_tags.test_and_set()) { - printMessage("Base modification data found (MM, ML tags)"); - } + // Append the modified positions to the output data + if (c_modified_positions.size() > 0) { + // Set the atomic flag and print a message if base + // modification tags are present in the file + if (!this->has_mm_ml_tags.test_and_set()) { + printMessage("Base modification data found (MM, ML tags)"); + } - // Add the modified positions to the output data - if (final_output.sample_c_modified_positions.find(chr) == final_output.sample_c_modified_positions.end()) { - final_output.sample_c_modified_positions[chr] = c_modified_positions; - } else { - final_output.sample_c_modified_positions[chr].insert(final_output.sample_c_modified_positions[chr].end(), c_modified_positions.begin(), c_modified_positions.end()); - } + // Add the modified positions to the output data + if (final_output.sample_c_modified_positions.find(chr) == final_output.sample_c_modified_positions.end()) { + final_output.sample_c_modified_positions[chr] = c_modified_positions; + } else { + final_output.sample_c_modified_positions[chr].insert(final_output.sample_c_modified_positions[chr].end(), c_modified_positions.begin(), c_modified_positions.end()); } } - - // Deallocate the state object - hts_base_mod_state_free(state); } + hts_base_mod_state_free(state); // Deallocate the base modification state object + + read_index++; // Update the read index } bam_destroy1(bam_record); bam_hdr_destroy(bam_header); sam_close(bam_file); - - return num_reads; } // Get the mapping of query positions to reference positions for a given alignment record 
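Each ML-tag prediction above is an 8-bit integer, so `mods[i].qual / 256.0` maps it into [0, 1) before the `base_mod_threshold` comparison, and only reads whose index survives the shuffle-and-truncate sampling step contribute to the read-length-percent vs. probability bookkeeping. A minimal Python sketch of the same logic, using pysam's `modified_bases` accessor instead of htslib's `bam_parse_basemod`/`bam_next_basemod` (the function and parameter names here are illustrative, not part of this patch):

```python
import random
import pysam

def sample_mod_calls(bam_path, read_count, sample_count, threshold=0.8):
    """Count threshold-passing modification calls over all primary reads,
    but only collect (read length %, probability) pairs for a random subset."""
    sampled = set(random.sample(range(read_count), min(sample_count, read_count)))
    passing = 0
    sampled_probs = []  # (fraction of read length, probability) pairs
    with pysam.AlignmentFile(bam_path, "rb") as bam:
        for idx, read in enumerate(bam):
            if read.is_secondary or read.is_supplementary or read.is_unmapped:
                continue
            read_length = read.query_length or 1
            # modified_bases: (canonical base, strand, mod) -> [(pos, qual), ...]
            for calls in (read.modified_bases or {}).values():
                for pos, qual in calls:
                    if qual < 0:  # unknown probability ('?' in the MM tag)
                        continue
                    prob = qual / 256.0  # 8-bit ML value -> probability in [0, 1)
                    if idx in sampled:
                        sampled_probs.append(((pos + 1) / read_length, prob))
                    if prob >= threshold:
                        passing += 1
    return passing, sampled_probs
```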
@@ -503,7 +573,6 @@ std::map HTSReader::getQueryToRefMap(bam1_t *record) query_to_ref_map[current_query_pos] = current_ref_pos + 1; // Use 1-indexed positions current_ref_pos++; current_query_pos++; - // query_to_ref_map[current_query_pos] = current_ref_pos + 1; // Use 1-indexed positions } break; case BAM_CINS: diff --git a/src/lrst.i b/src/lrst.i index 9def020..fc93d0e 100644 --- a/src/lrst.i +++ b/src/lrst.i @@ -36,41 +36,6 @@ lrst.i: SWIG module defining the Python wrapper for our C++ modules $result = list; } -// Map std::map>> to Python -// dictionary -// %typemap(out) std::map>> { -// PyObject *dict = PyDict_New(); -// for (auto const &it : $1) { -// PyObject *inner_dict = PyDict_New(); -// for (auto const &inner_it : it.second) { -// PyObject *tuple = PyTuple_Pack(2, -// PyUnicode_FromStringAndSize(&std::get<0>(inner_it.second), 1), -// PyFloat_FromDouble(std::get<1>(inner_it.second))); -// PyDict_SetItem(inner_dict, -// PyUnicode_FromStringAndSize(&inner_it.first, 1), -// tuple); -// } -// PyDict_SetItem(dict, PyLong_FromLong(it.first), inner_dict); -// } -// $result = dict; -// } - -// Map std::map> to Python -// dictionary -// %typemap(out) std::map> { -// PyObject *dict = PyDict_New(); -// for (auto const &it : $1) { -// PyObject *tuple = PyTuple_Pack(5, -// PyUnicode_FromStringAndSize(&std::get<0>(it.second), 1), -// PyUnicode_FromStringAndSize(&std::get<1>(it.second), 1), -// PyFloat_FromDouble(std::get<2>(it.second)), -// PyLong_FromLong(std::get<3>(it.second)), -// PyBool_FromLong(std::get<4>(it.second))); -// PyDict_SetItem(dict, PyLong_FromLong(it.first), tuple); -// } -// $result = dict; -// } - // Map std::map>> to Python dictionary %typemap(out) std::map>> { @@ -104,12 +69,11 @@ lrst.i: SWIG module defining the Python wrapper for our C++ modules %include %include -// Define the conversion for uint64_t arrays -//%array_class(uint64_t, uint64Array); - %template(IntVector) std::vector; %template(DoubleVector) std::vector; %template(Int2DVector) std::vector>; +%template(StringVector) std::vector; +%template(CharVector) std::vector; // These are the header functions wrapped by our lrst module (Like an 'import') %include "input_parameters.h" // Contains InputPara for passing parameters to C++ diff --git a/src/output_data.cpp b/src/output_data.cpp index cb89804..283f458 100644 --- a/src/output_data.cpp +++ b/src/output_data.cpp @@ -3,6 +3,7 @@ #include // sqrt #include #include +#include // std::round #include "output_data.h" #include "utils.h" @@ -84,9 +85,9 @@ void Basic_Seq_Statistics::add(Basic_Seq_Statistics& basic_qc){ this->read_lengths.insert(this->read_lengths.end(), basic_qc.read_lengths.begin(), basic_qc.read_lengths.end()); } - // Add GC content if not empty - if (!basic_qc.read_gc_content_count.empty()) { - this->read_gc_content_count.insert(this->read_gc_content_count.end(), basic_qc.read_gc_content_count.begin(), basic_qc.read_gc_content_count.end()); + // Update the per-read GC content distribution + for (int i = 0; i < 101; i++) { + this->read_gc_content_count[i] += basic_qc.read_gc_content_count[i]; } } @@ -190,7 +191,6 @@ Basic_Seq_Quality_Statistics::Basic_Seq_Quality_Statistics(){ pos_quality_distribution.resize(MAX_READ_LENGTH, ZeroDefault); pos_quality_distribution_dev.resize(MAX_READ_LENGTH, ZeroDefault); pos_quality_distribution_count.resize(MAX_READ_LENGTH, ZeroDefault); - read_average_base_quality_distribution.resize(MAX_READ_QUALITY, ZeroDefault); read_quality_distribution.resize(MAX_READ_QUALITY, ZeroDefault); } @@ -257,11 +257,142 @@ void 
Basic_Seq_Quality_Statistics::global_sum(){ // BAM output constructor Output_BAM::Output_BAM(){ + this->num_primary_alignment = 0; + this->num_secondary_alignment = 0; + this->num_supplementary_alignment = 0; + this->num_clip_bases = 0; + this->sample_modified_base_count = 0; + this->sample_modified_base_count_forward = 0; + this->sample_modified_base_count_reverse = 0; + this->forward_alignment = 0; + this->reverse_alignment = 0; + this->base_mod_counts = std::unordered_map(); + this->base_mod_counts_forward = std::unordered_map(); + this->base_mod_counts_reverse = std::unordered_map(); } Output_BAM::~Output_BAM(){ } +void Output_BAM::updateBaseModCounts(char mod_type, int strand) +{ + // Update the sample modified base count for predictions exceeding the threshold + this->sample_modified_base_count++; + this->base_mod_counts[mod_type]++; // Update the type-specific modified base count + + // Update the modified base count for the strand from primary alignments + if (strand == 0) { + this->sample_modified_base_count_forward++; + this->base_mod_counts_forward[mod_type]++; // Update the type-specific modified base count + } else if (strand == 1) { + this->sample_modified_base_count_reverse++; + this->base_mod_counts_reverse[mod_type]++; // Update the type-specific modified base count + } +} + +void Output_BAM::updateBaseModProbabilities(char mod_type, double pct_len, double probability) +{ + // Update the base modification probabilities + this->read_pct_len_vs_mod_prob[mod_type].push_back(std::make_pair(pct_len, probability)); +} + +void Output_BAM::updateReadModRate(int read_length, const std::unordered_map& base_mod_rates) { + ReadModData read_mod_data; + read_mod_data.read_length = read_length; + read_mod_data.base_mod_rates = base_mod_rates; + this->read_mod_data.push_back(read_mod_data); +} + +std::vector Output_BAM::getBaseModTypes() +{ + std::vector base_mod_types; + if (this->base_mod_counts.empty()) { + printError("No base modification counts found."); + return base_mod_types; + } + + for (const auto& it : this->base_mod_counts) { + base_mod_types.push_back(it.first); + } + + return base_mod_types; +} + +int Output_BAM::getReadModDataSize() +{ + return this->read_mod_data.size(); +} + +int Output_BAM::getNthReadModLength(int read_index) +{ + return this->read_mod_data[read_index].read_length; +} + +double Output_BAM::getNthReadModRate(int read_index, char mod_type) +{ + double mod_rate = 0.0; + try { + this->read_mod_data.at(read_index); + } catch (const std::out_of_range& oor) { + std::cerr << "Error: Read index " << read_index << " is out of range." 
<< std::endl; + } + try { + mod_rate = this->read_mod_data[read_index].base_mod_rates.at(mod_type); + } catch (const std::out_of_range& oor) { + // No modification rate found for the specified type in the read + mod_rate = 0.0; + } + return mod_rate; +} + +uint64_t Output_BAM::getModTypeCount(char mod_type) +{ + return this->base_mod_counts[mod_type]; +} + +uint64_t Output_BAM::getModTypeCount(char mod_type, int strand) +{ + if (strand == 0) { + return this->base_mod_counts_forward[mod_type]; + } else { + return this->base_mod_counts_reverse[mod_type]; + } +} + +double Output_BAM::getNthReadLenPct(int read_index, char mod_type) +{ + double read_len_pct = 0.0; + try { + this->read_pct_len_vs_mod_prob.at(mod_type); + } catch (const std::out_of_range& oor) { + std::cerr << "Error: Read length percentage not found for type " << mod_type << std::endl; + } + try { + read_len_pct = this->read_pct_len_vs_mod_prob[mod_type].at(read_index).first; + } catch (const std::out_of_range& oor) { + std::cerr << "Error: Read length percentage not found for read index " << read_index << " and type " << mod_type << std::endl; + return 0.0; + } + return read_len_pct; +} + +double Output_BAM::getNthReadModProb(int read_index, char mod_type) +{ + double mod_prob = -1.0; + try { + this->read_pct_len_vs_mod_prob.at(mod_type); + } catch (const std::out_of_range& oor) { + return mod_prob; + } + try { + mod_prob = this->read_pct_len_vs_mod_prob[mod_type].at(read_index).second; + } catch (const std::out_of_range& oor) { + // std::cerr << "Error: Modification probability not found for read index " << read_index << " and type " << mod_type << std::endl; + return -1.0; + } + return mod_prob; +} + int Output_BAM::getReadCount() { return this->read_move_table.size(); @@ -330,6 +461,11 @@ void Output_BAM::add(Output_BAM &output_data) this->seq_quality_info.base_quality_distribution[i] += output_data.seq_quality_info.base_quality_distribution[i]; } + // Update the read average base quality vector if it is not empty + for (int i=0; iseq_quality_info.read_average_base_quality_distribution[i] += output_data.seq_quality_info.read_average_base_quality_distribution[i]; + } + this->num_matched_bases += output_data.num_matched_bases; this->num_mismatched_bases += output_data.num_mismatched_bases; this->num_ins_bases += output_data.num_ins_bases; @@ -544,7 +680,6 @@ void Output_FAST5::addReadBaseSignals(Base_Signals values){ void Output_FAST5::addReadFastq(std::vector fq, FILE *read_details_fp) { const char * read_name; - double gc_content_pct; // Access the read name std::string header_str = fq[0]; @@ -552,9 +687,7 @@ void Output_FAST5::addReadFastq(std::vector fq, FILE *read_details_ std::string read_name_str; std::getline( iss_header, read_name_str, ' ' ); read_name = read_name_str.c_str(); - - // Access the sequence data - std::string sequence_data_str = fq[1]; + std::string sequence_data_str = fq[1]; // Access the sequence data // Update the total number of bases int base_count = sequence_data_str.length(); @@ -573,11 +706,16 @@ void Output_FAST5::addReadFastq(std::vector fq, FILE *read_details_ base_quality_values.push_back(base_quality_value); } + // Ensure the base quality values match the sequence length + if (base_quality_values.size() != base_count) { + printError("Warning: Base quality values do not match the sequence length for read ID " + std::string(read_name)); + } + // Update the base quality and GC content information int gc_count = 0; - double read_mean_base_qual = 0; + double cumulative_base_prob = 0; // Read 
cumulative base quality probability char current_base; - uint64_t base_quality_value; + int base_quality_value; for (int i = 0; i < base_count; i++) { current_base = sequence_data_str[i]; @@ -599,23 +737,48 @@ void Output_FAST5::addReadFastq(std::vector fq, FILE *read_details_ { long_read_info.total_tu_cnt += 1; } - // Get the base quality - base_quality_value = (uint64_t)base_quality_values[i]; - seq_quality_info.base_quality_distribution[base_quality_value] += 1; - read_mean_base_qual += (double)base_quality_value; + // Get the base quality (Phred) value + base_quality_value = base_quality_values[i]; + + // Update the per-base quality distribution + try { + seq_quality_info.base_quality_distribution[base_quality_value] += 1; + } catch (const std::out_of_range& oor) { + printError("Warning: Base quality value " + std::to_string(base_quality_value) + " exceeds maximum value"); + } + + // Convert the Phred quality value to a probability + double base_quality_prob = pow(10, -base_quality_value / 10.0); + cumulative_base_prob += base_quality_prob; + } + + // Calculate the mean base quality probability + cumulative_base_prob /= (double)base_count; + + // Convert the mean base quality probability to a Phred quality value + double read_mean_base_qual = -10.0 * log10(cumulative_base_prob); + + // Update the per-read GC content distribution + double gc_content_pct = (100.0 * gc_count) / static_cast(base_count); + int gc_content_int = static_cast(std::round(gc_content_pct)); + try { + long_read_info.read_gc_content_count[gc_content_int] += 1; + } catch (const std::out_of_range& oor) { + printError("Warning: Invalid GC content value " + std::to_string(gc_content_int)); } - // Calculate percent guanine & cytosine - gc_content_pct = 100.0 *( (double)gc_count / (double)base_count ); + // Update the per-read base quality distribution + int read_mean_base_qual_int = static_cast(std::round(read_mean_base_qual)); + + try { + seq_quality_info.read_quality_distribution[read_mean_base_qual_int] += 1; + } catch (const std::out_of_range& oor) { + printError("Warning: Base quality value " + std::to_string(read_mean_base_qual_int) + " exceeds maximum value"); + } - // Look into this section - long_read_info.read_gc_content_count[(int)(gc_content_pct + 0.5)] += 1; - read_mean_base_qual /= (double) base_count; - seq_quality_info.read_average_base_quality_distribution[(uint)(read_mean_base_qual + 0.5)] += 1; - fprintf(read_details_fp, "%s\t%d\t%.2f\t%.2f\n", read_name, base_count, gc_content_pct, read_mean_base_qual); + fprintf(read_details_fp, "%s\t%d\t%.2f\t%.2f\n", read_name, base_count, gc_content_pct, read_mean_base_qual); // Write to file - // Update the total number of reads - long_read_info.total_num_reads += 1; + long_read_info.total_num_reads += 1; // Update read count } // Get the read count diff --git a/src/plot_utils.py b/src/plot_utils.py index 669b602..c5f2bc3 100644 --- a/src/plot_utils.py +++ b/src/plot_utils.py @@ -20,8 +20,9 @@ MAX_READ_QUALITY = 100 PLOT_FONT_SIZE = 16 -# Return a dictionary of default plot filenames + def getDefaultPlotFilenames(): + """Create a default HTML plot data structure.""" plot_filenames = { # for fq/fa "read_length_distr": {'title': "Read Length", 'description': "Read Length Distribution"}, # for bam "read_alignments_bar": {'title': "Read Alignments", @@ -31,9 +32,11 @@ def getDefaultPlotFilenames(): "read_length_bar": {'title': "Read Length Statistics", 'description': "Read Length Statistics"}, "base_counts": {'title': "Base Counts", 'description': "Base Counts", 
'summary': ""}, - "basic_info": {'title': "Basic Statistics", - 'description': "Basic Statistics", 'summary': ""}, "read_length_hist": {'title': "Read Length Histogram", 'description': "Read Length Histogram", 'summary': ""}, + + "gc_content_hist": {'title': "GC Content Histogram", 'description': "GC Content Histogram", 'summary': ""}, + + "read_length_mod_rates": {'title': "Read Length vs. Modification Rates", 'description': "Read Length vs. Modification Rates", 'summary': ""}, "base_quality": {'title': "Base Quality Histogram", 'description': "Base Quality Histogram"}, @@ -45,8 +48,9 @@ def getDefaultPlotFilenames(): return plot_filenames -# Wrap the text in the table + def wrap(label): + """Wrap the label text.""" # First split the string into a list of words words = label.split(' ') @@ -67,12 +71,14 @@ def wrap(label): return new_label -# Plot the read alignment numbers -def plot_read_length_stats(output_data, file_type): + +def plot_read_length_stats(output_data, file_type, plot_filepaths): + """Plot the read length statistics.""" # Define the three categories category = ['N50', 'Mean', 'Median'] all_traces = [] + error_flag = False if file_type == 'BAM': # Create a bar trace for each type of read length statistic @@ -85,6 +91,10 @@ def plot_read_length_stats(output_data, file_type): trace = go.Bar(x=category, y=values, name=plot_title) all_traces.append(trace) + # Set the error flag if any of the values are zero (except for unmapped reads) + if i != 2 and (values[0] == 0 or values[1] == 0 or values[2] == 0): + error_flag = True + elif file_type == 'SeqTxt': # Create a bar trace for each type of read length statistic bar_titles = ['All Reads', 'Passed Reads', 'Failed Reads'] @@ -96,6 +106,10 @@ def plot_read_length_stats(output_data, file_type): trace = go.Bar(x=category, y=values, name=plot_title) all_traces.append(trace) + # Set the error flag if any of the values are zero (except for failed reads) + if i != 2 and (values[0] == 0 or values[1] == 0 or values[2] == 0): + error_flag = True + else: # Get the data for all reads key_list = ['n50_read_length', 'mean_read_length', 'median_read_length'] @@ -107,6 +121,11 @@ def plot_read_length_stats(output_data, file_type): trace = go.Bar(x=category, y=values, name=bar_title) all_traces.append(trace) + # Set the error flag if any of the values are zero + if values[0] == 0 or values[1] == 0 or values[2] == 0: + error_flag = True + + # Create the layout layout = go.Layout(title='', xaxis=dict(title='Statistics'), yaxis=dict(title='Length (bp)'), barmode='group', font=dict(size=PLOT_FONT_SIZE)) @@ -114,16 +133,19 @@ def plot_read_length_stats(output_data, file_type): fig = go.Figure(data=all_traces, layout=layout) # Generate the HTML - html_obj = fig.to_html(full_html=False, default_height=500, default_width=700) + # html_obj = fig.to_html(full_html=False, default_height=500, default_width=700) + plot_filepaths['read_length_bar']['dynamic'] = fig.to_html(full_html=False, default_height=500, default_width=700) - return html_obj + # Set the error flag + plot_filepaths['read_length_bar']['error_flag'] = error_flag -# Plot the base counts -def plot_base_counts(output_data, filetype): - # Define the five categories - category = ['A', 'C', 'G', 'T/U', 'N'] - # Create a bar trace for each type of data +def plot_base_counts(output_data, filetype, plot_filepaths): + """Plot overall base counts for the reads.""" + + # Create a bar trace for each base + error_flag = False + category = ['A', 'C', 'G', 'T/U', 'N'] all_traces = [] if filetype == 'BAM': 
bar_titles = ['All Reads', 'Mapped Reads', 'Unmapped Reads'] @@ -135,6 +157,16 @@ def plot_base_counts(output_data, filetype): trace = go.Bar(x=category, y=values, name=plot_title) all_traces.append(trace) + # Set the error flag if the N count is greater than 10% or the A, C, + # G, or T/U counts are zero (except for unmapped reads) + if i != 2: + if data.total_num_bases == 0: + error_flag = True + elif data.total_n_cnt / data.total_num_bases > 0.1: + error_flag = True + elif data.total_a_cnt == 0 or data.total_c_cnt == 0 or data.total_g_cnt == 0 or data.total_tu_cnt == 0: + error_flag = True + elif filetype == 'SeqTxt': bar_titles = ['All Reads', 'Passed Reads', 'Failed Reads'] data_objects = [output_data.all_long_read_info.long_read_info, output_data.passed_long_read_info.long_read_info, output_data.failed_long_read_info.long_read_info] @@ -145,6 +177,15 @@ def plot_base_counts(output_data, filetype): trace = go.Bar(x=category, y=values, name=plot_title) all_traces.append(trace) + # Set the error flag if the N count is greater than 10% or the A, C, + # G, or T/U counts are zero + if data.total_num_bases == 0: + error_flag = True + elif data.total_n_cnt / data.total_num_bases > 0.1: + error_flag = True + elif data.total_a_cnt == 0 or data.total_c_cnt == 0 or data.total_g_cnt == 0 or data.total_tu_cnt == 0: + error_flag = True + else: plot_title = 'All Reads' data = output_data.long_read_info @@ -152,86 +193,26 @@ def plot_base_counts(output_data, filetype): trace = go.Bar(x=category, y=values, name=plot_title) all_traces.append(trace) - # Create the layout - layout = go.Layout(title='', xaxis=dict(title='Base'), yaxis=dict(title='Counts'), barmode='group', font=dict(size=PLOT_FONT_SIZE)) + # Set the error flag if the N count is greater than 10% or the A, C, + # G, or T/U counts are zero + if data.total_num_bases == 0: + error_flag = True + elif data.total_n_cnt / data.total_num_bases > 0.1: + error_flag = True + elif data.total_a_cnt == 0 or data.total_c_cnt == 0 or data.total_g_cnt == 0 or data.total_tu_cnt == 0: + error_flag = True # Create the figure and add the traces + layout = go.Layout(title='', xaxis=dict(title='Base'), yaxis=dict(title='Counts'), barmode='group', font=dict(size=PLOT_FONT_SIZE)) fig = go.Figure(data=all_traces, layout=layout) # Generate the HTML - html_obj = fig.to_html(full_html=False, default_height=500, default_width=700) + plot_filepaths['base_counts']['dynamic'] = fig.to_html(full_html=False, default_height=500, default_width=700) + plot_filepaths['base_counts']['error_flag'] = error_flag - return html_obj -# Plot basic information about the reads in bar chart format -def plot_basic_info(output_data, file_type): - html_obj = '' - if file_type == 'BAM': - - # Create a bar trace for each type of data - bar_titles = ['All Reads', 'Mapped Reads', 'Unmapped Reads'] - data_objects = [output_data.long_read_info, output_data.mapped_long_read_info, output_data.unmapped_long_read_info] - - # Create subplots for each category - fig = make_subplots(rows=2, cols=2, subplot_titles=("Number of Reads", "Number of Bases", "Longest Read", "GC Content"), horizontal_spacing=0.3, vertical_spacing=0.2) - - # Add traces for each category - key_list = ['total_num_reads', 'total_num_bases', 'longest_read_length', 'gc_cnt'] - for i in range(4): - # Get the data for this category - key_name = key_list[i] - - # Add the traces for each type of data - data = [getattr(data_objects[0], key_name), getattr(data_objects[1], key_name), getattr(data_objects[2], key_name)] - - # Create the trace 
- trace = go.Bar(x=data, y=bar_titles, orientation='h') - - # Add the trace to the figure - fig.add_trace(trace, row=(i // 2) + 1, col=(i % 2) + 1) - fig.update_layout(showlegend=False) - - # Update the layout - fig.update_layout(showlegend=False, font=dict(size=PLOT_FONT_SIZE)) - - # Generate the HTML - html_obj = fig.to_html(full_html=False, default_height=800, default_width=1200) - - elif file_type == 'SeqTxt': - - # Create a bar trace for each type of data - bar_titles = ['All Reads', 'Passed Reads', 'Failed Reads'] - data_objects = [output_data.all_long_read_info.long_read_info, output_data.passed_long_read_info.long_read_info, output_data.failed_long_read_info.long_read_info] - - # Create subplots for each category - fig = make_subplots(rows=1, cols=3, subplot_titles=("Number of Reads", "Number of Bases", "Longest Read"), horizontal_spacing=0.1) - - # Add traces for each category - key_list = ['total_num_reads', 'total_num_bases', 'longest_read_length'] - for i in range(3): - # Get the data for this category - key_name = key_list[i] - - # Add the traces for each type of data - data = [getattr(data_objects[0], key_name), getattr(data_objects[1], key_name), getattr(data_objects[2], key_name)] - - # Create the trace - trace = go.Bar(x=data, y=bar_titles, orientation='h') - - # Add the trace to the figure - fig.add_trace(trace, row=1, col=i + 1) - - # Update the layout - fig.update_layout(showlegend=False, font=dict(size=PLOT_FONT_SIZE)) - - # Generate the HTML - html_obj = fig.to_html(full_html=False, default_height=500, default_width=1600) - - return html_obj - - -# Plot the read length histograms -def read_lengths_histogram(data, font_size): +def read_lengths_histogram(data, font_size, plot_filepaths): + """Plot the read length histogram.""" linear_bin_count = 10 log_bin_count = 10 @@ -251,9 +232,6 @@ def read_lengths_histogram(data, font_size): hist, _ = np.histogram(read_lengths, bins=edges) # Create a figure with two subplots - # fig = make_subplots( - # rows=2, cols=1, - # subplot_titles=("Read Length Histogram", "Log Read Length Histogram"), vertical_spacing=0.5) fig = make_subplots( rows=1, cols=2, subplot_titles=("Read Length Histogram", "Log Read Length Histogram"), vertical_spacing=0.0) @@ -261,7 +239,6 @@ def read_lengths_histogram(data, font_size): log_col=2 linear_bindata = np.dstack((edges[:-1], edges[1:], hist))[0, :, :] - # linear_bin_centers = np.round((linear_bindata[:, 0] + linear_bindata[:, 1]) / 2, 0) fig.add_trace(go.Bar(x=edges, y=hist, customdata=linear_bindata, hovertemplate='Length: %{customdata[0]:.0f}-%{customdata[1]:.0f}bp
<br>Counts:%{customdata[2]:.0f}', marker_color='#36a5c7'), row=1, col=linear_col) @@ -273,16 +250,13 @@ fig.add_vline(n50, line_width=1, line_dash="dash", annotation_text='N50', annotation_bgcolor="green", annotation_textangle=90, row=1, col=linear_col) - # Log histogram - # Get the log10 histogram of read lengths + # Log scale histogram read_lengths_log = np.log10(read_lengths, out=np.zeros_like(read_lengths), where=(read_lengths != 0)) - # log_hist, log_edges = np.histogram(read_lengths_log, bins=bin_count) log_edges = np.linspace(0, np.max(read_lengths_log), num=log_bin_count + 1) log_hist, _ = np.histogram(read_lengths_log, bins=log_edges) xd = log_edges log_bindata = np.dstack((np.power(10, log_edges)[:-1], np.power(10, log_edges)[1:], log_hist))[0, :, :] - # log_bin_centers = np.round((log_bindata[:, 0] + log_bindata[:, 1]) / 2, 0) yd = log_hist fig.add_trace(go.Bar(x=xd, y=yd, customdata=log_bindata, hovertemplate='Length: %{customdata[0]:.0f}-%{customdata[1]:.0f}bp<br>
    Counts:%{customdata[2]:.0f}', @@ -297,17 +271,7 @@ def read_lengths_histogram(data, font_size): fig.update_annotations(font=dict(color="white")) # Set tick value range for the log scale - # Use the bin edge centers as the tick values - # tick_vals = (log_edges[:-1] + log_edges[1:]) / 2 - # tick_labels = ['{:,}'.format(int(10 ** x)) for x in tick_vals] tick_vals = log_edges - # tick_labels = ['{:,}'.format(int(10 ** x)) for x in tick_vals] - - # Format the tick labels to be in kilobases (kb) if the value is greater than - # 1000, and in bases (b) otherwise - # tick_labels = ['{:,}kb'.format(int(x / 1000)) for x in tick_vals] - # tick_labels = ['{:,}kb'.format(int(x) for x in log_bin_centers) if x > - # 1000 else '{:,}b'.format(int(x)) for x in log_bin_centers] tick_labels = [] for i in range(len(log_bindata)): # Format the tick labels to be in kilobases (kb) if the value is greater @@ -322,21 +286,7 @@ def read_lengths_histogram(data, font_size): tick_labels.append('{}-{}'.format(left_val_str, right_val_str)) fig.update_xaxes(ticks="outside", title_text='Read Length (Log Scale)', title_standoff=0, row=1, col=log_col, tickvals=tick_vals, ticktext=tick_labels, tickangle=45) - # fig.update_xaxes(range=[0, np.max(log_edges)], ticks="outside", title_text='Read Length (Log Scale)', title_standoff=0, row=2, col=1) - # fig.update_xaxes(range=[0, np.max(log_edges)], ticks="outside", title_text='Read Length (Log Scale)', title_standoff=0, row=2, col=1, tickvals=tick_vals) - # tick_vals = list(range(0, 5)) - # fig.update_xaxes( - # range=[0, np.max(log_edges)], - # tickmode='array', - # tickvals=tick_vals, - # ticktext=['{:,}'.format(10 ** x) for x in tick_vals], - # ticks="outside", title_text='Read Length (Log Scale)', title_standoff=0, row=2, col=1) - - # Set the tick value range for the linear scale - # tick_vals = (edges[:-1] + edges[1:]) / 2 - # tick_labels = ['{:,}'.format(int(x)) for x in tick_vals] tick_vals = edges - # tick_labels = ['{:,}'.format(int(x)) for x in tick_vals] # Format the tick labels to be the range of the bin centers tick_labels = [] @@ -352,30 +302,87 @@ def read_lengths_histogram(data, font_size): tick_labels.append('{}-{}'.format(left_val_str, right_val_str)) - # tick_labels = ['{:,}kb'.format(int(x / 1000)) for x in tick_vals] - # tick_labels = ['{:,}kb'.format(int(x)) if x > 1000 else - # '{:,}b'.format(int(x)) for x in linear_bin_centers] linear_col=1 fig.update_xaxes(ticks="outside", title_text='Read Length', title_standoff=0, row=1, col=linear_col, tickvals=tick_vals, ticktext=tick_labels, tickangle=45) - # fig.update_xaxes(ticks="outside", title_text='Read Length', title_standoff=0, row=1, col=1, range=[0, np.max(edges)], tickvals=tick_vals) fig.update_yaxes(ticks="outside", title_text='Counts', title_standoff=0) # Update the layout fig.update_layout(showlegend=False, autosize=True, font=dict(size=PLOT_FONT_SIZE)) - # Set font sizes - # fig.update_layout(showlegend=False, autosize=False) - # fig.update_layout(font=dict(size=font_size), autosize=True) - fig.update_annotations(font_size=annotation_size) - # html_obj = fig.to_html(full_html=False, default_height=500, default_width=700) - html_obj = fig.to_html(full_html=False, default_height=500, default_width=1200) + + # Generate the HTML + plot_filepaths['read_length_hist']['dynamic'] = fig.to_html(full_html=False, default_height=500, default_width=1200) - return html_obj -# Save the 'Base quality' plot image. 
-def base_quality(data, font_size): +def read_gc_content_histogram(data, font_size, plot_filepaths): + """Plot the per-read GC content histogram.""" + bin_size = 1 + gc_content = np.array(data.read_gc_content_count) + + # Calculate the percentage of reads with a GC content of <30% + # gc_content_below_30 = np.sum(gc_content[:30]) + # logging.info("[TEST] Percentage of reads with GC content <30%: {}".format(gc_content_below_30 / np.sum(gc_content))) + + # # Calculate the percentage of reads with a GC content of >70% + # gc_content_above_70 = np.sum(gc_content[70:]) + # logging.info("[TEST] Percentage of reads with GC content >70%: {}".format(gc_content_above_70 / np.sum(gc_content))) + + # # Calculate the percentage of reads with a GC content of <20% + # gc_content_below_20 = np.sum(gc_content[:20]) + # logging.info("[TEST] Percentage of reads with GC content <20%: {}".format(gc_content_below_20 / np.sum(gc_content))) + + # # Calculate the percentage of reads with a GC content of >60% + # gc_content_above_60 = np.sum(gc_content[60:]) + # logging.info("[TEST] Percentage of reads with GC content >60%: {}".format(gc_content_above_60 / np.sum(gc_content))) + + # Set the error flag if the GC content is below 20% for more than 10% of the + # reads + error_flag = False + if np.sum(gc_content) == 0: + error_flag = True + elif np.sum(gc_content[:20]) / np.sum(gc_content) > 0.1: + error_flag = True + + # Bin the GC content if the bin size is greater than 1 + if bin_size > 1: + gc_content = np.array([np.sum(gc_content[i:i + bin_size]) for i in range(0, 101, bin_size)]) + + gc_content_bins = [i for i in range(0, 101, bin_size)] + + # Generate hover text for each bin + hover_text = [] + if bin_size > 1: + for i in range(len(gc_content_bins)): + hover_text.append('GC content: {}-{}%
<br>Counts: {}'.format(gc_content_bins[i], gc_content_bins[i] + bin_size, gc_content[i])) + else: + for i in range(len(gc_content_bins)): + hover_text.append('GC content: {}%<br>Counts: {}'.format(gc_content_bins[i], gc_content[i])) + + # Set the X values to be the center of the bins + if bin_size > 1: + x_values = [gc_content_bins[i] + bin_size / 2 for i in range(len(gc_content_bins))] + else: + x_values = gc_content_bins + + # Create the figure + fig = go.Figure() + fig.add_trace(go.Bar(x=x_values, y=gc_content, marker_color='#36a5c7', hovertext=hover_text, hoverinfo='text')) + + # Update the layout + fig.update_xaxes(ticks="outside", dtick=10, title_text='GC Content (%)', title_standoff=0) + fig.update_yaxes(ticks="outside", title_text='Number of Reads', title_standoff=0) + fig.update_layout(font=dict(size=PLOT_FONT_SIZE)) # Set font size + + plot_filepaths['gc_content_hist']['dynamic'] = fig.to_html(full_html=False, default_height=500, default_width=700) + plot_filepaths['gc_content_hist']['error_flag'] = error_flag
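Because the per-read GC distribution is stored as 101 fixed bins (GC% of 0-100), both the plot and the QC flag above reduce to simple slices over that vector. A small stand-alone sketch of the same heuristic (the `counts` array stands in for `read_gc_content_count`; the 20%/10% cutoffs match the flag above):

```python
import numpy as np

def gc_error_flag(counts, low_gc=20, max_fraction=0.1):
    """Flag the sample when too many reads fall below a GC-content cutoff."""
    counts = np.asarray(counts, dtype=float)  # 101 bins for GC% of 0..100
    total = counts.sum()
    if total == 0:
        return True  # an empty distribution is itself an error
    return counts[:low_gc].sum() / total > max_fraction

# Example: 5,500 reads centered near 45% GC plus 800 reads below 20% GC
counts = np.zeros(101)
counts[40:51] = 500   # 11 bins x 500 reads
counts[10:18] = 100   # 8 bins x 100 reads
print(gc_error_flag(counts))  # True: 800 / 6300 ~ 0.127 > 0.1
```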
+ + +def base_quality(data, font_size, plot_filepaths): + """Plot the base quality distribution.""" xd = np.arange(MAX_BASE_QUALITY) yd = np.array(data.base_quality_distribution) + xd = xd[:60] + yd = yd[:60] fig = go.Figure() customdata = np.dstack((xd, yd))[0, :, :] @@ -388,72 +395,94 @@ def base_quality(data, font_size): fig.update_yaxes(ticks="outside", title_text='Number of bases', title_standoff=0) fig.update_layout(font=dict(size=PLOT_FONT_SIZE)) # Set font size - return fig.to_html(full_html=False, default_height=500, default_width=700) + # return fig.to_html(full_html=False, default_height=500, default_width=700) + plot_filepaths['base_quality']['dynamic'] = fig.to_html(full_html=False, default_height=500, default_width=700) + + # Set the error flag if the base quality is below 20 for more than 10% of + # the bases + error_flag = False + if np.sum(yd) == 0: + error_flag = True + elif np.sum(yd[:20]) / np.sum(yd) > 0.1: + error_flag = True -def read_avg_base_quality(data, font_size): + plot_filepaths['base_quality']['error_flag'] = error_flag + + +def read_avg_base_quality(data, font_size, plot_filepaths): """Plot the read average base quality distribution.""" xd = np.arange(MAX_READ_QUALITY) yd = np.array(data.read_average_base_quality_distribution) + xd = xd[:60] + yd = yd[:60] fig = go.Figure() fig.add_trace(go.Bar(x=xd, y=yd, marker_color='#36a5c7')) - fig.update_xaxes(ticks="outside", dtick=10, title_text='Average Base Quality', title_standoff=0) fig.update_yaxes(ticks="outside", title_text='Number of Reads', title_standoff=0) fig.update_layout(font=dict(size=PLOT_FONT_SIZE)) # Set font size - return fig.to_html(full_html=False, default_height=500, default_width=700) + # return fig.to_html(full_html=False, default_height=500, default_width=700) + plot_filepaths['read_avg_base_quality']['dynamic'] = fig.to_html(full_html=False, default_height=500, default_width=700) + + # Set the error flag if the average base quality is below 20 for more than + # 10% of the reads + error_flag = False + if np.sum(yd) == 0: + error_flag = True + elif np.sum(yd[:20]) / np.sum(yd) > 0.1: + error_flag = True + + plot_filepaths['read_avg_base_quality']['error_flag'] = error_flag def plot_base_modifications(base_modifications): """Plot the base modifications per location.""" - # Get the modification types - modification_types = list(base_modifications.keys()) - # Create the figure + # Add a plot for each modification type fig = go.Figure() - - # Add a trace for each modification type + modification_types = list(base_modifications.keys()) for mod_type in modification_types: - # Get the modification data mod_data = base_modifications[mod_type] - - # Create
the trace - trace = go.Scatter(x=mod_data['positions'], y=mod_data['counts'], mode='markers', name=mod_type) - - # Add the trace to the figure + trace = go.Scattergl(x=mod_data['positions'], y=mod_data['counts'], mode='markers', name=mod_type) fig.add_trace(trace) - # Update the layout fig.update_layout(title='Base Modifications', xaxis_title='Position', yaxis_title='Counts', showlegend=True, font=dict(size=PLOT_FONT_SIZE)) - - # Generate the HTML html_obj = fig.to_html(full_html=False, default_height=500, default_width=700) return html_obj -# Main plot function def plot(output_data, para_dict, file_type): + """Generate the plots for the output data.""" + logging.info("Generating plots for file type: {}".format(file_type)) plot_filepaths = getDefaultPlotFilenames() - - # Get the font size for plotly plots - font_size = 14 - - # Create the summary table - create_summary_table(output_data, plot_filepaths, file_type) - - # Create the modified base table if available - if file_type == 'BAM' and para_dict["mod"] > 0: + font_size = 14 # Font size for the plots + create_summary_table(output_data, plot_filepaths, file_type) # Create the summary table + + # Modified base table and plots + try: + para_dict["mod"] + except KeyError: + para_dict["mod"] = False + + if file_type == 'BAM' and para_dict["mod"]: + # Output file for the read length vs. modification rates plot + output_folder = para_dict["output_folder"] + read_length_mod_rate_file = os.path.join(output_folder, 'read_length_hist.png') + plot_filepaths['read_length_mod_rates']['file'] = read_length_mod_rate_file + + # Generate the modified base table and read length vs. modification rates plot base_modification_threshold = para_dict["modprob"] create_modified_base_table(output_data, plot_filepaths, base_modification_threshold) - - # Check if the modified base table is available - if 'base_mods' in plot_filepaths: - logging.info("SUCCESS: Modified base table created") - else: + if 'base_mods' not in plot_filepaths: logging.warning("WARNING: Modified base table not created") # Create the TIN table if available + try: + para_dict["genebed"] + except KeyError: + para_dict["genebed"] = "" + if file_type == 'BAM' and para_dict["genebed"] != "": input_files = para_dict["input_files"] create_tin_table(output_data, input_files, plot_filepaths) @@ -464,9 +493,7 @@ def plot(output_data, para_dict, file_type): else: logging.warning("WARNING: TIN table not created") - # Generate plots - plot_filepaths['base_counts']['dynamic'] = plot_base_counts(output_data, file_type) - plot_filepaths['basic_info']['dynamic'] = plot_basic_info(output_data, file_type) + plot_base_counts(output_data, file_type, plot_filepaths) # Read length histogram if file_type == 'SeqTxt': @@ -475,27 +502,30 @@ def plot(output_data, para_dict, file_type): long_read_data = output_data.long_read_info if file_type != 'FAST5s': - plot_filepaths['read_length_hist']['dynamic'] = read_lengths_histogram(long_read_data, font_size) + read_lengths_histogram(long_read_data, font_size, plot_filepaths) + plot_read_length_stats(output_data, file_type, plot_filepaths) - plot_filepaths['read_length_bar']['dynamic'] = plot_read_length_stats(output_data, file_type) + # GC content histogram + if file_type == 'BAM': + read_gc_content_histogram(output_data.mapped_long_read_info, font_size, plot_filepaths) + elif file_type == 'SeqTxt': + read_gc_content_histogram(output_data.passed_long_read_info.long_read_info, font_size, plot_filepaths) + elif file_type == 'FASTQ' or file_type == 'FASTA': + 
read_gc_content_histogram(output_data.long_read_info, font_size, plot_filepaths) + # Base quality histogram if file_type != 'FASTA' and file_type != 'FAST5s' and file_type != 'SeqTxt': - # if file_type == 'SeqTxt': - # seq_quality_info = output_data.all_long_read_info.seq_quality_info - # else: seq_quality_info = output_data.seq_quality_info + base_quality(seq_quality_info, font_size, plot_filepaths) + + # Read average base quality histogram + if file_type == 'FASTQ' or file_type == 'FAST5' or file_type == 'BAM': + read_avg_base_quality(seq_quality_info, font_size, plot_filepaths) - # Base quality histogram - plot_filepaths['base_quality']['dynamic'] = base_quality(seq_quality_info, font_size) - - # Read quality histogram - read_quality_dynamic = read_avg_base_quality(seq_quality_info, font_size) - plot_filepaths['read_avg_base_quality']['dynamic'] = read_quality_dynamic - + # Plot the read alignments and base alignments if the file type is BAM if file_type == 'BAM': - # Plot read alignment QC - plot_filepaths['read_alignments_bar']['dynamic'] = plot_alignment_numbers(output_data) - plot_filepaths['base_alignments_bar']['dynamic'] = plot_errors(output_data) + plot_alignment_numbers(output_data, plot_filepaths) + plot_errors(output_data, plot_filepaths) elif file_type == 'FAST5s': plot_filepaths['ont_signal']['dynamic'] = plot_signal(output_data, para_dict) @@ -504,10 +534,9 @@ def plot_pod5(pod5_output, para_dict, bam_output=None): """Plot the ONT POD5 signal data for a random sample of reads.""" + out_path = para_dict["output_folder"] plot_filepaths = getDefaultPlotFilenames() - - # Create the summary table create_pod5_table(pod5_output, plot_filepaths) # Generate the signal plots @@ -600,7 +629,7 @@ # Plot the signal data x = np.arange(signal_length) - fig.add_trace(go.Scatter( + fig.add_trace(go.Scattergl( x=x, y=nth_read_data, mode='markers', marker=dict(color='LightSkyBlue', @@ -608,7 +637,7 @@ line=dict(color='MediumPurple', width=2)), opacity=0.5)) - # Update the plot style + # Update the plot style (using 0-100 to improve performance) fig.update_layout( title=nth_read_name, yaxis_title="Signal", @@ -617,7 +646,6 @@ xaxis=dict(range=[0, 100]) ) fig.update_traces(marker={'size': marker_size}) - # fig.update_xaxes(title="Index") # Append the dynamic HTML object to the output structure dynamic_html = fig.to_html(full_html=False) @@ -640,6 +668,8 @@ # Get read and base counts read_count = output_data.getReadCount() + if read_count == 0: + raise ValueError("No reads found in the dataset") # Randomly sample a small set of reads if it is a large dataset read_sample_size = min(read_count_max, read_count) @@ -692,7 +722,7 @@ # Plot x = np.arange(start_index, end_index, 1) - fig.add_trace(go.Scatter( + fig.add_trace(go.Scattergl( x=x, y=base_signals, mode='markers', marker=dict(color='LightSkyBlue', @@ -747,6 +777,29 @@ return output_html_plots +def format_cell(value, type_str='int', error_flag=False): + """Format the cell value for the summary table.""" + style = "background-color: #F88379;" if error_flag else "" + if type_str == 'int': + return "<td style=\"{}\">{:,d}</td>".format(style, value) + elif type_str == 'float': + return "<td style=\"{}\">{:.1f}</td>".format(style, value) + else: + logging.error("ERROR: Invalid type for formatting cell value") + +def format_row(row_name, values, type_str='int', col_ignore=None): + """Format the row for the summary table. Skip flagging null values in specific columns.""" + cell_str = [] + row_flag = False + for i, value in enumerate(values): + # Set the error flag if the value is 0 except for unmapped reads + error_flag = value == 0 and i != col_ignore + row_flag = row_flag or error_flag # Flag for the entire row + cell_str.append(format_cell(value, type_str, error_flag)) + + return "<tr><td>{}</td>{}</tr>".format(row_name, "".join(cell_str)), row_flag
+ + def create_summary_table(output_data, plot_filepaths, file_type): """Create the summary table for the basic statistics.""" plot_filepaths["basic_st"] = {} @@ -761,73 +814,135 @@ def create_summary_table(output_data, plot_filepaths, file_type): file_type_label = 'Basecall Summary' plot_filepaths["basic_st"]['description'] = "{} Basic Statistics".format(file_type_label) + table_error_flag = False if file_type == 'BAM': + # Add alignment statistics to the summary table table_str = "<table>\n<thead>\n<tr><th>Measurement</th><th>Mapped</th><th>Unmapped</th>" \ "<th>All</th></tr>\n</thead>" table_str += "\n<tbody>" - int_str_for_format = "<tr><td>{}</td><td>{:,d}</td><td>{:,d}</td><td>{:,d}</td></tr>" - double_str_for_format = "<tr><td>{}</td><td>{:.1f}</td><td>{:.1f}</td><td>{:.1f}</td></tr>" - table_str += int_str_for_format.format("#Total Reads", output_data.mapped_long_read_info.total_num_reads, - output_data.unmapped_long_read_info.total_num_reads, - output_data.long_read_info.total_num_reads) - table_str += int_str_for_format.format("#Total Bases", - output_data.mapped_long_read_info.total_num_bases, - output_data.unmapped_long_read_info.total_num_bases, - output_data.long_read_info.total_num_bases) - table_str += int_str_for_format.format("Longest Read Length", - output_data.mapped_long_read_info.longest_read_length, - output_data.unmapped_long_read_info.longest_read_length, - output_data.long_read_info.longest_read_length) - table_str += int_str_for_format.format("N50", - output_data.mapped_long_read_info.n50_read_length, - output_data.unmapped_long_read_info.n50_read_length, - output_data.long_read_info.n50_read_length) - table_str += double_str_for_format.format("GC Content(%)", - output_data.mapped_long_read_info.gc_cnt * 100, - output_data.unmapped_long_read_info.gc_cnt * 100, - output_data.long_read_info.gc_cnt * 100) - table_str += double_str_for_format.format("Mean Read Length", - output_data.mapped_long_read_info.mean_read_length, - output_data.unmapped_long_read_info.mean_read_length, - output_data.long_read_info.mean_read_length) - table_str += int_str_for_format.format("Median Read Length", - output_data.mapped_long_read_info.median_read_length, - output_data.unmapped_long_read_info.median_read_length, - output_data.long_read_info.median_read_length) + + # Total reads + row_str, row_flag = format_row("Total Reads", \ + [output_data.mapped_long_read_info.total_num_reads, \ + output_data.unmapped_long_read_info.total_num_reads, \ + output_data.long_read_info.total_num_reads], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Total bases + row_str, row_flag = format_row("Total Bases", \ + [output_data.mapped_long_read_info.total_num_bases, \ + output_data.unmapped_long_read_info.total_num_bases, \ + output_data.long_read_info.total_num_bases], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Longest read length + row_str, row_flag = format_row("Longest Read Length", \ + [output_data.mapped_long_read_info.longest_read_length, \ + output_data.unmapped_long_read_info.longest_read_length, \ + output_data.long_read_info.longest_read_length], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # N50 + row_str, row_flag = format_row("N50", \ + [output_data.mapped_long_read_info.n50_read_length, \ + output_data.unmapped_long_read_info.n50_read_length, \ + output_data.long_read_info.n50_read_length], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # GC content + row_str, row_flag = format_row("GC Content(%)", \ + [output_data.mapped_long_read_info.gc_cnt * 100, \ + output_data.unmapped_long_read_info.gc_cnt * 100, \ + output_data.long_read_info.gc_cnt * 100], \ + 'float', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Mean read length + row_str, row_flag = format_row("Mean Read Length", \ + [output_data.mapped_long_read_info.mean_read_length, \ + output_data.unmapped_long_read_info.mean_read_length, \ + output_data.long_read_info.mean_read_length], \ + 'float', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Median read length + row_str, row_flag = format_row("Median Read Length", \ + [output_data.mapped_long_read_info.median_read_length, \ + output_data.unmapped_long_read_info.median_read_length, \ + output_data.long_read_info.median_read_length], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag elif file_type == 'SeqTxt': table_str = "<table>\n<thead>\n<tr><th>Measurement</th><th>Passed</th><th>Failed</th><th>All</th></tr>\n</thead>" table_str += "\n<tbody>" - int_str_for_format = "<tr><td>{}</td><td>{:,d}</td><td>{:,d}</td><td>{:,d}</td></tr>" - double_str_for_format = "<tr><td>{}</td><td>{:.1f}</td><td>{:.1f}</td><td>{:.1f}</td></tr>" - table_str += int_str_for_format.format("#Total Reads", - output_data.passed_long_read_info.long_read_info.total_num_reads, - output_data.failed_long_read_info.long_read_info.total_num_reads, - output_data.all_long_read_info.long_read_info.total_num_reads) - table_str += int_str_for_format.format("#Total Bases", - output_data.passed_long_read_info.long_read_info.total_num_bases, - output_data.failed_long_read_info.long_read_info.total_num_bases, - output_data.all_long_read_info.long_read_info.total_num_bases) - table_str += int_str_for_format.format("Longest Read Length", - output_data.passed_long_read_info.long_read_info.longest_read_length, - output_data.failed_long_read_info.long_read_info.longest_read_length, - output_data.all_long_read_info.long_read_info.longest_read_length) - table_str += int_str_for_format.format("N50", - output_data.passed_long_read_info.long_read_info.n50_read_length, - output_data.failed_long_read_info.long_read_info.n50_read_length, - output_data.all_long_read_info.long_read_info.n50_read_length) - table_str += double_str_for_format.format("Mean Read Length", - output_data.passed_long_read_info.long_read_info.mean_read_length, - output_data.failed_long_read_info.long_read_info.mean_read_length, - output_data.all_long_read_info.long_read_info.mean_read_length) - table_str += int_str_for_format.format("Median Read Length", - output_data.passed_long_read_info.long_read_info.median_read_length, - output_data.failed_long_read_info.long_read_info.median_read_length, - output_data.all_long_read_info.long_read_info.median_read_length) + + # Total reads + row_str, row_flag = format_row("Total Reads", \ + [output_data.passed_long_read_info.long_read_info.total_num_reads, \ + output_data.failed_long_read_info.long_read_info.total_num_reads, \ + output_data.all_long_read_info.long_read_info.total_num_reads], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Total bases + row_str, row_flag = format_row("Total Bases", \ + [output_data.passed_long_read_info.long_read_info.total_num_bases, \ + output_data.failed_long_read_info.long_read_info.total_num_bases, \ + output_data.all_long_read_info.long_read_info.total_num_bases], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Longest read length + row_str, row_flag = format_row("Longest Read Length", \ + [output_data.passed_long_read_info.long_read_info.longest_read_length, \ + output_data.failed_long_read_info.long_read_info.longest_read_length, \ + output_data.all_long_read_info.long_read_info.longest_read_length], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # N50 + row_str, row_flag = format_row("N50", \ + [output_data.passed_long_read_info.long_read_info.n50_read_length, \ + output_data.failed_long_read_info.long_read_info.n50_read_length, \ + output_data.all_long_read_info.long_read_info.n50_read_length], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Mean read length + row_str, row_flag = format_row("Mean Read Length", \ + [output_data.passed_long_read_info.long_read_info.mean_read_length, \ + output_data.failed_long_read_info.long_read_info.mean_read_length, \ + output_data.all_long_read_info.long_read_info.mean_read_length], \ + 'float', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Median read length + row_str, row_flag =
format_row("Median Read Length", \ + [output_data.passed_long_read_info.long_read_info.median_read_length, \ + output_data.failed_long_read_info.long_read_info.median_read_length, \ + output_data.all_long_read_info.long_read_info.median_read_length], \ + 'int', 1) + table_str += row_str + table_error_flag = table_error_flag or row_flag elif file_type == 'FAST5s': # Get values @@ -837,51 +952,233 @@ def create_summary_table(output_data, plot_filepaths, file_type): # Set up the HTML table table_str = "
         table_str += "\n<tbody>"
-        int_str_for_format = "<tr><td>{}</td><td>{:,d}</td></tr>"
-        table_str += int_str_for_format.format("#Total Reads", read_count)
-        table_str += int_str_for_format.format("#Total Bases", total_base_count)
+
+        # Total reads
+        row_str, row_flag = format_row("Total Reads", [read_count], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        # Total bases
+        row_str, row_flag = format_row("Total Bases", [total_base_count], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
 
     else:
         table_str = "<table>\n<thead>\n<tr><th>Measurement</th><th>Statistics</th></tr>\n</thead>"
         table_str += "\n<tbody>"
-        int_str_for_format = "<tr><td>{}</td><td>{:,d}</td></tr>"
-        double_str_for_format = "<tr><td>{}</td><td>{:.1f}</td></tr>"
-        table_str += int_str_for_format.format("#Total Reads",
-                                               output_data.long_read_info.total_num_reads)
-        table_str += int_str_for_format.format("#Total Bases",
-                                               output_data.long_read_info.total_num_bases)
-        table_str += int_str_for_format.format("Longest Read Length",
-                                               output_data.long_read_info.longest_read_length)
-        table_str += int_str_for_format.format("N50",
-                                               output_data.long_read_info.n50_read_length)
-        table_str += double_str_for_format.format("GC Content(%)",
-                                                  output_data.long_read_info.gc_cnt * 100)
-        table_str += double_str_for_format.format("Mean Read Length",
-                                                  output_data.long_read_info.mean_read_length)
-        table_str += int_str_for_format.format("Median Read Length",
-                                               output_data.long_read_info.median_read_length)
+        # Total reads
+        row_str, row_flag = format_row("Total Reads", [output_data.long_read_info.total_num_reads], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        # Total bases
+        row_str, row_flag = format_row("Total Bases", [output_data.long_read_info.total_num_bases], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        # Longest read length
+        row_str, row_flag = format_row("Longest Read Length", [output_data.long_read_info.longest_read_length], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        # N50
+        row_str, row_flag = format_row("N50", [output_data.long_read_info.n50_read_length], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        # GC content
+        row_str, row_flag = format_row("GC Content(%)", [output_data.long_read_info.gc_cnt * 100], 'float', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        # Mean read length
+        row_str, row_flag = format_row("Mean Read Length", [output_data.long_read_info.mean_read_length], 'float', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        # Median read length
+        row_str, row_flag = format_row("Median Read Length", [output_data.long_read_info.median_read_length], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
 
     table_str += "\n</tbody>\n</table>"
+    # table_str += """
+    # <div class="help-text-container">
+    #     <span>💡</span>
+    #     <div class="help-text">This is your help text explaining the feature!</div>
+    # </div>
+    # """
     plot_filepaths["basic_st"]['detail'] = table_str
+    plot_filepaths["basic_st"]['error_flag'] = table_error_flag
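The hunks above call a `format_row` helper that is added elsewhere in this PR and is not visible in this excerpt. A minimal sketch of what such a helper might look like, with the signature inferred from the call sites (row name, value list, an `'int'`/`'float'` type tag, and a final column argument whose exact meaning is not shown here); the flagging heuristic is an assumption, not the PR's implementation:

```python
# Hypothetical sketch only: signature inferred from the call sites above.
def format_row(row_name, values, value_type, error_col):
    """Render one HTML table row; flag suspicious (non-positive) values."""
    error_flag = False
    cells = []
    for col, value in enumerate(values):
        # Range-check the designated column (all columns when error_col is None).
        if (error_col is None or col == error_col) and value <= 0:
            error_flag = True
        if value_type == 'int':
            cells.append("<td>{:,d}</td>".format(int(value)))
        else:
            cells.append("<td>{:.1f}</td>".format(value))
    return "<tr><td>{}</td>{}</tr>".format(row_name, "".join(cells)), error_flag
```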
    + # """ plot_filepaths["basic_st"]['detail'] = table_str + plot_filepaths["basic_st"]['error_flag'] = table_error_flag + + +def get_axis_name(row, axis_type='x'): + """Get the axis name for the plot.""" + axis_number = row + 1 + return f"{axis_type}axis{axis_number}" if axis_number > 1 else f"{axis_type}axis" + def create_modified_base_table(output_data, plot_filepaths, base_modification_threshold): """Create a summary table for the base modifications.""" + help_text = "Total unfiltered predictions are all predictions prior to applying the base modification probability threshold.\n" \ + "This threshold is set by the user (default: 0.5) and is used to filter out low-confidence base modifications.\n" \ + "Total modification counts are the number of base modifications that pass the threshold.\n" \ + "These counts are also separated by forward and reverse strand predictions.\n" \ + "CpG modification counts are the total CpG modifications that pass the threshold.\n" \ + "These are total counts and not site-specific counts." \ + plot_filepaths["base_mods"] = {} plot_filepaths["base_mods"]['file'] = "" plot_filepaths["base_mods"]['title'] = "Base Modifications" plot_filepaths["base_mods"]['description'] = "Base modification statistics" + table_error_flag = False + + # Print the types of modifications + base_mod_types = output_data.getBaseModTypes() + if base_mod_types: + # logging.info("Modification types: ") + # for mod_type in base_mod_types: + # logging.info(mod_type) + + # logging.info("Getting base modification statistics") + + # Get the read length (%) vs. base modification probability data for + # each sampled read + sample_count = 10000 + read_len_pct = [] + mod_prob = [] + for mod_type in base_mod_types: + for i in range(sample_count): + try: + prob = output_data.getNthReadModProb(i, mod_type) + if prob == -1: # Skip if no modifications for the read + continue + + pct = output_data.getNthReadLenPct(i, mod_type) + read_len_pct.append(pct) + mod_prob.append(prob) + except Exception as e: + logging.error(f"Error getting read length vs. base modification probability data: {e}") + + # Convert the lists to numpy arrays + read_len_pct = np.array(read_len_pct) * 100 # Convert to percentage + mod_prob = np.array(mod_prob) + + # Dictionary of modification character to full name + mod_char_to_name = {'m': '5mC', 'h': '5hmC', 'f': '5fC', 'c': '5caC', \ + 'g': '5hmU', 'e': '5fu', 'b': '5caU', \ + 'a': '6mA', 'o': '8oxoG', 'n': 'Xao', \ + 'C': 'Amb. C', 'A': 'Amb. A', 'T': 'Amb. T', 'G': 'Amb. G',\ + 'N': 'Amb. N', \ + 'v': 'pseU'} + + fig = make_subplots(rows=len(base_mod_types), cols=2, shared_xaxes=False, shared_yaxes=False, vertical_spacing=0.1, subplot_titles=[f"{mod_char_to_name[mod_type]} Modification Probability" for mod_type in base_mod_types]) + + for i, mod_type in enumerate(base_mod_types): + # logging.info(f"Creating trace for modification type: {mod_type} at row: {i + 1}") + + # Add the trace for the read length (%) vs. 
base modification + # probability scatter plot + fig.add_trace(go.Scatter + (x=read_len_pct, y=mod_prob, mode='markers', name=mod_char_to_name[mod_type], marker=dict(size=5), showlegend=False), + row=i + 1, col=1) + + # Create a histogram of the base modification probabilities + base_mod_prob_hist = go.Histogram(x=mod_prob, name=mod_char_to_name[mod_type], showlegend=False, nbinsx=20) + fig.add_trace(base_mod_prob_hist, row=i + 1, col=2) + + # Update the plot style + fig.update_xaxes(title="Read Length (%)", row=i + 1, col=1) + fig.update_yaxes(title="Modification Probability", row=i + 1, col=1) + fig.update_xaxes(title="Modification Probability", row=i + 1, col=2) + fig.update_yaxes(title="Frequency", row=i + 1, col=2) + fig.update_yaxes(range=[0, 1], row=i + 1, col=1) + + fig.update_layout(title="Read Length vs. Base Modification Probability", font=dict(size=PLOT_FONT_SIZE)) + + # Generate the HTML + if len(base_mod_types) > 0: + plot_height = 500 * len(base_mod_types) + plot_width = 700 * 2 + logging.info("Generating the read length vs. modification rates plot") + plot_filepaths["read_length_mod_rates"]['dynamic'] = fig.to_html(full_html=False, default_height=plot_height, default_width=plot_width) + else: + logging.warning("WARNING: No modification types found") # Create the base modification statistics table + logging.info("Creating the base modification statistics table") table_str = "\n" - table_str += "".format(output_data.modified_prediction_count) - table_str += "".format(base_modification_threshold) - table_str += "".format(output_data.sample_modified_base_count) - table_str += "".format(output_data.sample_modified_base_count_forward) - table_str += "".format(output_data.sample_modified_base_count_reverse) - table_str += "".format(output_data.sample_cpg_forward_count) - table_str += "".format(output_data.sample_cpg_reverse_count) + row_str, row_flag = format_row("Total Unfiltered Predictions", [output_data.modified_prediction_count], 'int', None) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + row_str, row_flag = format_row("Probability Threshold", [base_modification_threshold], 'float', 0) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + row_str, row_flag = format_row("Total Modification Counts", [output_data.sample_modified_base_count], 'int', None) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + row_str, row_flag = format_row("Total Modification Counts (Forward Strand Only)", [output_data.sample_modified_base_count_forward], 'int', None) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + row_str, row_flag = format_row("Total Modification Counts (Reverse Strand Only)", [output_data.sample_modified_base_count_reverse], 'int', None) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + row_str, row_flag = format_row("Total CpG Modification Counts (Forward Strand Only)", [output_data.sample_cpg_forward_count], 'int', None) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + row_str, row_flag = format_row("Total CpG Modification Counts (Reverse Strand Only)", [output_data.sample_cpg_reverse_count], 'int', None) + table_str += row_str + table_error_flag = table_error_flag or row_flag + + # Add the modification type data + for mod_type in base_mod_types: + try: + mod_name = mod_char_to_name[mod_type] + except KeyError: + logging.warning("WARNING: Unknown modification type: {}".format(mod_type)) + mod_name = mod_type + + 
+        mod_count = output_data.getModTypeCount(mod_type)
+        mod_count_fwd = output_data.getModTypeCount(mod_type, 0)
+        mod_count_rev = output_data.getModTypeCount(mod_type, 1)
+
+        row_str, row_flag = format_row("Total {} Counts in the Sample".format(mod_name), [mod_count], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        row_str, row_flag = format_row("Total {} Counts in the Sample (Forward Strand)".format(mod_name), [mod_count_fwd], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+        row_str, row_flag = format_row("Total {} Counts in the Sample (Reverse Strand)".format(mod_name), [mod_count_rev], 'int', None)
+        table_str += row_str
+        table_error_flag = table_error_flag or row_flag
+
+    # Finish the table
     table_str += "\n</tbody>\n</table>"
+
+    # # Add the help text
+    # table_str += """
+    # <div class="help-text-container">
+    #     <span>💡</span>
+    #     <div class="help-text">{}</div>
+    # </div>
+    # """.format(help_text)
+
+    # Add text below the table suggesting the user to use Modkit for more
+    # detailed analysis on per-site modification rates
+    table_str += "<br><i>For per-site modification rates, please use \
+        <a href=\"https://github.com/nanoporetech/modkit\">Modkit</a> by Oxford Nanopore Technologies.</i>"
+
+    plot_filepaths["base_mods"]['detail'] = table_str
+    plot_filepaths["base_mods"]['error_flag'] = table_error_flag
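Each table builder now stores an `error_flag` alongside its `detail` HTML, but this excerpt does not show how the report consumes the flags. One plausible, purely hypothetical consumer, which badges flagged sections when the final report is assembled:

```python
# Hypothetical sketch: how a report assembler might surface the error flags.
# The 'title'/'detail'/'error_flag' keys mirror the dictionaries filled above;
# the badge markup itself is an assumption.
def render_sections(plot_filepaths):
    html_parts = []
    for name, section in plot_filepaths.items():
        badge = " &#9888;" if section.get('error_flag') else ""
        html_parts.append("<h2>{}{}</h2>".format(section.get('title', name), badge))
        html_parts.append(section.get('detail', ''))
    return "\n".join(html_parts)
```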
    " + + plot_filepaths["base_mods"]['detail'] = table_str + plot_filepaths["base_mods"]['error_flag'] = table_error_flag def create_tin_table(output_data, input_files, plot_filepaths): """Create a summary table for the RNA-Seq TIN values.""" @@ -892,44 +1189,34 @@ def create_tin_table(output_data, input_files, plot_filepaths): # Create a table with the first column showing the BAM filepath, and the # following columns showing TIN count, mean, median, and standard deviation - table_str = "\n\n\n" + table_str = "
    BAM FileCountMeanMedianStdDev
    \n\n\n" table_str += "\n" # Loop through each BAM file + error_flag = False for bam_file in input_files: # Format the filepath as filename only bam_filename = os.path.basename(bam_file) # Get the file data - tin_count = output_data.getTINCount(bam_file) - tin_mean = output_data.getTINMean(bam_file) + # tin_count = output_data.getTINCount(bam_file) + # tin_mean = output_data.getTINMean(bam_file) tin_median = output_data.getTINMedian(bam_file) - tin_std = output_data.getTINStdDev(bam_file) + # tin_std = output_data.getTINStdDev(bam_file) # Add the data to the table - table_str += "".format(bam_filename, tin_count, tin_mean, tin_median, tin_std) + # row_str, row_flag = format_row(bam_filename, [tin_count, tin_mean, + # tin_median, tin_std], 'float', None) + row_str, row_flag = format_row(bam_filename, [tin_median, output_data.getTINCount(bam_file)], 'float', None) + table_str += row_str + error_flag = error_flag or row_flag table_str += "\n\n
    BAM FileMedian TIN ScoreNumber of Transcripts
    {}{:,d}{:.1f}{:.1f}{:.1f}
    " # Add the table to the plot filepaths plot_filepaths["tin"]['detail'] = table_str + plot_filepaths["tin"]['error_flag'] = error_flag - # plot_filepaths["base_mods"] = {} - # plot_filepaths["base_mods"]['file'] = "" - # plot_filepaths["base_mods"]['title'] = "Base Modifications" - # plot_filepaths["base_mods"]['description'] = "Base modification statistics" - - # # Create the base modification statistics table - # table_str = "\n" - # table_str += "".format(output_data.modified_prediction_count) - # table_str += "".format(base_modification_threshold) - # table_str += "".format(output_data.sample_modified_base_count) - # table_str += "".format(output_data.sample_modified_base_count_forward) - # table_str += "".format(output_data.sample_modified_base_count_reverse) - # table_str += "".format(output_data.sample_cpg_forward_count) - # table_str += "".format(output_data.sample_cpg_reverse_count) - # table_str += "\n\n
    Total Predictions{:,d}
    Probability Threshold{:.2f}
    Total Modified Bases in the Sample{:,d}
    Total in the Forward Strand{:,d}
    Total in the Reverse Strand{:,d}
    Total modified CpG Sites in the Sample (Forward Strand){:,d}
    Total modified CpG Sites in the Sample (Reverse Strand){:,d}
    " - # plot_filepaths["base_mods"]['detail'] = table_str def create_pod5_table(output_dict, plot_filepaths): """Create a summary table for the ONT POD5 signal data.""" @@ -938,26 +1225,32 @@ def create_pod5_table(output_dict, plot_filepaths): plot_filepaths["basic_st"]['title'] = "Summary Table" file_type_label = "POD5" plot_filepaths["basic_st"]['description'] = f"{file_type_label} Basic Statistics" + table_error_flag = False - # Get values - read_count = len(output_dict.keys()) - # Set up the HTML table table_str = "\n\n\n" table_str += "\n" - int_str_for_format = "" - table_str += int_str_for_format.format("#Total Reads", read_count) + # int_str_for_format = "" + # table_str += int_str_for_format.format("Total Reads", read_count) + read_count = len(output_dict.keys()) + row_str, row_flag = format_row("Total Reads", [read_count], 'int', None) + table_str += row_str + table_error_flag = table_error_flag or row_flag table_str += "\n\n
 
 
 def create_pod5_table(output_dict, plot_filepaths):
     """Create a summary table for the ONT POD5 signal data."""
@@ -938,26 +1225,32 @@ def create_pod5_table(output_dict, plot_filepaths):
     plot_filepaths["basic_st"]['title'] = "Summary Table"
     file_type_label = "POD5"
     plot_filepaths["basic_st"]['description'] = f"{file_type_label} Basic Statistics"
+    table_error_flag = False
 
-    # Get values
-    read_count = len(output_dict.keys())
-
     # Set up the HTML table
     table_str = "<table>\n<thead>\n<tr><th>Measurement</th><th>Statistics</th></tr>\n</thead>"
     table_str += "\n<tbody>"
-    int_str_for_format = "<tr><td>{}</td><td>{:,d}</td></tr>"
-    table_str += int_str_for_format.format("#Total Reads", read_count)
+    # int_str_for_format = "<tr><td>{}</td><td>{:,d}</td></tr>"
+    # table_str += int_str_for_format.format("Total Reads", read_count)
+    read_count = len(output_dict.keys())
+    row_str, row_flag = format_row("Total Reads", [read_count], 'int', None)
+    table_str += row_str
+    table_error_flag = table_error_flag or row_flag
     table_str += "\n</tbody>\n</table>"
 
     plot_filepaths["basic_st"]['detail'] = table_str
+    plot_filepaths["basic_st"]['error_flag'] = table_error_flag
 
 
-def plot_alignment_numbers(data):
+def plot_alignment_numbers(data, plot_filepaths):
     category = ['Primary Alignments', 'Supplementary Alignments', 'Secondary Alignments',
                 'Reads with Supplementary Alignments', 'Reads with Secondary Alignments',
                 'Reads with Secondary and Supplementary Alignments', 'Forward Alignments', 'Reverse Alignments']
     category = [wrap(x) for x in category]
 
+    # Set the error flag if primary alignments equal 0
+    error_flag = data.num_primary_alignment == 0
+
     # Create a horizontally aligned bar plot trace from the data using plotly
     trace = go.Bar(x=[data.num_primary_alignment, data.num_supplementary_alignment, data.num_secondary_alignment,
                       data.num_reads_with_supplementary_alignment, data.num_reads_with_secondary_alignment,
@@ -973,14 +1266,13 @@ def plot_alignment_numbers(data):
     # Create the figure object
     fig = go.Figure(data=[trace], layout=layout)
 
-    # Generate the HTML object for the plot
-    html_obj = fig.to_html(full_html=False, default_height=500, default_width=1000)
-
-    return html_obj
+    # Update the HTML data for the plot
+    plot_filepaths['read_alignments_bar']['dynamic'] = fig.to_html(full_html=False, default_height=500, default_width=1000)
+    plot_filepaths['read_alignments_bar']['error_flag'] = error_flag
 
 
-# Plot base alignment statistics
-def plot_errors(output_data):
+def plot_errors(output_data, plot_filepaths):
+    """Plot the error statistics for the alignment data."""
     category = \
         ['Matched Bases', 'Mismatched Bases', 'Inserted Bases', 'Deleted Bases', 'Clipped Bases\n(Primary Alignments)']
     category = [wrap(x) for x in category]
@@ -999,7 +1291,14 @@ def plot_errors(output_data):
     fig = go.Figure(data=[trace], layout=layout)
 
     # Generate the HTML object for the plot
-    html_obj = fig.to_html(full_html=False, default_height=500, default_width=700)
+    # html_obj = fig.to_html(full_html=False, default_height=500,
+    # default_width=700)
+    plot_filepaths['base_alignments_bar']['dynamic'] = fig.to_html(full_html=False, default_height=500, default_width=700)
 
-    return html_obj
+    # Set the error flag if mismatch or clipped bases > matched bases
+    error_flag = output_data.num_mismatched_bases > output_data.num_matched_bases or \
+        output_data.num_clip_bases > output_data.num_matched_bases
+    plot_filepaths['base_alignments_bar']['error_flag'] = error_flag
+
+    # return html_obj
<< std::endl; thread_index = 0; // Index where this thread's data will be stored - while( std::getline( *file_stream, current_line )) { + while( std::getline( file_stream, current_line ) ) + { std::istringstream column_stream( current_line ); // Read each column value from the record line @@ -93,7 +98,7 @@ SeqTxt_Module::SeqTxt_Module(Input_Para& input_parameters){ has_error = 0; file_index = 0; - input_file_stream = NULL; + // input_file_stream = NULL; if (file_index >= _input_parameters.num_input_files){ std::cerr << "Input file list error." << std::endl; has_error |= 1; @@ -102,8 +107,9 @@ SeqTxt_Module::SeqTxt_Module(Input_Para& input_parameters){ // Open the first file in the list const char * first_filepath = _input_parameters.input_files[file_index].c_str(); - input_file_stream = new std::ifstream(first_filepath); - if (!(input_file_stream->is_open())){ + // input_file_stream = new std::ifstream(first_filepath); + input_file_stream.open(first_filepath); + if (!(input_file_stream.is_open())){ std::cerr << "Cannot open sequencing_summary.txt file="<< first_filepath <(relapse_end_time - relapse_start_time).count() << std::endl; + auto relapse_end_time = std::chrono::high_resolution_clock::now(); + std::cout<<"Elapsed time (seconds): "<< std::chrono::duration_cast(relapse_end_time - relapse_start_time).count() << std::endl; - std::cout<<"sequencing_summary.txt QC "<< (has_error==0?"generated":"failed") << std::endl; + std::cout<<"sequencing_summary.txt QC "<< (has_error==0?"generated":"failed") << std::endl; - return has_error; + return has_error; } -void SeqTxt_Module::SeqTxt_do_thread(std::ifstream* file_stream, Input_Para& ref_input_op, int thread_id, SeqTxt_Thread_data& ref_thread_data, Output_SeqTxt& ref_output ){ +void SeqTxt_Module::SeqTxt_do_thread(std::ifstream& file_stream, Input_Para& ref_input_op, int thread_id, Output_SeqTxt& ref_output, std::map header_columns, size_t batch_size_of_record){ size_t read_ss_size, read_ss_i; - while (true){ - myMutex_readSeqTxt.lock(); - std::map header_column_data = ref_thread_data.getHeaderColumns(); - read_ss_size = ref_thread_data.read_ss_record(file_stream, header_column_data); - - if (read_ss_size == 0 && !(file_index < ref_input_op.num_input_files) ){ - myMutex_readSeqTxt.unlock(); - break; - } - if ( read_ss_size < batch_size_of_record ){ - if ( file_index < ref_input_op.num_input_files ){ - std::cout<< "INFO: Open sequencing_summary.txt file="<< ref_input_op.input_files[file_index] <close(); - file_stream->clear(); - - file_stream->open( ref_input_op.input_files[file_index].c_str() ); - std::string firstline; - std::getline( *file_stream, firstline ); - file_index++; + int total_read_count = 0; + while (true) { + SeqTxt_Thread_data ref_thread_data(ref_input_op, header_columns, thread_id, batch_size_of_record); + { + std::lock_guard lock(myMutex_readSeqTxt); + std::map header_column_data = ref_thread_data.getHeaderColumns(); + read_ss_size = ref_thread_data.read_ss_record(file_stream, header_column_data); + + if (read_ss_size == 0 && !(file_index < ref_input_op.num_input_files) ){ + break; + } + + if ( read_ss_size < batch_size_of_record ){ + if ( file_index < ref_input_op.num_input_files ){ + file_stream.close(); + file_stream.clear(); + + file_stream.open( ref_input_op.input_files[file_index].c_str() ); + std::string firstline; + std::getline( file_stream, firstline ); + file_index++; + } } } - myMutex_readSeqTxt.unlock(); - if (read_ss_size == 0 ) { continue; } + if (read_ss_size == 0 ) { + continue; + } else { + total_read_count 
+= read_ss_size; + printMessage("Thread " + std::to_string(thread_id+1) + " read " + std::to_string(read_ss_size) + " records (total " + std::to_string(total_read_count) + ")"); + } // Columns used for statistics: passes_filtering, sequence_length_template, mean_qscore_template - //ref_thread_data.t_output_SeqTxt_.reset(); ref_thread_data.t_output_SeqTxt_.all_long_read_info.long_read_info.resize(); ref_thread_data.t_output_SeqTxt_.passed_long_read_info.long_read_info.resize(); ref_thread_data.t_output_SeqTxt_.failed_long_read_info.long_read_info.resize(); for(read_ss_i=0; read_ss_ilong_read_info.total_num_reads++; - size_t sequence_base_count = ref_thread_data.stored_records[read_ss_i].sequence_length_template; - seqtxt_statistics->long_read_info.total_num_bases += sequence_base_count; + bool passes_filtering_value = ref_thread_data.stored_records[read_ss_i].passes_filtering; + Basic_SeqTxt_Statistics& seqtxt_statistics = (passes_filtering_value == true) ? ref_thread_data.t_output_SeqTxt_.passed_long_read_info : ref_thread_data.t_output_SeqTxt_.failed_long_read_info; + + seqtxt_statistics.long_read_info.total_num_reads++; + size_t sequence_base_count = ref_thread_data.stored_records[read_ss_i].sequence_length_template; + seqtxt_statistics.long_read_info.total_num_bases += sequence_base_count; // Store the read length - seqtxt_statistics->long_read_info.read_lengths.push_back(sequence_base_count); + seqtxt_statistics.long_read_info.read_lengths.push_back(sequence_base_count); // Update the longest read length int64_t current_read_length = (int64_t) ref_thread_data.stored_records[read_ss_i].sequence_length_template; - if ( seqtxt_statistics->long_read_info.longest_read_length < current_read_length){ - seqtxt_statistics->long_read_info.longest_read_length = current_read_length; - } - seqtxt_statistics->long_read_info.read_length_count[ ref_thread_data.stored_records[read_ss_i].sequence_length_templateseq_quality_info.read_quality_distribution[ int( ref_thread_data.stored_records[read_ss_i].mean_qscore_template ) ] += 1; - if ( seqtxt_statistics->seq_quality_info.min_read_quality == MoneDefault || - seqtxt_statistics->seq_quality_info.min_read_quality>int( ref_thread_data.stored_records[read_ss_i].mean_qscore_template ) ){ - seqtxt_statistics->seq_quality_info.min_read_quality = int( ref_thread_data.stored_records[read_ss_i].mean_qscore_template ); - } - if ( seqtxt_statistics->seq_quality_info.max_read_quality < int( ref_thread_data.stored_records[read_ss_i].mean_qscore_template) ){ - seqtxt_statistics->seq_quality_info.max_read_quality = int( ref_thread_data.stored_records[read_ss_i].mean_qscore_template); - } - } + if ( seqtxt_statistics.long_read_info.longest_read_length < current_read_length){ + seqtxt_statistics.long_read_info.longest_read_length = current_read_length; + } + seqtxt_statistics.long_read_info.read_length_count[ ref_thread_data.stored_records[read_ss_i].sequence_length_templateint( ref_thread_data.stored_records[read_ss_i].mean_qscore_template ) ){ + seqtxt_statistics.seq_quality_info.min_read_quality = int( ref_thread_data.stored_records[read_ss_i].mean_qscore_template ); + } + if ( seqtxt_statistics.seq_quality_info.max_read_quality < int( ref_thread_data.stored_records[read_ss_i].mean_qscore_template) ){ + seqtxt_statistics.seq_quality_info.max_read_quality = int( ref_thread_data.stored_records[read_ss_i].mean_qscore_template); + } + } + std::lock_guard lock(myMutex_output); ref_output.add( ref_thread_data.t_output_SeqTxt_ ); - myMutex_output.unlock(); } } diff --git 
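The `SeqTxt_do_thread` rewrite above replaces manual `lock()`/`unlock()` pairs with scoped `std::lock_guard` blocks, so the mutex is released on every exit path, including the early `break`. For readers more familiar with Python, the same scoped-locking idea looks like this (an illustration, not part of the PR):

```python
# Python analogue of the std::lock_guard pattern adopted above.
import threading

read_lock = threading.Lock()

def read_batch(shared_stream):
    # The lock is released when the block exits, even on early return
    # or exception - exactly what the lock_guard scopes guarantee.
    with read_lock:
        return shared_stream.readline()
```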
diff --git a/src/tin.cpp b/src/tin.cpp
index 0c2d0c4..3beae15 100644
--- a/src/tin.cpp
+++ b/src/tin.cpp
@@ -171,7 +171,7 @@ bool checkMinReads(htsFile* bam_file, hts_idx_t* idx, bam_hdr_t* header, std::s
     return min_reads_met;
 }
 
-void calculateTIN(TINStats* tin_stats, const std::string& gene_bed, const std::string& bam_filepath, int min_cov, int sample_size, const std::string& output_folder)
+void calculateTIN(TINStats* tin_stats, const std::string& gene_bed, const std::string& bam_filepath, int min_cov, int sample_size, const std::string& output_folder, int thread_count)
 {
     std::cout << "Using TIN minimum coverage " << min_cov << " and sample size " << sample_size << std::endl;
 
@@ -182,6 +182,9 @@ void calculateTIN(TINStats* tin_stats, const std::string& gene_bed, const std::s
         exit(1);
     }
 
+    // Enable multi-threading
+    hts_set_threads(bam_file, thread_count);
+
     // Read the BAM header
     bam_hdr_t* header = sam_hdr_read(bam_file);
     if (header == NULL) {
@@ -206,6 +209,7 @@ void calculateTIN(TINStats* tin_stats, const std::string& gene_bed, const std::s
 
     // Loop through the gene BED file and calculate the TIN score for each
     // transcript
+    std::cout << "Calculating TIN scores for each transcript..." << std::endl;
     std::vector<double> TIN_scores;
     std::vector<std::string> gene_ids;
     std::string line;
@@ -396,6 +400,11 @@ void calculateTIN(TINStats* tin_stats, const std::string& gene_bed, const std::s
 
         // Store the TIN score for the transcript
         tin_map[name] = std::make_tuple(chrom, start, end, TIN);
+
+        // Log every 1000 transcripts
+        if (gene_ids.size() % 1000 == 0) {
+            std::cout << "Processed " << gene_ids.size() << " transcripts" << std::endl;
+        }
     }
 
     // Close the BAM file
@@ -413,6 +422,7 @@ void calculateTIN(TINStats* tin_stats, const std::string& gene_bed, const std::s
     if (TIN_scores.size() == 0) {
         std::cerr << "No TIN scores calculated" << std::endl;
     } else {
+        std::cout << "Calculating TIN summary for " << TIN_scores.size() << " transcripts..." << std::endl;
 
         // Print the TIN mean, median, and standard deviation
         double TIN_sum = 0;
diff --git a/src/utils.cpp b/src/utils.cpp
index 4d27130..c2b1dc2 100644
--- a/src/utils.cpp
+++ b/src/utils.cpp
@@ -1,10 +1,12 @@
 #include "utils.h"
 
 /// @cond
-#include <iostream>
-#include <mutex>
 #include <string>
+#include <iostream>
+#include <iomanip>
 #include <mutex>
+#include <fstream>
+#include <sys/resource.h> // getrusage
 /// @endcond
 
@@ -24,3 +26,14 @@ void printError(std::string message)
     std::lock_guard<std::mutex> lock(print_mtx);
     std::cerr << message << std::endl;
 }
+
+void printMemoryUsage(const std::string& functionName) {
+    struct rusage usage;
+    getrusage(RUSAGE_SELF, &usage);
+
+    // Convert from KB to GB
+    double mem_usage_gb = (double)usage.ru_maxrss / 1024.0 / 1024.0;
+    std::lock_guard<std::mutex> lock(print_mtx);
+    std::cout << functionName << " memory usage: "
+              << std::fixed << std::setprecision(2) << mem_usage_gb << " GB" << std::endl;
+}
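The new `printMemoryUsage` reports peak resident set size via `getrusage`, converting `ru_maxrss` from kilobytes to gigabytes (on Linux `ru_maxrss` is in KB; macOS reports bytes instead). For reference, an equivalent using Python's standard `resource` module; an illustration only, not part of the change:

```python
import resource

def print_memory_usage(function_name):
    """Report peak RSS, mirroring printMemoryUsage() in src/utils.cpp."""
    # On Linux, ru_maxrss is in kilobytes; divide twice to get gigabytes.
    max_rss_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    print(f"{function_name} memory usage: {max_rss_kb / 1024.0 / 1024.0:.2f} GB")
```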