From 00700bbdce08fb86dd1abc2d664727aec2f7fd67 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Wed, 20 Mar 2019 08:46:43 -0400 Subject: [PATCH 01/33] Filter tests migrated to python 3 --- code/analyze/filter_1_main.py | 121 ++++++------- code/analyze/filter_helpers.py | 56 +++--- code/test/analyze/test_filter_1_main.py | 57 ++++++ code/test/analyze/test_filter_helpers.py | 218 +++++++++++++++++++++++ 4 files changed, 359 insertions(+), 93 deletions(-) create mode 100644 code/test/analyze/test_filter_1_main.py create mode 100644 code/test/analyze/test_filter_helpers.py diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py index 531bc51..523ceb9 100644 --- a/code/analyze/filter_1_main.py +++ b/code/analyze/filter_1_main.py @@ -1,7 +1,7 @@ # two levels of filtering: # 1. remove regions that don't look confidently introgressed at all, # based on fraction gaps/masked, number of matches to S288c and not S288c -# --> _filtered1 +# --> _filtered1 # 2. remove regions that we can't confidently pin on a specific reference, # based on whether it matches similarly to other reference(s) # --> _filtered2 @@ -10,64 +10,65 @@ # to choose filtering thresholds for next level -import re import sys -import os -import copy -import predict -from filter_helpers import * -sys.path.insert(0, '..') +from analyze import predict +from analyze.filter_helpers import passes_filters1, write_filtered_line import global_params as gp -sys.path.insert(0, '../misc/') -import read_table -import read_fasta - -args = predict.process_predict_args(sys.argv[1:]) - -for species_from in args['known_states'][1:]: - - print species_from - - fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_quality.txt' - region_summary, fields = read_table.read_table_rows(fn, '\t') - - fields1i = fields + ['reason'] - fields1 = fields - - fn_out1i = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered1intermediate.txt' - - fn_out1 = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered1.txt' - - f_out1i = open(fn_out1i, 'w') - f_out1i.write('\t'.join(fields1i) + '\n') - - f_out1 = open(fn_out1, 'w') - f_out1.write('\t'.join(fields1) + '\n') - - for region_id in region_summary: - #print region_id, '****' - region = region_summary[region_id] - headers, seqs = read_fasta.read_fasta(gp.analysis_out_dir_absolute + \ - args['tag'] + \ - '/regions/' + region_id + '.fa.gz', \ - gz = True) - info_string = seqs[-1] - seqs = seqs[:-1] - - # filtering stage 1: things that we're confident in calling not - # S288c - p, reason = passes_filters1(region, info_string) - region['reason'] = reason - write_filtered_line(f_out1i, region_id, region, fields1i) - - if p: - write_filtered_line(f_out1, region_id, region, fields1) - - f_out1i.close() - f_out1.close() +from misc import read_table +from misc import read_fasta + + +def main(): + args = predict.process_predict_args(sys.argv[1:]) + + for species_from in args['known_states'][1:]: + + print(species_from) + + fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ + 'blocks_' + species_from + \ + '_' + args['tag'] + '_quality.txt' + region_summary, fields = read_table.read_table_rows(fn, '\t') + + fields1i = fields + ['reason'] + fields1 = fields + + fn_out1i = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ + 'blocks_' + species_from + \ + '_' + args['tag'] + '_filtered1intermediate.txt' + + fn_out1 = 
gp.analysis_out_dir_absolute + args['tag'] + '/' + \ + 'blocks_' + species_from + \ + '_' + args['tag'] + '_filtered1.txt' + + f_out1i = open(fn_out1i, 'w') + f_out1i.write('\t'.join(fields1i) + '\n') + + f_out1 = open(fn_out1, 'w') + f_out1.write('\t'.join(fields1) + '\n') + + for region_id in region_summary: + region = region_summary[region_id] + headers, seqs = read_fasta.read_fasta( + gp.analysis_out_dir_absolute + + args['tag'] + + '/regions/' + region_id + '.fa.gz', + gz=True) + info_string = seqs[-1] + seqs = seqs[:-1] + + # filtering stage 1: things that we're confident in calling not + # S288c + p, reason = passes_filters1(region, info_string) + region['reason'] = reason + write_filtered_line(f_out1i, region_id, region, fields1i) + + if p: + write_filtered_line(f_out1, region_id, region, fields1) + + f_out1i.close() + f_out1.close() + + +if __name__ == "__main__": + main() diff --git a/code/analyze/filter_helpers.py b/code/analyze/filter_helpers.py index d034796..1ecd6f4 100644 --- a/code/analyze/filter_helpers.py +++ b/code/analyze/filter_helpers.py @@ -1,27 +1,20 @@ -import re -import sys -import os -import copy -import gene_predictions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../sim/') -import sim_analyze_hmm_bw as sim -sys.path.insert(0, '../misc/') -import mystats -import read_table -import seq_functions +from misc import seq_functions + def write_filtered_line(f, region_id, region, fields): - f.write(region_id + '\t' + '\t'.join([str(region[field]) for field in fields[1:]])) + f.write(region_id + '\t' + '\t'.join([str(region[field]) + for field in fields[1:]])) f.write('\n') + def passes_filters(region): - + # fraction gaps + masked filter fraction_gaps_masked_threshold = .5 fraction_gaps_masked = \ - (float(region['number_gaps']) + float(region['number_masked_non_gap'])) / \ + (float(region['number_gaps']) + + float(region['number_masked_non_gap'])) / \ (int(region['end']) - int(region['start']) + 1) if fraction_gaps_masked > fraction_gaps_masked_threshold: return False @@ -36,12 +29,13 @@ def passes_filters(region): # result in much larger divergence than we'd expect) id_ref1_threshold = .7 id_ref1 = float(region['number_match_ref1']) / \ - (float(region['aligned_length']) - float(region['number_gaps'])) + (float(region['aligned_length']) - float(region['number_gaps'])) if id_ref1 < id_ref1_threshold: return False - + return True + def passes_filters1(region, info_string): # filtering out things that we can't call introgressed in general # with confidence (i.e. 
doesn't seem like a strong case against @@ -49,7 +43,7 @@ def passes_filters1(region, info_string): r = gp.alignment_ref_order[0] s = region['predicted_species'] - + aligned_length = (int(region['end']) - int(region['start']) + 1) # FILTER: fraction gaps + masked @@ -62,41 +56,38 @@ def passes_filters1(region, info_string): fraction_gaps_masked_s = \ 1 - float(region['num_sites_nonmask_' + s]) / aligned_length - #print fraction_gaps_masked_r, fraction_gaps_masked_s if fraction_gaps_masked_r > fraction_gaps_masked_threshold: return False, 'fraction gaps/masked in master = ' + \ str(fraction_gaps_masked_r) if fraction_gaps_masked_s > fraction_gaps_masked_threshold: return False, 'fraction gaps/masked in predicted = ' + \ str(fraction_gaps_masked_s) - + # FILTER: number sites analyzed by HMM that match predicted # reference count_P = info_string.count('P') count_C = info_string.count('C') number_match_only_threshold = 7 if count_P < number_match_only_threshold: - return False, 'count_P = ' + str(count_P) + return False, f'count_P = {count_P}' if count_P <= count_C: - return False, 'count_P = ' + str(count_P) + ' and count_C = ' + str(count_C) + return False, f'count_P = {count_P} and count_C = {count_C}' # FILTER: divergence with predicted reference and master reference # (S288c) id_predicted = float(region['match_nongap_' + s]) / \ - float(region['num_sites_nongap_' + s]) + float(region['num_sites_nongap_' + s]) id_master = float(region['match_nongap_' + r]) / \ - float(region['num_sites_nongap_' + r]) - #print region['match_nongap_' + s], region['num_sites_nongap_' + s], region['match_nongap_' + r], region['num_sites_nongap_' + r] + float(region['num_sites_nongap_' + r]) if id_master >= id_predicted: return False, 'id with master = ' + str(id_master) + \ ' and id with predicted = ' + str(id_predicted) - #if id_predicted < .7: - # return False, 'id with predicted = ' + str(id_predicted) if id_master < .7: return False, 'id with master = ' + str(id_master) return True, '' + def passes_filters2(region, seqs, threshold): # filter out things we can't assign to one species specifically; # also return the other reasonable alternatives if we're filtering @@ -105,7 +96,7 @@ def passes_filters2(region, seqs, threshold): refs = gp.alignment_ref_order n = len(seqs[0]) s = region['predicted_species'] - + ids = {} totals = {} P_counts = {} @@ -118,7 +109,9 @@ def passes_filters2(region, seqs, threshold): totals[refs[ri]] = r_total P_count = 0 for i in range(n): - if seqs[ri][i] in skip or seqs[0][i] in skip or seqs[-1][i] in skip: + if (seqs[ri][i] in skip or + seqs[0][i] in skip or + seqs[-1][i] in skip): continue if seqs[-1][i] == seqs[ri][i] and seqs[-1][i] != seqs[0][i]: P_count += 1 @@ -126,14 +119,11 @@ def passes_filters2(region, seqs, threshold): alts = {} for r in ids.keys(): - #if float(totals[r]) / totals[s] > .75 and \ - # ids[r] >= threshold * ids[s] and \ - # P_counts[r] >= threshold * P_counts[s]: # TODO should threshold be the same for both? 
if ids[r] >= threshold * ids[s] and \ P_counts[r] >= threshold * P_counts[s]: alts[r] = (ids[r], P_counts[r]) - + alt_states = sorted(alts.keys(), key=lambda x: alts[x][0], reverse=True) alt_ids = [alts[state][0] for state in alt_states] alt_P_counts = [alts[state][1] for state in alt_states] diff --git a/code/test/analyze/test_filter_1_main.py b/code/test/analyze/test_filter_1_main.py new file mode 100644 index 0000000..549fd71 --- /dev/null +++ b/code/test/analyze/test_filter_1_main.py @@ -0,0 +1,57 @@ +from analyze import filter_1_main as main + + +def test_main(mocker, capsys): + mocker.patch('analyze.filter_1_main.predict.process_predict_args', + return_value={ + 'known_states': ['state1', 'state2'], + 'tag': 'tag' + }) + mocker.patch('analyze.filter_1_main.gp.analysis_out_dir_absolute', + '/dir') + mocker.patch('analyze.filter_1_main.read_table.read_table_rows', + return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) + mocked_file = mocker.patch('analyze.filter_1_main.open') + mock_fasta = mocker.patch('analyze.filter_1_main.read_fasta.read_fasta', + return_value=(['> seq', '> info'], + ['atcg', 'x..'])) + mock_filter = mocker.patch('analyze.filter_1_main.passes_filters1', + side_effect=[(False, 'test'), # r1 + (True, '')]) # r2 + mock_write = mocker.patch('analyze.filter_1_main.write_filtered_line') + + main.main() + + captured = capsys.readouterr().out + assert captured == 'state2\n' + + assert mock_fasta.call_count == 2 + mock_fasta.assert_any_call('/dirtag/regions/r2.fa.gz', gz=True) + mock_fasta.assert_any_call('/dirtag/regions/r1.fa.gz', gz=True) + + assert mocked_file.call_count == 2 + mocked_file.assert_any_call( + '/dirtag/blocks_state2_tag_filtered1intermediate.txt', 'w') + mocked_file.assert_any_call( + '/dirtag/blocks_state2_tag_filtered1.txt', 'w') + + # just headers, capture others + mocked_file().write.assert_has_calls([ + mocker.call('regions\treason\n'), + mocker.call('regions\n')]) + + assert mock_filter.call_count == 2 + # seems like this references the object, which changes after call + mock_filter.assert_has_calls([ + mocker.call({'reason': 'test'}, 'x..'), + mocker.call({'reason': '', 'a': 1}, 'x..')]) + + assert mock_write.call_count == 3 + mock_write.assert_has_calls([ + mocker.call(mocker.ANY, 'r1', {'reason': 'test'}, + ['regions', 'reason']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': ''}, + ['regions', 'reason']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': ''}, + ['regions']), + ]) diff --git a/code/test/analyze/test_filter_helpers.py b/code/test/analyze/test_filter_helpers.py new file mode 100644 index 0000000..a0cd91e --- /dev/null +++ b/code/test/analyze/test_filter_helpers.py @@ -0,0 +1,218 @@ +from analyze import filter_helpers +from io import StringIO +import numpy as np + + +def test_write_filtered_line(): + # single value, first field is ignored + output = StringIO() + filter_helpers.write_filtered_line(output, 'r1', {'chr': 'I'}, ['', 'chr']) + + assert output.getvalue() == 'r1\tI\n' + + # no value + output = StringIO() + filter_helpers.write_filtered_line(output, 'r1', {}, []) + + assert output.getvalue() == 'r1\t\n' + + # two values + output = StringIO() + filter_helpers.write_filtered_line(output, 'r1', + {'a': 'b', 'c': 'd'}, + ['', 'c', 'a']) + + assert output.getvalue() == 'r1\td\tb\n' + + +def test_passes_filters(): + # check gaps + number masked / end-start+1 > 0.5 + region = {'number_gaps': 1, + 'number_masked_non_gap': 0, + 'start': 0, + 'end': 1, + 'number_match_ref2_not_ref1': 0, + 'number_match_ref1': 0, + 
'aligned_length': 0, + } + assert filter_helpers.passes_filters(region) is False + region = {'number_gaps': 1, + 'number_masked_non_gap': 1, + 'start': 0, + 'end': 1, + 'number_match_ref2_not_ref1': 0, + 'number_match_ref1': 0, + 'aligned_length': 0, + } + assert filter_helpers.passes_filters(region) is False + + # check match only > 7 + region = {'number_gaps': 0, + 'number_masked_non_gap': 0, + 'start': 0, + 'end': 1, + 'number_match_ref2_not_ref1': 6, + 'number_match_ref1': 0, + 'aligned_length': 0, + } + assert filter_helpers.passes_filters(region) is False + + # check divergences (match_ref1 / aligned - gapped) < 0.7 + region = {'number_gaps': 1, + 'number_masked_non_gap': 0, + 'start': 0, + 'end': 1, + 'number_match_ref2_not_ref1': 7, + 'number_match_ref1': 6, + 'aligned_length': 11, + } + assert filter_helpers.passes_filters(region) is False + + # passes + region = {'number_gaps': 0, + 'number_masked_non_gap': 0, + 'start': 0, + 'end': 1, # fraction gaps > 0.5 + 'number_match_ref2_not_ref1': 7, # >= 7 + 'number_match_ref1': 7, # div >= 0.7 + 'aligned_length': 10, + } + assert filter_helpers.passes_filters(region) is True + + +def test_passes_filters1(mocker): + mocker.patch('analyze.filter_helpers.gp.alignment_ref_order', + ['ref']) + + # fail fraction gapped on reference + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 4, + 'num_sites_nonmask_pred': 0, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + } + + assert filter_helpers.passes_filters1(region, '') == \ + (False, 'fraction gaps/masked in master = 0.6') + + # fail fraction gapped on predicted + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 3, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + } + + assert filter_helpers.passes_filters1(region, '') == \ + (False, 'fraction gaps/masked in predicted = 0.7') + + # fail match counts + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + } + + assert filter_helpers.passes_filters1(region, 'CP') == \ + (False, 'count_P = 1') + assert filter_helpers.passes_filters1(region, 'CCCCCCCCPPPPPPP') == \ + (False, 'count_P = 7 and count_C = 8') + + # fail divergence, master >= pred + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 5, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 6, + 'num_sites_nongap_ref': 10, + } + + assert filter_helpers.passes_filters1(region, 'CPPPPPPP') == \ + (False, 'id with master = 0.6 and id with predicted = 0.5') + + # fail divergence, master >= 0.7 + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 6, + 'num_sites_nongap_ref': 10, + } + + assert filter_helpers.passes_filters1(region, 'CPPPPPPP') == \ + (False, 'id with master = 0.6') + + # passes + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 7, + 'num_sites_nongap_ref': 
10, + } + + assert filter_helpers.passes_filters1(region, 'CPPPPPPP') == \ + (True, '') + + +def test_passes_filters2(mocker): + mocker.patch('analyze.filter_helpers.gp.alignment_ref_order', + ['ref', '1', '2', '3', '4']) + mocker.patch('analyze.filter_helpers.gp.gap_symbol', '-') + mocker.patch('analyze.filter_helpers.gp.unsequenced_symbol', 'n') + + region = {'predicted_species': '1', + } + seqs = [list('attatt'), # reference + list('aggcat'), # 4 / 5, p = 2 + list('a--tta'), # 2 / 4, p = 1 + list('nng---'), # no matches, '3' not in outputs + list('attatt'), # 2 / 5, p = 0 + list('ag-tat')] # test sequence + + seqs = np.array(seqs) + threshold = 0 + filt, states, ids, p_count = filter_helpers.passes_filters2( + region, seqs, threshold) + assert filt is False + assert states == ['1', '2', '4'] + assert ids == [0.8, 0.5, 0.4] + assert p_count == [2, 1, 0] + + threshold = 0.1 + filt, states, ids, p_count = filter_helpers.passes_filters2( + region, seqs, threshold) + assert filt is False + assert states == ['1', '2'] + assert ids == [0.8, 0.5] + assert p_count == [2, 1] + + threshold = 0.9 + filt, states, ids, p_count = filter_helpers.passes_filters2( + region, seqs, threshold) + assert filt is True + assert states == ['1'] + assert ids == [0.8] + assert p_count == [2] From 4fe85eccee62aec63286080169ba8b4e990f2613 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Wed, 20 Mar 2019 15:14:51 -0400 Subject: [PATCH 02/33] Filter 1 working with single gz file using seek --- code/analyze/extract_region.py | 63 ++---- code/analyze/filter_1_main.py | 68 +++--- code/misc/region_reader.py | 123 ++++++++++ code/test/analyze/test_extract_region.py | 93 +++++--- code/test/analyze/test_filter_1_main.py | 22 +- .../helper_scripts/compare_filter_outputs.sh | 11 + code/test/helper_scripts/run_filter_1.sh | 14 ++ code/test/misc/test_region_reader.py | 211 ++++++++++++++++++ 8 files changed, 477 insertions(+), 128 deletions(-) create mode 100644 code/misc/region_reader.py create mode 100755 code/test/helper_scripts/compare_filter_outputs.sh create mode 100755 code/test/helper_scripts/run_filter_1.sh create mode 100644 code/test/misc/test_region_reader.py diff --git a/code/analyze/extract_region.py b/code/analyze/extract_region.py index 476e4f0..77b97df 100644 --- a/code/analyze/extract_region.py +++ b/code/analyze/extract_region.py @@ -1,18 +1,15 @@ #!/usr/bin/env python3 import argparse -import os -import pickle -import gzip -import sys +from misc.region_reader import Region_Reader def main(): args = parse_args() - args = validate_args(args) - index = pickle.load(open(args['pickle'], 'rb')) - locations = decode_regions(args['regions'], index, args['list_sort']) - with gzip.open(args['filename'], 'rt') as reader: - write_regions(reader, locations, args['suppress_header']) + args, reader = validate_args(args) + with reader: + locations = decode_regions(args['regions'], + reader, args['list_sort']) + write_regions(reader, locations) def parse_args(args=None): @@ -44,40 +41,24 @@ def validate_args(args): ''' Performs checks and conversions of input, raises ValueErrors if invalid ''' - if not os.path.exists(args['filename']): - raise ValueError(f'{args["filename"]} not found') + reader = Region_Reader(args['filename'], + as_fa=False, + suppress_header=args['suppress_header'], + num_lines=15) - if args['filename'][-6:] != '.fa.gz': - raise ValueError(f'{args["filename"]} expected to be .fa.gz') + args['regions'] = [reader.convert_region(r) for r in args['regions']] - args['pickle'] = args['filename'][:-6] + '.pkl' - if 
not os.path.exists(args['pickle']): - raise ValueError(f'{args["pickle"]} not found with region file') + return args, reader - parsed_regions = [] - for region in args['regions']: - r = region - if r[0] == 'r': - r = r[1:] - if not r.isdigit(): - raise ValueError(f'{region} could not be parsed') - parsed_regions.append(int(r)) - args['regions'] = parsed_regions - return args - - -def decode_regions(regions, index, retain_sort): +def decode_regions(regions, reader, retain_sort): ''' Converts list of regions to file locations based on index dictionary Retain_sort controls if the output list order is determined by the region order or the disk location (i.e. values of index dict) ''' - try: - result = [index[r] for r in regions] - except KeyError as e: - raise KeyError(f'r{e} not found in index') + result = [reader.decode_region(r) for r in regions] if retain_sort: return result @@ -85,25 +66,13 @@ def decode_regions(regions, index, retain_sort): return sorted(result) -def write_regions(reader, locations, suppress_header, num_lines=15): +def write_regions(reader, locations): ''' Writes the regions specified by index to stdout If print_header is false, ignore first line after location ''' - if suppress_header is True: - num_lines -= 1 - for location in locations: - reader.seek(location) - if suppress_header is True: - reader.readline() - for i in range(num_lines): - line = reader.readline() - if line == '': - print(f'{location} outside of file', file=sys.stderr) - break - else: - print(line, end='') + reader.read_location(location) if __name__ == '__main__': diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py index 523ceb9..68b3021 100644 --- a/code/analyze/filter_1_main.py +++ b/code/analyze/filter_1_main.py @@ -16,58 +16,48 @@ import global_params as gp from misc import read_table from misc import read_fasta +from misc.region_reader import Region_Reader def main(): args = predict.process_predict_args(sys.argv[1:]) + out_dir = gp.analysis_out_dir_absolute + args['tag'] for species_from in args['known_states'][1:]: print(species_from) - fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_quality.txt' - region_summary, fields = read_table.read_table_rows(fn, '\t') + region_summary, fields = read_table.read_table_rows( + f'{out_dir}/blocks_{species_from}_{args["tag"]}_quality.txt', + '\t') fields1i = fields + ['reason'] fields1 = fields - fn_out1i = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered1intermediate.txt' - - fn_out1 = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered1.txt' - - f_out1i = open(fn_out1i, 'w') - f_out1i.write('\t'.join(fields1i) + '\n') - - f_out1 = open(fn_out1, 'w') - f_out1.write('\t'.join(fields1) + '\n') - - for region_id in region_summary: - region = region_summary[region_id] - headers, seqs = read_fasta.read_fasta( - gp.analysis_out_dir_absolute + - args['tag'] + - '/regions/' + region_id + '.fa.gz', - gz=True) - info_string = seqs[-1] - seqs = seqs[:-1] - - # filtering stage 1: things that we're confident in calling not - # S288c - p, reason = passes_filters1(region, info_string) - region['reason'] = reason - write_filtered_line(f_out1i, region_id, region, fields1i) - - if p: - write_filtered_line(f_out1, region_id, region, fields1) - - f_out1i.close() - f_out1.close() + with open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' + 
'_filtered1intermediate.txt', 'w') as f_out1i, \
+             open(f'{out_dir}/blocks_{species_from}_{args["tag"]}'
+                  '_filtered1.txt', 'w') as f_out1, \
+             Region_Reader(f'{out_dir}/regions/{species_from}.fa.gz',
+                           as_fa=True) as region_reader:
+
+            f_out1i.write('\t'.join(fields1i) + '\n')
+            f_out1.write('\t'.join(fields1) + '\n')
+
+            for region_id in region_summary:
+                region = region_summary[region_id]
+                headers, seqs = region_reader.read_region(region_id)
+                info_string = seqs[-1]
+                seqs = seqs[:-1]
+
+                # filtering stage 1: things that we're confident in calling not
+                # S288c
+                p, reason = passes_filters1(region, info_string)
+                region['reason'] = reason
+                write_filtered_line(f_out1i, region_id, region, fields1i)
+
+                if p:
+                    write_filtered_line(f_out1, region_id, region, fields1)
 
 
 if __name__ == "__main__":
diff --git a/code/misc/region_reader.py b/code/misc/region_reader.py
new file mode 100644
index 0000000..fe45d42
--- /dev/null
+++ b/code/misc/region_reader.py
@@ -0,0 +1,123 @@
+import pickle
+import gzip
+import os
+import sys
+import numpy as np
+
+
+class Region_Reader():
+    def __init__(self, region_file,
+                 as_fa=False,
+                 suppress_header=True,
+                 num_lines=15):
+        '''
+        Checks for valid filename and existence of corresponding pickle
+        as_fa: if true will return headers and sequences as read_fasta does
+        suppress_header: if true will not print the #region_id line
+        num_lines: number of lines to print once seek to index. Includes region
+        header line.
+        '''
+        if not os.path.exists(region_file):
+            raise ValueError(f'{region_file} not found')
+
+        if region_file[-6:] != '.fa.gz':
+            raise ValueError(f'{region_file} expected to be .fa.gz')
+
+        pickle = region_file[:-6] + '.pkl'
+        if not os.path.exists(pickle):
+            raise ValueError(f'{pickle} not found with region file')
+
+        self.region_file = region_file
+        self.pickle = pickle
+        self.as_fa = as_fa
+        self.suppress_header = suppress_header
+        self.num_lines = num_lines
+        # read one less line when header is skipped
+        if suppress_header is True:
+            self.num_lines -= 1
+
+    def __enter__(self):
+        self.region_reader = gzip.open(self.region_file, 'rt')
+        self.index = pickle.load(open(self.pickle, 'rb'))
+
+    def __exit__(self):
+        self.region_reader.close()
+
+    def read_region(self, region_name):
+        '''
+        read the supplied region name, either printing to stdout or returning
+        (headers, seqs) tuple depending on as_fa value
+        '''
+        region = self.convert_region(region_name)
+        location = self.decode_region(region)
+        return self.read_location(location)
+
+    def read_location(self, location):
+        '''
+        helper method used in extract_region for directly handling locations
+        '''
+        self.region_reader.seek(location)
+
+        if self.suppress_header is True:
+            self.region_reader.readline()
+        else:
+            print(self.region_reader.readline(), end='')
+
+        if self.as_fa:
+            return self.encode_fa(location)
+        else:
+            self.print_region(location)
+
+    def convert_region(self, region_name):
+        '''
+        Checks that region is a digit that starts with r
+        If so, returns the integer value of the region for decoding
+        '''
+        r = region_name
+        if r[0] == 'r':
+            r = r[1:]
+        if not r.isdigit():
+            raise ValueError(f'{region_name} could not be parsed')
+        return int(r)
+
+    def decode_region(self, region_number):
+        '''
+        Convert region to disk location. 
+ Raises key error if region doesn't exist + ''' + try: + result = self.index[region_number] + except KeyError as e: + raise KeyError(f'r{e} not found in index') + + return result + + def encode_fa(self, location): + ''' + Reads the region file entry and returns headers, seqs + Assumes even numbered lines are headers, odd are sequences + ''' + headers = [] + seqs = [] + for i in range(self.num_lines): + line = self.region_reader.readline() + if line == '': + raise ValueError(f'{location} outside of file') + if i % 2 == 0: # header + headers.append(line[:-1]) + else: + seqs.append(line[:-1]) + + return headers, np.asarray(seqs) + + def print_region(self, location): + ''' + reads the region file entry, printing to stdout + ''' + for i in range(self.num_lines): + line = self.region_reader.readline() + if line == '': + print(f'{location} outside of file', file=sys.stderr) + break + else: + print(line, end='') diff --git a/code/test/analyze/test_extract_region.py b/code/test/analyze/test_extract_region.py index aa0ba26..b910653 100644 --- a/code/test/analyze/test_extract_region.py +++ b/code/test/analyze/test_extract_region.py @@ -1,6 +1,7 @@ from analyze import extract_region as ex import pytest from io import StringIO +from misc.region_reader import Region_Reader def compare_args(args, non_defaults): @@ -40,97 +41,121 @@ def test_validate_args(mocker): # fail on filename existing mocker.patch('os.path.exists', return_value=False) with pytest.raises(ValueError) as e: - ex.validate_args({'filename': 'test'}) + ex.validate_args({'filename': 'test', 'suppress_header': False}) assert 'test not found' in str(e) # fail on filename format mocker.patch('os.path.exists', return_value=True) with pytest.raises(ValueError) as e: - ex.validate_args({'filename': 'test'}) + ex.validate_args({'filename': 'test', + 'suppress_header': False}) # fail on pickle mocker.patch('os.path.exists', side_effect=[True, False]) with pytest.raises(ValueError) as e: - ex.validate_args({'filename': 'test.fa.gz'}) + ex.validate_args({'filename': 'test.fa.gz', + 'suppress_header': False}) # fail on regions mocker.patch('os.path.exists', side_effect=[True, True]) with pytest.raises(ValueError) as e: - ex.validate_args({'filename': 'test.fa.gz', 'regions': ['z123']}) + ex.validate_args({'filename': 'test.fa.gz', + 'regions': ['z123'], + 'suppress_header': False}) # fail on regions mocker.patch('os.path.exists', side_effect=[True, True]) with pytest.raises(ValueError) as e: - ex.validate_args({'filename': 'test.fa.gz', 'regions': ['rz123']}) + ex.validate_args({'filename': 'test.fa.gz', + 'regions': ['rz123'], + 'suppress_header': False}) # fail on regions mocker.patch('os.path.exists', side_effect=[True, True]) with pytest.raises(ValueError) as e: ex.validate_args({'filename': 'test.fa.gz', - 'regions': 'r123 12 z2'.split()}) + 'regions': 'r123 12 z2'.split(), + 'suppress_header': False}) assert 'z2 could not be parsed' in str(e) # success! 
mocker.patch('os.path.exists', side_effect=[True, True]) - args = ex.validate_args({'filename': 'test.fa.gz', - 'regions': 'r123 12 42'.split()}) - assert args['pickle'] == 'test.pkl' + args, reader = ex.validate_args({'filename': 'test.fa.gz', + 'regions': 'r123 12 42'.split(), + 'suppress_header': False}) + assert reader.pickle == 'test.pkl' assert args['regions'] == [123, 12, 42] -def test_decode_regions(): - index = {1: 2, 10: 3, 100: 4} +@pytest.fixture +def r(mocker): + mocker.patch('os.path.exists', side_effect=[True, True]) + return Region_Reader('test.fa.gz') + + +def test_decode_regions(r): + r.index = {1: 2, 10: 3, 100: 4} # raise key error with pytest.raises(KeyError) as e: - ex.decode_regions([1, 3], index, True) + ex.decode_regions([1, 3], r, True) assert 'r3 not found in index' in str(e) - result = ex.decode_regions([1, 1, 100, 10], index, True) + result = ex.decode_regions([1, 1, 100, 10], r, True) assert result == [2, 2, 4, 3] - result = ex.decode_regions([1, 1, 100, 10], index, False) + result = ex.decode_regions([1, 1, 100, 10], r, False) assert result == [2, 2, 3, 4] -def test_write_regions(capsys): +def test_write_regions(r, capsys): # empty regions - reader = StringIO('') - ex.write_regions(reader, [], True, 1) + r.region_reader = StringIO('') + r.suppress_header = True + r.num_lines = 1 + ex.write_regions(r, []) assert capsys.readouterr().out == '' # outside of file - reader = StringIO('') - ex.write_regions(reader, [100], False, 2) + r.region_reader = StringIO('') + r.suppress_header = False + r.num_lines = 2 + ex.write_regions(r, [100]) soe = capsys.readouterr() assert soe.out == '' assert soe.err == '100 outside of file\n' # outside of file on second - reader = StringIO('a test\n') - ex.write_regions(reader, [0], False, 2) + r.region_reader = StringIO('a test\n') + r.suppress_header = False + r.num_lines = 2 + ex.write_regions(r, [0]) soe = capsys.readouterr() assert soe.err == '0 outside of file\n' assert soe.out == 'a test\n' # normal, no header - reader = StringIO('header\n' - 'line 1\n' - 'line 2\n' - 'header\n' - 'line 3\n') - ex.write_regions(reader, [0, 21, 0], True, 2) + r.region_reader = StringIO('header\n' + 'line 1\n' + 'line 2\n' + 'header\n' + 'line 3\n') + r.suppress_header = True + r.num_lines = 1 + ex.write_regions(r, [0, 21, 0]) soe = capsys.readouterr() assert soe.err == '' assert soe.out == 'line 1\nline 3\nline 1\n' # normal, with header - reader = StringIO('head 1\n' - 'line 1\n' - 'line 2\n' - 'head 2\n' - 'line 3\n') - ex.write_regions(reader, [0, 21, 0], False, 2) + r.region_reader = StringIO('head 1\n' + 'line 1\n' + 'line 2\n' + 'head 2\n' + 'line 3\n') + r.suppress_header = False + r.num_lines = 1 + ex.write_regions(r, [0, 21, 0]) soe = capsys.readouterr() - assert soe.err == '' assert soe.out == 'head 1\nline 1\nhead 2\nline 3\nhead 1\nline 1\n' + assert soe.err == '' diff --git a/code/test/analyze/test_filter_1_main.py b/code/test/analyze/test_filter_1_main.py index 549fd71..aae4a68 100644 --- a/code/test/analyze/test_filter_1_main.py +++ b/code/test/analyze/test_filter_1_main.py @@ -1,4 +1,5 @@ from analyze import filter_1_main as main +from misc.region_reader import Region_Reader def test_main(mocker, capsys): @@ -8,13 +9,15 @@ def test_main(mocker, capsys): 'tag': 'tag' }) mocker.patch('analyze.filter_1_main.gp.analysis_out_dir_absolute', - '/dir') + '/dir') mocker.patch('analyze.filter_1_main.read_table.read_table_rows', return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) mocked_file = mocker.patch('analyze.filter_1_main.open') 
-    mock_fasta = mocker.patch('analyze.filter_1_main.read_fasta.read_fasta',
-                              return_value=(['> seq', '> info'],
-                                            ['atcg', 'x..']))
+
+    mock_read = mocker.patch('analyze.filter_1_main.Region_Reader')
+    mock_read().__enter__().read_region.return_value = (['> seq', '> info'],
+                                                        ['atcg', 'x..'])
+
     mock_filter = mocker.patch('analyze.filter_1_main.passes_filters1',
                                side_effect=[(False, 'test'),  # r1
                                             (True, '')])  # r2
@@ -25,9 +28,12 @@ def test_main(mocker, capsys):
     captured = capsys.readouterr().out
     assert captured == 'state2\n'
 
-    assert mock_fasta.call_count == 2
-    mock_fasta.assert_any_call('/dirtag/regions/r2.fa.gz', gz=True)
-    mock_fasta.assert_any_call('/dirtag/regions/r1.fa.gz', gz=True)
+    assert mock_read.call_count == 2  # called once during setup
+    mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True)
+
+    assert mock_read().__enter__().read_region.call_count == 2
+    mock_read().__enter__().read_region.assert_any_call('r1')
+    mock_read().__enter__().read_region.assert_any_call('r2')
 
     assert mocked_file.call_count == 2
     mocked_file.assert_any_call(
         '/dirtag/blocks_state2_tag_filtered1intermediate.txt', 'w')
     mocked_file.assert_any_call(
         '/dirtag/blocks_state2_tag_filtered1.txt', 'w')
 
     # just headers, capture others
-    mocked_file().write.assert_has_calls([
+    mocked_file().__enter__().write.assert_has_calls([
         mocker.call('regions\treason\n'),
         mocker.call('regions\n')])
 
diff --git a/code/test/helper_scripts/compare_filter_outputs.sh b/code/test/helper_scripts/compare_filter_outputs.sh
new file mode 100755
index 0000000..1f05f00
--- /dev/null
+++ b/code/test/helper_scripts/compare_filter_outputs.sh
@@ -0,0 +1,11 @@
+#! /bin/bash
+
+actual=/tigress/tcomi/aclark4_temp/results/analysis_test/
+expected=/tigress/tcomi/aclark4_temp/results/analysisp4e2/
+echo starting comparison of $(basename $actual) to $(basename $expected)
+
+for file in $(ls ${expected}*_filtered1*.txt); do
+    act=$(echo $file | sed 's/p4e2/_test/g')
+    cmp $act $file \
+        && echo $file passed! 
|| echo $file failed #&& exit +done diff --git a/code/test/helper_scripts/run_filter_1.sh b/code/test/helper_scripts/run_filter_1.sh new file mode 100755 index 0000000..9d14467 --- /dev/null +++ b/code/test/helper_scripts/run_filter_1.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --time=0-1 +#SBATCH -n 1 +#SBATCH -o "/tigress/tcomi/aclark4_temp/results/filter_%A" + +export PYTHONPATH=/home/tcomi/projects/aclark4_introgression/code/ + +module load anaconda3 +conda activate introgression3 + +ARGS="_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" + +python ${PYTHONPATH}analyze/filter_1_main.py $ARGS diff --git a/code/test/misc/test_region_reader.py b/code/test/misc/test_region_reader.py new file mode 100644 index 0000000..9443b8a --- /dev/null +++ b/code/test/misc/test_region_reader.py @@ -0,0 +1,211 @@ +from misc.region_reader import Region_Reader +import pytest +from pytest import approx +from io import StringIO +import numpy as np + + +def test_init(mocker): + # fail on filename existing + mocker.patch('os.path.exists', return_value=False) + with pytest.raises(ValueError) as e: + Region_Reader('test') + assert 'test not found' in str(e) + + # fail on filename format + mocker.patch('os.path.exists', return_value=True) + with pytest.raises(ValueError) as e: + Region_Reader('test') + + # fail on pickle + mocker.patch('os.path.exists', side_effect=[True, False]) + with pytest.raises(ValueError) as e: + Region_Reader('test.fa.gz') + + # success, with defaults + mocker.patch('os.path.exists', side_effect=[True, True]) + r = Region_Reader('test.fa.gz') + assert r.region_file == 'test.fa.gz' + assert r.pickle == 'test.pkl' + assert r.as_fa is False + assert r.suppress_header is True + assert r.num_lines == 14, 'Suppress header did not change num_lines' + + # non defaults + mocker.patch('os.path.exists', side_effect=[True, True]) + r = Region_Reader('test1.fa.gz', + as_fa=True, + suppress_header=False, + num_lines=4) + assert r.region_file == 'test1.fa.gz' + assert r.pickle == 'test1.pkl' + assert r.as_fa is True + assert r.suppress_header is False + assert r.num_lines == 4 + + +@pytest.fixture +def r(mocker): + mocker.patch('os.path.exists', side_effect=[True, True]) + return Region_Reader('test.fa.gz') + + +def test_read_region(r, capsys): + # get fa, don't suppress header + r.region_reader = StringIO('header 1\n' + 'line 1\n' + '#header\n' + 'header 2\n' + 'line 2\n') + r.num_lines = 2 + r.as_fa = True + r.suppress_header = False + r.index = {1: 16} + header, seqs = r.read_region('r1') + assert header == ['header 2'] + assert seqs == approx(np.asarray(['line 2'])) + soe = capsys.readouterr() + assert soe.out == '#header\n' + assert soe.err == '' + + # print, suppress header + r.region_reader = StringIO('header 1\n' + 'line 1\n' + '#header\n' + 'header 2\n' + 'line 2\n') + r.num_lines = 2 + r.as_fa = False + r.suppress_header = True + r.read_region('1') + soe = capsys.readouterr() + assert soe.out == 'header 2\nline 2\n' + assert soe.err == '' + + +def test_read_location(r, capsys): + # get fa, don't suppress header + r.region_reader = StringIO('header 1\n' + 'line 1\n' + '#header\n' + 'header 2\n' + 'line 2\n') + r.num_lines = 2 + r.as_fa = True + r.suppress_header = False + header, seqs = r.read_location(16) + assert header == ['header 2'] + assert seqs == approx(np.asarray(['line 2'])) + soe = capsys.readouterr() + assert soe.out == '#header\n' + assert soe.err == '' + + # print, suppress header + r.region_reader = StringIO('header 1\n' + 'line 1\n' + '#header\n' 
+ 'header 2\n' + 'line 2\n') + r.num_lines = 2 + r.as_fa = False + r.suppress_header = True + r.read_location(16) + soe = capsys.readouterr() + assert soe.out == 'header 2\nline 2\n' + assert soe.err == '' + + +def test_convert_region(r): + with pytest.raises(ValueError) as e: + r.convert_region('z123') + assert 'z123 could not be parsed' in str(e) + + with pytest.raises(ValueError) as e: + r.convert_region('zr123') + assert 'zr123 could not be parsed' in str(e) + + assert r.convert_region('123') == 123 + assert r.convert_region('r123') == 123 + + +def test_decode_region(r): + index = {1: 2, 10: 3, 100: 4} + r.index = index + + # raise key error + with pytest.raises(KeyError) as e: + r.decode_region(3) + assert 'r3 not found in index' in str(e) + + assert r.decode_region(1) == 2 + assert r.decode_region(10) == 3 + assert r.decode_region(100) == 4 + + +def test_encode_fa(r): + # outside of file + r.region_reader = StringIO('') + with pytest.raises(ValueError) as e: + r.encode_fa(100) + assert '100 outside of file' in str(e) + + r.region_reader = StringIO('header 1\n' + 'line 1\n' + 'header 2\n' + 'line 2\n') + r.num_lines = 4 + header, seqs = r.encode_fa(0) + assert header == ['header 1', 'header 2'] + assert seqs == approx(np.asarray(['line 1', 'line 2'])) + + r.region_reader = StringIO('header 1\n' + 'line 1\n' + 'header 2\n' + 'line 2\n') + r.num_lines = 3 + header, seqs = r.encode_fa(0) + assert header == ['header 1', 'header 2'] + assert seqs == approx(np.asarray(['line 1'])) + + +def test_print_region(r, capsys): + # outside of file + r.region_reader = StringIO('') + r.print_region(100) + soe = capsys.readouterr() + assert soe.out == '' + assert soe.err == '100 outside of file\n' + + # outside of file on second position + r.region_reader = StringIO('a test\n') + r.num_lines = 2 + r.print_region(0) + soe = capsys.readouterr() + assert soe.err == '0 outside of file\n' + assert soe.out == 'a test\n' + + # normal + r.region_reader = StringIO('header\n' + 'line 1\n' + 'line 2\n' + 'header\n' + 'line 3\n') + r.num_lines = 1 + r.print_region(0) + r.region_reader = StringIO('header\n' + 'line 1\n') + soe = capsys.readouterr() + assert soe.err == '' + assert soe.out == 'header\n' + + # normal + r.region_reader = StringIO('head 1\n' + 'line 1\n' + 'line 2\n' + 'head 2\n' + 'line 3\n') + r.num_lines = 2 + r.print_region(0) + soe = capsys.readouterr() + assert soe.err == '' + assert soe.out == 'head 1\nline 1\n' From 1a9421486727789b3ae702dacd1e1751295524e3 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Thu, 21 Mar 2019 10:58:16 -0400 Subject: [PATCH 03/33] Filter 1 running, matching output But is ~10x slower, likely due to seeking through file --- code/analyze/filter_helpers.py | 57 +++++++++++-------- code/misc/region_reader.py | 12 +++- code/test/analyze/test_filter_1_main.py | 1 - .../helper_scripts/compare_filter_outputs.sh | 11 +++- .../helper_scripts/intermediate_format.py | 25 ++++++++ 5 files changed, 78 insertions(+), 28 deletions(-) create mode 100644 code/test/helper_scripts/intermediate_format.py diff --git a/code/analyze/filter_helpers.py b/code/analyze/filter_helpers.py index 1ecd6f4..b77594b 100644 --- a/code/analyze/filter_helpers.py +++ b/code/analyze/filter_helpers.py @@ -1,11 +1,13 @@ import global_params as gp from misc import seq_functions +import numpy as np def write_filtered_line(f, region_id, region, fields): - f.write(region_id + '\t' + '\t'.join([str(region[field]) - for field in fields[1:]])) - f.write('\n') + f.write(f'{region_id}\t' + + '\t'.join([str(region[field]) 
+ for field in fields[1:]]) + + '\n') def passes_filters(region): @@ -57,11 +59,11 @@ def passes_filters1(region, info_string): 1 - float(region['num_sites_nonmask_' + s]) / aligned_length if fraction_gaps_masked_r > fraction_gaps_masked_threshold: - return False, 'fraction gaps/masked in master = ' + \ - str(fraction_gaps_masked_r) + return False, f'fraction gaps/masked in master = '\ + f'{fraction_gaps_masked_r}' if fraction_gaps_masked_s > fraction_gaps_masked_threshold: - return False, 'fraction gaps/masked in predicted = ' + \ - str(fraction_gaps_masked_s) + return False, f'fraction gaps/masked in predicted = '\ + f'{fraction_gaps_masked_s}' # FILTER: number sites analyzed by HMM that match predicted # reference @@ -79,11 +81,12 @@ def passes_filters1(region, info_string): float(region['num_sites_nongap_' + s]) id_master = float(region['match_nongap_' + r]) / \ float(region['num_sites_nongap_' + r]) + if id_master >= id_predicted: - return False, 'id with master = ' + str(id_master) + \ - ' and id with predicted = ' + str(id_predicted) + return False, f'id with master = {id_master} '\ + f'and id with predicted = {id_predicted}' if id_master < .7: - return False, 'id with master = ' + str(id_master) + return False, f'id with master = {id_master}' return True, '' @@ -94,28 +97,34 @@ def passes_filters2(region, seqs, threshold): # it out refs = gp.alignment_ref_order - n = len(seqs[0]) s = region['predicted_species'] ids = {} totals = {} P_counts = {} - skip = [gp.gap_symbol, gp.unsequenced_symbol] - for ri in range(1, len(refs)): + seqs = np.asarray(seqs) + # skip any gap or unsequenced in ref or test + # also skip if ref and test equal (later test ri == test but not ref) + skip = np.any( + (seqs[0] == gp.gap_symbol, + seqs[0] == gp.unsequenced_symbol, + seqs[-1] == gp.gap_symbol, + seqs[-1] == gp.unsequenced_symbol, + seqs[0] == seqs[-1]), + axis=0) + + for ri, ref in enumerate(refs): + if ri == 0: + continue r_match, r_total = seq_functions.seq_id(seqs[-1], seqs[ri]) if r_total != 0: - ids[refs[ri]] = float(r_match) / r_total - totals[refs[ri]] = r_total - P_count = 0 - for i in range(n): - if (seqs[ri][i] in skip or - seqs[0][i] in skip or - seqs[-1][i] in skip): - continue - if seqs[-1][i] == seqs[ri][i] and seqs[-1][i] != seqs[0][i]: - P_count += 1 - P_counts[refs[ri]] = P_count + ids[ref] = r_match / r_total + totals[ref] = r_total + P_counts[ref] = np.sum( + np.logical_and( + np.logical_not(skip), + seqs[ri] == seqs[-1])) alts = {} for r in ids.keys(): diff --git a/code/misc/region_reader.py b/code/misc/region_reader.py index fe45d42..c6280bf 100644 --- a/code/misc/region_reader.py +++ b/code/misc/region_reader.py @@ -39,10 +39,20 @@ def __init__(self, region_file, def __enter__(self): self.region_reader = gzip.open(self.region_file, 'rt') self.index = pickle.load(open(self.pickle, 'rb')) + return self - def __exit__(self): + def __exit__(self, type, value, traceback): self.region_reader.close() + def __repr__(self): + print( + f'region_file = {self.region_file}\n' + f'pickle = {self.pickle}\n' + f'as_fa = {self.as_fa}\n' + f'suppress_header = {self.suppress_header}\n' + f'num_lines = {self.num_lines}\n' + ) + def read_region(self, region_name): ''' read the supplied region name, either printing to stdout or returning diff --git a/code/test/analyze/test_filter_1_main.py b/code/test/analyze/test_filter_1_main.py index aae4a68..d07e89d 100644 --- a/code/test/analyze/test_filter_1_main.py +++ b/code/test/analyze/test_filter_1_main.py @@ -1,5 +1,4 @@ from analyze import filter_1_main 
as main
-from misc.region_reader import Region_Reader
 
 
 def test_main(mocker, capsys):
diff --git a/code/test/helper_scripts/compare_filter_outputs.sh b/code/test/helper_scripts/compare_filter_outputs.sh
index 1f05f00..adce0f9 100755
--- a/code/test/helper_scripts/compare_filter_outputs.sh
+++ b/code/test/helper_scripts/compare_filter_outputs.sh
@@ -4,8 +4,15 @@ actual=/tigress/tcomi/aclark4_temp/results/analysis_test/
 expected=/tigress/tcomi/aclark4_temp/results/analysisp4e2/
 echo starting comparison of $(basename $actual) to $(basename $expected)
 
-for file in $(ls ${expected}*_filtered1*.txt); do
+for file in $(ls ${expected}*_filtered1.txt); do
     act=$(echo $file | sed 's/p4e2/_test/g')
-    cmp $act $file \
+    cmp <(sort $act) <(sort $file) \
         && echo $file passed! || echo $file failed #&& exit
 done
+
+for file in $(ls ${expected}*_filtered1intermediate.txt); do
+    act=$(echo $file | sed 's/p4e2/_test/g')
+    cmp <(sort $act | python intermediate_format.py) \
+        <(sort $file | python intermediate_format.py) \
+        && echo $file passed! || echo $file failed #&& exit
+done
diff --git a/code/test/helper_scripts/intermediate_format.py b/code/test/helper_scripts/intermediate_format.py
new file mode 100644
index 0000000..3ae7f27
--- /dev/null
+++ b/code/test/helper_scripts/intermediate_format.py
@@ -0,0 +1,25 @@
+import sys
+
+
+def main():
+    precision = 10
+    with sys.stdin as reader:
+        for line in reader:
+            line = line.strip()
+            tokens = line.split('\t')
+            if '=' in tokens[-1]:
+                eq_tokens = tokens[-1].split(' ')
+                for i in range(len(eq_tokens)):
+                    try:
+                        float(eq_tokens[i])
+                    except ValueError:
+                        continue
+                    if len(eq_tokens[i]) > precision:
+                        eq_tokens[i] = eq_tokens[i][:precision]
+                tokens[-1] = ' '.join(eq_tokens)
+            line = '\t'.join(tokens)
+            print(line)
+
+
+if __name__ == "__main__":
+    main()

From f77e1ef2b88d2dd986f07c43b932da549ce90bd7 Mon Sep 17 00:00:00 2001
From: Troy Comi
Date: Thu, 21 Mar 2019 13:51:38 -0400
Subject: [PATCH 04/33] Filter 1 with region yield

Changed the implementation of Region_Reader to yield headers and seqs
This manages to cut the memory (somehow from 60 to 1 MB) and the
runtime from 2 minutes to 10 seconds (last commit at 13 minutes).
---
 code/analyze/filter_1_main.py           |  3 +--
 code/misc/region_reader.py              | 25 ++++++++++++++++++-------
 code/test/analyze/test_filter_1_main.py |  9 +++------
 code/test/misc/test_region_reader.py    | 19 +++++++++++++++++++
 4 files changed, 41 insertions(+), 15 deletions(-)

diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py
index 68b3021..9383a4c 100644
--- a/code/analyze/filter_1_main.py
+++ b/code/analyze/filter_1_main.py
@@ -44,9 +44,8 @@ def main():
             f_out1i.write('\t'.join(fields1i) + '\n')
             f_out1.write('\t'.join(fields1) + '\n')
 
-            for region_id in region_summary:
+            for region_id, header, seqs in region_reader.yield_fa():
                 region = region_summary[region_id]
-                headers, seqs = region_reader.read_region(region_id)
                 info_string = seqs[-1]
                 seqs = seqs[:-1]
 
diff --git a/code/misc/region_reader.py b/code/misc/region_reader.py
index c6280bf..fa66057 100644
--- a/code/misc/region_reader.py
+++ b/code/misc/region_reader.py
@@ -9,13 +9,13 @@ class Region_Reader():
     def __init__(self, region_file,
                  as_fa=False,
                  suppress_header=True,
-                 num_lines=15):
+                 num_lines=14):
         '''
         Checks for valid filename and existence of corresponding pickle
         as_fa: if true will return headers and sequences as read_fasta does
         suppress_header: if true will not print the #region_id line
-        num_lines: number of lines to print once seek to index. 
Includes region - header line. + num_lines: number of lines to print once seek to index. Does not + include region header line. ''' if not os.path.exists(region_file): raise ValueError(f'{region_file} not found') @@ -32,9 +32,6 @@ def __init__(self, region_file, self.as_fa = as_fa self.suppress_header = suppress_header self.num_lines = num_lines - # read one less line when header is skipped - if suppress_header is True: - self.num_lines -= 1 def __enter__(self): self.region_reader = gzip.open(self.region_file, 'rt') @@ -45,7 +42,7 @@ def __exit__(self, type, value, traceback): self.region_reader.close() def __repr__(self): - print( + return ( f'region_file = {self.region_file}\n' f'pickle = {self.pickle}\n' f'as_fa = {self.as_fa}\n' @@ -102,6 +99,20 @@ def decode_region(self, region_number): return result + def yield_fa(self): + ''' + repeatedly yield tuples of region, headers, sequences from fa file + assumes file position starts at header for region + suppress_header is taken as true (will not print) + ''' + while True: + region = self.region_reader.readline()[1:-1] + try: + header, seq = self.encode_fa(region) + yield (region, header, seq) + except ValueError: + break + def encode_fa(self, location): ''' Reads the region file entry and returns headers, seqs diff --git a/code/test/analyze/test_filter_1_main.py b/code/test/analyze/test_filter_1_main.py index d07e89d..8b21b2f 100644 --- a/code/test/analyze/test_filter_1_main.py +++ b/code/test/analyze/test_filter_1_main.py @@ -14,8 +14,9 @@ def test_main(mocker, capsys): mocked_file = mocker.patch('analyze.filter_1_main.open') mock_read = mocker.patch('analyze.filter_1_main.Region_Reader') - mock_read().__enter__().read_region.return_value = (['> seq', '> info'], - ['atcg', 'x..']) + mock_read().__enter__().yield_fa.return_value = iter([ + ('r1', ['> seq', '> info'], ['atcg', 'x..']), + ('r2', ['> seq', '> info'], ['atcg', 'x..'])]) mock_filter = mocker.patch('analyze.filter_1_main.passes_filters1', side_effect=[(False, 'test'), # r1 @@ -30,10 +31,6 @@ def test_main(mocker, capsys): assert mock_read.call_count == 2 # called once during setup mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True) - assert mock_read().__enter__().read_region.call_count == 2 - mock_read().__enter__().read_region.assert_any_call('r1') - mock_read().__enter__().read_region.assert_any_call('r2') - assert mocked_file.call_count == 2 mocked_file.assert_any_call( '/dirtag/blocks_state2_tag_filtered1intermediate.txt', 'w') diff --git a/code/test/misc/test_region_reader.py b/code/test/misc/test_region_reader.py index 9443b8a..e56d9a3 100644 --- a/code/test/misc/test_region_reader.py +++ b/code/test/misc/test_region_reader.py @@ -142,6 +142,25 @@ def test_decode_region(r): assert r.decode_region(100) == 4 +def test_yield_fa(r): + r.region_reader = StringIO('#h1\n' + 'header 1\n' + 'line 1\n' + '#h2\n' + 'header 2\n' + 'line 2\n') + r.num_lines = 2 + regions = ['h1', 'h2'] + headers = [['header 1'], ['header 2']] + seqs = [np.asarray(['line 1']), np.asarray(['line 2'])] + i = 0 + for region, header, seq in r.yield_fa(): + assert region == regions[i] + assert header == headers[i] + assert seq == seqs[i] + i += 1 + + def test_encode_fa(r): # outside of file r.region_reader = StringIO('') From 42d91dd09881d2ac3fd09f6f40e2475b31323b0c Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Fri, 22 Mar 2019 10:18:55 -0400 Subject: [PATCH 05/33] Filter 2 refactor Changed filter2 to use numpy and the new yielded regions Execution is ~10s and uses 1.5 MB memory In 
addition to floating point precision differences, noticed
difference in sorting of alt ids when the values are equal going from
python 2 to 3.  Handled differences in comparison scripts which format
to 10 digits and sort ids when values are equal.
---
 code/analyze/filter_1_main.py                 |   1 -
 code/analyze/filter_2_main.py                 | 121 +++++++++---------
 code/misc/region_reader.py                    |   9 +-
 code/test/analyze/r10805.fa                   |  14 ++
 code/test/analyze/test_filter_2_main.py       |  89 +++++++++++++
 code/test/analyze/test_filter_helpers.py      |  40 +++++-
 .../helper_scripts/compare_filter_outputs.sh  |  17 ++-
 ...ate_format.py => intermediate_format_1.py} |   0
 .../helper_scripts/intermediate_format_2.py   |  30 +++++
 code/test/helper_scripts/run_filter_2.sh      |  14 ++
 code/test/misc/test_region_reader.py          |  35 ++++-
 11 files changed, 295 insertions(+), 75 deletions(-)
 create mode 100644 code/test/analyze/r10805.fa
 create mode 100644 code/test/analyze/test_filter_2_main.py
 rename code/test/helper_scripts/{intermediate_format.py => intermediate_format_1.py} (100%)
 create mode 100644 code/test/helper_scripts/intermediate_format_2.py
 create mode 100755 code/test/helper_scripts/run_filter_2.sh

diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py
index 9383a4c..043a782 100644
--- a/code/analyze/filter_1_main.py
+++ b/code/analyze/filter_1_main.py
@@ -15,7 +15,6 @@
 from analyze.filter_helpers import passes_filters1, write_filtered_line
 import global_params as gp
 from misc import read_table
-from misc import read_fasta
 from misc.region_reader import Region_Reader
 
 
diff --git a/code/analyze/filter_2_main.py b/code/analyze/filter_2_main.py
index a79e5d4..604b65a 100644
--- a/code/analyze/filter_2_main.py
+++ b/code/analyze/filter_2_main.py
@@ -1,7 +1,7 @@
 # two levels of filtering:
 # 1. remove regions that don't look confidently introgressed at all,
 # based on fraction gaps/masked, number of matches to S288c and not S288c
-# --> _filtered1 
+# --> _filtered1
 # 2. 
remove regions that we can't confidently pin on a specific reference, # based on whether it matches similarly to other reference(s) # --> _filtered2 @@ -9,68 +9,61 @@ # do second level of filtering here, based on previously selected # thresholds -import re import sys -import os -import copy -import predict -from filter_helpers import * -sys.path.insert(0, '..') +from analyze import predict +from analyze.filter_helpers import (write_filtered_line, + passes_filters2) import global_params as gp -sys.path.insert(0, '../misc/') -import read_table -import read_fasta - -args = predict.process_predict_args(sys.argv[2:]) -threshold = float(sys.argv[1]) - -for species_from in args['known_states'][1:]: - - print species_from - - fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered1.txt' - region_summary, fields = read_table.read_table_rows(fn, '\t') - - fields2i = fields + ['alternative_states', 'alternative_ids', \ - 'alternative_P_counts'] - fields2 = fields - - fn_out2i = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered2intermediate.txt' - - fn_out2 = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered2.txt' - - f_out2i = open(fn_out2i, 'w') - f_out2i.write('\t'.join(fields2i) + '\n') - - f_out2 = open(fn_out2, 'w') - f_out2.write('\t'.join(fields2) + '\n') - - for region_id in region_summary: - #print region_id, '****' - region = region_summary[region_id] - headers, seqs = read_fasta.read_fasta(gp.analysis_out_dir_absolute + \ - args['tag'] + \ - '/regions/' + region_id + '.fa.gz', \ - gz = True) - info_string = seqs[-1] - seqs = seqs[:-1] - - # filtering stage 2: things that we're confident in calling - # introgressed from one species specifically - p, alt_states, alt_ids, alt_P_counts = passes_filters2(region, seqs, threshold) - region['alternative_states'] = ','.join(alt_states) - region['alternative_ids'] = ','.join([str(x) for x in alt_ids]) - region['alternative_P_counts'] = ','.join([str(x) for x in alt_P_counts]) - write_filtered_line(f_out2i, region_id, region, fields2i) - - if p: - write_filtered_line(f_out2, region_id, region, fields2) - - f_out2i.close() - f_out2.close() +from misc import read_table +from misc.region_reader import Region_Reader + + +def main(): + args = predict.process_predict_args(sys.argv[2:]) + threshold = float(sys.argv[1]) + out_dir = gp.analysis_out_dir_absolute + args['tag'] + + for species_from in args['known_states'][1:]: + + print(species_from) + + region_summary, fields = read_table.read_table_rows( + f'{out_dir}/blocks_{species_from}_{args["tag"]}_filtered1.txt', + '\t') + + fields2i = fields + ['alternative_states', 'alternative_ids', + 'alternative_P_counts'] + fields2 = fields + + with open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' + '_filtered2intermediate.txt', 'w') as f_out2i, \ + open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' + '_filtered2.txt', 'w') as f_out2, \ + Region_Reader(f'{out_dir}/regions/{species_from}.fa.gz', + as_fa=True) as region_reader: + + f_out2i.write('\t'.join(fields2i) + '\n') + f_out2.write('\t'.join(fields2) + '\n') + + for region_id, header, seqs in \ + region_reader.yield_fa(region_summary.keys()): + region = region_summary[region_id] + + seqs = seqs[:-1] + + # filtering stage 2: things that we're confident in calling + # introgressed from one species specifically + p, alt_states, alt_ids, alt_P_counts = 
passes_filters2( + region, seqs, threshold) + region['alternative_states'] = ','.join(alt_states) + region['alternative_ids'] = ','.join([str(x) for x in alt_ids]) + region['alternative_P_counts'] = ','.join( + [str(x) for x in alt_P_counts]) + write_filtered_line(f_out2i, region_id, region, fields2i) + + if p: + write_filtered_line(f_out2, region_id, region, fields2) + + +if __name__ == '__main__': + main() diff --git a/code/misc/region_reader.py b/code/misc/region_reader.py index fa66057..46d6746 100644 --- a/code/misc/region_reader.py +++ b/code/misc/region_reader.py @@ -99,17 +99,22 @@ def decode_region(self, region_number): return result - def yield_fa(self): + def yield_fa(self, keys=None): ''' repeatedly yield tuples of region, headers, sequences from fa file assumes file position starts at header for region suppress_header is taken as true (will not print) + keys are a list of regions which are valid. If provided will only + yield regions found in keys ''' while True: region = self.region_reader.readline()[1:-1] try: header, seq = self.encode_fa(region) - yield (region, header, seq) + if keys is None or region in keys: + yield (region, header, np.asarray([list(s) for s in seq])) + else: + continue except ValueError: break diff --git a/code/test/analyze/r10805.fa b/code/test/analyze/r10805.fa new file mode 100644 index 0000000..0f65359 --- /dev/null +++ b/code/test/analyze/r10805.fa @@ -0,0 +1,14 @@ +> S288c 279920 281150 +tttagaggattaccatttcaacagatcgtccttagcatataagtagtcgtcaaaaatgaattcaacttcgtctgtttcggcattgtagccgccaactctgatggattcgtggtttttgacaatgatgtcacagcctttttcctttaggaagtccaagtcgaaagtagtggcaataccaatgatcttacaaccggcggcttttccggcggcaatacctgctggagcgtcttcaaatactactaccttagatttggaagggtcttgctcattgatcggatatcctaagccattcctgcccttcagatatggttctggatgaggcttaccctgtttgacatcattagcggtaatgaagtactttggtctcctgattcccagatgctcgaaccatttttgtgccatatcacgggtaccggaagttgccacagcccatttctcttttggtagagcgttcaaagcgttgcacagcttaactgcacctgggacttcaatggatttttcaccgtacttgaccggaatttcagcttctaatttgttaacatactcttcattggcaaagtctggagcgaacttagcaatggcatcaaacgttctccaaccatgcgagacttggataacgtgttcagcatcgaaataaggtttgtccttaccgaaatccctccagaatgcagcaatggctggttgagagatgataatggtaccgtcgacgtcgaacaaagcggcgttaactttcaaagatagaggtttagtagtcaatcccattccgaatattgtttttattgttttat--gtttttccactgatctggtaaacactagctggttggcgctattaatatgaaaagagttagaccaaattgagtagaaaagaaacctttggcaatcctaactatgttgttttagcttgtgtatttaagcgc------------atatatatatatttctgaaaatgacaacatcaaaagaaacgaacttatttagaataaaaagaaacgacttggcttcttattattcctactttacgtcacgtgggaggcccgtttagg----ggggcagctatgtagtttttccgagcgtactttctttcagcatccgaaaagtcctcacttgacggcttacacggaaacgccgcggattgtggggcacagatgatgacgcagacggaacactgc--agaaatctttttaccttgtcgttaaagacgatattagagagaagagttt-ggctggggacaaagtgccagctttt +> CBS432 276420 277656 
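# (illustrative sketch, hypothetical records) the new keys argument to
# yield_fa() behaves like this stand-alone generator: records whose
# region id is absent from keys are skipped rather than yielded
def _yield_by_keys(records, keys=None):
    for region, header, seq in records:
        if keys is None or region in keys:
            yield (region, header, seq)

_records = [('r1', ['> h1'], ['atcg']), ('r2', ['> h2'], ['ggcc'])]
assert [r[0] for r in _yield_by_keys(_records, keys={'r1'})] == ['r1']
assert [r[0] for r in _yield_by_keys(_records)] == ['r1', 'r2']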
+tttaaaaggttaccatttcaaaagatcgtccttagcatacaagtagtcatcaaatataaattcaacttcatccgtttcggcatcgtaaccgccaactctgatggattcatggttcttaacaatgatgtcacaacctttttcctttaggaagtctaaatcgaaagtagtggcaataccaatgatcttacaaccggcggccttaccggcggcaatacctgctggggcgtcttcaaaaactaccaccttggatttggaaggatcttgttcattgattggatatcctaggccatttctacctttcaaatatggttctggatgaggtttaccttgtttgacatcattagcggtaatgaaatactttggtctcctgattcctagatgttcaaaccatttttgagccatatcacgggtaccagaagttgccacagcccatttttcttttggtagagcgttcaaagcattgcacagcttgactgcacctggaacttcaatggacttttcaccgtacttgaccggaatttcagcttctagtttgttaacatactcttcattggcaaagtccggagcaaacttagcaatggcatcgtaagttctccagccatgtgagacttggataacatgctcggcatcgaaataaggcttgtccttaccgaagtccctccagaatgcagcgatggctggttgagagatgatgatggtaccgtcgacgtcgaacaaagcggcgttaacttttaaagataaaggtttagtagtcaatcccattttaaatattgtttttattattttcttatattttctgctgttgtaataaacactagcttggt-gtactctgaaaatgaaaaagacaaaaacagactgaat--tagaggaaccttgggctatgttaattgtgttcttctagattgtatgattaacctta----------tatatatatatatttatgaaaatgactacatcaaaagaaacgaacata-ccaaaagaaaaagaaacgacttggcttcttattattcctactttacgtcacgtggga-gcccgtttagggggaggggcaggcatatagtttttccgagcgtagtttcttttagcatccgaaaagtcctcagttggcggcttactcggaaacgccgcggattgtggggcacagaggatagcgcagacggaacacggcagagagaacaattttacttgtcgctaaagacgaaaccagagatgagagtttcggatagggacaaaatgcccgctttc +> N_45 285263 286497 +tttaaaaggttaccatttcaaaagatcgtccttagcatacaagtagtcatcaaatataaattcaacttcatccgtttcggcatcgtaaccgccaactctgatggattcatggttcttaacaatgatgtcacaacctttttcctttaggaagtctaaatcgaaagtagtggcaataccaatgatcttacaaccggcggccttaccggcggcaatacctgctggggcgtcttcaaaaactaccaccttggatttggaaggatcttgttcattgattggatatcctaggccatttctacctttcaaatatggttctggatgaggtttaccttgtttgacatcattagcggtaatgaaatactttggtctcctgattcctagatgttcaaaccatttttgagccatatcacgggtaccagaagttgccacagcccatttctcttttggtagagcgttcaaagcattgcacagcttgactgcacctggaacttcaatggacttttcaccgtacttgaccggaatttcagcttctaatttgttaacatactcttcattggcaaagtccggagcaaacttagcaatggcatcgtaagttctccagccatgcgagacttggataacatgctcggcatcgaaataaggcttgtccttaccgaagtccctccagaatgcagcgatggctggttgagagatgatgatggtaccgtcgacgtcgaacaaagcggcgttaacttttaaagataaaggtttagtagtcaatcccattttaaatattgtttttactattttcttatatgttctgctgttgtaataaacactagcttggt-gtactctgaaaatgaaaaagacaaaaacagactgaat--tagaggaacctcgggctatgttaattgtgttcttctagattgtatgattaacctt------------atatatatatatttatgaaaatgactacatcaaaagaaacgaacata-ccaaaagaaaaagaaacgacttggcttcttattattcctactttacgtcacgtggga-gcccgtttagggggaggggcaggcatatagtttttccgagcgtagtttcttttagcatccgaaaagtcctcagtttgcggcttactcggaaacgccgcggattgtggggcacagaggatagcgcagacggaacacggcagagagaacaattttacttgtcgctaaagacgaaaccagagatgagagtttcggatagggacaaaatgcccgctttc +> DBVPG6304 291210 292453 
+cttaaaaggttaccatttcaaaagatcgtccttggcatacaagtagtcatcaaatataaattcaacttcatccgtttcggcatcgtaaccgccaactctgatggattcatggttcttgacaatgatgtcacaacctttttcctttaggaagtctaagtcgaaagtagtggcaataccaatgatcttacaaccggcagccttaccggcggcaatacctgctggggcgtcttcaaaaacaaccaccttggatttggaaggatcttgttcattgattgggtatccaaggccatttctacctttcaaatatggttctggatgaggtttaccttgtttaacatcattagcggtgataaaatactttggtctcttgattcctagatgttcaaaccatttttgagccatatcacgggtaccagaagttgccacagcccatttctcttttggtagagcattcaaagcattgcacagcttgactgcacctggaacttcaatggatttttcaccgtatttgactggaatttcagcttctaatttattaacatattcttcattggcaaagtccggagcaaacttagcaatggcatcgtaagttctccaaccatgcgagacttggataacatgctcggcatcgaaataaggcttgtccttaccgaagtccctccagaatgcagcgatggctggttgagagatgatgatggtaccatcgacatcgaacaaagcggcgttaactttcaaagataaaggtttagtagtcaatcccattttaaatgttgtttttattattttcttatattttttgttgttgtaataaacactagcttgat-gtgctctgaaaatgaaaaagactaaaacaaactgaat--tagaggaaccttgggctatgttaattgtgttcttc-agatagtatgattaaccttgtatatatatatgtatatatgtatatatgaaaatgactacatcaaaagaaacgaacgta-ccaaaagaaaaagaaacgacttggtttcttattattcccactttacgtcacgtggga-gcccgtttagggggaggggttggcatatagtttttccgagcgtagtttcttttagcatccgaaaagtcctcagttggcggcttactcggaaacgccgcggattgtggggcacagaggatagcgcagacggatcacggc--agggaacaattttacttgtcgctaaagacgatatcagagaagagagtttcggatggggacaaaatgcccgctttc +> UWOPS91_917_1 293672 294915 +cttaaaaggttaccatttcaaaagatcgtccttggcatacaagtagtcatcaaatataaattcaacttcatccgtttcggcatcataaccgccaactctgatggattcatggttcttgacaatgatgtcacaacctttttcctttaggaagtctaagtcgaaagtagtggcaataccaatgatcttacaaccggcggccttaccggcggcaatacctgctggggcgtcttcaaaaacaaccaccttggatttggaaggatcttgttcattgattggatatcctaggccatttctacctttcaaatatggttctggatgaggtttaccttgtttaacatcattggcggtgataaaatactttggtctcctgattcctagatgttcaaaccatttttgagccatatcacgggtaccagaagttgccacagcccatttctcttttggtagagcattcaaagcattgcacagcttgactgcacctggaacttcaatggatttttcaccgtatttgactggaatttcagcttctaatttattaacatattcttcattggcaaagtccggagcaaacttagcaatggcatcgtaggtcctccaaccatgcgagacttggataacatgctcggcatcgaaataaggcttgtccttaccgaagtccctccagaatgcagcgatggctggttgagagatgatgatggtaccgtcgacgtcgaacaaagcggcgttaactttcaaagataaaggtttagtagtcaatcccattttaaatattgtttttattattttcttatattttttgttgttgtaataaacactagcttgat-gtgctctgaaaatgaaaaagactaaaacaaactgaat--tagaggaaccttgggctatgttaattgtgttcttc-agatagtatgattaaccttgtatatatatatgtatatatgtatatatgaaaatgactacatcaaaagaaacgaacgta-ccaaaagaaaaagaaacgacttggtttcttattattcccactttacgtcacgtggga-gcccgtttagggggaggggttggcatatagtttttccgagcgtagtttcttttagcatccgaaaagtcctcagttggcggcttactcggaaacgccgcggattgtggggcacagaggatagcgcagacggatcacggc--agggaacaattttacttgtcgctaaagacgatatcagagaagagagtttcggatggggacaaaatgcccgctttc +> yjm248 287893 289127 
+tttaaaaggttaccatttcaaaaggtcgtccttagcatacaagtagtcatcaaatataaattcaacttcatccgtttcggcatcgtaaccgccaactctgatggattcatggttcttaacaatgatgtcacaacctttttcctttaggaagtctaaatcgaaagtagtggcaataccaatgatcttacaaccggcggccttaccggcggcaatacctgctggggcgtcttcaaaaactaccaccttggatttggaaggatcttgttcattgattggatatcctaggccatttctacctttcaaatatggttctggatgaggtttaccttgtttgacatcattagcggtaatgaaatactttggtctcctgattcctagatgttcaaaccatttttgagccatatcacgggtaccagaagttgccacagcccatttttcttttggtagagcgttcaaagcattgcacagcttgactgcacctggaacttcaatggacttttcaccgtacttgaccggaatttcagcttctagtttgttaacatactcttcattggcaaagtccggagcaaacttagcaatggcatcgtaagttctccagccatgcgagacttggataacatgctcggcatcgaaataaggcttgtccttaccgaagtccctccagaatgcagcgatggctggttgagagatgatgatggtaccgtcgacgtcgaacaaagcggcgttaacttttaaagataaaggtttagtagtcaatcccattttaaatattgtttttattattttcttatattttctgctgttgtaataaacactagcttggt-gtactctgaaaatgaaaaagacaaaaacagactgaat--tagaggaaccttgggctatgttaattgtgttcttctagattgtatgattaacctt------------atatatatatatttatgaaaatgactacatcaaaagaaacgaacata-ccaaaagaaaaagaaacgacttggcttcttattattcctactttacgtcacgtggga-gcccgtttagggggaggggcaggcatatagtttttccgagcgtagtttcttttagcatccgaaaagtcctcagttggcggcttactcggaaacgccgcggattgtggggcacagaggatagcgcagacggaacacggcagagagaacaattttacttgtcgctaaagacgaaaccagagatgagagtttcggatagggacaaaatgcccgctttc +> info +B...P.P.P............P..X........B.....P........P.....P..P...........P..P..........PB..P....................P.....P..P..............P....................P..P......................................B..P..P....................P...........P..B..P.....P...........P.....P........P..B.....B.P......P..P..P....P..................P.....P.....B........B.....B..B..P............B.......P.....P..P...........P.................P....................X..............B........P...........P...........P...........P...........B.....B................X...B........B.................P.....P.................PP.P..B.....P.....B..............P..P..P..............P..............P.................P....................P........B.....B.......................P.......P......................PPP...B.........C.P....P.--PP.C..BPPB..P.P.PP............P.P.-.PP..P.P..P.......PPPPP.P.P..P.P...P.--P.P..P.....CP...P..PP...P.P....P..Pb..P.B..P.PP....P.P_------------____________B.P..........P..................P..-PP.P..P.................B.............B..................-...........----....BB.PP..P.................P.......P...................P..CP........P.............................P...PP..........B...P..--..BP.P.PP...PP.......P.........P.PP.....PP.......-..P.P........P....P.....P diff --git a/code/test/analyze/test_filter_2_main.py b/code/test/analyze/test_filter_2_main.py new file mode 100644 index 0000000..dd18e02 --- /dev/null +++ b/code/test/analyze/test_filter_2_main.py @@ -0,0 +1,89 @@ +from analyze import filter_2_main as main + + +def test_main(mocker, capsys): + mocker.patch('sys.argv', ['', '0.1']) + mocker.patch('analyze.filter_2_main.predict.process_predict_args', + return_value={ + 'known_states': ['state1', 'state2'], + 'tag': 'tag' + }) + mocker.patch('analyze.filter_2_main.gp.analysis_out_dir_absolute', + '/dir') + mocker.patch('analyze.filter_2_main.read_table.read_table_rows', + return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) + + mocked_file = mocker.patch('analyze.filter_2_main.open') + + mock_read = mocker.patch('analyze.filter_2_main.Region_Reader') + mock_read().__enter__().yield_fa.return_value = iter([ + ('r1', ['> seq', '> info'], ['atcg', 'x..']), + ('r2', ['> seq', '> info'], ['atcg', 'x..'])]) + + mock_filter = 
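# (aside, illustrative) pytest-mock delegates to unittest.mock, so a
# list side_effect makes each successive call return the next element;
# that is how the two regions in this test get different filter results
from unittest.mock import Mock
fake_filter = Mock(side_effect=[(False, 'r1 fails'), (True, 'r2 passes')])
assert fake_filter() == (False, 'r1 fails')
assert fake_filter() == (True, 'r2 passes')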
mocker.patch('analyze.filter_2_main.passes_filters2', + side_effect=[ + (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]), + (True, ['1'], [0.8], [2]) + ]) + mock_write = mocker.patch('analyze.filter_2_main.write_filtered_line') + + main.main() + + captured = capsys.readouterr().out + assert captured == 'state2\n' + + assert mock_read.call_count == 2 + mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True) + + assert mocked_file.call_count == 2 + mocked_file.assert_any_call( + '/dirtag/blocks_state2_tag_filtered2intermediate.txt', 'w') + mocked_file.assert_any_call( + '/dirtag/blocks_state2_tag_filtered2.txt', 'w') + + # just headers, capture others + mocked_file().__enter__().write.assert_has_calls([ + mocker.call('regions\talternative_states\t' + 'alternative_ids\talternative_P_counts\n'), + mocker.call('regions\n')]) + + assert mock_filter.call_count == 2 + # seems like this references the object, which changes after call + mock_filter.assert_has_calls([ + mocker.call( + {'alternative_states': '1,2', + 'alternative_ids': '0.8,0.5', + 'alternative_P_counts': '2,1,0'}, + ['atcg'], 0.1), + mocker.call( + {'a': 1, + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2'}, + ['atcg'], 0.1)]) + + assert mock_write.call_count == 3 + mock_write.assert_has_calls([ + mocker.call(mocker.ANY, 'r1', + {'alternative_states': '1,2', + 'alternative_ids': '0.8,0.5', + 'alternative_P_counts': '2,1,0'}, + ['regions', 'alternative_states', + 'alternative_ids', 'alternative_P_counts'] + ), + mocker.call(mocker.ANY, 'r2', + {'a': 1, + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2'}, + ['regions', 'alternative_states', + 'alternative_ids', 'alternative_P_counts'] + ), + mocker.call(mocker.ANY, 'r2', + {'a': 1, + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2'}, + ['regions'] + ) + ]) diff --git a/code/test/analyze/test_filter_helpers.py b/code/test/analyze/test_filter_helpers.py index a0cd91e..94d0212 100644 --- a/code/test/analyze/test_filter_helpers.py +++ b/code/test/analyze/test_filter_helpers.py @@ -1,6 +1,10 @@ from analyze import filter_helpers from io import StringIO import numpy as np +from misc import read_fasta +import os +import warnings +from pytest import approx def test_write_filtered_line(): @@ -58,14 +62,13 @@ def test_passes_filters(): assert filter_helpers.passes_filters(region) is False # check divergences (match_ref1 / aligned - gapped) < 0.7 - region = {'number_gaps': 1, - 'number_masked_non_gap': 0, + region = {'number_masked_non_gap': 0, 'start': 0, 'end': 1, 'number_match_ref2_not_ref1': 7, 'number_match_ref1': 6, 'aligned_length': 11, - } + 'number_gaps': 1} assert filter_helpers.passes_filters(region) is False # passes @@ -192,7 +195,6 @@ def test_passes_filters2(mocker): list('attatt'), # 2 / 5, p = 0 list('ag-tat')] # test sequence - seqs = np.array(seqs) threshold = 0 filt, states, ids, p_count = filter_helpers.passes_filters2( region, seqs, threshold) @@ -216,3 +218,33 @@ def test_passes_filters2(mocker): assert states == ['1'] assert ids == [0.8] assert p_count == [2] + + +def test_passes_filters2_on_region(mocker): + mocker.patch('analyze.filter_helpers.gp.alignment_ref_order', + ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) + mocker.patch('analyze.filter_helpers.gp.gap_symbol', '-') + mocker.patch('analyze.filter_helpers.gp.unsequenced_symbol', 'n') + + fa = os.path.join(os.path.split(__file__)[0], 'r10805.fa') + + if os.path.exists(fa): + headers, 
seqs = read_fasta.read_fasta(fa, gz=False) + seqs = seqs[:-1] + p, alt_states, alt_ids, alt_P_counts = filter_helpers.passes_filters2( + {'predicted_species': 'N_45'}, seqs, 0.1) + assert p is False + assert alt_states == ['CBS432', 'N_45', 'UWOPS91_917_1', 'DBVPG6304'] + assert alt_ids == approx([0.9983805668016195, 0.994331983805668, + 0.9642857142857143, 0.9618506493506493]) + assert alt_P_counts == [145, 143, 128, 129] + + p, alt_states, alt_ids, alt_P_counts = filter_helpers.passes_filters2( + {'predicted_species': 'N_45'}, seqs, 0.98) + assert p is False + assert alt_states == ['CBS432', 'N_45'] + assert alt_ids == approx([0.9983805668016195, 0.994331983805668]) + assert alt_P_counts == [145, 143] + + else: + warnings.warn('Unable to test with datafile r10805.fa') diff --git a/code/test/helper_scripts/compare_filter_outputs.sh b/code/test/helper_scripts/compare_filter_outputs.sh index adce0f9..533888b 100755 --- a/code/test/helper_scripts/compare_filter_outputs.sh +++ b/code/test/helper_scripts/compare_filter_outputs.sh @@ -12,7 +12,20 @@ done for file in $(ls ${expected}*_filtered1intermediate.txt); do act=$(echo $file | sed 's/p4e2/_test/g') - cmp <(sort $act | python intermediate_format.py) \ - <(sort $file | python intermediate_format.py) \ + cmp <(sort $act | python intermediate_format_1.py) \ + <(sort $file | python intermediate_format_1.py) \ && echo $file passed! || echo $file failed #&& exit done + +for file in $(ls ${expected}*_filtered2.txt); do + act=$(echo $file | sed 's/p4e2/_test/g') + cmp <(sort $act) <(sort $file) \ + && echo $file passed! || echo $file failed #&& exit +done + +for file in $(ls ${expected}*_filtered2intermediate.txt); do + act=$(echo $file | sed 's/p4e2/_test/g') + cmp <(sort $act | python intermediate_format_2.py) \ + <(sort $file | python intermediate_format_2.py) \ + && echo $file passed! 
|| echo $file failed && exit +done diff --git a/code/test/helper_scripts/intermediate_format.py b/code/test/helper_scripts/intermediate_format_1.py similarity index 100% rename from code/test/helper_scripts/intermediate_format.py rename to code/test/helper_scripts/intermediate_format_1.py diff --git a/code/test/helper_scripts/intermediate_format_2.py b/code/test/helper_scripts/intermediate_format_2.py new file mode 100644 index 0000000..ec5ca97 --- /dev/null +++ b/code/test/helper_scripts/intermediate_format_2.py @@ -0,0 +1,30 @@ +import sys + + +def main(): + precision = 10 + with sys.stdin as reader: + for line in reader: + line = line.strip() + tokens = line.split('\t') + # limit float sizes to 10 characters + for j in range(len(tokens)-2, len(tokens)): + float_tokens = tokens[j].split(',') + for i in range(len(float_tokens)): + try: + float(float_tokens[i]) + except ValueError: + continue + if len(float_tokens[i]) > precision: + float_tokens[i] = float_tokens[i][:precision] + tokens[j] = ','.join(float_tokens) + # check if alt ids are equal, sorting is messed up from py2 to 3 + id_toks = tokens[-2].split(',') + if len(id_toks) > 1 and id_toks[0] == id_toks[1]: + tokens[-3] = ','.join(sorted(tokens[-3].split(','))) + line = '\t'.join(tokens) + print(line) + + +if __name__ == "__main__": + main() diff --git a/code/test/helper_scripts/run_filter_2.sh b/code/test/helper_scripts/run_filter_2.sh new file mode 100755 index 0000000..a3cbb08 --- /dev/null +++ b/code/test/helper_scripts/run_filter_2.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --time=0-1 +#SBATCH -n 1 +#SBATCH -o "/tigress/tcomi/aclark4_temp/results/filter_%A" + +export PYTHONPATH=/home/tcomi/projects/aclark4_introgression/code/ + +module load anaconda3 +conda activate introgression3 + +ARGS="0.98 _test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" + +python ${PYTHONPATH}analyze/filter_2_main.py $ARGS diff --git a/code/test/misc/test_region_reader.py b/code/test/misc/test_region_reader.py index e56d9a3..2b98c6b 100644 --- a/code/test/misc/test_region_reader.py +++ b/code/test/misc/test_region_reader.py @@ -152,13 +152,44 @@ def test_yield_fa(r): r.num_lines = 2 regions = ['h1', 'h2'] headers = [['header 1'], ['header 2']] - seqs = [np.asarray(['line 1']), np.asarray(['line 2'])] + seqs = [np.asarray([list('line 1')]), np.asarray([list('line 2')])] i = 0 for region, header, seq in r.yield_fa(): assert region == regions[i] assert header == headers[i] - assert seq == seqs[i] + assert seq == approx(seqs[i]) i += 1 + assert i == 2 + + r.region_reader = StringIO('#h1\n' + 'header 1\n' + 'line 1\n' + '#h2\n' + 'header 2\n' + 'line 2\n') + i = 0 + # with keys added, should only yield one value + for region, header, seq in r.yield_fa({'h1': ''}.keys()): + assert region == regions[i] + assert header == headers[i] + assert seq == approx(seqs[i]) + i += 1 + assert i == 1 + + r.region_reader = StringIO('#h1\n' + 'header 1\n' + 'line 1\n' + '#h2\n' + 'header 2\n' + 'line 2\n') + i = 0 + # with keys added, should only yield one value + for region, header, seq in r.yield_fa({'h0': ''}.keys()): + assert region == regions[i] + assert header == headers[i] + assert seq == approx(seqs[i]) + i += 1 + assert i == 0 def test_encode_fa(r): From a1c6155139129cd5a89937b1d15e160b1c4fc7a6 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Fri, 22 Mar 2019 13:04:43 -0400 Subject: [PATCH 06/33] Adding docstrings and return types --- code/analyze/extract_region.py | 16 +++-- code/analyze/filter_1_main.py | 5 +- 
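# (illustrative, hypothetical row) the tie-handling branch of
# intermediate_format_2.py: when the first two alternative ids compare
# equal, the alternative states are re-sorted so the Python 2 and
# Python 3 outputs agree on ordering
tokens = ['r1', 'N_45,CBS432', '0.98,0.98', '5,5']
id_toks = tokens[-2].split(',')
if len(id_toks) > 1 and id_toks[0] == id_toks[1]:
    tokens[-3] = ','.join(sorted(tokens[-3].split(',')))
assert tokens[-3] == 'CBS432,N_45'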
code/analyze/filter_2_main.py | 5 +- code/analyze/filter_helpers.py | 102 +++++++++++++++++++++----------- code/analyze/id_regions_main.py | 5 +- code/analyze/predict.py | 41 ++++++++++--- 6 files changed, 124 insertions(+), 50 deletions(-) diff --git a/code/analyze/extract_region.py b/code/analyze/extract_region.py index 77b97df..e7649d7 100644 --- a/code/analyze/extract_region.py +++ b/code/analyze/extract_region.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 import argparse from misc.region_reader import Region_Reader +from typing import List, Tuple def main(): + ''' + Main method to read in arguments from stdin and perform lookup with + Region_Reader + ''' args = parse_args() args, reader = validate_args(args) with reader: @@ -12,7 +17,7 @@ def main(): write_regions(reader, locations) -def parse_args(args=None): +def parse_args(args: List[str] = None) -> argparse.Namespace: ''' Read in input arguments or the supplied list of strings Returns a dictionary of options @@ -37,7 +42,8 @@ def parse_args(args=None): return vars(parser.parse_args(args)) -def validate_args(args): +def validate_args(args: argparse.Namespace) -> Tuple[argparse.Namespace, + Region_Reader]: ''' Performs checks and conversions of input, raises ValueErrors if invalid ''' @@ -51,7 +57,9 @@ def validate_args(args): return args, reader -def decode_regions(regions, reader, retain_sort): +def decode_regions(regions: List[int], + reader: Region_Reader, + retain_sort: bool) -> List[int]: ''' Converts list of regions to file locations based on index dictionary Retain_sort controls if the output list order is determined by the @@ -66,7 +74,7 @@ def decode_regions(regions, reader, retain_sort): return sorted(result) -def write_regions(reader, locations): +def write_regions(reader: Region_Reader, locations: List[int]) -> None: ''' Writes the regions specified by index to stdout If print_header is false, ignore first line after location diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py index 043a782..70dc334 100644 --- a/code/analyze/filter_1_main.py +++ b/code/analyze/filter_1_main.py @@ -18,7 +18,10 @@ from misc.region_reader import Region_Reader -def main(): +def main() -> None: + ''' + Perform first step of filtering + ''' args = predict.process_predict_args(sys.argv[1:]) out_dir = gp.analysis_out_dir_absolute + args['tag'] diff --git a/code/analyze/filter_2_main.py b/code/analyze/filter_2_main.py index 604b65a..01986f8 100644 --- a/code/analyze/filter_2_main.py +++ b/code/analyze/filter_2_main.py @@ -18,7 +18,10 @@ from misc.region_reader import Region_Reader -def main(): +def main() -> None: + ''' + Perform second stage of filtering + ''' args = predict.process_predict_args(sys.argv[2:]) threshold = float(sys.argv[1]) out_dir = gp.analysis_out_dir_absolute + args['tag'] diff --git a/code/analyze/filter_helpers.py b/code/analyze/filter_helpers.py index b77594b..0a24a38 100644 --- a/code/analyze/filter_helpers.py +++ b/code/analyze/filter_helpers.py @@ -1,17 +1,29 @@ import global_params as gp from misc import seq_functions import numpy as np - - -def write_filtered_line(f, region_id, region, fields): - f.write(f'{region_id}\t' - + '\t'.join([str(region[field]) - for field in fields[1:]]) - + '\n') - - -def passes_filters(region): - +from typing import List, Dict, TextIO, Tuple + + +def write_filtered_line(writer: TextIO, + region_id: str, + region: Dict, + fields: List) -> None: + ''' + Write the region id and values in "region" dict to open file writer + ''' + writer.write(f'{region_id}\t' + + 
'\t'.join([str(region[field]) + for field in fields[1:]]) + + '\n') + + +def passes_filters(region: Dict) -> bool: + ''' + test if the supplied region satisfies: + -Fraction of gaps and masked < 0.5 + -Number of matching > 7 + -Divergence < 0.7 + ''' # fraction gaps + masked filter fraction_gaps_masked_threshold = .5 fraction_gaps_masked = \ @@ -38,13 +50,23 @@ def passes_filters(region): return True -def passes_filters1(region, info_string): - # filtering out things that we can't call introgressed in general - # with confidence (i.e. doesn't seem like a strong case against - # being S288c) - - r = gp.alignment_ref_order[0] - s = region['predicted_species'] +def passes_filters1(region: Dict, info: str) -> Tuple[bool, str]: + ''' + filtering out things that we can't call introgressed in general + with confidence (i.e. doesn't seem like a strong case against + being S288c) + Return true if the region passes the filter, or false with a string + specifying which filter failed + Tests: + -fraction of gaps masked in reference > 0.5 + -fraction of gaps masked in predicted species > 0.5 + -number of matches to predicted > 7 + -number of matches to predicted > number matches to reference + -divergence with predicted species + ''' + + reference_species = gp.alignment_ref_order[0] + predicted_species = region['predicted_species'] aligned_length = (int(region['end']) - int(region['start']) + 1) @@ -54,9 +76,9 @@ def passes_filters1(region, info_string): # reference x nor the test sequence is masked or has a gap or # unsequenced character fraction_gaps_masked_r = \ - 1 - float(region['num_sites_nonmask_' + r]) / aligned_length + 1 - region['num_sites_nonmask_' + reference_species] / aligned_length fraction_gaps_masked_s = \ - 1 - float(region['num_sites_nonmask_' + s]) / aligned_length + 1 - region['num_sites_nonmask_' + predicted_species] / aligned_length if fraction_gaps_masked_r > fraction_gaps_masked_threshold: return False, f'fraction gaps/masked in master = '\ @@ -65,10 +87,10 @@ def passes_filters1(region, info_string): return False, f'fraction gaps/masked in predicted = '\ f'{fraction_gaps_masked_s}' - # FILTER: number sites analyzed by HMM that match predicted - # reference - count_P = info_string.count('P') - count_C = info_string.count('C') + # FILTER: number sites analyzed by HMM that match predicted (P) + # reference (C) + count_P = info.count('P') + count_C = info.count('C') number_match_only_threshold = 7 if count_P < number_match_only_threshold: return False, f'count_P = {count_P}' @@ -77,10 +99,10 @@ def passes_filters1(region, info_string): # FILTER: divergence with predicted reference and master reference # (S288c) - id_predicted = float(region['match_nongap_' + s]) / \ - float(region['num_sites_nongap_' + s]) - id_master = float(region['match_nongap_' + r]) / \ - float(region['num_sites_nongap_' + r]) + id_predicted = float(region['match_nongap_' + predicted_species]) / \ + float(region['num_sites_nongap_' + predicted_species]) + id_master = float(region['match_nongap_' + reference_species]) / \ + float(region['num_sites_nongap_' + reference_species]) if id_master >= id_predicted: return False, f'id with master = {id_master} '\ @@ -91,16 +113,29 @@ def passes_filters1(region, info_string): return True, '' -def passes_filters2(region, seqs, threshold): - # filter out things we can't assign to one species specifically; - # also return the other reasonable alternatives if we're filtering - # it out +def passes_filters2(region: Dict, + seqs: np.array, + threshold: float) -> Tuple[bool, 
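# (illustrative, hypothetical numbers) the stage-1 arithmetic in one
# place: the gaps/masked fraction, and the P/C counts from the info
# string ('P' = site matches predicted reference, 'C' = matches S288c)
region = {'number_gaps': 40, 'number_masked_non_gap': 20,
          'start': 0, 'end': 99}
frac = (region['number_gaps'] + region['number_masked_non_gap']) / \
    (region['end'] - region['start'] + 1)
assert frac == 0.6   # above the 0.5 threshold, so this region fails
info = 'PPPPPPPPCC..P'
assert info.count('P') == 9 and info.count('C') == 2
# a region is kept only when count_P clears the threshold of 7 and
# beats count_C, per the docstring above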
+ List[str], + List[float], + List[int]]: + ''' + filter out things we can't assign to one species specifically; + return the other reasonable alternatives if we're filtering + it out + Returns a tuple of: + True if the region passes the filter + A list of likely species for the region + A list of fraction of matching sequence for each species + A list of total matching sites + Fails the filter if the number of matches and fraction matching are at + least as high for more than one state in the region + ''' refs = gp.alignment_ref_order s = region['predicted_species'] ids = {} - totals = {} P_counts = {} seqs = np.asarray(seqs) @@ -120,7 +155,6 @@ def passes_filters2(region, seqs, threshold): r_match, r_total = seq_functions.seq_id(seqs[-1], seqs[ri]) if r_total != 0: ids[ref] = r_match / r_total - totals[ref] = r_total P_counts[ref] = np.sum( np.logical_and( np.logical_not(skip), diff --git a/code/analyze/id_regions_main.py b/code/analyze/id_regions_main.py index b75cbf3..839eb29 100644 --- a/code/analyze/id_regions_main.py +++ b/code/analyze/id_regions_main.py @@ -4,7 +4,10 @@ import global_params as gp -def main(): +def main() -> None: + ''' + Adds a unique region id to block files, producing labeled text files + ''' args = predict.process_predict_args(sys.argv[1:]) # order regions by chromosome, start (break ties alphabetically by strain) diff --git a/code/analyze/predict.py b/code/analyze/predict.py index d96d399..f0474ad 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -8,9 +8,13 @@ import global_params as gp from misc import read_fasta import numpy as np +from typing import List, Dict, Tuple -def process_predict_args(arg_list): +def process_predict_args(arg_list: List[str]) -> Dict: + ''' + Parses arguments from argv, producing a dictionary of parsed values + ''' d = {} i = 0 @@ -24,7 +28,7 @@ def process_predict_args(arg_list): d['threshold'] = 'viterbi' try: d['threshold'] = float(arg_list[i]) - except: + except ValueError: pass i += 1 @@ -66,8 +70,12 @@ def process_predict_args(arg_list): return d -def read_aligned_seqs(fn, strain): - headers, seqs = read_fasta.read_fasta(fn) +def read_aligned_seqs(fast_file: str, + strain: str) -> Tuple[np.array, np.array]: + ''' + Read fasta file, returning sequences of references and the specified strain + ''' + headers, seqs = read_fasta.read_fasta(fast_file) d = {} for i in range(len(seqs)): name = headers[i][1:].split(' ')[0] @@ -81,7 +89,13 @@ def read_aligned_seqs(fn, strain): return ref_seqs, predict_seq -def set_expectations(args, n): +def set_expectations(args: Dict, n: int) -> None: + ''' + sets expected number of tracts and bases for each reference + based on expected length of introgressed tracts and expected + total fraction of genome + also takes n, length of the sequence to analyze + ''' species_to = gp.alignment_ref_order[0] species_from = gp.alignment_ref_order[1:] @@ -103,7 +117,16 @@ def set_expectations(args, n): args['expected_num_tracts'][species_to] -def ungap_and_code(predict_seq, ref_seqs, index_ref=0): +def ungap_and_code(predict_seq: str, + ref_seqs: List[str], + index_ref: int = 0) -> Tuple[np.array, np.array]: + ''' + Remove any sequence locations where a gap is present and code + into matching or mismatching sequence + Returns the coded sequences, by default an array of + where matching, - + where mismatching. Also return the positions where the sequences are not + gapped.
+ ''' # index_ref is index of reference strain to index relative to # build character array sequences = np.array([list(predict_seq)] + @@ -131,6 +154,9 @@ def ungap_and_code(predict_seq, ref_seqs, index_ref=0): def poly_sites(sequences, positions): + ''' + WORKING ON ADDING DOC STRINGS AND TYPING!! + ''' seq_len = len(sequences[0]) # check if seq only contains match_symbol retain = np.vectorize( @@ -293,9 +319,6 @@ def predict_introgressed(ref_seqs, predict_seq, predict_args, if return_positions: return positions - # sets expected number of tracts and bases for each reference - # based on expected length of introgressed tracts and expected - # total fraction of genome set_expectations(predict_args, len(predict_seq)) # set initial hmm parameters based on combination of (1) initial From 884400785c4f7e2c49f3c688e7155e3e1805736b Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Fri, 22 Mar 2019 15:30:39 -0400 Subject: [PATCH 07/33] Filter 2 thresholds refactor Changed operation of threshold scan to limit number of read throughs of the region files. --- code/analyze/filter_2_thresholds_main.py | 128 ++++++++++-------- .../analyze/test_filter_2_thresholds_main.py | 104 ++++++++++++++ .../helper_scripts/run_filter_2_thresholds.sh | 14 ++ 3 files changed, 193 insertions(+), 53 deletions(-) create mode 100644 code/test/analyze/test_filter_2_thresholds_main.py create mode 100755 code/test/helper_scripts/run_filter_2_thresholds.sh diff --git a/code/analyze/filter_2_thresholds_main.py b/code/analyze/filter_2_thresholds_main.py index 7a62fe8..7a35ee6 100644 --- a/code/analyze/filter_2_thresholds_main.py +++ b/code/analyze/filter_2_thresholds_main.py @@ -8,58 +8,80 @@ # then we'll make some plots in R to see if there's a sort of obvious # place to draw the line -import re import sys -import os -import copy -from collections import defaultdict -import predict -from filter_helpers import * -sys.path.insert(0, '..') +from analyze import predict +from analyze.filter_helpers import passes_filters2 import global_params as gp -sys.path.insert(0, '../misc/') -import read_table -import read_fasta - -args = predict.process_predict_args(sys.argv[1:]) - -#thresholds = [.99, .98, .97, .96, .95, .94, .93, .92, .91, .9, .88, .85, .82, .8, .75, .7, .6, .5] -#thresholds = [.999, .995, .985, .975, .965, .955, .945, .935, .925, .915, .905, .89, .87, .86] -thresholds = [1] - -open_mode = 'a' -f = open(gp.analysis_out_dir_absolute + args['tag'] + \ - '/filter_2_thresholds_' + args['tag'] + '.txt', open_mode) -if open_mode == 'w': - f.write('threshold\tpredicted_state\talternative_states\tcount\n') -for threshold in thresholds: - print threshold - for species_from in args['known_states'][1:]: - - print '*', species_from - - fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered1.txt' - region_summary, fields = read_table.read_table_rows(fn, '\t') - - d = defaultdict(int) - for region_id in region_summary: - #print region_id, '****' - region = region_summary[region_id] - headers, seqs = read_fasta.read_fasta(gp.analysis_out_dir_absolute + \ - args['tag'] + \ - '/regions/' + region_id + '.fa.gz', \ - gz = True) - info_string = seqs[-1] - seqs = seqs[:-1] - - p, alt_states, alt_ids, alt_P_counts = \ - passes_filters2(region, seqs, threshold) - - d[','.join(sorted(alt_states))] += 1 - - for key in d: - f.write(str(threshold) + '\t' + species_from + '\t' + \ - key + '\t' + str(d[key]) + '\n') -f.close() +from misc import read_table +from misc.region_reader import 
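# (illustrative sketch of the coding idea, hypothetical data; assumes
# gap symbol '-', match symbol '+', mismatch symbol '-')
import numpy as np
chars = np.array([list('at-g'), list('atcg')])  # prediction, one ref
keep = np.all(chars != '-', axis=0)             # drop gapped columns
positions = np.where(keep)[0]
coded = np.where(chars[0, keep] == chars[1, keep], '+', '-')
assert list(positions) == [0, 1, 3] and ''.join(coded) == '+++'
# poly_sites then drops columns whose coded string is all matches:
sites = np.array(['++', '+-', '--'])
poly = sites[np.array([s.count('+') != len(s) for s in sites])]
assert list(poly) == ['+-', '--']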
Region_Reader + + +thresholds = [.999, .995, .985, .975, .965, .955, .945, + .935, .925, .915, .905, .89, .87, .86] +# thresholds = [.99, .98, .97, .96, .95, .94, .93, .92, +# .91, .9, .88, .85, .82, .8, .75, .7, .6, .5] +# thresholds = [1] + + +def main(): + args = predict.process_predict_args(sys.argv[1:]) + out_dir = gp.analysis_out_dir_absolute + args['tag'] + + open_mode = 'w' + with open(f'{out_dir}/filter_2_thresholds_{args["tag"]}.txt', open_mode)\ + as writer: + if open_mode == 'w': + writer.write( + 'threshold\tpredicted_state\talternative_states\tcount\n') + + data_table = {} + for species_from in args['known_states'][1:]: + print(f'* {species_from}') + + region_summary, fields = read_table.read_table_rows( + f'{out_dir}/blocks_{species_from}' + f'_{args["tag"]}_filtered1.txt', + '\t') + + with Region_Reader(f'{out_dir}/regions/{species_from}.fa.gz', + as_fa=True) as region_reader: + for region_id, header, seqs in \ + region_reader.yield_fa(region_summary.keys()): + + region = region_summary[region_id] + seqs = seqs[:-1] + + for threshold in thresholds: + _, alt_states, _, _ = \ + passes_filters2(region, seqs, threshold) + + record_data_hit(data_table, + threshold, + species_from, + ','.join(sorted(alt_states))) + + for threshold in thresholds: + for species in args['known_states'][1:]: + d = data_table[threshold][species] + for key in d.keys(): + writer.write(f'{threshold}\t{species}\t{key}\t{d[key]}\n') + + +def record_data_hit(data_dict, threshold, species, key): + ''' + adds an entry to the data table or increments if exists + ''' + if threshold not in data_dict: + data_dict[threshold] = {} + + if species not in data_dict[threshold]: + data_dict[threshold][species] = {} + + if key not in data_dict[threshold][species]: + data_dict[threshold][species][key] = 0 + + data_dict[threshold][species][key] += 1 + + +if __name__ == "__main__": + main() diff --git a/code/test/analyze/test_filter_2_thresholds_main.py b/code/test/analyze/test_filter_2_thresholds_main.py new file mode 100644 index 0000000..9758737 --- /dev/null +++ b/code/test/analyze/test_filter_2_thresholds_main.py @@ -0,0 +1,104 @@ +from analyze import filter_2_thresholds_main as main + + +def test_main(mocker, capsys): + mocker.patch('sys.argv', ['', '0.1']) + mocker.patch( + 'analyze.filter_2_thresholds_main.predict.process_predict_args', + return_value={ + 'known_states': ['state1', 'state2'], + 'tag': 'tag' + }) + mocker.patch( + 'analyze.filter_2_thresholds_main.thresholds', + [0.99, 0.95]) + mocker.patch( + 'analyze.filter_2_thresholds_main.gp.analysis_out_dir_absolute', + '/dir') + mocker.patch('analyze.filter_2_thresholds_main.read_table.read_table_rows', + return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) + + mocked_file = mocker.patch('analyze.filter_2_thresholds_main.open') + mock_read = mocker.patch('analyze.filter_2_thresholds_main.Region_Reader') + mock_read().__enter__().yield_fa.return_value = iter([ + ('r1', ['> seq', '> info'], ['atcg', 'x..']), + ('r2', ['> seq', '> info'], ['atcg', 'x..'])]) + mock_filter = mocker.patch( + 'analyze.filter_2_thresholds_main.passes_filters2', + side_effect=[ + (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]), + (True, ['1'], [0.8], [2]), + (True, ['1'], [0.8], [2]), + (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]) + ]) + + main.main() + + captured = capsys.readouterr().out + assert captured == '* state2\n' + + assert mock_read.call_count == 2 + mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True) + + assert mocked_file.call_count == 1 + 
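# (design aside, illustrative) record_data_hit() spells out what a
# nested defaultdict gives for free; an equivalent sketch:
from collections import defaultdict
data = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
data[0.99]['state2']['CBS432,N_45'] += 1
assert data[0.99]['state2']['CBS432,N_45'] == 1
# the explicit-dict version in the patch avoids defaultdict's silent
# entry creation on missing-key reads, at the cost of more bookkeeping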
mocked_file.assert_any_call( + '/dirtag/filter_2_thresholds_tag.txt', 'w') + + mocked_file().__enter__().write.assert_has_calls([ + mocker.call('threshold\tpredicted_state\talternative_states\tcount\n'), + mocker.call('0.99\tstate2\t1,2\t1\n'), + mocker.call('0.99\tstate2\t1\t1\n'), + mocker.call('0.95\tstate2\t1\t1\n'), + mocker.call('0.95\tstate2\t1,2\t1\n'), + ]) + + assert mock_filter.call_count == 4 + print(mock_filter.call_args_list) + mock_filter.assert_has_calls([ + mocker.call({}, ['atcg'], 0.99), + mocker.call({}, ['atcg'], 0.95), + mocker.call({'a': 1}, ['atcg'], 0.99), + mocker.call({'a': 1}, ['atcg'], 0.95), + ]) + + +def test_record_data_hit(): + dt = {} + main.record_data_hit(dt, 0.9, 's1', 'k1') + assert dt == {0.9: {'s1': {'k1': 1}}} + main.record_data_hit(dt, 0.9, 's1', 'k1') + main.record_data_hit(dt, 0.9, 's1', 'k1') + assert dt == {0.9: {'s1': {'k1': 3}}} + main.record_data_hit(dt, 0.9, 's1', 'k2') + assert dt == { + 0.9: { + 's1': {'k1': 3, 'k2': 1} + } + } + main.record_data_hit(dt, 0.9, 's2', 'k2') + assert dt == { + 0.9: { + 's1': {'k1': 3, 'k2': 1}, + 's2': {'k2': 1} + } + } + main.record_data_hit(dt, 0.8, 's2', 'k2') + assert dt == { + 0.9: { + 's1': {'k1': 3, 'k2': 1}, + 's2': {'k2': 1} + }, + 0.8: { + 's2': {'k2': 1} + } + } + main.record_data_hit(dt, 0.9, 's2', 'k2') + assert dt == { + 0.9: { + 's1': {'k1': 3, 'k2': 1}, + 's2': {'k2': 2} + }, + 0.8: { + 's2': {'k2': 1} + } + } diff --git a/code/test/helper_scripts/run_filter_2_thresholds.sh b/code/test/helper_scripts/run_filter_2_thresholds.sh new file mode 100755 index 0000000..e6144ab --- /dev/null +++ b/code/test/helper_scripts/run_filter_2_thresholds.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --time=0-1 +#SBATCH -n 1 +#SBATCH -o "/tigress/tcomi/aclark4_temp/results/thresh_%A" + +export PYTHONPATH=/home/tcomi/projects/aclark4_introgression/code/ + +module load anaconda3 +conda activate introgression3 + +ARGS="_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" + +python ${PYTHONPATH}analyze/filter_2_thresholds_main.py $ARGS From 7c82f959f078990583ca0f00f7267756df910caf Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Tue, 26 Mar 2019 14:43:12 -0400 Subject: [PATCH 08/33] Summarize Strain under test, refactored Limited changes as original method was fairly fast --- code/analyze/summarize_strain_states_main.py | 175 +++++++++--------- .../test_summarize_strain_states_main.py | 90 +++++++++ .../helper_scripts/run_summarize_strain.sh | 14 ++ 3 files changed, 190 insertions(+), 89 deletions(-) create mode 100644 code/test/analyze/test_summarize_strain_states_main.py create mode 100755 code/test/helper_scripts/run_summarize_strain.sh diff --git a/code/analyze/summarize_strain_states_main.py b/code/analyze/summarize_strain_states_main.py index d161570..eddfd01 100644 --- a/code/analyze/summarize_strain_states_main.py +++ b/code/analyze/summarize_strain_states_main.py @@ -1,117 +1,114 @@ -import re import sys -import os -import copy import itertools -import gene_predictions -import predict +from analyze import predict from collections import defaultdict -from filter_helpers import * -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import read_table -import read_fasta +from misc import read_table -args = predict.process_predict_args(sys.argv[1:]) -d = defaultdict(lambda: defaultdict(int)) -for species_from in args['known_states'][1:]: +def main(): + args = predict.process_predict_args(sys.argv[1:]) - print species_from + d = defaultdict(lambda: 
defaultdict(int)) + outdir = gp.analysis_out_dir_absolute + args['tag'] + states = args['known_states'][1:] + for species_from in states: - fn_filtered1i = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered1intermediate.txt' - fn_filtered2i = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_filtered2intermediate.txt' + print(species_from) - regions1, fields1 = read_table.read_table_rows(fn_filtered1i, '\t') - regions2, fields2 = read_table.read_table_rows(fn_filtered2i, '\t') + regions1, _ = read_table.read_table_rows( + f'{outdir}/blocks_{species_from}_' + f'{args["tag"]}_filtered1intermediate.txt', '\t') + regions2, _ = read_table.read_table_rows( + f'{outdir}/blocks_{species_from}_' + f'{args["tag"]}_filtered2intermediate.txt', '\t') - for region_id in regions1: + for region_id, region1 in regions1.items(): - strain = regions1[region_id]['strain'] - length = int(regions1[region_id]['end']) - int(regions1[region_id]['start']) + 1 - d[strain]['num_regions_' + species_from] += 1 - d[strain]['num_regions_total'] += 1 - d[strain]['num_bases_' + species_from] += length - d[strain]['num_bases_total'] += length - if regions1[region_id]['reason'] == '': - d[strain]['num_regions_' + species_from + '_filtered1'] += 1 + strain = region1['strain'] + length = int(region1['end']) - int(region1['start']) + 1 + + d[strain][f'num_regions_{species_from}'] += 1 + d[strain]['num_regions_total'] += 1 + d[strain][f'num_bases_{species_from}'] += length + d[strain]['num_bases_total'] += length + + if regions1[region_id]['reason'] != '': + continue + + d[strain][f'num_regions_{species_from}_filtered1'] += 1 d[strain]['num_regions_total_filtered1'] += 1 - d[strain]['num_bases_' + species_from + '_filtered1'] += length + d[strain][f'num_bases_{species_from}_filtered1'] += length d[strain]['num_bases_total_filtered1'] += length alt_states = regions2[region_id]['alternative_states'].split(',') for species_from_alt in alt_states: - d[strain]['num_regions_' + species_from_alt + \ + d[strain][f'num_regions_{species_from_alt}' '_filtered2_inclusive'] += 1 - d[strain]['num_bases_' + species_from_alt + \ + d[strain][f'num_bases_{species_from_alt}' '_filtered2_inclusive'] += length if species_from_alt == species_from: d[strain]['num_regions_total_filtered2_inclusive'] += 1 d[strain]['num_bases_total_filtered2_inclusive'] += length - + if len(alt_states) == 1: - d[strain]['num_regions_' + species_from + \ + d[strain][f'num_regions_{species_from}' '_filtered2'] += 1 d[strain]['num_regions_total_filtered2'] += 1 - d[strain]['num_bases_' + species_from + \ + d[strain][f'num_bases_{species_from}' '_filtered2'] += length d[strain]['num_bases_total_filtered2'] += length - else: - d[strain]['num_bases_' + '_or_'.join(sorted(alt_states)) + '_filtered2i'] += length - - d[strain]['num_bases_' + str(len(alt_states)) + '_filtered2i'] += length - - -strain_info = [line[:-1].split('\t') for line in open('../../100_genomes_info.txt', 'r')] -strain_origins = dict(zip([x[0].lower() for x in strain_info], \ - [(x[5], x[3], x[4]) for x in strain_info])) -for strain in d.keys(): - d[strain]['population'] = strain_origins[strain][0] - d[strain]['geographic_origin'] = strain_origins[strain][1] - d[strain]['environmental_origin'] = strain_origins[strain][2] - -fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + 'state_counts_by_strain.txt' -f = open(fn, 'w') -fields = [] - -fields += ['population', 'geographic_origin', 
'environmental_origin'] - -fields += ['num_regions_' + x for x in args['known_states'][1:]] -fields += ['num_regions_total'] -fields += ['num_regions_' + x + '_filtered1' for x in args['known_states'][1:]] -fields += ['num_regions_total_filtered1'] -fields += ['num_regions_' + x + '_filtered2' for x in args['known_states'][1:]] -fields += ['num_regions_total_filtered2'] -fields += ['num_regions_' + x + '_filtered2_inclusive' for x in args['known_states'][1:]] -fields += ['num_regions_total_filtered2_inclusive'] - -fields += ['num_bases_' + x for x in args['known_states'][1:]] -fields += ['num_bases_total'] -fields += ['num_bases_' + x + '_filtered1' for x in args['known_states'][1:]] -fields += ['num_bases_total_filtered1'] -fields += ['num_bases_' + x + '_filtered2' for x in args['known_states'][1:]] -fields += ['num_bases_total_filtered2'] -fields += ['num_bases_' + x + '_filtered2_inclusive' for x in args['known_states'][1:]] -fields += ['num_bases_total_filtered2_inclusive'] - -r = sorted(gp.alignment_ref_order[1:]) -for n in range(2, len(r)+1): - x = itertools.combinations(r, n) - for combo in x: - fields += ['num_bases_' + '_or_'.join(combo) + '_filtered2i'] - fields += ['num_bases_' + str(n) + '_filtered2i'] - -f.write('strain' + '\t' + '\t'.join(fields) + '\n') - -for strain in sorted(d.keys()): - f.write(strain + '\t') - f.write('\t'.join([str(d[strain][x]) for x in fields])) - f.write('\n') -f.close() + d[strain]['num_bases_' + + '_or_'.join(sorted(alt_states)) + + '_filtered2i'] += length + + d[strain][f'num_bases_{len(alt_states)}_filtered2i'] += length + + with open( + '/home/tcomi/projects/aclark4_introgression/100_genomes_info.txt', + 'r') as reader: + strain_info = [line[:-1].split('\t') for line in reader] + strain_info = {x[0].lower(): (x[5], x[3], x[4]) for x in strain_info} + + for strain in d.keys(): + d[strain]['population'] = strain_info[strain][0] + d[strain]['geographic_origin'] = strain_info[strain][1] + d[strain]['environmental_origin'] = strain_info[strain][2] + + fields = ['population', 'geographic_origin', 'environmental_origin'] +\ + [f'num_regions_{x}' for x in states] +\ + ['num_regions_total'] +\ + [f'num_regions_{x}_filtered1' for x in states] +\ + ['num_regions_total_filtered1'] +\ + [f'num_regions_{x}_filtered2' for x in states] +\ + ['num_regions_total_filtered2'] +\ + [f'num_regions_{x}_filtered2_inclusive' for x in states] +\ + ['num_regions_total_filtered2_inclusive'] +\ + [f'num_bases_{x}' for x in states] +\ + ['num_bases_total'] +\ + [f'num_bases_{x}_filtered1' for x in states] +\ + ['num_bases_total_filtered1'] +\ + [f'num_bases_{x}_filtered2' for x in states] +\ + ['num_bases_total_filtered2'] +\ + [f'num_bases_{x}_filtered2_inclusive' for x in states] +\ + ['num_bases_total_filtered2_inclusive'] + + r = sorted(gp.alignment_ref_order[1:]) + for n in range(2, len(r)+1): + for combo in itertools.combinations(r, n): + fields += ['num_bases_' + '_or_'.join(combo) + '_filtered2i'] + fields += ['num_bases_' + str(n) + '_filtered2i'] + + with open(f'{outdir}/state_counts_by_strain.txt', 'w') as writer: + writer.write('strain\t' + '\t'.join(fields) + '\n') + + for strain in sorted(d.keys()): + writer.write(f'{strain}\t' + + '\t'.join([str(d[strain][x]) for x in fields]) + + '\n') + + +if __name__ == '__main__': + main() diff --git a/code/test/analyze/test_summarize_strain_states_main.py b/code/test/analyze/test_summarize_strain_states_main.py new file mode 100644 index 0000000..2c2fadb --- /dev/null +++ 
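# (illustrative) the combination fields enumerate every multi-state
# label; with three hypothetical non-master references:
import itertools
refs = sorted(['CBS432', 'DBVPG6304', 'N_45'])
combos = [c for n in range(2, len(refs) + 1)
          for c in itertools.combinations(refs, n)]
assert ('CBS432', 'N_45') in combos
assert len(combos) == 4   # three pairs plus one triple
# yielding fields such as 'num_bases_CBS432_or_N_45_filtered2i'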
b/code/test/analyze/test_summarize_strain_states_main.py @@ -0,0 +1,90 @@ +import analyze.summarize_strain_states_main as main + + +def test_main(mocker, capsys): + mocker.patch( + 'analyze.summarize_strain_states_main.predict.process_predict_args', + return_value={ + 'known_states': ['state1', 'state2', 'state3'], + 'tag': 'tag' + }) + mocker.patch( + 'analyze.summarize_strain_states_main.gp.analysis_out_dir_absolute', + '/dir') + mocker.patch( + 'analyze.summarize_strain_states_main.gp.alignment_ref_order', + ['state1', 'state2', 'state3']) + + mock_read = mocker.patch( + 'analyze.summarize_strain_states_main.read_table.read_table_rows', + side_effect=[ + ({'r1': {'strain': 's1', 'start': 10, 'end': 20, + 'reason': ''}, + 'r2': {'strain': 's2', 'start': 25, 'end': 40, + 'reason': 'test'}}, + ['regions']), + ({'r1': {'alternative_states': 'state2,state3'}, + 'r2': {'not_called': ''}}, + ['regions']), + ({'r1': {'strain': 's1', 'start': 35, 'end': 40, + 'reason': ''}, + 'r2': {'strain': 's2', 'start': 4, 'end': 8, + 'reason': ''}}, + ['regions']), + ({'r1': {'alternative_states': 'state3,state2'}, + 'r2': {'alternative_states': 'state3'}}, + ['regions']) + ]) + + handle1 = mocker.MagicMock() + handle1.__enter__.return_value.__iter__.return_value = \ + ("s1\tnothing\tnothing\tgeo1\tenv1\tpop1\n", + "s2\tnothing\tnothing\tgeo2\tenv2\tpop2\n") + handle2 = mocker.MagicMock() + + mocker.patch( + 'analyze.summarize_strain_states_main.open', + side_effect=(handle1, handle2) + ) + + main.main() + + captured = capsys.readouterr().out + assert captured == 'state2\nstate3\n' + + assert handle1.called_with(mocker.ANY, 'r') + + assert handle2.called_with('/dirtag/state_counts_by_strain.txt', 'w') + + calls = handle2.__enter__().write.call_args_list + assert handle2.__enter__().write.call_count == 3 + assert calls[0][0] == \ + ('strain\tpopulation\tgeographic_origin\tenvironmental_origin\t' + 'num_regions_state2\tnum_regions_state3\tnum_regions_total\t' + 'num_regions_state2_filtered1\tnum_regions_state3_filtered1\t' + 'num_regions_total_filtered1\tnum_regions_state2_filtered2\t' + 'num_regions_state3_filtered2\tnum_regions_total_filtered2\t' + 'num_regions_state2_filtered2_inclusive\t' + 'num_regions_state3_filtered2_inclusive\t' + 'num_regions_total_filtered2_inclusive\tnum_bases_state2\t' + 'num_bases_state3\tnum_bases_total\tnum_bases_state2_filtered1\t' + 'num_bases_state3_filtered1\tnum_bases_total_filtered1\t' + 'num_bases_state2_filtered2\tnum_bases_state3_filtered2\t' + 'num_bases_total_filtered2\tnum_bases_state2_filtered2_inclusive\t' + 'num_bases_state3_filtered2_inclusive\t' + 'num_bases_total_filtered2_inclusive\t' + 'num_bases_state2_or_state3_filtered2i\tnum_bases_2_filtered2i\n',) + assert calls[1][0] == \ + ('s1\tpop1\tgeo1\tenv1\t1\t1\t2\t1\t1\t2\t0\t0\t0\t2\t2\t2\t11\t6\t' + '17\t11\t6\t17\t0\t0\t0\t17\t17\t17\t17\t17\n',) + assert calls[2][0] == \ + ('s2\tpop2\tgeo2\tenv2\t1\t1\t2\t0\t1\t1\t0\t1\t1\t0\t1\t1\t16\t5\t21' + '\t0\t5\t5\t0\t5\t5\t0\t5\t5\t0\t0\n',) + + assert mock_read.call_count == 4 + fname = '/dirtag/blocks_state{s}_tag_filtered{f}intermediate.txt' + mock_read.assert_has_calls([ + mocker.call(fname.format(s=2, f=1), '\t'), + mocker.call(fname.format(s=2, f=2), '\t'), + mocker.call(fname.format(s=3, f=1), '\t'), + mocker.call(fname.format(s=3, f=2), '\t')]) diff --git a/code/test/helper_scripts/run_summarize_strain.sh b/code/test/helper_scripts/run_summarize_strain.sh new file mode 100755 index 0000000..77fbe6a --- /dev/null +++ 
b/code/test/helper_scripts/run_summarize_strain.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --time=0-1 +#SBATCH -n 1 +#SBATCH -o "/tigress/tcomi/aclark4_temp/results/filter_%A" + +export PYTHONPATH=/home/tcomi/projects/aclark4_introgression/code/ + +module load anaconda3 +conda activate introgression3 + +ARGS="_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" + +python ${PYTHONPATH}analyze/summarize_strain_states_main.py $ARGS From 32675821857947a3dcc73eaf2d07256838fcef13 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Mon, 1 Apr 2019 16:11:16 -0400 Subject: [PATCH 09/33] Working on documentation --- code/align/align_helpers.py | 29 ++- code/analyze/combine_chromosome_files_main.py | 31 --- code/analyze/filter_1_main.py | 8 + code/analyze/filter_2_main.py | 8 + code/analyze/filter_2_thresholds_main.py | 12 +- code/analyze/id_regions_main.py | 5 + code/analyze/predict.py | 212 ++++++++++++++---- code/analyze/predict_main.py | 14 +- code/analyze/summarize_region_quality.py | 171 +++++++++----- code/analyze/summarize_region_quality_main.py | 17 +- code/analyze/summarize_strain_states_main.py | 12 +- code/misc/binary_search.py | 6 +- code/misc/read_fasta.py | 8 +- code/misc/read_table.py | 26 ++- code/misc/region_reader.py | 24 +- code/misc/seq_functions.py | 8 +- code/sim/sim_predict.py | 6 +- code/sim/sim_process.py | 32 ++- code/test/analyze/test_predict.py | 24 +- code/test/misc/test_read_table.py | 5 - code/test/sim/test_sim_process.py | 2 +- 21 files changed, 473 insertions(+), 187 deletions(-) delete mode 100644 code/analyze/combine_chromosome_files_main.py diff --git a/code/align/align_helpers.py b/code/align/align_helpers.py index be216de..a1f44ee 100644 --- a/code/align/align_helpers.py +++ b/code/align/align_helpers.py @@ -1,12 +1,24 @@ import os import global_params as gp +from typing import List, Tuple -def flatten(l): +def flatten(l: List[List]) -> List: + ''' + Flatten list of lists into a single list + ''' return [item for sublist in l for item in sublist] -def get_strains(dirs): +def get_strains(dirs: List[str]) -> List[Tuple[str, str]]: + ''' + Find all strains in the provided list of directories + Returns a sorted list of tuples with (strain_name, directory) entries + Checks for files with the fasta_suffix and contain _chr + strain_name is the name of the file up to _chr. + Raises assertion error if the number of files found is < number of strains + * the number of chromosomes + ''' # get all non-reference strains of cerevisiae and paradoxus; could # generalize this someday... 
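# (illustrative sketch, hypothetical filenames) get_strains() takes a
# strain name to be everything before '_chr' in each fasta file name,
# deduplicated across chromosomes:
fns = ['yjm248_chrI.fa', 'yjm248_chrII.fa', 'CBS432_chrI.fa']
strains = sorted({fn[:fn.index('_chr')] for fn in fns})
assert strains == ['CBS432', 'yjm248']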
@@ -15,10 +27,11 @@ def get_strains(dirs): for d in dirs: fns = os.listdir(d) # only look at fasta files in the directory - fns = filter(lambda x: x.endswith(gp.fasta_suffix), fns) # only look at files containing '_chr' which should be chromosome # sequence files - fns = list(filter(lambda x: '_chr' in x, fns)) + fns = list( + filter(lambda x: x.endswith(gp.fasta_suffix) and '_chr' in x, + fns)) num_files = len(fns) if num_files == 0: print(f'found no chromosome sequence files in {d} ' @@ -33,7 +46,13 @@ def get_strains(dirs): return sorted(s) -def concatenate_fasta(input_files, names, output_file): +def concatenate_fasta(input_files: List[str], + names: List[str], + output_file: str) -> None: + ''' + Combines several fasta files together into a single output + Adds header between each input fasta as > name[i] filename + ''' with open(output_file, 'w') as output: for i, file in enumerate(input_files): with open(file, 'r') as input: diff --git a/code/analyze/combine_chromosome_files_main.py b/code/analyze/combine_chromosome_files_main.py deleted file mode 100644 index 1ec1814..0000000 --- a/code/analyze/combine_chromosome_files_main.py +++ /dev/null @@ -1,31 +0,0 @@ -import sys -import os -import gzip -import predict -import global_params as gp - -args = predict.process_predict_args(sys.argv[1:]) - -header = open(gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + args['known_states'][0] + \ - '_' + args['tag'] + '_chr' + gp.chrms[0] + '_quality.txt', 'r').readline() - -for species_from in args['known_states']: - fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_quality.txt' - f = open(fn, 'w') - f.write(header) - for chrm in gp.chrms: - fn_chrm = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'blocks_' + species_from + \ - '_' + args['tag'] + '_chr' + chrm + '_quality.txt' - try: - fc = open(fn_chrm, 'r') - except: - continue - fc.readline() - for line in fc.readlines(): - f.write(line) - f.close() - diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py index 70dc334..c2c7ec7 100644 --- a/code/analyze/filter_1_main.py +++ b/code/analyze/filter_1_main.py @@ -21,6 +21,14 @@ def main() -> None: ''' Perform first step of filtering + Input files: + -blocks_{species}_quality.txt + + Output files: + -blocks_{species}_filtered1intermediate.txt + -blocks_{species}_filtered1.txt + -regions/{species}.fa.gz + -regions/{species}.pkl ''' args = predict.process_predict_args(sys.argv[1:]) out_dir = gp.analysis_out_dir_absolute + args['tag'] diff --git a/code/analyze/filter_2_main.py b/code/analyze/filter_2_main.py index 01986f8..8e98b25 100644 --- a/code/analyze/filter_2_main.py +++ b/code/analyze/filter_2_main.py @@ -21,6 +21,14 @@ def main() -> None: ''' Perform second stage of filtering + Input files: + -blocks_{species}_filtered1.txt + regions/{species}.fa.gz + regions/{species}.pkl + + Output files: + -blocks_{species}_filtered2.txt + -blocks_{species}_filtered2intermediate.txt ''' args = predict.process_predict_args(sys.argv[2:]) threshold = float(sys.argv[1]) diff --git a/code/analyze/filter_2_thresholds_main.py b/code/analyze/filter_2_thresholds_main.py index 7a35ee6..f6cb61a 100644 --- a/code/analyze/filter_2_thresholds_main.py +++ b/code/analyze/filter_2_thresholds_main.py @@ -23,7 +23,17 @@ # thresholds = [1] -def main(): +def main() -> None: + ''' + Perform second stage of filtering with several threshold levels + Input files: + -blocks_{species}_filtered1.txt + -regions/{species}.fa.gz + 
-regions/{species}.pkl + + Output files: + -filter_2_thresholds.txt + ''' args = predict.process_predict_args(sys.argv[1:]) out_dir = gp.analysis_out_dir_absolute + args['tag'] diff --git a/code/analyze/id_regions_main.py b/code/analyze/id_regions_main.py index 839eb29..adc59c6 100644 --- a/code/analyze/id_regions_main.py +++ b/code/analyze/id_regions_main.py @@ -7,6 +7,11 @@ def main() -> None: ''' Adds a unique region id to block files, producing labeled text files + Input files: + -blocks_{species}.txt + + Output files: + -blocks_{species}_labeled.txt ''' args = predict.process_predict_args(sys.argv[1:]) diff --git a/code/analyze/predict.py b/code/analyze/predict.py index f0474ad..6f8b37b 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -8,7 +8,7 @@ import global_params as gp from misc import read_fasta import numpy as np -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, TextIO def process_predict_args(arg_list: List[str]) -> Dict: @@ -123,7 +123,7 @@ def ungap_and_code(predict_seq: str, ''' Remove any sequence locations where a gap is present and code into matching or mismatching sequence - Returns the coded sequences, by default an array of + where matching, - + Returns the coded sequences, by default an array of + where matching, - where mismatching. Also return the positions where the sequences are not gapped. ''' @@ -145,29 +145,38 @@ def ungap_and_code(predict_seq: str, gp.match_symbol, gp.mismatch_symbol) - # 1: indexing removes currently examined sequence - matches = [''.join(row) - for row in np.transpose(matches[:, np.all(isvalid, axis=0)])] + matches = np.fromiter((''.join(row) + for row in np.transpose( + matches[:, np.all(isvalid, axis=0)])), + dtype=f'U{len(sequences) - 1}') - # NOTE list is for unit test comparisons return matches, positions -def poly_sites(sequences, positions): +def poly_sites(sequences: np.array, + positions: np.array) -> Tuple[np.array, np.array]: ''' - WORKING ON ADDING DOC STRINGS AND TYPING!! + Remove all sequences where the sequence is all match_symbol + Returns the filtered sequence and position ''' seq_len = len(sequences[0]) # check if seq only contains match_symbol retain = np.vectorize( lambda x: x.count(gp.match_symbol) != seq_len)(sequences) indices = np.where(retain)[0] - ps_poly = [positions[i] for i in indices] - seq_poly = [sequences[i] for i in indices] + ps_poly = positions[indices] + seq_poly = sequences[indices] return seq_poly, ps_poly -def get_symbol_freqs(sequence): +def get_symbol_freqs(sequence: np.array) -> Tuple[Dict, Dict, List]: + ''' + Calculate metrics from the provided, coded sequence + Returns: + the fraction matching for each species + the fraction of each matching pattern (e.g. 
+--++) + the weighted fraction of matches for each species + ''' individual = [] weighted = [] @@ -194,8 +203,15 @@ def get_symbol_freqs(sequence): return individual, symbols, weighted -def initial_probabilities(known_states, unknown_states, - expected_frac, weighted_match_freqs): +def initial_probabilities(known_states: List[str], + unknown_states: List[str], + expected_frac: Dict, + weighted_match_freqs: List[float]) -> np.array: + ''' + Estimate the initial probability of being in each state + based on the number of states and their expected fractions + Returns the initial probability of each state + ''' init = [] expectation_weight = .9 @@ -212,7 +228,13 @@ def initial_probabilities(known_states, unknown_states, return init / np.sum(init) -def emission_probabilities(known_states, unknown_states, symbols): +def emission_probabilities(known_states: List[str], + unknown_states: List[str], + symbols: List[str]) -> List[Dict]: + ''' + Estimate initial emission probabilities + Return estimates as list of default dict of probabilities + ''' probabilities = { gp.mismatch_symbol + gp.match_symbol: 0.9, @@ -260,8 +282,13 @@ def emission_probabilities(known_states, unknown_states, symbols): return result -def transition_probabilities(known_states, unknown_states, - expected_frac, expected_tract_lengths): +def transition_probabilities(known_states: List[str], + unknown_states: List[str], + expected_frac: Dict, + expected_tract_lengths: Dict) -> np.array: + ''' + Estimate initial transition probabilities + ''' # doesn't depend on sequence observations but maybe it should? @@ -286,16 +313,26 @@ def transition_probabilities(known_states, unknown_states, return transitions / transitions.sum(axis=1)[:, None] -def initial_hmm_parameters(seq, known_states, unknown_states, - expected_frac, expected_tract_lengths): +def initial_hmm_parameters(seq: np.array, + known_states: List[str], + unknown_states: List[str], + expected_frac: Dict, + expected_tract_lengths: Dict) -> hmm_bw.HMM: + ''' + Build a HMM object initialized based on expected values and provided data + ''' # get frequencies of individual symbols (e.g. '+') and all full # combinations of symbols (e.g. 
'+++-') - individual_symbol_freqs, symbol_freqs, weighted_match_freqs = get_symbol_freqs(seq) + (individual_symbol_freqs, + symbol_freqs, + weighted_match_freqs) = get_symbol_freqs(seq) init = initial_probabilities(known_states, unknown_states, expected_frac, weighted_match_freqs) - emis = emission_probabilities(known_states, unknown_states, symbol_freqs.keys()) + emis = emission_probabilities(known_states, + unknown_states, + symbol_freqs.keys()) trans = transition_probabilities(known_states, unknown_states, expected_frac, expected_tract_lengths) @@ -308,9 +345,33 @@ def initial_hmm_parameters(seq, known_states, unknown_states, return hmm -def predict_introgressed(ref_seqs, predict_seq, predict_args, - train=True, only_poly_sites=True, - return_positions=False): +def predict_introgressed(ref_seqs: np.array, + predict_seq: np.array, + predict_args: Dict, + train: bool = True, + only_poly_sites: bool = True, + return_positions: bool = False) -> Tuple[ + List[str], + np.array, + hmm_bw.HMM, + hmm_bw.HMM, + np.array + ]: + ''' + Predict regions of introgression within the predicted sequence + ref_seqs: 2d np character array of the reference sequences + predict_seq: np character array of the sequence to perform prediction on + train: control whether or not to perform Baum-Welch estimation on HMM + only_poly_sites: control if only polymorphic sites should be considered + return_positions: if true, only the position of sites in reference sequence + is returned + Generally will return a tuple of the following: + The predicted types as a list of states + The posterior decoding of the trained HMM + The trained HMM object + The untrained HMM without sequences + The positions of sites with respect to the reference sequence + ''' # code sequence by which reference it matches at each site seq_coded, positions = ungap_and_code(predict_seq, ref_seqs) @@ -363,24 +424,39 @@ def predict_introgressed(ref_seqs, predict_seq, predict_args, return predicted, p[0], hmm, hmm_init, positions -def write_positions(ps, writer, strain, chrm): +def write_positions(positions: np.array, + writer: TextIO, + strain: str, + chrm: str) -> None: + ''' + Write the positions of the specific strain, chromosome as a line to the + provided textIO object + ''' writer.write(f'{strain}\t{chrm}\t' + - '\t'.join([str(x) for x in ps]) + '\n') + '\t'.join([str(x) for x in positions]) + '\n') -def read_positions(fn): - # dictionary keyed by strain and then chromosome - with gzip.open(fn, 'rb') as reader: +def read_positions(filename: str) -> Dict[str, Dict[str, List[int]]]: + ''' + Read in positions from the provided filename, returning a dictionary + keyed first by the strain, then chromosome. Returned positions are + lists of ints + ''' + with gzip.open(filename, 'rb') as reader: result = defaultdict({}) for line in reader: line = line.split() strain, chrm = line[0:2] - ps = [int(x) for x in line[2:]] - result[strain][chrm] = ps + positions = [int(x) for x in line[2:]] + result[strain][chrm] = positions return result -def write_blocks_header(writer): +def write_blocks_header(writer: TextIO) -> None: + ''' + Write header line to tab delimited block file: + strain chromosome predicted_species start end num_sites_hmm + ''' # NOTE: num_sites_hmm represents the sites considered by the HMM, # so it might exclude non-polymorphic sites in addition to gaps writer.write('\t'.join(['strain', @@ -392,26 +468,38 @@ def write_blocks_header(writer): + '\n') -# TODO: find source of all the newlines in output!! 
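The positions helpers above round-trip as follows; a minimal sketch on hypothetical data (one stdlib note: collections.defaultdict expects a callable such as dict, not a dict instance, for the nested result):

    import gzip
    from collections import defaultdict

    positions = {'strain1': {'I': [11, 52, 98]}}  # hypothetical data

    with gzip.open('positions.txt.gz', 'wt') as writer:
        for strain, chrms in positions.items():
            for chrm, ps in chrms.items():
                writer.write(f'{strain}\t{chrm}\t' +
                             '\t'.join(str(x) for x in ps) + '\n')

    result = defaultdict(dict)  # a callable, so missing strains get a dict
    with gzip.open('positions.txt.gz', 'rt') as reader:
        for line in reader:
            strain, chrm, *sites = line.split()
            result[strain][chrm] = [int(x) for x in sites]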
-def write_blocks(state_seq_blocks, ps, writer, strain, chrm, species_pred): - # file format is: - # strain chrm predicted_species start end number_non_gap +def write_blocks(state_seq_blocks: List[Tuple[int, int]], + positions: np.array, + writer: TextIO, + strain: str, + chrm: str, + species_pred: str) -> None: + ''' + Write entry into tab delimited block file, with columns: + strain chromosome predicted_species start end num_sites_hmm + ''' writer.write('\n'.join( ['\t'.join([strain, chrm, species_pred, - str(ps[start]), - str(ps[end]), + str(positions[start]), + str(positions[end]), str(end - start + 1)]) for start, end in state_seq_blocks])) - if state_seq_blocks: + if state_seq_blocks: # ensure ends with \n writer.write('\n') -def read_blocks(fn, labeled=False): - # return dictionary of (start, end, number_non_gap, [region_id]), - # keyed by strain and then chromosome - with open(fn, 'r') as reader: +def read_blocks(filename: str, + labeled: bool = False) -> Dict[ + str, Dict[str, Tuple[int, int, int, str]]]: + ''' + Read in the supplied block file, returning a dict keyed on strain, + then chromosome. Values are tuples of start, end, and number of postions + for the block. + If labeled is true, values contain the region_id as last element + ''' + with open(filename, 'r') as reader: reader.readline() # header result = defaultdict(lambda: defaultdict(list)) for line in reader: @@ -428,8 +516,11 @@ def read_blocks(fn, labeled=False): return result -def get_emis_symbols(known_states): - +def get_emis_symbols(known_states: List[str]) -> List[str]: + ''' + Generate all permutations of match and mismatch symbols with + len(known_states) characters, in lexigraphical order + ''' symbols = [gp.match_symbol, gp.mismatch_symbol] emis_symbols = [''.join(x) for x in itertools.product(symbols, repeat=len(known_states))] @@ -437,7 +528,15 @@ def get_emis_symbols(known_states): return emis_symbols -def write_hmm_header(known_states, unknown_states, symbols, writer): +def write_hmm_header(known_states: List[str], + unknown_states: List[str], + symbols: List[str], + writer: TextIO) -> None: + ''' + Write the header line for an hmm file to the provided textIO object + Output is tab delimited with: + strain chromosome initial_probs emissions transitions + ''' writer.write('strain\tchromosome\t') @@ -455,7 +554,17 @@ def write_hmm_header(known_states, unknown_states, symbols, writer): writer.write('\n') -def write_hmm(hmm, writer, strain, chrm, emis_symbols): +def write_hmm(hmm: hmm_bw.HMM, + writer: TextIO, + strain: str, + chrm: str, + emis_symbols: List[str]): + ''' + Write information on the provided hmm as a line to the supplied textIO + object. + Output is tab delimited with: + strain chromosome initial_probs emissions transitions + ''' writer.write(f'{strain}\t{chrm}\t') states = len(hmm.hidden_states) @@ -472,7 +581,16 @@ def write_hmm(hmm, writer, strain, chrm, emis_symbols): writer.write('\n') -def write_state_probs(probs, writer, strain, chrm, states): +def write_state_probs(probs: Dict[str, List[float]], + writer: TextIO, + strain: str, + chrm: str, + states: List[str]) -> None: + ''' + Write the probability each state to the supplied textIO object + Output is tab delimited with: + strain chrom state1:prob1,prob2,...,probn state2... 
+ ''' writer.write(f'{strain}\t{chrm}\t') writer.write('\t'.join( diff --git a/code/analyze/predict_main.py b/code/analyze/predict_main.py index 43fd3f8..293f9f3 100644 --- a/code/analyze/predict_main.py +++ b/code/analyze/predict_main.py @@ -7,8 +7,20 @@ from align import align_helpers from misc import read_fasta -# read in analysis parameters +''' +Predict states from aligned sequences +Input files: +-refs_{strain}_chr{chromosome}_mafft.fa + +Output files: +-blocks{species}.txt +-hmm_init.txt +-hmm.txt +-positions.txt +-probs.txt +''' +# read in analysis parameters args = predict.process_predict_args(sys.argv[1:]) strain_dirs = align_helpers.get_strains( diff --git a/code/analyze/summarize_region_quality.py b/code/analyze/summarize_region_quality.py index 8bb825c..08d934a 100644 --- a/code/analyze/summarize_region_quality.py +++ b/code/analyze/summarize_region_quality.py @@ -2,6 +2,8 @@ import global_params as gp from misc import binary_search import numpy as np +from typing import List, Tuple, Dict + cen_starts = [151465, 238207, 114385, 449711, 151987, 148510, 496920, 105586, 355629, 436307, 440129, 150828, @@ -51,6 +53,7 @@ def distance_from_telomere(start, end, chrm): # region overlaps centromere: return minimum distance from either telomere return min(start - tel_left_ends[i], tel_right_starts[i] - end) + def distance_from_centromere(start, end, chrm): assert start <= end, str(start) + ' ' + str(end) @@ -65,13 +68,15 @@ def distance_from_centromere(start, end, chrm): # region overlaps centromere: return 0 return 0 + def write_region_summary_plus(fn, regions, fields): f = open(fn, 'w') f.write('region_id\t' + '\t'.join(fields) + '\n') keys = sorted(regions.keys(), key=lambda x: int(x[1:])) for region_id in keys: f.write(region_id + '\t') - f.write('\t'.join([str(regions[region_id][field]) for field in fields])) + f.write('\t'.join([str(regions[region_id][field]) + for field in fields])) f.write('\n') f.close() @@ -85,6 +90,7 @@ def gap_columns(seqs): break return g + def longest_consecutive(s, c): max_consecutive = 0 current_consecutive = 0 @@ -123,6 +129,7 @@ def masked_columns(seqs): mask_non_gap_total += 1 return mask_total, mask_non_gap_total + def index_by_reference(ref_seq, seq): # return dictionary keyed by reference index, with value the # corresponding index in non-reference sequence @@ -139,9 +146,12 @@ def index_by_reference(ref_seq, seq): return d -def index_alignment_by_reference(ref_seq): - # want a way to go from reference sequence coordinate to index in - # alignment +def index_alignment_by_reference(ref_seq: np.array) -> np.array: + ''' + Find locations of non-gapped sites in reference sequence + want a way to go from reference sequence coordinate to index in + alignment + ''' return np.where(ref_seq != gp.gap_symbol)[0] @@ -152,25 +162,54 @@ def num_sites_between(sites, start, end): return j - i, sites[i:j] -def read_masked_intervals(fn): - with open(fn, 'r') as reader: +def read_masked_intervals(filename: str) -> List[Tuple[int, int]]: + ''' + Read the interval file provided and return start and end sequences + as a list of tuples of 2 ints + ''' + with open(filename, 'r') as reader: reader.readline() # header - ints = [] + intervals = [] for line in reader: line = line.split() - ints.append((int(line[0]), int(line[2]))) + intervals.append((int(line[0]), int(line[2]))) - return ints + return intervals -def convert_intervals_to_sites(ints): +def convert_intervals_to_sites(intervals: List[Tuple[int, int]]) -> np.array: + ''' + Given a list of start, end positions, 
returns a 1D np.array of all sites
+    contained in the intervals list
+    convert_intervals_to_sites([(1, 2), (4, 6)]) -> [1, 2, 4, 5, 6]
+    '''
     sites = []
-    for start, end in ints:
+    for start, end in intervals:
         sites += range(start, end + 1)
     return np.array(sites)


-def seq_id_hmm(seq1, seq2, offset, include_sites):
+def seq_id_hmm(seq1: np.array,
+               seq2: np.array,
+               offset: int,
+               include_sites: List[int]) -> Tuple[
+        int, int, Dict[str, List[bool]]]:
+    '''
+    Compare two sequences and provide statistics of their overlap considering
+    only the included sites.
+    Takes the two sequences to consider, an offset of the included sites,
+    and a list of the included sites.
+    Returns:
+    -the total number of matching sites, where seq1[i] == seq2[i] and
+    i is an element in included_sites - offset
+    -the total number of sites considered in the included sites, i.e. where
+    included_sites - offset >= 0 and < len(seq)
+    -a dict with the following keys:
+    -gap_flag: true where seq1 or seq2 == gap_symbol
+    -unseq_flag: true where seq1 or seq2 == unsequenced_symbol
+    -hmm_flag: true where i is an element in included_sites - offset
+    -match: true where seq1 == seq2, regardless of symbol
+    '''
     sites = np.array(include_sites) - offset

     info_gap = np.logical_or(seq1 == gp.gap_symbol,
@@ -198,11 +237,25 @@
             'hmm_flag': info_hmm,
             'match': info_match}

-def seq_id_unmasked(seq1, seq2, offset, exclude_sites1, exclude_sites2):
-    # total_sites is number of sites at which neither sequence is
-    # masked or has a gap or unsequenced character; total_match is the
-    # number of those sites at which the two sequences match
-    # gapped and unsequenced locations
+def seq_id_unmasked(seq1: np.array,
+                    seq2: np.array,
+                    offset: int,
+                    exclude_sites1: List[int],
+                    exclude_sites2: List[int]) -> Tuple[
+        int, int, Dict[str, List[bool]]]:
+    '''
+    Compare two sequences and provide statistics of their overlap,
+    excluding the given sites.
+    Takes two sequences and an offset applied to each excluded sites list.
+    Returns:
+    -total number of matching sites in non-excluded sites. A position is
+    excluded if it is an element of either excluded site list - offset,
+    or it is a gap or unsequenced symbol in either sequence.
+ -total number of non-excluded sites + A dict with the following keys: + -mask_flag: a boolean array that is true if the position is in + either excluded list - offset + ''' info_gap = np.logical_or(seq1 == gp.gap_symbol, seq2 == gp.gap_symbol) info_unseq = np.logical_or(seq1 == gp.unsequenced_symbol, @@ -237,44 +290,26 @@ def seq_id_unmasked(seq1, seq2, offset, exclude_sites1, exclude_sites2): return total_match, total_sites, {'mask_flag': info_mask} - n = len(seq1) - total_sites = 0 - total_match = 0 - - skip = [gp.gap_symbol, gp.unsequenced_symbol] - info_mask = [False for i in range(n)] - for i in range(n): - - if binary_search.present(exclude_sites1, i + offset) or \ - binary_search.present(exclude_sites2, i + offset): - info_mask[i] = True - continue - if seq1[i] not in skip and seq2[i] not in skip: - total_sites += 1 - if seq1[i] == seq2[i]: - total_match += 1 - - # TODO: keep track of gapped/masked sites for master/predicted to - # incorporate into info string later - return total_match, total_sites, {'mask_flag': info_mask} - - -def make_info_string_unknown(info, master_ind): - - # used with indices to decode result - decoder = np.array(list('Xx._-')) - indices = np.zeros(info['gap_any_flag'].shape, int) - indices[info['match_flag'][:, master_ind]] = 1 # x - matches = np.all(info['match_flag'], axis=1) - indices[matches] = 2 # . - indices[info['mask_any_flag']] = 3 # _ - indices[info['gap_any_flag']] = 4 # - - - return ''.join(decoder[indices]) - - -def make_info_string(info, master_ind, predict_ind): +def make_info_string(info: Dict[str, List[bool]], + master_ind: int, + predict_ind: int) -> str: + ''' + Summarize info dictionary into a string. master_ind is the index of + the master reference state. predict_ind is the index of the predicted + state. The return string is encoded as each position as: + '-': if either master or predict has a gap + '_': if either master or predict is masked + '.': if any state has a match + 'b': both predict and master match + 'c': master matches but not predict + 'p': predict matches but not master + 'x': no other condition applies + if the position is in the hmm_flag it will be capitalized for x, p, c, or + b + in order of precidence, e.g. if a position satisfies both '-' and '.', + it will be '-'. + ''' if predict_ind >= info['match_flag'].shape[1]: return make_info_string_unknown(info, master_ind) @@ -296,3 +331,31 @@ def make_info_string(info, master_ind, predict_ind): axis=1)] = 10 # - return ''.join(decoder[indices]) + + +def make_info_string_unknown(info: Dict[str, List[bool]], + master_ind: int) -> str: + ''' + Summarize info dictionary into a string for unknown state. + master_ind is the index of the master reference state. + The return string is encoded as each position as: + '-': if any state has a gap + '_': if any state has a mask + '.': all states match + 'x': master matches + 'X': no other condition applies + in order of precidence, e.g. if a position satisfies both '-' and '.', + it will be '-'. + ''' + + # used with indices to decode result + decoder = np.array(list('Xx._-')) + indices = np.zeros(info['gap_any_flag'].shape, int) + + indices[info['match_flag'][:, master_ind]] = 1 # x + matches = np.all(info['match_flag'], axis=1) + indices[matches] = 2 # . 
+ indices[info['mask_any_flag']] = 3 # _ + indices[info['gap_any_flag']] = 4 # - + + return ''.join(decoder[indices]) diff --git a/code/analyze/summarize_region_quality_main.py b/code/analyze/summarize_region_quality_main.py index e375de1..f1db407 100644 --- a/code/analyze/summarize_region_quality_main.py +++ b/code/analyze/summarize_region_quality_main.py @@ -17,7 +17,22 @@ import pickle -def main(): +def main() -> None: + ''' + Summarize region quality of each region + First parameter is the species to process + Input files: + -blocks_{species}_labeled.txt + -{species}_chr_intervals.txt + -{species}_chr_mafft.fa + -{species}_chr_mafft.fa + + Output files: + -positions_{tag}.txt.gz + -regions file as {species}.fa.gz + -index file for the fz.gz + -blocks_{species}_quality.txt + ''' args = predict.process_predict_args(sys.argv[2:]) diff --git a/code/analyze/summarize_strain_states_main.py b/code/analyze/summarize_strain_states_main.py index eddfd01..cf74860 100644 --- a/code/analyze/summarize_strain_states_main.py +++ b/code/analyze/summarize_strain_states_main.py @@ -6,7 +6,17 @@ from misc import read_table -def main(): +def main() -> None: + ''' + Generate summary information for the state of each position in the sequence + Input files: + -blocks_{species}_filtered1intermediate.txt + -blocks_{species}_filtered2intermediate.txt + -100_genomes_info.txt + + Output files: + -state_counts_by_strain.txt + ''' args = predict.process_predict_args(sys.argv[1:]) d = defaultdict(lambda: defaultdict(int)) diff --git a/code/misc/binary_search.py b/code/misc/binary_search.py index 73c4cee..79594ca 100644 --- a/code/misc/binary_search.py +++ b/code/misc/binary_search.py @@ -1,7 +1,9 @@ import bisect +from typing import List -def present(a, x): - 'Locate the leftmost value exactly equal to x' + +def present(a: List[int], x: int) -> bool: + 'Locate the leftmost value exactly equal to x in a' i = bisect.bisect_left(a, x) if i != len(a) and a[i] == x: return True diff --git a/code/misc/read_fasta.py b/code/misc/read_fasta.py index 4554396..b09c691 100644 --- a/code/misc/read_fasta.py +++ b/code/misc/read_fasta.py @@ -1,8 +1,14 @@ import gzip import numpy as np +from typing import Tuple, List -def read_fasta(fn, gz=False): +def read_fasta(fn: str, gz: bool = False) -> Tuple[ + List[str], np.array]: + ''' + Read the provided fasta file, returning the + headers (lines startin with >) and sequences + ''' headers = [] seqs = [] diff --git a/code/misc/read_table.py b/code/misc/read_table.py index 7c2f5bd..e5ba31b 100644 --- a/code/misc/read_table.py +++ b/code/misc/read_table.py @@ -1,9 +1,24 @@ import gzip -import io +from typing import List, Dict, Tuple -def read_table_rows(fn, sep, header=True, key_ind=0): - # returns dictionary of rows keyed by first item in row +def read_table_rows(fn: str, + sep: str, + header: bool = True, + key_ind: int = 0) -> Tuple[ + Dict[str, Dict[str, List[str]]], + List[str]]: + ''' + Read the text file of tabular data by rows + fn: filename to read + sep: the column delimiter + header: flag to indicate a header is present + If a header is provided, labels are returned from the first row. 
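To make the returned shape concrete, a sketch of the core loop on hypothetical input (reimplementing just the essential logic, not a call to the real reader):

    text = 'region_id\tstart\tend\nr1\t10\t20\nr2\t30\t40\n'
    lines = text.splitlines()
    labels = lines[0].split('\t')   # header row supplies the labels

    table = {}
    for line in lines[1:]:
        fields = line.split('\t')
        # key on column 0 (key_ind = 0); other columns keyed by label
        table[fields[0]] = dict(zip(labels[1:], fields[1:]))

    # table == {'r1': {'start': '10', 'end': '20'},
    #           'r2': {'start': '30', 'end': '40'}}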
+ Return value becomes a dictionary of dictionaries, keyed first + by the key_ind, then the column label + key_ind: the column index to use as keys in output + returns dictionary of rows keyed by key_ind and labels + ''' reader = None if fn.endswith('.gz'): @@ -31,7 +46,10 @@ def read_table_rows(fn, sep, header=True, key_ind=0): return table, labels -def read_table_columns(fn, sep, group_by=None, **filter_output): +def read_table_columns(fn: str, + sep: str, + group_by: str = None, + **filter_output) -> Tuple[Dict, List[str]]: ''' Reads sep delimited file to generate dictionary of columns, keyed by labels Optionally, a column to group by can be specified, changing the return diff --git a/code/misc/region_reader.py b/code/misc/region_reader.py index 46d6746..53aac84 100644 --- a/code/misc/region_reader.py +++ b/code/misc/region_reader.py @@ -3,13 +3,15 @@ import os import sys import numpy as np +from typing import Dict, List, Tuple class Region_Reader(): - def __init__(self, region_file, - as_fa=False, - suppress_header=True, - num_lines=14): + def __init__(self, + region_file: str, + as_fa: bool = False, + suppress_header: bool = True, + num_lines: int = 14): ''' Checks for valid filename and existance of corresponding pickle as_fa: if true will return headers and sequences as read_fasta does @@ -50,7 +52,7 @@ def __repr__(self): f'num_lines = {self.num_lines}\n' ) - def read_region(self, region_name): + def read_region(self, region_name: str): ''' read the supplied region name, either printing to stdout or returning (headers, seqs) tuple depending on as_fa value @@ -59,7 +61,7 @@ def read_region(self, region_name): location = self.decode_region(region) return self.read_location(location) - def read_location(self, location): + def read_location(self, location: int): ''' helper method used in extract_region for directly handling locations ''' @@ -75,7 +77,7 @@ def read_location(self, location): else: self.print_region(location) - def convert_region(self, region_name): + def convert_region(self, region_name: str) -> int: ''' Checks that region is a digit that starts with r If so, returns the integer value of the region for decoding @@ -87,7 +89,7 @@ def convert_region(self, region_name): raise ValueError(f'{region_name} could not be parsed') return int(r) - def decode_region(self, region_number): + def decode_region(self, region_number: int) -> int: ''' Convert region to disk location. 
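The pattern behind this lookup is an offset index pickled beside the data file; a sketch under an assumed layout (a dict of region number to byte offset), not the exact class internals:

    import pickle

    with open('regions.pkl', 'rb') as f:
        index = pickle.load(f)       # assumed: {region_number: offset}

    with open('regions.fa') as f:
        f.seek(index[42])            # jump straight to one record
        record = [f.readline() for _ in range(14)]  # num_lines default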
Raises key error if region doesn't exist @@ -99,7 +101,7 @@ def decode_region(self, region_number): return result - def yield_fa(self, keys=None): + def yield_fa(self, keys=None) -> Tuple[str, List[str], List[List[str]]]: ''' repeatedly yield tuples of region, headers, sequences from fa file assumes file position starts at header for region @@ -118,7 +120,7 @@ def yield_fa(self, keys=None): except ValueError: break - def encode_fa(self, location): + def encode_fa(self, location: int) -> Tuple[List[str], List[List[str]]]: ''' Reads the region file entry and returns headers, seqs Assumes even numbered lines are headers, odd are sequences @@ -136,7 +138,7 @@ def encode_fa(self, location): return headers, np.asarray(seqs) - def print_region(self, location): + def print_region(self, location: int) -> None: ''' reads the region file entry, printing to stdout ''' diff --git a/code/misc/seq_functions.py b/code/misc/seq_functions.py index 60f9a25..fe4b65b 100644 --- a/code/misc/seq_functions.py +++ b/code/misc/seq_functions.py @@ -1,4 +1,5 @@ import numpy as np +from typing import List, Tuple r = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', @@ -56,7 +57,12 @@ def index_ignoring_gaps(s, i, s_start, gap_symbol): return x -def seq_id(ref_seq, seq): +def seq_id(ref_seq: List[str], seq: List[str]) -> Tuple[int, int]: + ''' + Given two sequences, determine the total number of valid matching sites + and number of valid sites. A site is valid if it is a upper or lower case + ATCG + ''' length = min(ref_seq.size, seq.size) valid_seq = list(r.keys()) valid = np.logical_and( diff --git a/code/sim/sim_predict.py b/code/sim/sim_predict.py index 37245e6..7eba6e4 100644 --- a/code/sim/sim_predict.py +++ b/code/sim/sim_predict.py @@ -2,6 +2,7 @@ import itertools from sim import sim_process import global_params as gp +from typing import List def process_args(arg_list, sim_args, i=1): @@ -419,7 +420,10 @@ def initial_hmm_parameters(seqs_coded, species_to_indices, species_to, \ return p['init'], p['emis'], p['trans'] -def convert_predictions(path, states): +def convert_predictions(path: List[int], states: List[str]): + ''' + Convert a path of index values into strings based on the states list + ''' return [states[p] for p in path] diff --git a/code/sim/sim_process.py b/code/sim/sim_process.py index 31d1855..57325d8 100644 --- a/code/sim/sim_process.py +++ b/code/sim/sim_process.py @@ -1,4 +1,5 @@ import numpy as np +from typing import List, Dict, Tuple # given fractional positions for snvs and length of sequence l, @@ -82,7 +83,7 @@ def read_one_sim(f, num_sites, num_samples): t_string = f.readline() recomb_sites = [] trees = [] - + while t_string[0] == '[': t_start = t_string.find(']') + 1 recomb_sites.append(int(t_string[1:t_start-1])) @@ -110,7 +111,14 @@ def read_one_sim(f, num_sites, num_samples): return sim -def convert_to_blocks_one(state_seq, states): +def convert_to_blocks_one(state_seq: List[str], + states: List[str]) -> Dict[ + str, List[Tuple[int, int]]]: + ''' + Convert a list of sequences into a structure with start and end positions + Return structure is a dict keyed on species with values of Lists of + each block, which is a tuple with start and end positions + ''' # single individual state sequence blocks = {} for state in states: @@ -302,8 +310,14 @@ def read_state_probs(f, line): return d, rep, line -def threshold_predicted(predicted, probs, threshold, default_state): - +def threshold_predicted(predicted: List[str], + probs: List[float], + threshold: float, + default_state: str) -> List[str]: + ''' + 
Given a list of states, predicted, and the associated probabilities, probs + Converts any states with probability < threshold to the default state + ''' predicted_thresholded = np.array(predicted) probs = np.array(probs) predicted_thresholded[probs < threshold] = default_state @@ -328,9 +342,13 @@ def fill_seqs(polymorphic_seqs, polymorphic_sites, nsites, fill): return seqs_filled -def get_max_path(p, states): - # p is a list of dictionaries, one per site; each dict has keys - # for each state, with associated probability +def get_max_path(p: np.array, states: List[float]) -> Tuple[ + List[int], List[float]]: + ''' + p is a list of dictionaries, one per site; each dict has keys + for each state, with associated probability + Return the maximum likelihood path and the associated probabilities + ''' max_positions = np.argmax(p, axis=1) max_path = [states[i] for i in max_positions] max_probs = [p[i, pos] for i, pos in enumerate(max_positions)] diff --git a/code/test/analyze/test_predict.py b/code/test/analyze/test_predict.py index c556411..500aaca 100644 --- a/code/test/analyze/test_predict.py +++ b/code/test/analyze/test_predict.py @@ -3,7 +3,7 @@ import pytest from pytest import approx from io import StringIO -from collections import Counter, defaultdict +from collections import defaultdict import random import numpy as np @@ -141,7 +141,7 @@ def test_ungap_and_code(): ['abc', 'def', 'ghi'], # several references 0) # reference index assert positions == approx([]) - assert sequence == [] + assert sequence == approx([]) # one match sequence, positions = predict.ungap_and_code( @@ -157,7 +157,7 @@ def test_ungap_and_code(): ['abc', 'def', '-hi'], 0) assert positions == approx([]) - assert sequence == [] + assert sequence == approx([]) # two matches sequence, positions = predict.ungap_and_code( @@ -165,7 +165,7 @@ def test_ungap_and_code(): ['abc', 'def', 'gei'], 0) assert positions == approx([0, 1]) - assert sequence == ['+--', '-++'] + assert (sequence == ['+--', '-++']).all() # mess with ref index sequence, positions = predict.ungap_and_code( @@ -173,14 +173,13 @@ def test_ungap_and_code(): ['a--bc', 'deeef', 'geeei'], 0) assert positions == approx([0, 1]) - assert sequence == ['+--', '-++'] + assert (sequence == ['+--', '-++']).all() sequence, positions = predict.ungap_and_code( 'a--e-', ['a--bc', 'deeef', 'geeei'], 1) assert positions == approx([0, 3]) - assert sequence == ['+--', '-++'] - + assert (sequence == ['+--', '-++']).all() sequence, positions = predict.ungap_and_code( 'a---ef--i', @@ -189,16 +188,16 @@ def test_ungap_and_code(): 'a-ceef-hh'], 0) - assert sequence == '+++ -++ +-+ ++-'.split() + assert (sequence == '+++ -++ +-+ ++-'.split()).all() assert positions == approx([0, 3, 4, 7]) def test_poly_sites(): sequence, positions = predict.poly_sites( - '+++ -++ +-+ ++-'.split(), - [0, 3, 4, 7] + np.array('+++ -++ +-+ ++-'.split()), + np.array([0, 3, 4, 7]) ) - assert sequence == '-++ +-+ ++-'.split() + assert (sequence == '-++ +-+ ++-'.split()).all() assert positions == approx([3, 4, 7]) @@ -239,7 +238,7 @@ def test_get_symbol_freqs(): def symbol_test_helper(sequence): - ind, symb, weigh = predict.get_symbol_freqs(sequence) + ind, symb, weigh = predict.get_symbol_freqs(np.array(sequence)) num_states = len(sequence[0]) num_sites = len(sequence) @@ -301,7 +300,6 @@ def test_emission_probabilities(args): # normal mode symbols = predict.get_emis_symbols([1]*5) - # NOTE not sure why this takes the keys in predict.py or uses len-2 emis = 
predict.emission_probabilities(args['known_states'], args['unknown_states'], symbols) diff --git a/code/test/misc/test_read_table.py b/code/test/misc/test_read_table.py index aa5c7e4..0691f94 100644 --- a/code/test/misc/test_read_table.py +++ b/code/test/misc/test_read_table.py @@ -1,6 +1,5 @@ from misc import read_table from io import StringIO -import pytest def test_read_table_rows_empty(mocker): @@ -20,8 +19,6 @@ def test_read_table_rows_empty(mocker): def return_args(arg): return arg - mocker.patch('misc.read_table.io.BufferedReader', - side_effect=return_args) d, labels = read_table.read_table_rows('mocked.gz', '\t', False) mocked_gz.assert_called_with('mocked.gz', 'rt') @@ -113,8 +110,6 @@ def test_read_table_columns_empty(mocker): def return_args(arg): return arg - mocker.patch('misc.read_table.io.BufferedReader', - side_effect=return_args) d, labels = read_table.read_table_columns('mocked.gz', '\t') mocked_gz.assert_called_with('mocked.gz', 'rt') diff --git a/code/test/sim/test_sim_process.py b/code/test/sim/test_sim_process.py index 40c7953..2f54781 100644 --- a/code/test/sim/test_sim_process.py +++ b/code/test/sim/test_sim_process.py @@ -24,7 +24,7 @@ def test_get_max_path(hm): assert probs == max_probs -def test_get_threshold_predicted(hm): +def test_threshold_predicted(hm): post = hm.posterior_decoding() path, probs = sim_process.get_max_path(post[0], hm.hidden_states) for thresh in (0, 0.2, 0.5, 0.8, 1): From 58634eb4e975d850059f269ede74cbdf2f5d37c7 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Tue, 9 Apr 2019 12:56:48 -0400 Subject: [PATCH 10/33] Suppress log warnings --- code/hmm/hmm_bw.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/code/hmm/hmm_bw.py b/code/hmm/hmm_bw.py index 26c8e37..081c002 100644 --- a/code/hmm/hmm_bw.py +++ b/code/hmm/hmm_bw.py @@ -291,14 +291,16 @@ def forward(self) -> np.array: # Markov process was at state j at time t # returns array of size observations, observations[0], hidden_states # determine emission probabilities for each measured value - emis = np.transpose(np.log(self.emissions[:, self.observations])) + with np.errstate(divide='ignore'): + emis = np.transpose(np.log(self.emissions[:, self.observations])) trans = np.log(self.transitions) alpha = np.empty((len(self.observations), len(self.observations[0]), len(self.hidden_states)), float) # initialize to initial probabilitiy * observed emission - alpha[:, 0, :] = np.log(self.initial_p[None, :]) + emis[0, :, :] + with np.errstate(divide='ignore'): + alpha[:, 0, :] = np.log(self.initial_p[None, :]) + emis[0, :, :] # recursively fill array for i in range(1, len(self.observations[0])): alpha[:, i, :] = np.logaddexp.reduce(alpha[:, i-1, :][:, :, None] + @@ -312,8 +314,9 @@ def backward(self) -> np.array: ''' # probability that the sequence from t+1 to end was observed # and Markov process was at state j at time t - emis = np.transpose(np.log(self.emissions[:, self.observations])) - trans = np.log(self.transitions) + with np.errstate(divide='ignore'): + emis = np.transpose(np.log(self.emissions[:, self.observations])) + trans = np.log(self.transitions) beta = np.zeros((len(self.observations), len(self.observations[0]), len(self.hidden_states)), float) @@ -337,12 +340,13 @@ def calculate_max_states(self) -> Tuple[np.array, np.array]: len(self.hidden_states)), int) # build array of emissions based on observations - emissions = np.log(np.transpose(self.emissions)[self.observations]) + with np.errstate(divide='ignore'): + emissions = 
np.log(np.transpose(self.emissions)[self.observations]) - trans_emis = np.log(self.transitions[None, :, :]) +\ - emissions[:, None, :] + trans_emis = np.log(self.transitions[None, :, :]) +\ + emissions[:, None, :] - probabilities[0, :] = np.log(self.initial_p) + emissions[0] + probabilities[0, :] = np.log(self.initial_p) + emissions[0] states[0, :] = -1 for i in range(1, len(emissions)): From 96d6cfb088ff520eb620fa3cf24bf13bf79175e4 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Tue, 9 Apr 2019 14:18:39 -0400 Subject: [PATCH 11/33] Tests passing Mocked and encoded constants enough to get things passing --- code/analyze/filter_1_main.py | 2 +- code/analyze/filter_helpers.py | 1 - code/analyze/predict.py | 5 +- code/analyze/summarize_strain_states_main.py | 2 +- code/test/analyze/test_filter_1_main.py | 4 +- code/test/analyze/test_filter_2_main.py | 4 +- code/test/analyze/test_filter_helpers.py | 34 +++++------ code/test/analyze/test_id_regions_main.py | 21 +++++-- code/test/analyze/test_predict.py | 56 +++++++++++++------ .../test_summarize_region_quality_main.py | 9 +++ .../test_summarize_strain_states_main.py | 3 - 11 files changed, 86 insertions(+), 55 deletions(-) diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py index 83f8f74..50194f2 100644 --- a/code/analyze/filter_1_main.py +++ b/code/analyze/filter_1_main.py @@ -63,7 +63,7 @@ def main() -> None: # S288c p, reason = passes_filters1(region, info_string, - args['known_state'][0]) + args['known_states'][0]) region['reason'] = reason write_filtered_line(f_out1i, region_id, region, fields1i) diff --git a/code/analyze/filter_helpers.py b/code/analyze/filter_helpers.py index a916a28..c263436 100644 --- a/code/analyze/filter_helpers.py +++ b/code/analyze/filter_helpers.py @@ -134,7 +134,6 @@ def passes_filters2(region: Dict, than one state for the region ''' - refs = gp.alignment_ref_order s = region['predicted_species'] ids = {} diff --git a/code/analyze/predict.py b/code/analyze/predict.py index 1303b63..e5b5655 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -323,7 +323,7 @@ def initial_hmm_parameters(seq: np.array, known_states: List[str], unknown_states: List[str], expected_frac: Dict, - expected_tract_lengths: Dict) -> hmm_bw.HMM: + expected_lengths: Dict) -> hmm_bw.HMM: ''' Build a HMM object initialized based on expected values and provided data ''' @@ -340,7 +340,7 @@ def initial_hmm_parameters(seq: np.array, unknown_states, symbol_freqs.keys()) trans = transition_probabilities(known_states, unknown_states, - expected_frac, expected_length) + expected_frac, expected_lengths) # new Hidden Markov Model hmm = hmm_bw.HMM() @@ -431,7 +431,6 @@ def predict_introgressed(ref_seqs: np.array, return predicted, p[0], hmm, hmm_init, positions -# TODO is there another convert to blocks (one)? 
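The np.errstate(divide='ignore') guards introduced in the log-warning patch above merit a short illustration: log(0) is -inf, which is exactly the log-space value wanted for an impossible state, so only the warning is unwanted. A self-contained example:

    import numpy as np

    probs = np.array([0.9, 0.1, 0.0])    # a zero-probability entry
    with np.errstate(divide='ignore'):   # silence 'divide by zero' only
        log_probs = np.log(probs)        # [-0.105..., -2.302..., -inf]

    # -inf composes correctly in log-space sums:
    total = np.logaddexp(log_probs[0], log_probs[2])  # log(0.9 + 0.0)
    assert np.isclose(total, np.log(0.9))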
def convert_to_blocks(state_seq: List[str], states: List[str]) -> Dict[ str, List[Tuple[int, int]]]: diff --git a/code/analyze/summarize_strain_states_main.py b/code/analyze/summarize_strain_states_main.py index cf74860..344dae7 100644 --- a/code/analyze/summarize_strain_states_main.py +++ b/code/analyze/summarize_strain_states_main.py @@ -105,7 +105,7 @@ def main() -> None: [f'num_bases_{x}_filtered2_inclusive' for x in states] +\ ['num_bases_total_filtered2_inclusive'] - r = sorted(gp.alignment_ref_order[1:]) + r = sorted(args['known_states'][1:]) for n in range(2, len(r)+1): for combo in itertools.combinations(r, n): fields += ['num_bases_' + '_or_'.join(combo) + '_filtered2i'] diff --git a/code/test/analyze/test_filter_1_main.py b/code/test/analyze/test_filter_1_main.py index 8b21b2f..5d3f097 100644 --- a/code/test/analyze/test_filter_1_main.py +++ b/code/test/analyze/test_filter_1_main.py @@ -45,8 +45,8 @@ def test_main(mocker, capsys): assert mock_filter.call_count == 2 # seems like this references the object, which changes after call mock_filter.assert_has_calls([ - mocker.call({'reason': 'test'}, 'x..'), - mocker.call({'reason': '', 'a': 1}, 'x..')]) + mocker.call({'reason': 'test'}, 'x..', 'state1'), + mocker.call({'reason': '', 'a': 1}, 'x..', 'state1')]) assert mock_write.call_count == 3 mock_write.assert_has_calls([ diff --git a/code/test/analyze/test_filter_2_main.py b/code/test/analyze/test_filter_2_main.py index dd18e02..da383b3 100644 --- a/code/test/analyze/test_filter_2_main.py +++ b/code/test/analyze/test_filter_2_main.py @@ -54,13 +54,13 @@ def test_main(mocker, capsys): {'alternative_states': '1,2', 'alternative_ids': '0.8,0.5', 'alternative_P_counts': '2,1,0'}, - ['atcg'], 0.1), + ['atcg'], 0.1, ['state1', 'state2']), mocker.call( {'a': 1, 'alternative_states': '1', 'alternative_ids': '0.8', 'alternative_P_counts': '2'}, - ['atcg'], 0.1)]) + ['atcg'], 0.1, ['state1', 'state2'])]) assert mock_write.call_count == 3 mock_write.assert_has_calls([ diff --git a/code/test/analyze/test_filter_helpers.py b/code/test/analyze/test_filter_helpers.py index 94d0212..9d8ea48 100644 --- a/code/test/analyze/test_filter_helpers.py +++ b/code/test/analyze/test_filter_helpers.py @@ -84,9 +84,6 @@ def test_passes_filters(): def test_passes_filters1(mocker): - mocker.patch('analyze.filter_helpers.gp.alignment_ref_order', - ['ref']) - # fail fraction gapped on reference region = {'predicted_species': 'pred', 'start': 0, @@ -99,7 +96,7 @@ def test_passes_filters1(mocker): 'num_sites_nongap_ref': 0, } - assert filter_helpers.passes_filters1(region, '') == \ + assert filter_helpers.passes_filters1(region, '', 'ref') == \ (False, 'fraction gaps/masked in master = 0.6') # fail fraction gapped on predicted @@ -114,7 +111,7 @@ def test_passes_filters1(mocker): 'num_sites_nongap_ref': 0, } - assert filter_helpers.passes_filters1(region, '') == \ + assert filter_helpers.passes_filters1(region, '', 'ref') == \ (False, 'fraction gaps/masked in predicted = 0.7') # fail match counts @@ -129,9 +126,10 @@ def test_passes_filters1(mocker): 'num_sites_nongap_ref': 0, } - assert filter_helpers.passes_filters1(region, 'CP') == \ + assert filter_helpers.passes_filters1(region, 'CP', 'ref') == \ (False, 'count_P = 1') - assert filter_helpers.passes_filters1(region, 'CCCCCCCCPPPPPPP') == \ + assert filter_helpers.passes_filters1(region, + 'CCCCCCCCPPPPPPP', 'ref') == \ (False, 'count_P = 7 and count_C = 8') # fail divergence, master >= pred @@ -146,7 +144,7 @@ def test_passes_filters1(mocker): 
'num_sites_nongap_ref': 10, } - assert filter_helpers.passes_filters1(region, 'CPPPPPPP') == \ + assert filter_helpers.passes_filters1(region, 'CPPPPPPP', 'ref') == \ (False, 'id with master = 0.6 and id with predicted = 0.5') # fail divergence, master >= 0.7 @@ -161,7 +159,7 @@ def test_passes_filters1(mocker): 'num_sites_nongap_ref': 10, } - assert filter_helpers.passes_filters1(region, 'CPPPPPPP') == \ + assert filter_helpers.passes_filters1(region, 'CPPPPPPP', 'ref') == \ (False, 'id with master = 0.6') # passes @@ -176,13 +174,11 @@ def test_passes_filters1(mocker): 'num_sites_nongap_ref': 10, } - assert filter_helpers.passes_filters1(region, 'CPPPPPPP') == \ + assert filter_helpers.passes_filters1(region, 'CPPPPPPP', 'ref') == \ (True, '') def test_passes_filters2(mocker): - mocker.patch('analyze.filter_helpers.gp.alignment_ref_order', - ['ref', '1', '2', '3', '4']) mocker.patch('analyze.filter_helpers.gp.gap_symbol', '-') mocker.patch('analyze.filter_helpers.gp.unsequenced_symbol', 'n') @@ -197,7 +193,7 @@ def test_passes_filters2(mocker): threshold = 0 filt, states, ids, p_count = filter_helpers.passes_filters2( - region, seqs, threshold) + region, seqs, threshold, ['ref', '1', '2', '3', '4']) assert filt is False assert states == ['1', '2', '4'] assert ids == [0.8, 0.5, 0.4] @@ -205,7 +201,7 @@ def test_passes_filters2(mocker): threshold = 0.1 filt, states, ids, p_count = filter_helpers.passes_filters2( - region, seqs, threshold) + region, seqs, threshold, ['ref', '1', '2', '3', '4']) assert filt is False assert states == ['1', '2'] assert ids == [0.8, 0.5] @@ -213,7 +209,7 @@ def test_passes_filters2(mocker): threshold = 0.9 filt, states, ids, p_count = filter_helpers.passes_filters2( - region, seqs, threshold) + region, seqs, threshold, ['ref', '1', '2', '3', '4']) assert filt is True assert states == ['1'] assert ids == [0.8] @@ -221,8 +217,6 @@ def test_passes_filters2(mocker): def test_passes_filters2_on_region(mocker): - mocker.patch('analyze.filter_helpers.gp.alignment_ref_order', - ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) mocker.patch('analyze.filter_helpers.gp.gap_symbol', '-') mocker.patch('analyze.filter_helpers.gp.unsequenced_symbol', 'n') @@ -232,7 +226,8 @@ def test_passes_filters2_on_region(mocker): headers, seqs = read_fasta.read_fasta(fa, gz=False) seqs = seqs[:-1] p, alt_states, alt_ids, alt_P_counts = filter_helpers.passes_filters2( - {'predicted_species': 'N_45'}, seqs, 0.1) + {'predicted_species': 'N_45'}, seqs, 0.1, + ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) assert p is False assert alt_states == ['CBS432', 'N_45', 'UWOPS91_917_1', 'DBVPG6304'] assert alt_ids == approx([0.9983805668016195, 0.994331983805668, @@ -240,7 +235,8 @@ def test_passes_filters2_on_region(mocker): assert alt_P_counts == [145, 143, 128, 129] p, alt_states, alt_ids, alt_P_counts = filter_helpers.passes_filters2( - {'predicted_species': 'N_45'}, seqs, 0.98) + {'predicted_species': 'N_45'}, seqs, 0.98, + ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) assert p is False assert alt_states == ['CBS432', 'N_45'] assert alt_ids == approx([0.9983805668016195, 0.994331983805668]) diff --git a/code/test/analyze/test_id_regions_main.py b/code/test/analyze/test_id_regions_main.py index bdbfc00..5811bd5 100644 --- a/code/test/analyze/test_id_regions_main.py +++ b/code/test/analyze/test_id_regions_main.py @@ -4,14 +4,19 @@ def test_main_blank(mocker): # setup global params to match expectations - mocker.patch('analyze.predict.gp.alignment_ref_order', - 
['ref', 'state1']) mocker.patch('analyze.id_regions_main.gp.chrms', ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']) mocker.patch('analyze.id_regions_main.gp.analysis_out_dir_absolute', 'dir/') - + mocker.patch( + 'analyze.summarize_strain_states_main.predict.process_predict_args', + return_value={ + 'known_states': ['S288c', 'CBS432', 'N_45', + 'DBVPG6304', 'UWOPS91_917_1'], + 'states': ['ref', 'state1', 'unknown'], + 'tag': 'tag' + }) mocker.patch('sys.argv', "test.py tag .001 viterbi 1000 .025 unknown 1000 .01".split()) mocker.patch('analyze.predict.read_blocks', @@ -36,8 +41,12 @@ def test_main_blank(mocker): def test_main(mocker): # setup global params to match expectations - mocker.patch('analyze.predict.gp.alignment_ref_order', - ['ref', 'state1']) + mocker.patch( + 'analyze.summarize_strain_states_main.predict.process_predict_args', + return_value={ + 'states': ['ref', 'state1', 'unknown'], + 'tag': 'tag' + }) mocker.patch('analyze.id_regions_main.gp.chrms', ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']) @@ -72,10 +81,10 @@ def test_main(mocker): main.main() - assert mocked_file.call_count == 3 mocked_file.assert_any_call('dir/tag/blocks_ref_tag_labeled.txt', 'w') mocked_file.assert_any_call('dir/tag/blocks_state1_tag_labeled.txt', 'w') mocked_file.assert_any_call('dir/tag/blocks_unknown_tag_labeled.txt', 'w') + assert mocked_file.call_count == 3 # headers calls = [ diff --git a/code/test/analyze/test_predict.py b/code/test/analyze/test_predict.py index 016b4e1..874e749 100644 --- a/code/test/analyze/test_predict.py +++ b/code/test/analyze/test_predict.py @@ -18,13 +18,35 @@ def test_gp_symbols(): @pytest.fixture def args(): - args = predict.process_predict_args('p4e2 .001 viterbi 10000 .025 10000\ - .025 10000 .025 10000 .025 unknown\ - 1000 .01'.split()) + args = {} + args['tag'] = 'p4e2' + args['improvement_frac'] = 0.001 + args['threshold'] = 'viterbi' + + args['known_states'] = ['S288c', 'CBS432', 'N_45', + 'DBVPG6304', 'UWOPS91_917_1'] + args['unknown_states'] = ['unknown'] + args['states'] = args['known_states'] + ['unknown'] + + args['expected_frac'] = {'DBVPG6304': 0.025, + 'UWOPS91_917_1': 0.025, + 'unknown': 0.01, + 'CBS432': 0.025, + 'N_45': 0.025, + 'S288c': 0.89} + + args['expected_length'] = {'DBVPG6304': 10000.0, + 'UWOPS91_917_1': 10000.0, + 'unknown': 1000.0, + 'CBS432': 10000.0, + 'N_45': 10000.0, + 'S288c': 0} + args['expected_num_tracts'] = {} + args['expected_bases'] = {} return args -def test_process_predict_args(): +def old_test_process_predict_args(): # test with default args args = predict.process_predict_args('p4e2 .001 viterbi 10000 .025 10000\ .025 10000 .025 10000 .025 unknown\ @@ -44,7 +66,7 @@ def test_process_predict_args(): 'N_45': 0.025, 'S288c': 0.89} - assert args['expected_tract_lengths'] == {'DBVPG6304': 10000.0, + assert args['expected_length'] == {'DBVPG6304': 10000.0, 'UWOPS91_917_1': 10000.0, 'unknown': 1000.0, 'CBS432': 10000.0, @@ -56,7 +78,7 @@ def test_process_predict_args(): assert len(args.keys()) == 10 -def test_process_predict_args_threshold(): +def old_test_process_predict_args_threshold(): args = predict.process_predict_args('p4e2 .001 test 10000 .025 10000\ .025 10000 .025 10000 .025 unknown\ 1000 .01'.split()) @@ -68,7 +90,7 @@ def test_process_predict_args_threshold(): assert args['threshold'] == 0.1 -def test_process_predict_args_exceptions(): +def old_test_process_predict_args_exceptions(): # not enough unknown 
values with pytest.raises(IndexError): predict.process_predict_args('p4e2 .001 0.1 10000 .025 10000\ @@ -202,7 +224,7 @@ def test_poly_sites(): def test_set_expectations_default(args): - prev_tract = dict(args['expected_tract_lengths']) + prev_tract = dict(args['expected_length']) assert args['expected_num_tracts'] == {} assert args['expected_bases'] == {} predict.set_expectations(args, 1e5) # made number arbitrary @@ -219,7 +241,7 @@ def test_set_expectations_default(args): 'N_45': 0.025 * 1e5, 'S288c': 1e5 - 1e4} prev_tract['S288c'] = 45000 - assert args['expected_tract_lengths'] == prev_tract + assert args['expected_length'] == prev_tract def test_get_symbol_freqs(): @@ -382,11 +404,11 @@ def mynorm(d): def test_transition_probabilities(args): - args['expected_tract_lengths']['S288c'] = 45000 + args['expected_length']['S288c'] = 45000 trans = predict.transition_probabilities(args['known_states'], args['unknown_states'], args['expected_frac'], - args['expected_tract_lengths']) + args['expected_length']) np_trans = np_transition(args) for i in range(len(trans)): @@ -396,7 +418,7 @@ def test_transition_probabilities(args): def np_transition(args): states = args['known_states'] + args['unknown_states'] expected_frac = args['expected_frac'] - expected_tract_lengths = args['expected_tract_lengths'] + expected_length = args['expected_length'] trans = [] for i in range(len(states)): state_from = states[i] @@ -405,9 +427,9 @@ def np_transition(args): for j in range(len(states)): state_to = states[j] if state_from == state_to: - trans[i].append(1 - 1./expected_tract_lengths[state_from]) + trans[i].append(1 - 1./expected_length[state_from]) else: - trans[i].append(1./expected_tract_lengths[state_from] * + trans[i].append(1./expected_length[state_from] * expected_frac[state_to] * scale_other) trans[i] /= np.sum(trans[i]) @@ -416,14 +438,14 @@ def np_transition(args): def test_initial_hmm_parameters(args): - args['expected_tract_lengths']['S288c'] = 45000 + args['expected_length']['S288c'] = 45000 symbols = predict.get_emis_symbols([1]*5) hm = predict.initial_hmm_parameters( symbols, args['known_states'], args['unknown_states'], args['expected_frac'], - args['expected_tract_lengths']) + args['expected_length']) assert args['expected_frac'] == {'DBVPG6304': 0.025, 'UWOPS91_917_1': 0.025, @@ -667,7 +689,7 @@ def test_convert_to_blocks_one(): seq = [str(random.randint(0, 9)) for i in range(100)] help_test_convert_blocks(states, seq) - + def help_test_convert_blocks(states, seq): blocks = predict.convert_to_blocks(seq, states) diff --git a/code/test/analyze/test_summarize_region_quality_main.py b/code/test/analyze/test_summarize_region_quality_main.py index 1decede..065a293 100644 --- a/code/test/analyze/test_summarize_region_quality_main.py +++ b/code/test/analyze/test_summarize_region_quality_main.py @@ -7,6 +7,15 @@ def test_main(mocker): mocker.patch( 'analyze.summarize_region_quality_main.gp.analysis_out_dir_absolute', 'dir/') + mocker.patch( + 'analyze.summarize_strain_states_main.predict.process_predict_args', + return_value={ + 'known_states': ['S288c', 'CBS432', 'N_45', + 'DBVPG6304', 'UWOPS91_917_1'], + 'states': ['S288c', 'CBS432', 'N_45', + 'DBVPG6304', 'UWOPS91_917_1', 'unknown'], + 'tag': 'tag' + }) mocker.patch('analyze.summarize_region_quality_main.gp.chrms', ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']) diff --git a/code/test/analyze/test_summarize_strain_states_main.py b/code/test/analyze/test_summarize_strain_states_main.py 
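These test rewrites all lean on the same pytest-mock idiom; a stripped-down sketch with a hypothetical module path and main function:

    from analyze import some_main  # hypothetical module under test

    def test_example(mocker):
        # patch the name where it is looked up, not where it is defined
        mock_args = mocker.patch(
            'analyze.some_main.predict.process_predict_args',
            return_value={'tag': 'tag',
                          'known_states': ['S288c', 'CBS432'],
                          'states': ['S288c', 'CBS432', 'unknown']})

        some_main.main()                # receives the canned dict
        mock_args.assert_called_once()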
index 2c2fadb..bf3fae1 100644 --- a/code/test/analyze/test_summarize_strain_states_main.py +++ b/code/test/analyze/test_summarize_strain_states_main.py @@ -11,9 +11,6 @@ def test_main(mocker, capsys): mocker.patch( 'analyze.summarize_strain_states_main.gp.analysis_out_dir_absolute', '/dir') - mocker.patch( - 'analyze.summarize_strain_states_main.gp.alignment_ref_order', - ['state1', 'state2', 'state3']) mock_read = mocker.patch( 'analyze.summarize_strain_states_main.read_table.read_table_rows', From 4be5be79bfcd0309ae58885ad7dda5669a079644 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Wed, 10 Apr 2019 11:09:36 -0400 Subject: [PATCH 12/33] Config yaml Added a yaml version of the config to replace global params and setup args files. Helper script clean_config performs lookups of referenced entries in config. Have not added to main methods yet. --- code/config.yaml | 102 ++++++++++++++++++++ code/misc/clean_config.py | 140 ++++++++++++++++++++++++++++ code/test/misc/test_clean_config.py | 107 +++++++++++++++++++++ 3 files changed, 349 insertions(+) create mode 100644 code/config.yaml create mode 100644 code/misc/clean_config.py create mode 100644 code/test/misc/test_clean_config.py diff --git a/code/config.yaml b/code/config.yaml new file mode 100644 index 0000000..8343076 --- /dev/null +++ b/code/config.yaml @@ -0,0 +1,102 @@ +--- +# biological parameters +mu: 1.84e-10 + + +# should we leave the alignments already completed in the alignments +# directory alone? +resume_alignment: False + +HMM_symbols: + match: '+' + mismatch: '-' + unknown: '?' + unsequenced: 'n' + gap: '-' + unaligned: '?' + masked: 'x' + +paths: + output_base: /tigress/tcomi/aclark4_temp/results + + fasta_suffix: .fa # suffix for _all_ fasta files + # suffix for _all_ alignment files + # this needs to match the suffix output by mugsy + alignment_suffix: .maf + + # sequence locations/names + # master_ref now automatically assumed to be first + # reference specified in setup_args file + + # now specified in setup_args file + + # alignment files + + # alignments directory now specified in setup_args file + + mask_dir: /tigress/tcomi/aclark4_temp/par4/masked/ + alignments_dir: /tigress/tcomi/aclark4_temp/par4/ + + simulations: + sim_base: __OUTPUT_BASE__/sim + prefix: sim_out_ + suffix: .txt + + analysis: + analysis_base: __OUTPUT_BASE__/analysis + regions: __ANALYSIS_BASE__/regions/ + genes: __ANALYSIS_BASE__/genes/ + + # software install locations + software: + root_install: /tigress/anneec/software + mugsy: __ROOT_INSTALL__/mugsy/ + tcoffee: "__ROOT_INSTALL__/\ + T-COFFEE_installer_Version_11.00.8cbe486_linux_x64/bin/" + mafft: __ROOT_INSTALL__/mafft/bin/ + ms: __ROOT_INSTALL__/msdir/ + # including dustmasker + blast: "__ROOT_INSTALL__/ncbi-blast-2.7.1+-src/\ + c++/ReleaseMT/bin/" + orffinder: __ROOT_INSTALL__/ + ldselect: __ROOT_INSTALL__/ldSelect/ + structure: __ROOT_INSTALL__/structure/ + +chrms: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', + 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'] + +analysis_params: + tag: p2e4 + improvement_frac: 0.001 + threshold: viterbi + input_root: /tigress/AKEY/akey_vol2/aclark4/nobackup + + reference: + name: S228c + fasta_name: SC88c_SGD-R64 + base_dir: __INPUT_ROOT__/100_genomes/genomes/S288c_SGD-R64/ + gene_bank_dir: __INPUT_ROOT__/S288c/ + + known_states: + - name: CBS432 + fasta_name: CBS432 + base_dir: /tigress/anneec/projects/introgression/data/CBS432/ + gene_bank_dir: __INPUT_ROOT__/CBS432/ + + - name: N_45 + fasta_name: N_45 + base_dir: 
__INPUT_ROOT__/para_sgrp/strains/N_45/
+    gene_bank_dir: null
+
+  - name: DBVPG6304
+    fasta_name: DBVPG6304
+    base_dir: __INPUT_ROOT__/para_sgrp/strains/DBVPG6304/
+    gene_bank_dir: null
+
+  - name: UWOPS91_917_1
+    fasta_name: UWOPS91_917_1
+    base_dir: __INPUT_ROOT__/para_sgrp/strains/UWOPS91_917_1/
+    gene_bank_dir: null
+
+  test_dir: __INPUT_ROOT__/100_genomes/genomes_gb/
+  gene_bank_all: __INPUT_ROOT__/100_genomes/sequence.gb
diff --git a/code/misc/clean_config.py b/code/misc/clean_config.py
new file mode 100644
index 0000000..f886a62
--- /dev/null
+++ b/code/misc/clean_config.py
@@ -0,0 +1,140 @@
+import re
+from copy import copy
+from typing import Dict, List
+
+
+'''
+clean_config.py
+
+Helper functions for performing replacements on yaml config files
+'''
+
+
+def clean_config(config: Dict,
+                 valid_replacements: Dict[str, str] = None) -> Dict:
+    '''
+    Performs substitution of variables in strings, recursively replacing
+    strings of the form __.+__ with the value of the matching key.  Nested
+    variables with the same name replace parent values.
+    config is the possibly nested dict with values to replace
+    valid_replacements are the valid entries for performing replacements
+    '''
+    result = {}
+    if valid_replacements is None:
+        valid_replacements = dict()
+    len_values = len(config)
+    while config:
+        # want to look at valid replacements first,
+        # to possibly replace their values
+        keys = config.keys()
+        keys = list([k for k in keys if k in valid_replacements] +
+                    [k for k in keys if k not in valid_replacements])
+
+        for key in keys:
+            value = config[key]
+            if isinstance(value, str):
+                value = replace_entry(value, valid_replacements)
+                if value is None:
+                    continue  # don't remove
+                result[key] = value
+                valid_replacements[key] = value
+
+            elif isinstance(value, dict):
+                result[key] = clean_config(value,
+                                           copy(valid_replacements))
+
+            elif isinstance(value, list):
+                result[key] = clean_list(value,
+                                         valid_replacements)
+
+            else:
+                result[key] = value
+                valid_replacements[key] = str(value)
+
+            config.pop(key)
+
+        if len_values == len(config):
+            raise Exception('Failed to dereference all keys, remaining '
+                            f'values are:\n {print_dict(config)}')
+
+        len_values = len(config)
+
+    return result
+
+
+def clean_list(config: List,
+               valid_replacements: Dict[str, str] = None) -> List:
+    '''
+    Performs substitution on list of config objects
+    '''
+    result = []
+    for value in config:
+        if isinstance(value, str):
+            output = replace_entry(value, valid_replacements)
+            if output is None:
+                raise Exception(f'Failed to dereference list entry: "{value}"')
+            result.append(output)
+
+        elif isinstance(value, list):
+            result.append(clean_list(value, valid_replacements))
+
+        elif isinstance(value, dict):
+            result.append(clean_config(value, copy(valid_replacements)))
+
+        else:
+            result.append(value)
+
+    return result
+
+
+def replace_entry(value: str, valid_replacements: Dict[str, str]) -> str:
+    '''
+    Replace instances of __.+__ with the matching key in valid_replacements
+    If valid_replacements is None or the key is not found, return None
+    Else return the (possibly) substituted string with all instances of /+
+    replaced with / (common in path replacements)
+    '''
+    replacements = re.findall('__(.+?)__', value)
+    for replacement in set(replacements):
+        replace = replacement.lower()
+        if valid_replacements is None or replace not in valid_replacements:
+            return None
+        value = re.sub(f'__{replacement}__',
+                       valid_replacements[replace],
+                       value)
+    return re.sub('/+', '/', value)
+
+
+def print_dict(d: Dict, lvl: int = 0) -> str:
+    '''
+    Return pretty representation of the dictionary d.
+    lvl is the starting amount to indent the line
+    '''
+    result = ''
+    for k, v in d.items():
+        if isinstance(v, dict):
+            result += ' ' * lvl + f'{k} -\n'
+            result += print_dict(v, lvl+1)
+        elif isinstance(v, list):
+            result += ' ' * lvl + f'{k} -\n'
+            result += print_list(v, lvl+1)
+        else:
+            result += ' ' * lvl + f'{k} - {v}\n'
+    return result
+
+
+def print_list(l: List, lvl: int = 0) -> str:
+    '''
+    Return pretty representation of the list l.
+    lvl is the starting amount to indent the line
+    '''
+    result = ''
+    for i, v in enumerate(l):
+        result += ' ' * lvl + f'{i}:\n'
+        if isinstance(v, dict):
+            result += print_dict(v, lvl+1)
+        elif isinstance(v, list):
+            result += print_list(v, lvl+1)
+        else:
+            result += ' ' * lvl + f'{v},\n'
+    return result
diff --git a/code/test/misc/test_clean_config.py b/code/test/misc/test_clean_config.py
new file mode 100644
index 0000000..e933e40
--- /dev/null
+++ b/code/test/misc/test_clean_config.py
@@ -0,0 +1,107 @@
+import pytest
+from misc.clean_config import clean_config, print_dict, clean_list
+from yaml import load
+
+
+def test_simple():
+    config = {'base_name': 'base',
+              'test': '__BASE_NAME__/test.txt',
+              'test2': '__BASE_NAME__/__BASE_NAME__/test.txt',
+              'test3': '__BASE_NAME____BASE_NAME__/test.txt',
+              'test4': '__BASE_NAME__/__TEST__/test.txt',
+              'test5': '__BASE_NAME__/__TEST__/test__DIGIT__.txt',
+              'test6': '__BASE_NAME__//test__DIGIT__.txt',
+              'test7': '//test///test////test/',
+              'digit': 10,
+              }
+    config = clean_config(config)
+    assert config == {'base_name': 'base',
+                      'test': 'base/test.txt',
+                      'test2': 'base/base/test.txt',
+                      'test3': 'basebase/test.txt',
+                      'test4': 'base/base/test.txt/test.txt',
+                      'test5': 'base/base/test.txt/test10.txt',
+                      'test6': 'base/test10.txt',
+                      'test7': '/test/test/test/',
+                      'digit': 10,
+                      }
+
+
+def test_circular():
+    config = {'base_name': '__TEST2__',
+              'test': '__BASE_NAME__/test.txt',
+              'test2': '__BASE_NAME__/test.txt',
+              }
+    with pytest.raises(Exception) as e:
+        clean_config(config)
+    assert 'Failed to dereference all keys' in str(e)
+
+
+def test_nest_dict():
+    config = {
+        'base_name': 'base',
+        'test': '__BASE_NAME__/test.txt',
+        'dict2': {
+            'test2': '__BASE_NAME__/test2.txt',
+            'base_name': 'base2',
+            'dict3': {
+                'base_name': 'base3',
+                'test3': '__BASE_NAME__/test3.txt'
+            },
+            'dict4': {
+                'test4': '__BASE_NAME__/test3.txt'
+            }
+        },
+        'test5': '__BASE_NAME__/test5.txt',
+        'dict5': {
+            'base_name': '__BASE_NAME__',
+            'test6': '__BASE_NAME__/test_5.txt'
+        }
+    }
+    config = clean_config(config)
+    assert config == {
+        'base_name': 'base',
+        'test': 'base/test.txt',
+        'dict2': {
+            'test2': 'base2/test2.txt',
+            'base_name': 'base2',
+            'dict3': {
+                'base_name': 'base3',
+                'test3': 'base3/test3.txt'
+            },
+            'dict4': {
+                'test4': 'base2/test3.txt'
+            }},
+        'test5': 'base/test5.txt',
+        'dict5': {
+            'base_name': 'base',
+            'test6': 'base/test_5.txt'
+        }
+    }
+
+
+def test_clean_list():
+    config = [
+        'test',
+        {'base': 'base',
+         'test2': '__BASE__/test2'},
+        {'base': 'base2',
+         'test3': '__BASE__/test3'},
+        [1, 2, 3],
+        17
+    ]
+    assert clean_list(config) == [
+        'test',
+        {'base': 'base',
+         'test2': 'base/test2'},
+        {'base': 'base2',
+         'test3': 'base2/test3'},
+        [1, 2, 3],
+        17
+    ]
+
+    with pytest.raises(Exception) as e:
+        clean_list(['__NOT_FOUND__'])
+    assert 'Failed to dereference list entry: "__NOT_FOUND__"' in str(e)
+
+    assert clean_list(['__BASE__/test'], {'base': 'base'}) == ['base/test']
From 9102a16e677dcd662edd0ee93a3fcc9bbfcae2 Mon Sep 17 00:00:00 2001
From: Troy Comi Date: Fri, 19 Apr 2019 08:36:31 -0400 Subject: [PATCH 13/33] Working on predict main --- .gitignore | 1 + code/analyze/main.py | 191 ++++++++++++++++ code/analyze/predict.py | 42 ++++ code/analyze/predict_main.py | 8 +- code/config.yaml | 52 +++-- .../misc/{clean_config.py => config_utils.py} | 125 ++++++++++- code/setup.py | 16 ++ code/test/analyze/test_main.py | 30 +++ code/test/analyze/test_predict.py | 4 + code/test/hmm/test_hmm_bw.py | 3 +- code/test/misc/test_clean_config.py | 107 --------- code/test/misc/test_config_utils.py | 209 ++++++++++++++++++ 12 files changed, 644 insertions(+), 144 deletions(-) create mode 100644 code/analyze/main.py rename code/misc/{clean_config.py => config_utils.py} (56%) create mode 100644 code/setup.py create mode 100644 code/test/analyze/test_main.py delete mode 100644 code/test/misc/test_clean_config.py create mode 100644 code/test/misc/test_config_utils.py diff --git a/.gitignore b/.gitignore index cbfb72c..124f3ae 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ code/*/scratch/* code/setup/* .coverage *.swp +*egg-info diff --git a/code/analyze/main.py b/code/analyze/main.py new file mode 100644 index 0000000..bc7a533 --- /dev/null +++ b/code/analyze/main.py @@ -0,0 +1,191 @@ +import click +import yaml +import glob +import re +import logging as log +from misc import config_utils +from misc.config_utils import (get_nested, check_wildcards, get_states, + validate) +from typing import List, Dict + + +# TODO also check for snakemake object? +@click.group(invoke_without_command=True) +@click.option('--config', '-c', + multiple=True, + type=click.File('r'), + help='Base configuration yaml.') +@click.option('-v', '--verbosity', count=True, default=2) +@click.pass_context +def cli(ctx, config, verbosity): + ''' + Main entry script to run analyze methods + ''' + + verbosity = 4 if verbosity > 4 else verbosity + levelstr = ['CRITICAL', 'ERROR', + 'WARNING', 'INFO', + 'DEBUG'][verbosity] + level = [log.CRITICAL, log.ERROR, + log.WARNING, log.INFO, + log.DEBUG][verbosity] + + log.basicConfig(level=level) + log.info(f'Verbosity set to {levelstr}') + + ctx.ensure_object(dict) + + log.info(f'Reading in {len(config)} config files') + for path in config: + conf = yaml.safe_load(path) + ctx.obj = config_utils.merge_dicts(ctx.obj, conf) + + ctx.obj = config_utils.clean_config(ctx.obj) + log.debug('Cleaned config:\n' + config_utils.print_dict(ctx.obj)) + + if ctx.invoked_subcommand is None: + click.echo_via_pager( + click.style( + 'No command supplied. 
Read in the following config:\n',
+                fg='yellow') +
+            config_utils.print_dict(ctx.obj))
+
+
+@cli.command()
+@click.pass_context
+@click.option('--blocks', default='', help='Block file location with {state}')
+@click.option('--prefix', default='', help='Prefix of test-strain files '
+              'default to list of states joined with _.')
+@click.option('--test-strains', default='',
+              help='Test files location with {strain} and {chrom}')
+@click.option('--hmm-initial', default='',
+              help='Initial hmm parameter text file')
+@click.option('--hmm-trained', default='',
+              help='Trained hmm parameter text file')
+@click.option('--positions', default='',
+              help='Positions file, gzipped')
+@click.option('--probabilities', default='',
+              help='Probabilities file, gzipped')
+@click.option('--alignment', default='',
+              help='Alignment file location with '
+              '{prefix}, {strain}, and {chrom}')
+def predict(ctx,
+            blocks,
+            prefix,
+            test_strains,
+            hmm_initial,
+            hmm_trained,
+            positions,
+            probabilities,
+            alignment):
+    config = ctx.obj
+
+    chromosomes = validate(config,
+                           'chromosomes',
+                           'No chromosomes specified in config file!')
+
+    blocks = validate(config,
+                      'paths.analysis.block_files',
+                      'No block file provided',
+                      blocks)
+
+    check_wildcards(blocks, 'state')
+    log.info(f'output blocks file for predict is {blocks}')
+
+    known, unknown = get_states(config)
+    if prefix == '':
+        prefix = '_'.join(known)
+
+    log.info(f'prefix is {prefix}')
+
+    if test_strains == '':
+        test_strains = get_nested(config, 'paths.test_strains')
+    else:
+        # need to support list for test strains
+        test_strains = [test_strains]
+    for test_strain in test_strains:
+        check_wildcards(test_strain, 'strain,chrom')
+
+    log.info(f'found {len(test_strains)} test strains')
+
+    strains = get_strains(config, test_strains, prefix, chromosomes)
+    log.info(f'found {len(strains)} unique strains')
+
+    hmm_initial = validate(config,
+                           'paths.analysis.hmm_initial',
+                           'No initial hmm file provided',
+                           hmm_initial)
+    log.info(f'hmm_initial is {hmm_initial}')
+
+    hmm_trained = validate(config,
+                           'paths.analysis.hmm_trained',
+                           'No trained hmm file provided',
+                           hmm_trained)
+    log.info(f'hmm_trained is {hmm_trained}')
+
+    positions = validate(config,
+                         'paths.analysis.positions',
+                         'No positions file provided',
+                         positions)
+    log.info(f'positions is {positions}')
+
+    probabilities = validate(config,
+                             'paths.analysis.probabilities',
+                             'No probabilities file provided',
+                             probabilities)
+    log.info(f'probabilities is {probabilities}')
+
+    alignment = validate(config,
+                         'paths.analysis.alignment',
+                         'No alignment file provided',
+                         alignment)
+    check_wildcards(alignment, 'prefix,strain,chrom')
+    alignment = alignment.replace('{prefix}', prefix)
+    log.info(f'alignment is {alignment}')
+
+
+def get_strains(config: Dict,
+                test_strains: List,
+                prefix: str,
+                chromosomes: List):
+    '''
+    Helper method to get strains supplied in config, or from test_strains
+    '''
+    strains = get_nested(config, 'strains')
+
+    if strains is None:
+        # try to build strains from wildcards in test_strains
+        strains = {}
+        for test_strain in test_strains:
+            strain_glob = test_strain.format(
+                prefix=prefix,
+                strain='*',
+                chrom='*')
+            log.info(f'searching for {strain_glob}')
+            for fname in glob.iglob(strain_glob):
+                match = re.match(
+                    test_strain.format(
+                        prefix=prefix,
+                        strain='(?P<strain>.*?)',
+                        chrom='(?P<chrom>[^_]*?)'
+                    ),
+                    fname)
+                if match:
+                    log.debug(f'matched with {match.group("strain", "chrom")}')
+                    strain, chrom = match.group('strain', 'chrom')
+                    if strain not in strains:
+                        strains[strain]
= [] + strains[strain].append(chrom) + + if len(strains) == 0: + err = f'Found no chromosome sequence files in {test_strains}' + log.exception(err) + raise ValueError(err) + + for strain, chroms in strains.items(): + if len(chromosomes) != len(chroms): + err = (f'Strain {strain} has incorrect number of chromosomes. ' + f'Expected {len(chromosomes)} found {len(chroms)}') + log.exception(err) + raise ValueError(err) + return list(sorted(strains.keys())) diff --git a/code/analyze/predict.py b/code/analyze/predict.py index e5b5655..c34790c 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -1,4 +1,5 @@ import copy +import os import gzip import itertools from collections import defaultdict, Counter @@ -9,6 +10,9 @@ from misc import read_fasta import numpy as np from typing import List, Dict, Tuple, TextIO +from contextlib import ExitStack +import logging as log +from misc.read_fasta import read_fasta def process_predict_args(arg_list: List[str]) -> Dict: @@ -637,3 +641,41 @@ def write_state_probs(probs: Dict[str, List[float]], for i, state in enumerate(states)])) writer.write('\n') + + +def run(known_states, unknown_states, + hmm_initial, hmm_trained, + blocks, positions, probabilities, + chromosomes, strains, alignment): + + emission_symbols = get_emis_symbols(known_states) + + with open(hmm_initial, 'w') as initial, \ + open(hmm_trained, 'w') as trained, \ + gzip.open(positions, 'wt') as positions, \ + gzip.open(probabilities, 'wt') as probabilities, \ + ExitStack() as stack: + + block_writers = {state: + stack.enter_context( + open(blocks.format(state=state), 'w')) + for state in known_states + unknown_states} + + write_hmm_header(known_states, unknown_states, + emission_symbols, initial) + write_hmm_header(known_states, unknown_states, + emission_symbols, trained) + + for chrom in chromosomes: + for strain in strains: + log.info(f'working on: {strain} {chrom}') + alignment_file = alignment.format(strain=strain, chrom=chrom) + + headers, sequences = read_fasta(alignment_file) + + references = sequences[:-1] + predicted = sequences[-1] + + states, probabilities, hmm_trained, hmm_initial, positions =\ + predict_introgressed(references, predicted, + ARGS, train=True) diff --git a/code/analyze/predict_main.py b/code/analyze/predict_main.py index d564543..3b7d8c9 100644 --- a/code/analyze/predict_main.py +++ b/code/analyze/predict_main.py @@ -30,13 +30,9 @@ # output files and if and where to resume ##====== -if not os.path.isdir(gp.analysis_out_dir_absolute + args['tag']): - os.makedirs(gp.analysis_out_dir_absolute + args['tag']) - -# positions -# TODO move this to more general location and make separate files for -# each strain x chrm base_dir = f'{gp.analysis_out_dir_absolute}{args["tag"]}' +if not os.path.isdir(base_dir): + os.makedirs(base_dir) # introgressed blocks blocks_f = {} diff --git a/code/config.yaml b/code/config.yaml index 8343076..c4bc225 100644 --- a/code/config.yaml +++ b/code/config.yaml @@ -16,36 +16,37 @@ HMM_symbols: unaligned: '?' 
 masked: 'x'
 
-paths:
-  output_base: /tigress/tcomi/aclark4_temp/results
+output_root: /tigress/tcomi/aclark4_temp/results
+input_root: /tigress/AKEY/akey_vol2/aclark4/nobackup
 
+paths:
   fasta_suffix: .fa  # suffix for _all_ fasta files
   # suffix for _all_ alignment files
   # this needs to match the suffix output by mugsy
   alignment_suffix: .maf
 
-  # sequence locations/names
-  # master_ref now automatically assumed to be first
-  # reference specified in setup_args file
-
-  # now specified in setup_args file
-
-  # alignment files
-
-  # alignments directory now specified in setup_args file
-
-  mask_dir: /tigress/tcomi/aclark4_temp/par4/masked/
-  alignments_dir: /tigress/tcomi/aclark4_temp/par4/
+  masks: /tigress/tcomi/aclark4_temp/par4/masked/
+  alignments: /tigress/tcomi/aclark4_temp/par4/
+
+  test_strains:
+    - "__INPUT_ROOT__/100_genomes/genomes_gb/\
+      {strain}_chr{chrom}.fa"
 
   simulations:
-    sim_base: __OUTPUT_BASE__/sim
+    sim_base: __OUTPUT_ROOT__/sim
     prefix: sim_out_
     suffix: .txt
 
   analysis:
-    analysis_base: __OUTPUT_BASE__/analysis
+    analysis_base: __OUTPUT_ROOT__/analysis
     regions: __ANALYSIS_BASE__/regions/
     genes: __ANALYSIS_BASE__/genes/
+    block_files: __ANALYSIS_BASE__/blocks_{state}.txt
+    hmm_initial: __ANALYSIS_BASE__/hmm_initial.txt
+    hmm_trained: __ANALYSIS_BASE__/hmm_trained.txt
+    positions: __ANALYSIS_BASE__/positions.txt.gz
+    probabilities: __ANALYSIS_BASE__/probabilities.txt.gz
+    alignment: __ALIGNMENTS__/{prefix}_{strain}_chr{chrom}_mafft.maf
 
   # software install locations
   software:
@@ -62,8 +63,15 @@ paths:
     ldselect: __ROOT_INSTALL__/ldSelect/
     structure: __ROOT_INSTALL__/structure/
 
-chrms: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII',
-        'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']
+chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII',
+              'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']
+
+# can optionally list all strains to consider
+# if blank will glob with TEST_STRAINS paths
+# strains:
+
+# can provide a prefix for the alignment files
+# if blank will be the reference and known state names joined with '_'
 
 analysis_params:
   tag: p2e4
@@ -71,32 +79,30 @@ analysis_params:
   threshold: viterbi
   input_root: /tigress/AKEY/akey_vol2/aclark4/nobackup
 
+  # master known state, prepended to list of known states
   reference:
     name: S228c
-    fasta_name: SC88c_SGD-R64
     base_dir: __INPUT_ROOT__/100_genomes/genomes/S288c_SGD-R64/
     gene_bank_dir: __INPUT_ROOT__/S288c/
 
   known_states:
     - name: CBS432
-      fasta_name: CBS432
       base_dir: /tigress/anneec/projects/introgression/data/CBS432/
       gene_bank_dir: __INPUT_ROOT__/CBS432/
 
     - name: N_45
-      fasta_name: N_45
      base_dir: __INPUT_ROOT__/para_sgrp/strains/N_45/
      gene_bank_dir: null
 
    - name: DBVPG6304
-      fasta_name: DBVPG6304
      base_dir: __INPUT_ROOT__/para_sgrp/strains/DBVPG6304/
      gene_bank_dir: null
 
    - name: UWOPS91_917_1
-      fasta_name: UWOPS91_917_1
      base_dir: __INPUT_ROOT__/para_sgrp/strains/UWOPS91_917_1/
      gene_bank_dir: null
 
-  test_dir: __INPUT_ROOT__/100_genomes/genomes_gb/
+  unknown_states:
+    - name: unknown
+
   gene_bank_all: __INPUT_ROOT__/100_genomes/sequence.gb
diff --git a/code/misc/clean_config.py b/code/misc/config_utils.py
similarity index 56%
rename from code/misc/clean_config.py
rename to code/misc/config_utils.py
index f886a62..5ba7e40 100644
--- a/code/misc/clean_config.py
+++ b/code/misc/config_utils.py
@@ -1,12 +1,13 @@
 import re
 from copy import copy
-from typing import Dict, List
+from typing import Dict, List, Tuple
+import logging as log
 
 
 '''
-clean_config.py
+config_utils.py
 
-Helper functions for performing replacements on yaml config files
+Helper functions for working with yaml config files
 '''
 
 
@@ -130,11 +131,121 @@ def print_list(l: List, lvl: int = 0) -> str:
     '''
     result = ''
     for i, v in enumerate(l):
-        result += ' ' * lvl + f'{i}:\n'
         if isinstance(v, dict):
-            result += print_dict(v, lvl+1)
+            result += ' ' * lvl + f'{i}:\n' + print_dict(v, lvl+1)
         elif isinstance(v, list):
-            result += print_list(v, lvl+1)
+            result += ' ' * lvl + f'{i}:\n' + print_list(v, lvl+1)
         else:
-            result += ' ' * lvl + f'{v},\n'
+            result += ' ' * lvl + f'{i}:\t{v},\n'
 
     return result
+
+
+def merge_dicts(parent: Dict, new: Dict) -> Dict:
+    '''
+    Merge the new dict into parent. Existing items are overwritten and
+    dicts are merged recursively; other values, including lists, are
+    replaced outright (use merge_lists to combine lists).
+    '''
+
+    for k, v in new.items():
+        if k in parent:
+            if isinstance(v, dict):
+                parent[k] = merge_dicts(parent[k], v)
+
+            else:
+                parent[k] = v
+        else:
+            parent[k] = v
+
+    return parent
+
+
+def merge_lists(parent: List, new: List) -> List:
+    '''
+    Merge new list into parent. If a new item isn't in the list, add it.
+    Overwriting and nesting is not supported as it seems ill-defined.
+    '''
+    for v in new:
+        if v not in parent:
+            parent.append(v)
+
+    return parent
+
+
+def get_nested(config: Dict, keys: str):
+    '''
+    Return the value of the nested keys, or None if the key is invalid
+    keys is a period separated list of keys as a string
+    '''
+    keys = keys.split('.')
+    value = config
+    try:
+        for k in keys:
+            value = value[k]
+    except KeyError:
+        return None
+    return value
+
+
+def check_wildcards(path: str, wildcards: str) -> bool:
+    '''
+    Check if the supplied path contains all required wildcards
+    wildcards are provided as a comma separated list string
+    returns True if all wildcards are present in path, e.g. {wildcard} in path
+    else raises a ValueError with the unfound wildcard
+    '''
+    for wildcard in wildcards.split(','):
+        if f'{{{wildcard}}}' not in path:
+            err = f'{{{wildcard}}} not found in {path}'
+            log.exception(err)
+            raise ValueError(err)
+
+    return True
+
+
+def get_states(config: Dict) -> Tuple[List, List]:
+    '''
+    From the provided config dict, build lists of known and unknown states
+    from the analysis params
+    '''
+
+    ref = get_nested(config, 'analysis_params.reference.name')
+    if ref is None:
+        ref = []
+    else:
+        ref = [ref]
+
+    known = get_nested(config, 'analysis_params.known_states')
+    if known is None:
+        known = []
+
+    known_states = ref + [s['name'] for s in known]
+
+    unknown = get_nested(config, 'analysis_params.unknown_states')
+    if unknown is None:
+        unknown = []
+
+    unknown_states = [s['name'] for s in unknown]
+
+    return known_states, unknown_states
+
+
+def validate(config: Dict,
+             path: str,
+             exception: str,
+             value: str = None):
+    '''
+    validate the supplied value, raising exception if no value is found
+    config: the config dictionary to lookup
+    path: the path in nested config dict
+    exception: string to display if no value is found
+    value: starting value.
values of None or '' will cause lookup into config + ''' + + if value is None or value == '': + value = get_nested(config, path) + + if value is None: + log.exception(exception) + raise ValueError(exception) + + return value diff --git a/code/setup.py b/code/setup.py new file mode 100644 index 0000000..d975663 --- /dev/null +++ b/code/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup, find_packages + +setup( + name='introgression', + version='0.1', + packages=find_packages(), + include_package_data=True, + install_requires=[ + 'Click', + ], + entry_points=''' + [console_scripts] + introgression=analyze.main:cli + +''', +) diff --git a/code/test/analyze/test_main.py b/code/test/analyze/test_main.py new file mode 100644 index 0000000..a847f7b --- /dev/null +++ b/code/test/analyze/test_main.py @@ -0,0 +1,30 @@ +import pytest +from click.testing import CliRunner +import analyze.main as main +import yaml + + +@pytest.fixture +def runner(): + return CliRunner() + + +def test_main_cli(runner, mocker): + result = runner.invoke(main.cli) + assert result.exit_code == 0 + + with runner.isolated_filesystem(): + clean = mocker.patch('analyze.main.config_utils.clean_config', + return_value=dict()) + with open('config1.yaml', 'w') as f: + yaml.dump({'test': '123'}, f) + with open('config2.yaml', 'w') as f: + yaml.dump({'test': '23', 'test2': '34'}, f) + + result = runner.invoke( + main.cli, + '--config config1.yaml --config config2.yaml'.split()) + assert result.exit_code == 0 + clean.assert_called_with( + {'test': '23', + 'test2': '34'}) diff --git a/code/test/analyze/test_predict.py b/code/test/analyze/test_predict.py index 874e749..8cb67c8 100644 --- a/code/test/analyze/test_predict.py +++ b/code/test/analyze/test_predict.py @@ -709,3 +709,7 @@ def help_test_convert_blocks(states, seq): for k in blocks: assert blocks[k] == result[k] + + +def test_run(): + pass diff --git a/code/test/hmm/test_hmm_bw.py b/code/test/hmm/test_hmm_bw.py index ced8288..6e7c538 100644 --- a/code/test/hmm/test_hmm_bw.py +++ b/code/test/hmm/test_hmm_bw.py @@ -189,7 +189,8 @@ def test_emission_probabilities(hm3): den = np.logaddexp.reduce(gamma, axis=0) den = np.logaddexp.reduce(den, axis=0) - obs = np.array([i == hm.observations for i in range(len(hm.observed_states))]) + obs = np.array([i == hm.observations + for i in range(len(hm.observed_states))]) obs = np.moveaxis(obs, [0, 1, 2], [2, 0, 1]) gam = np.where(obs[:, :, None, :], gamma[:, :, :, None], np.NINF) num = np.logaddexp.reduce(np.logaddexp.reduce(gam)) diff --git a/code/test/misc/test_clean_config.py b/code/test/misc/test_clean_config.py deleted file mode 100644 index e933e40..0000000 --- a/code/test/misc/test_clean_config.py +++ /dev/null @@ -1,107 +0,0 @@ -import pytest -from misc.clean_config import clean_config, print_dict, clean_list -from yaml import load - - -def test_simple(): - config = {'base_name': 'base', - 'test': '__BASE_NAME__/test.txt', - 'test2': '__BASE_NAME__/__BASE_NAME__/test.txt', - 'test3': '__BASE_NAME____BASE_NAME__/test.txt', - 'test4': '__BASE_NAME__/__TEST__/test.txt', - 'test5': '__BASE_NAME__/__TEST__/test__DIGIT__.txt', - 'test6': '__BASE_NAME__//test__DIGIT__.txt', - 'test7': '//test///test////test/', - 'digit': 10, - } - config = clean_config(config) - assert config == {'base_name': 'base', - 'test': 'base/test.txt', - 'test2': 'base/base/test.txt', - 'test3': 'basebase/test.txt', - 'test4': 'base/base/test.txt/test.txt', - 'test5': 'base/base/test.txt/test10.txt', - 'test6': 'base/test10.txt', - 'test7': '/test/test/test/', - 
'digit': 10, - } - - -def test_circular(): - config = {'base_name': '__TEST2__', - 'test': '__BASE_NAME__/test.txt', - 'test2': '__BASE_NAME__/test.txt', - } - with pytest.raises(Exception) as e: - clean_config(config) - assert 'Failed to dereference all keys' in str(e) - - -def test_nest_dict(): - config = { - 'base_name': 'base', - 'test': '__BASE_NAME__/test.txt', - 'dict2': { - 'test2': '__BASE_NAME__/test2.txt', - 'base_name': 'base2', - 'dict3': { - 'base_name': 'base3', - 'test3': '__BASE_NAME__/test3.txt' - }, - 'dict4': { - 'test4': '__BASE_NAME__/test3.txt' - } - }, - 'test5': '__BASE_NAME__/test5.txt', - 'dict5': { - 'base_name': '__BASE_NAME__', - 'test6': '__BASE_NAME__/test_5.txt' - } - } - config = clean_config(config) - assert config == { - 'base_name': 'base', - 'test': 'base/test.txt', - 'dict2': { - 'test2': 'base2/test2.txt', - 'base_name': 'base2', - 'dict3': { - 'base_name': 'base3', - 'test3': 'base3/test3.txt' - }, - 'dict4': { - 'test4': 'base2/test3.txt' - }}, - 'test5': 'base/test5.txt', - 'dict5': { - 'base_name': 'base', - 'test6': 'base/test_5.txt' - } - } - - -def test_clean_list(): - config = [ - 'test', - {'base': 'base', - 'test2': '__BASE__/test2'}, - {'base': 'base2', - 'test3': '__BASE__/test3'}, - [1, 2, 3], - 17 - ] - assert clean_list(config) == [ - 'test', - {'base': 'base', - 'test2': 'base/test2'}, - {'base': 'base2', - 'test3': 'base2/test3'}, - [1, 2, 3], - 17 - ] - - with pytest.raises(Exception) as e: - clean_list(['__NOT_FOUND__']) - assert 'Failed to dereference list entry: "__NOT_FOUND__"' in str(e) - - assert clean_list(['__BASE__/test'], {'base': 'base'}) == ['base/test'] diff --git a/code/test/misc/test_config_utils.py b/code/test/misc/test_config_utils.py new file mode 100644 index 0000000..aa855a0 --- /dev/null +++ b/code/test/misc/test_config_utils.py @@ -0,0 +1,209 @@ +import pytest +from misc.config_utils import (clean_config, clean_list, + merge_lists, merge_dicts, + get_nested, check_wildcards, + get_states, validate) + + +def test_simple(): + config = {'base_name': 'base', + 'test': '__BASE_NAME__/test.txt', + 'test2': '__BASE_NAME__/__BASE_NAME__/test.txt', + 'test3': '__BASE_NAME____BASE_NAME__/test.txt', + 'test4': '__BASE_NAME__/__TEST__/test.txt', + 'test5': '__BASE_NAME__/__TEST__/test__DIGIT__.txt', + 'test6': '__BASE_NAME__//test__DIGIT__.txt', + 'test7': '//test///test////test/', + 'digit': 10, + } + config = clean_config(config) + assert config == {'base_name': 'base', + 'test': 'base/test.txt', + 'test2': 'base/base/test.txt', + 'test3': 'basebase/test.txt', + 'test4': 'base/base/test.txt/test.txt', + 'test5': 'base/base/test.txt/test10.txt', + 'test6': 'base/test10.txt', + 'test7': '/test/test/test/', + 'digit': 10, + } + + +def test_circular(): + config = {'base_name': '__TEST2__', + 'test': '__BASE_NAME__/test.txt', + 'test2': '__BASE_NAME__/test.txt', + } + with pytest.raises(Exception) as e: + clean_config(config) + assert 'Failed to dereference all keys' in str(e) + + +def test_nest_dict(): + config = { + 'base_name': 'base', + 'test': '__BASE_NAME__/test.txt', + 'dict2': { + 'test2': '__BASE_NAME__/test2.txt', + 'base_name': 'base2', + 'dict3': { + 'base_name': 'base3', + 'test3': '__BASE_NAME__/test3.txt' + }, + 'dict4': { + 'test4': '__BASE_NAME__/test3.txt' + } + }, + 'test5': '__BASE_NAME__/test5.txt', + 'dict5': { + 'base_name': '__BASE_NAME__', + 'test6': '__BASE_NAME__/test_5.txt' + } + } + config = clean_config(config) + assert config == { + 'base_name': 'base', + 'test': 'base/test.txt', + 'dict2': { + 
'test2': 'base2/test2.txt', + 'base_name': 'base2', + 'dict3': { + 'base_name': 'base3', + 'test3': 'base3/test3.txt' + }, + 'dict4': { + 'test4': 'base2/test3.txt' + }}, + 'test5': 'base/test5.txt', + 'dict5': { + 'base_name': 'base', + 'test6': 'base/test_5.txt' + } + } + + +def test_clean_list(): + config = [ + 'test', + {'base': 'base', + 'test2': '__BASE__/test2'}, + {'base': 'base2', + 'test3': '__BASE__/test3'}, + [1, 2, 3], + 17 + ] + assert clean_list(config) == [ + 'test', + {'base': 'base', + 'test2': 'base/test2'}, + {'base': 'base2', + 'test3': 'base2/test3'}, + [1, 2, 3], + 17 + ] + + with pytest.raises(Exception) as e: + clean_list(['__NOT_FOUND__']) + assert 'Failed to dereference list entry: "__NOT_FOUND__"' in str(e) + + assert clean_list(['__BASE__/test'], {'base': 'base'}) == ['base/test'] + + +def test_merge_lists(): + assert merge_lists([], list('abc')) == list('abc') + assert merge_lists(list('abc'), list('abc')) == list('abc') + assert merge_lists(list('abc'), list('d')) == list('abcd') + assert merge_lists(list('abc'), []) == list('abc') + assert merge_lists(list('abc'), [1, 2, 3]) == ['a', 'b', 'c', 1, 2, 3] + assert merge_lists([{'a': 1, 'b': 2}, 1], [{'a': 1, 'b': 2}, 3]) ==\ + [{'a': 1, 'b': 2}, 1, 3] + assert merge_lists([{'a': 1, 'b': 2}, 1], [{'a': 1, 'b': 3}, 3]) ==\ + [{'a': 1, 'b': 2}, 1, {'a': 1, 'b': 3}, 3] + + +def test_merge_dicts(): + assert merge_dicts({1: 1, 2: 2}, {}) == {1: 1, 2: 2} + assert merge_dicts({}, {1: 1, 2: 2}) == {1: 1, 2: 2} + assert merge_dicts({1: 3}, {1: 1, 2: 2}) == {1: 1, 2: 2} + # only new value type matters + assert merge_dicts({1: 3, 2: {}}, {1: 1, 2: 2}) == {1: 1, 2: 2} + # nested dict + assert merge_dicts({1: 3, 2: {3: 4}}, {1: 1, 2: {3: 3}}) == \ + {1: 1, 2: {3: 3}} + # nested list, just overwrite + assert merge_dicts({1: 3, 2: {3: [1, 2]}}, {1: 1, 2: {3: [3, 4]}}) == \ + {1: 1, 2: {3: [3, 4]}} + + +def test_get_nested(): + assert get_nested({'a': 1}, 'a') == 1 + assert get_nested({'a': 1}, 'b') is None + assert get_nested({'a': {'b': 2}}, 'a.b') == 2 + assert get_nested({'a': {'b': 2}}, 'a.c') is None + assert get_nested({'a': {'b': {'c': 3}}}, 'a.b.c') == 3 + + +def test_check_wildcards(mocker): + assert check_wildcards('{test}.txt', 'test') + assert check_wildcards('{test}{string}.txt', 'test,string') + + mock_log = mocker.patch('misc.config_utils.log.exception') + with pytest.raises(ValueError) as e: + check_wildcards('test.txt', 'test') + + mock_log.assert_called_with('{test} not found in test.txt') + assert '{test} not found in test.txt' in str(e) + + +def test_get_states(): + assert get_states({}) == ([], []) + assert get_states( + { + 'analysis_params': { + 'known_states': [ + {'name': 'k1'}, + {'name': 'k2'}, + {'name': 'k3'}, + ], + 'unknown_states': [ + {'name': 'u1'}, + {'name': 'u2'}, + ] + } + }) == ('k1 k2 k3'.split(), 'u1 u2'.split()) + assert get_states( + { + 'analysis_params': { + 'reference': {'name': 'ref'}, + 'unknown_states': [ + {'name': 'u1'}, + {'name': 'u2'}, + ] + } + }) == ('ref'.split(), 'u1 u2'.split()) + assert get_states( + { + 'analysis_params': { + 'reference': {'name': 'ref'}, + 'known_states': [ + {'name': 'k1'}, + {'name': 'k2'}, + {'name': 'k3'}, + ], + 'unknown_states': [ + {'name': 'u1'}, + {'name': 'u2'}, + ] + } + }) == ('ref k1 k2 k3'.split(), 'u1 u2'.split()) + + +def test_validate(mocker): + assert validate({}, '', '', 'test') == 'test' + assert validate({'path': 'test'}, 'path', '') == 'test' + assert validate({'path': 'test'}, 'path', '', '') == 'test' + mock_log = 
mocker.patch('misc.config_utils.log.exception') + with pytest.raises(ValueError) as e: + validate({'path': 'test'}, 'path2', 'except', '') + assert 'except' in str(e) + mock_log.assert_called_with('except') From 0ee0ecf83b44fa0618e76d81aa58195b5a47a10b Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Fri, 19 Apr 2019 10:27:37 -0400 Subject: [PATCH 14/33] Started Predict refactor Started moving predict methods into a Predictor class to clean up the long argument lists for performing a prediction. Need to test new code and get old tests passing again with minimal object --- code/analyze/main.py | 131 +++-------------- code/analyze/predict.py | 315 +++++++++++++++++++++++++++++----------- code/environment.yml | 12 +- 3 files changed, 266 insertions(+), 192 deletions(-) diff --git a/code/analyze/main.py b/code/analyze/main.py index bc7a533..6fd217a 100644 --- a/code/analyze/main.py +++ b/code/analyze/main.py @@ -80,112 +80,29 @@ def predict(ctx, alignment): config = ctx.obj - chromosomes = validate(config, - 'chromosomes', - 'No chromosomes specified in config file!') - - blocks = validate(config, - 'paths.analysis.block_files', - 'No block file provided', - blocks) - - check_wildcards(blocks, 'state') - log.info(f'output blocks file for predict is {blocks}') - - known, unknown = get_states(config) - if prefix == '': - prefix = '_'.join(known) - - log.info(f'prefix is {prefix}') - - if test_strains == '': - test_strains = get_nested(config, 'paths.test_strains') - else: - # need to support list for test strains - test_strains = [test_strains] - for test_strain in test_strains: - check_wildcards(test_strain, 'strain,chrom') - - log.info(f'found {len(test_strains)} test strains') - - strains = get_strains(config, test_strains, prefix, chromosomes) - log.info(f'found {len(strains)} unique strains') - - hmm_initial = validate(config, - 'paths.analysis.hmm_initial', - 'No initial hmm file provided', - hmm_initial) - log.info(f'hmm_initial is {hmm_initial}') - - hmm_trained = validate(config, - 'paths.analysis.hmm_trained', - 'No trained hmm file provided', - hmm_trained) - log.info(f'hmm_trained is {hmm_trained}') - - positions = validate(config, - 'paths.analysis.positions', - 'No positions file provided', - positions) + predictor = predict.Predictor(config) + predictor.set_chromosomes() + + predictor.set_blocks_file(blocks) + log.info(f'output blocks file for predict is {predictor.blocks}') + + predictor.set_prefix(prefix) + log.info(f'prefix is {predictor.prefix}') + + predictor.set_strains(test_strains) + log.info(f'found {len(predictor.test_strains)} test strains') + log.info(f'found {len(predictor.strains)} unique strains') + + predictor.set_output_files(hmm_initial, + hmm_trained, + positions, + probabilities, + alignment) + log.info(f'hmm_initial is {predictor.hmm_initial}') + log.info(f'hmm_trained is {predictor.hmm_trained}') log.info(f'positions is {positions}') + log.info(f'probabilities is {predictor.probabilities}') + log.info(f'alignment is {predictor.alignment}') - probabilities = validate(config, - 'paths.analysis.probabilities', - 'No probabilities file provided', - probabilities) - log.info(f'probabilities is {probabilities}') - - alignment = validate(config, - 'paths.analysis.alignment', - 'No alignment file provided', - alignment) - check_wildcards(alignment, 'prefix,strain,chrom') - alignment = alignment.replace('{prefix}', prefix) - log.info(f'alignment is {alignment}') - - -def get_strains(config: Dict, - test_strains: List, - prefix: str, - chromosomes: List): - ''' - Helper 
method to get strains supplied in config, or from test_strains
-    '''
-    strains = get_nested(config, 'strains')
-
-    if strains is None:
-        # try to build strains from wildcards in test_strains
-        strains = {}
-        for test_strain in test_strains:
-            strain_glob = test_strain.format(
-                prefix=prefix,
-                strain='*',
-                chrom='*')
-            log.info(f'searching for {strain_glob}')
-            for fname in glob.iglob(strain_glob):
-                match = re.match(
-                    test_strain.format(
-                        prefix=prefix,
-                        strain='(?P<strain>.*?)',
-                        chrom='(?P<chrom>[^_]*?)'
-                    ),
-                    fname)
-                if match:
-                    log.debug(f'matched with {match.group("strain", "chrom")}')
-                    strain, chrom = match.group('strain', 'chrom')
-                    if strain not in strains:
-                        strains[strain] = []
-                    strains[strain].append(chrom)
-
-        if len(strains) == 0:
-            err = f'Found no chromosome sequence files in {test_strains}'
-            log.exception(err)
-            raise ValueError(err)
-
-    for strain, chroms in strains.items():
-        if len(chromosomes) != len(chroms):
-            err = (f'Strain {strain} has incorrect number of chromosomes. '
-                   f'Expected {len(chromosomes)} found {len(chroms)}')
-            log.exception(err)
-            raise ValueError(err)
-    return list(sorted(strains.keys()))
+    predictor.validate_arguments()
+    predictor.run_prediction()
diff --git a/code/analyze/predict.py b/code/analyze/predict.py
index c34790c..fbcfe5f 100644
--- a/code/analyze/predict.py
+++ b/code/analyze/predict.py
@@ -1,18 +1,20 @@
 import copy
-import os
 import gzip
+import glob
+import re
 import itertools
 from collections import defaultdict, Counter
 from hmm import hmm_bw
 from sim import sim_predict
 from sim import sim_process
 import global_params as gp
-from misc import read_fasta
 import numpy as np
 from typing import List, Dict, Tuple, TextIO
 from contextlib import ExitStack
 import logging as log
 from misc.read_fasta import read_fasta
+from misc.config_utils import (check_wildcards, validate,
+                               get_states, get_nested)
 
 
 def process_predict_args(arg_list: List[str]) -> Dict:
@@ -74,24 +76,237 @@ def process_predict_args(arg_list: List[str]) -> Dict:
     return d
 
 
-def read_aligned_seqs(fast_file: str,
-                      strain: str) -> Tuple[np.array, np.array]:
-    '''
-    Read fasta file, returning sequences of references and the specified strain
-    '''
-    headers, seqs = read_fasta.read_fasta(fast_file)
-    d = {}
-    for i in range(len(seqs)):
-        name = headers[i][1:].split(' ')[0]
-        d[name] = seqs[i]
-
-    ref_seqs = []
-    for ref in gp.alignment_ref_order:
-        ref_seqs.append(d[ref])
-    predict_seq = d[strain]
-
-    return ref_seqs, predict_seq
+class Predictor():
+    '''
+    Predictor class
+    Stores all variables needed to run an HMM prediction
+    '''
+    def __init__(self, configuration: Dict):
+        self.config = configuration
+        self.known_states, self.unknown_states = get_states(self.config)
+        self.chromosomes = None
+        self.blocks = None
+        self.prefix = None
+        self.strains = None
+        self.hmm_initial = None
+        self.hmm_trained = None
+        self.positions = None
+        self.probabilities = None
+        self.alignment = None
+
+    def set_chromosomes(self):
+        '''
+        Gets the chromosome list from provided config, raising a ValueError
+        if undefined.
+        '''
+        self.chromosomes = validate(
+            self.config,
+            'chromosomes',
+            'No chromosomes specified in config file!')
+
+    def set_blocks_file(self, blocks: str = None):
+        '''
+        Set the block wildcard filename. Checks for appropriate wildcards
+        '''
+        self.blocks = validate(
+            self.config,
+            'paths.analysis.block_files',
+            'No block file provided',
+            blocks)
+
+        check_wildcards(self.blocks, 'state')
+
+    def set_prefix(self, prefix: str = ''):
+        '''
+        Set prefix string of the predictor to the supplied value or
+        build it from the known states
+        '''
+        if prefix == '':
+            self.prefix = '_'.join(self.known_states)
+        else:
+            self.prefix = prefix
+
+    def set_strains(self, test_strains: str = ''):
+        '''
+        build the strains to perform prediction on
+        '''
+        if test_strains == '':
+            test_strains = get_nested(self.config, 'paths.test_strains')
+        else:
+            # need to support list for test strains
+            test_strains = [test_strains]
+        for test_strain in test_strains:
+            check_wildcards(test_strain, 'strain,chrom')
+
+        self.find_strains(test_strains)
+
+    def find_strains(self, test_strains: List[str]):
+        '''
+        Helper method to get strains supplied in config, or from test_strains
+        '''
+        strains = get_nested(self.config, 'strains')
+
+        if strains is None:
+            # try to build strains from wildcards in test_strains
+            strains = {}
+            for test_strain in self.test_strains:
+                # find matching files
+                strain_glob = test_strain.format(
+                    prefix=self.prefix,
+                    strain='*',
+                    chrom='*')
+                log.info(f'searching for {strain_glob}')
+                for fname in glob.iglob(strain_glob):
+                    # extract wildcard matches
+                    match = re.match(
+                        test_strain.format(
+                            prefix=self.prefix,
+                            strain='(?P<strain>.*?)',
+                            chrom='(?P<chrom>[^_]*?)'
+                        ),
+                        fname)
+                    if match:
+                        log.debug(
+                            f'matched with {match.group("strain", "chrom")}')
+                        strain, chrom = match.group('strain', 'chrom')
+                        if strain not in strains:
+                            strains[strain] = []
+                        strains[strain].append(chrom)
+
+            if len(strains) == 0:
+                err = ('Found no chromosome sequence files '
+                       f'in {self.test_strains}')
+                log.exception(err)
+                raise ValueError(err)
+
+            for strain, chroms in strains.items():
+                if len(self.chromosomes) != len(chroms):
+                    err = (f'Strain {strain} has incorrect number of '
+                           f'chromosomes. Expected {len(chromosomes)} '
+                           f'found {len(chroms)}')
+                    log.exception(err)
+                    raise ValueError(err)
+
+        self.strains = list(sorted(strains.keys()))
+
+    def set_output_files(self,
+                         hmm_initial: str,
+                         hmm_trained: str,
+                         positions: str,
+                         probabilities: str,
+                         alignment: str):
+        '''
+        Set output files from provided values or config.
+        Raises value errors if a file is not provided.
+        Checks alignment for all wildcards and replaces prefix.
+ ''' + self.hmm_initial = validate(self.config, + 'paths.analysis.hmm_initial', + 'No initial hmm file provided', + hmm_initial) + + self.hmm_trained = validate(self.config, + 'paths.analysis.hmm_trained', + 'No trained hmm file provided', + hmm_trained) + + self.positions = validate(self.config, + 'paths.analysis.positions', + 'No positions file provided', + positions) + + self.probabilities = validate(self.config, + 'paths.analysis.probabilities', + 'No probabilities file provided', + probabilities) + + alignment = validate(self.config, + 'paths.analysis.alignment', + 'No alignment file provided', + alignment) + check_wildcards(alignment, 'prefix,strain,chrom') + self.alignment = alignment.replace('{prefix}', self.prefix) + + def validate_arguments(self): + ''' + Check that all required instance variables are set to perform a + prediction run + ''' + args = [ + 'chromosomes', + 'blocks', + 'prefix', + 'strains', + 'hmm_initial', + 'hmm_trained', + 'positions', + 'probabilities', + 'alignment', + ] + variables = self.__dict__ + for arg in args: + if variables[arg] is None: + err = ('Failed to validate Predictor, required argument ' + f'{arg} was unset') + log.exception(err) + raise ValueError(err) + + def run_prediction(self): + ''' + Run prediction with this predictor object + ''' + self.emission_symbols = get_emis_symbols(self.known_states) + + with open(self.hmm_initial, 'w') as initial, \ + open(self.hmm_trained, 'w') as trained, \ + gzip.open(self.positions, 'wt') as positions, \ + gzip.open(self.probabilities, 'wt') as probabilities, \ + ExitStack() as stack: + + block_writers = {state: + stack.enter_context( + open(self.blocks.format(state=state), 'w')) + for state in + self.known_states + self.unknown_states} + + self.write_hmm_header(initial) + self.write_hmm_header(trained) + + for chrom in chromosomes: + for strain in strains: + log.info(f'working on: {strain} {chrom}') + alignment_file = alignment.format(strain=strain, chrom=chrom) + + headers, sequences = read_fasta(alignment_file) + + references = sequences[:-1] + predicted = sequences[-1] + + states, probabilities, hmm_trained, hmm_initial, positions =\ + predict_introgressed(references, predicted, + ARGS, train=True) + + def write_hmm_header(self, writer: TextIO) -> None: + ''' + Write the header line for an hmm file to the provided textIO object + Output is tab delimited with: + strain chromosome initial_probs emissions transitions + ''' + + writer.write('strain\tchromosome\t') + + states = self.known_states + self.unknown_states + + writer.write('\t'.join( + [f'init_{s}' for s in states] + # initial + [f'emis_{s}_{symbol}' + for s in states + for symbol in self.emission_symbols] + # emissions + [f'trans_{s1}_{s2}' + for s1 in states + for s2 in states])) # transitions + writer.write('\n') def set_expectations(args: Dict, n: int) -> None: ''' @@ -570,30 +785,6 @@ def get_emis_symbols(known_states: List[str]) -> List[str]: return emis_symbols -def write_hmm_header(known_states: List[str], - unknown_states: List[str], - symbols: List[str], - writer: TextIO) -> None: - ''' - Write the header line for an hmm file to the provided textIO object - Output is tab delimited with: - strain chromosome initial_probs emissions transitions - ''' - - writer.write('strain\tchromosome\t') - - states = known_states + unknown_states - - writer.write('\t'.join( - [f'init_{s}' for s in states] + # initial - [f'emis_{s}_{symbol}' - for s in states - for symbol in symbols] + # emissions - [f'trans_{s1}_{s2}' - for s1 in states - for s2 in 
states]))  # transitions
-
-    writer.write('\n')
 
 
 def write_hmm(hmm: hmm_bw.HMM,
@@ -641,41 +832,3 @@ def write_state_probs(probs: Dict[str, List[float]],
                      for i, state in enumerate(states)]))
 
         writer.write('\n')
-
-
-def run(known_states, unknown_states,
-        hmm_initial, hmm_trained,
-        blocks, positions, probabilities,
-        chromosomes, strains, alignment):
-
-    emission_symbols = get_emis_symbols(known_states)
-
-    with open(hmm_initial, 'w') as initial, \
-            open(hmm_trained, 'w') as trained, \
-            gzip.open(positions, 'wt') as positions, \
-            gzip.open(probabilities, 'wt') as probabilities, \
-            ExitStack() as stack:
-
-        block_writers = {state:
-                         stack.enter_context(
-                             open(blocks.format(state=state), 'w'))
-                         for state in known_states + unknown_states}
-
-        write_hmm_header(known_states, unknown_states,
-                         emission_symbols, initial)
-        write_hmm_header(known_states, unknown_states,
-                         emission_symbols, trained)
-
-        for chrom in chromosomes:
-            for strain in strains:
-                log.info(f'working on: {strain} {chrom}')
-                alignment_file = alignment.format(strain=strain, chrom=chrom)
-
-                headers, sequences = read_fasta(alignment_file)
-
-                references = sequences[:-1]
-                predicted = sequences[-1]
-
-                states, probabilities, hmm_trained, hmm_initial, positions =\
-                    predict_introgressed(references, predicted,
-                                         ARGS, train=True)
diff --git a/code/environment.yml b/code/environment.yml
index 7dca4f3..227ba9e 100644
--- a/code/environment.yml
+++ b/code/environment.yml
@@ -1,13 +1,15 @@
 ---
-name: introgression3
+name: introgression
 channels:
+  - conda-forge
   - defaults
 dependencies:
   - atomicwrites=1.3.0=py_0
   - attrs=18.2.0=py37h28b3542_0
   - blas=1.0=mkl
-  - ca-certificates=2019.1.23=0
-  - certifi=2018.11.29=py37_0
+  - ca-certificates=2019.3.9=hecc5488_0
+  - certifi=2019.3.9=py37_0
+  - click=7.0=py_0
   - coverage=4.5.2=py37h7b6447c_0
   - intel-openmp=2019.1=144
   - libedit=3.1.20181209=hc058e9b_0
@@ -22,7 +24,7 @@ dependencies:
   - ncurses=6.1=he6710b0_1
   - numpy=1.15.4=py37h7e9f1db_0
   - numpy-base=1.15.4=py37hde5b4d6_0
-  - openssl=1.1.1a=h7b6447c_0
+  - openssl=1.1.1b=h14c3975_1
   - pip=19.0.1=py37_0
   - pluggy=0.8.1=py37_0
   - py=1.7.0=py37_0
@@ -30,6 +32,7 @@ dependencies:
   - pytest-cov=2.6.1=py37_0
   - pytest-mock=1.10.0=py37_0
   - python=3.7.2=h0371630_0
+  - pyyaml=5.1=py37h14c3975_0
   - readline=7.0=h7b6447c_5
   - setuptools=40.8.0=py37_0
   - six=1.12.0=py37_0
@@ -37,4 +40,5 @@ dependencies:
   - tk=8.6.8=hbc83047_0
   - wheel=0.32.3=py37_0
   - xz=5.2.4=h14c3975_4
+  - yaml=0.1.7=h14c3975_1001
   - zlib=1.2.11=h7b6447c_3
From dfc08462a75e87e2bff20f9ca20c559deec48697 Mon Sep 17 00:00:00 2001
From: Troy Comi
Date: Thu, 25 Apr 2019 15:02:21 -0400
Subject: [PATCH 15/33] Translated predict_main

Changed the implementation of predict_main.py to use Click, with support for
the new yaml configuration file. Refactored predict into two main objects to
simplify the main code. Added a README.md.
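An example invocation of the new command line interface (the config file
name here is illustrative):

    introgression --config config.yaml predict --threshold viterbi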
---
 .gitignore                                    |    1 +
 README.md                                     |   53 +
 code/analyze/main.py                          |   72 +-
 code/analyze/predict.py                       | 1079 ++++++++++-------
 code/analyze/predict_main.py                  |  122 --
 code/config.yaml                              |   13 +-
 code/hmm/hmm_bw.py                            |   63 +-
 code/misc/config_utils.py                     |    2 +
 code/misc/read_fasta.py                       |    2 +-
 code/sim/sim_analyze_hmm_bw.py                |    7 +-
 code/sim/sim_process.py                       |    4 +-
 code/test/analyze/test_main.py                |  101 +-
 code/test/analyze/test_main_predict_args.py   |  285 +++++
 code/test/analyze/test_main_predict_config.py |  397 ++++++
 code/test/analyze/test_predict.py             |  715 -----------
 code/test/analyze/test_predict_hmm_builder.py |  527 ++++++++
 code/test/analyze/test_predict_predictor.py   | 1052 ++++++++++++++++
 code/test/hmm/test_hmm_bw.py                  |   19 +-
 code/test/misc/test_config_utils.py           |    1 +
 19 files changed, 3130 insertions(+), 1385 deletions(-)
 create mode 100644 README.md
 delete mode 100644 code/analyze/predict_main.py
 create mode 100644 code/test/analyze/test_main_predict_args.py
 create mode 100644 code/test/analyze/test_main_predict_config.py
 delete mode 100644 code/test/analyze/test_predict.py
 create mode 100644 code/test/analyze/test_predict_hmm_builder.py
 create mode 100644 code/test/analyze/test_predict_predictor.py

diff --git a/.gitignore b/.gitignore
index 124f3ae..b00c0bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,4 @@ code/setup/*
 .coverage
 *.swp
 *egg-info
+tags
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..b1aa71d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,53 @@
+# introgression
+> Some sort of short description
+
+## Background
+Things about science
+
+## Installation
+All required packages are specified in the conda environment located in
+`code/environment.yml`. The introgression environment can be generated with
+```
+conda env create -f environment.yml
+```
+To access the command line bindings of the main class, install the package
+with pip:
+```
+conda activate introgression
+pip install --editable .
+```
+while in the code directory.
+
+## Usage
+
+### Configuration
+A set of initial parameters is provided in `code/config.yaml`; these need to
+be set for your specific system and dataset.
+
+Strings of the form \_\_KEY\_\_
+are substituted during execution and are used as a shortcut. For example,
+with 'output\_root' set to `/data/results`, the value `__OUTPUT_ROOT__/genes/`
+becomes `/data/results/genes/`.
+
+Strings of the form {state} are used for wildcards within the code. Their
+location and surrounding characters can change, but the wildcard must be the
+same. For example, `blocks_{state}.txt` can be changed to
+`{state}_with-block.txt` but not `blocks_{st}.txt`.
+
+### Command Line
+With the package installed and the conda environment activated, main methods
+are accessed with the `introgression` command. Some documentation is provided
+by adding the argument `--help` to introgression or any of its subcommands.
+
+### introgression
+Options include (see the example below):
+- --config: specify one or more configuration files. Files are evaluated in
+order. Conflicting values are overwritten by the newest file. This allows a
+base configuration for the system and analysis-specific configurations added
+as needed.
+- verbosity: set by varying the number of v's attached to the option, with
+`-v` indicating a log level of critical and `-vvvvv` indicating debug logging.
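+
+For example, a base system configuration and an analysis-specific override
+can be combined at a higher log level (the file names here are illustrative):
+```
+introgression --config base.yaml --config analysis.yaml -vvvv predict
+```
+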
+Available subcommands are: +- predict + +## License +TBD diff --git a/code/analyze/main.py b/code/analyze/main.py index 6fd217a..281c307 100644 --- a/code/analyze/main.py +++ b/code/analyze/main.py @@ -1,12 +1,8 @@ import click import yaml -import glob -import re import logging as log from misc import config_utils -from misc.config_utils import (get_nested, check_wildcards, get_states, - validate) -from typing import List, Dict +import analyze.predict # TODO also check for snakemake object? @@ -15,27 +11,30 @@ multiple=True, type=click.File('r'), help='Base configuration yaml.') -@click.option('-v', '--verbosity', count=True, default=2) +@click.option('-v', '--verbosity', count=True, default=3) @click.pass_context def cli(ctx, config, verbosity): ''' Main entry script to run analyze methods ''' + verbosity -= 1 verbosity = 4 if verbosity > 4 else verbosity - levelstr = ['CRITICAL', 'ERROR', - 'WARNING', 'INFO', - 'DEBUG'][verbosity] - level = [log.CRITICAL, log.ERROR, - log.WARNING, log.INFO, - log.DEBUG][verbosity] + levelstr, level = [ + ('CRITICAL', log.CRITICAL), + ('ERROR', log.ERROR), + ('WARNING', log.WARNING), + ('INFO', log.INFO), + ('DEBUG', log.DEBUG), + ][verbosity] log.basicConfig(level=level) log.info(f'Verbosity set to {levelstr}') ctx.ensure_object(dict) - log.info(f'Reading in {len(config)} config files') + confs = len(config) + log.info(f'Reading in {confs} config file{"" if confs == 1 else "s"}') for path in config: conf = yaml.safe_load(path) ctx.obj = config_utils.merge_dicts(ctx.obj, conf) @@ -66,9 +65,14 @@ def cli(ctx, config, verbosity): help='Positions file, gzipped') @click.option('--probabilities', default='', help='Probabilities file, gzipped') +@click.option('--threshold', default='', + help='Threshold to apply to estimated path. 
Valid values are ' + 'floats or `viterbi\'') @click.option('--alignment', default='', help='Alignment file location with ' '{prefix}, {strain}, and {chrom}') +@click.option('--only-poly-sites/--all-sites', default=True, + help='Consider only polymorphic sites or all sites') def predict(ctx, blocks, prefix, @@ -77,32 +81,46 @@ def predict(ctx, hmm_trained, positions, probabilities, - alignment): + threshold, + alignment, + only_poly_sites): config = ctx.obj - predictor = predict.Predictor(config) + predictor = analyze.predict.Predictor(config) predictor.set_chromosomes() + log.info(f'Found {len(predictor.chromosomes)} chromosomes in config') + + predictor.set_threshold(threshold) + log.info(f'Threshold value is \'{predictor.threshold}\'') predictor.set_blocks_file(blocks) - log.info(f'output blocks file for predict is {predictor.blocks}') + log.info(f'Output blocks file is \'{predictor.blocks}\'') predictor.set_prefix(prefix) - log.info(f'prefix is {predictor.prefix}') + log.info(f'Prefix is \'{predictor.prefix}\'') predictor.set_strains(test_strains) - log.info(f'found {len(predictor.test_strains)} test strains') - log.info(f'found {len(predictor.strains)} unique strains') + if predictor.test_strains is None: + log.info(f'No test_strains provided') + else: + str_len = len(predictor.test_strains) + log.info(f'Found {str_len} test strain' + f'{"" if str_len == 1 else "s"}') + log.info(f'Found {len(predictor.strains)} unique strains') predictor.set_output_files(hmm_initial, hmm_trained, positions, probabilities, alignment) - log.info(f'hmm_initial is {predictor.hmm_initial}') - log.info(f'hmm_trained is {predictor.hmm_trained}') - log.info(f'positions is {positions}') - log.info(f'probabilities is {predictor.probabilities}') - log.info(f'alignment is {predictor.alignment}') - - predictor.validate_arguments() - predictor.run_prediction() + log.info(f'Hmm_initial file is \'{predictor.hmm_initial}\'') + log.info(f'Hmm_trained file is \'{predictor.hmm_trained}\'') + log.info(f'Positions file is \'{predictor.positions}\'') + log.info(f'Probabilities file is \'{predictor.probabilities}\'') + log.info(f'Alignment file is \'{predictor.alignment}\'') + + predictor.run_prediction(only_poly_sites) + + +if __name__ == '__main__': + cli() diff --git a/code/analyze/predict.py b/code/analyze/predict.py index fbcfe5f..e819b55 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -7,7 +7,6 @@ from hmm import hmm_bw from sim import sim_predict from sim import sim_process -import global_params as gp import numpy as np from typing import List, Dict, Tuple, TextIO from contextlib import ExitStack @@ -17,11 +16,13 @@ get_states, get_nested) +# TODO remove gp references for symbols. pass args or fold into object? 
def process_predict_args(arg_list: List[str]) -> Dict: ''' Parses arguments from argv, producing dictionary of parsed values ''' + import global_params as gp d = {} i = 0 @@ -84,6 +85,7 @@ class Predictor(): def __init__(self, configuration: Dict): self.config = configuration self.known_states, self.unknown_states = get_states(self.config) + self.states = self.known_states + self.unknown_states self.chromosomes = None self.blocks = None self.prefix = None @@ -93,6 +95,7 @@ def __init__(self, configuration: Dict): self.positions = None self.probabilities = None self.alignment = None + self.threshold = None def set_chromosomes(self): ''' @@ -122,10 +125,32 @@ def set_prefix(self, prefix: str = ''): build it from the known states ''' if prefix == '': + if self.known_states == []: + err = 'Unable to build prefix, no known states provided' + log.exception(err) + raise ValueError(err) + self.prefix = '_'.join(self.known_states) else: self.prefix = prefix + def set_threshold(self, threshold: str = None): + ''' + Set the threshold. Checks if set and converts to float if possible + ''' + self.threshold = validate( + self.config, + 'analysis_params.threshold', + 'No threshold provided', + threshold) + try: + self.threshold = float(self.threshold) + except ValueError: + if self.threshold != 'viterbi': + err = f'Unsupported threshold value: {self.threshold}' + log.exception(err) + raise ValueError(err) + def set_strains(self, test_strains: str = ''): ''' build the strains to perform prediction on @@ -135,32 +160,40 @@ def set_strains(self, test_strains: str = ''): else: # need to support list for test strains test_strains = [test_strains] - for test_strain in test_strains: - check_wildcards(test_strain, 'strain,chrom') + + if test_strains is not None: + for test_strain in test_strains: + check_wildcards(test_strain, 'strain,chrom') self.find_strains(test_strains) - def find_strains(self, test_strains: List[str]): + def find_strains(self, test_strains: List[str] = None): ''' Helper method to get strains supplied in config, or from test_strains ''' strains = get_nested(self.config, 'strains') + self.test_strains = test_strains if strains is None: + if test_strains is None: + err = ('Unable to find strains in config and ' + 'no test_strains provided') + log.exception(err) + raise ValueError(err) + # try to build strains from wildcards in test_strains strains = {} - for test_strain in self.test_strains: + for test_strain in test_strains: # find matching files strain_glob = test_strain.format( - prefix=self.prefix, strain='*', chrom='*') log.info(f'searching for {strain_glob}') for fname in glob.iglob(strain_glob): # extract wildcard matches + print(fname) match = re.match( test_strain.format( - prefix=self.prefix, strain='(?P.*?)', chrom='(?P[^_]*?)' ), @@ -175,19 +208,22 @@ def find_strains(self, test_strains: List[str]): if len(strains) == 0: err = ('Found no chromosome sequence files ' - f'in {self.test_strains}') + f'in {test_strains}') log.exception(err) raise ValueError(err) for strain, chroms in strains.items(): if len(self.chromosomes) != len(chroms): err = (f'Strain {strain} has incorrect number of ' - f'chromosomes. Expected {len(chromosomes)} ' + f'chromosomes. 
Expected {len(self.chromosomes)} ' f'found {len(chroms)}') log.exception(err) raise ValueError(err) - self.strains = list(sorted(strains.keys())) + self.strains = list(sorted(strains.keys())) + + else: # strains set in config + self.strains = list(sorted(set(strains))) def set_output_files(self, hmm_initial: str, @@ -210,10 +246,11 @@ def set_output_files(self, 'No trained hmm file provided', hmm_trained) - self.positions = validate(self.config, - 'paths.analysis.positions', - 'No positions file provided', - positions) + if positions == '': + self.positions = get_nested(self.config, + 'paths.analysis.positions') + else: + self.positions = positions self.probabilities = validate(self.config, 'paths.analysis.probabilities', @@ -230,7 +267,7 @@ def set_output_files(self, def validate_arguments(self): ''' Check that all required instance variables are set to perform a - prediction run + prediction run. Returns true if valid, raises value error otherwise ''' args = [ 'chromosomes', @@ -239,9 +276,11 @@ def validate_arguments(self): 'strains', 'hmm_initial', 'hmm_trained', - 'positions', 'probabilities', 'alignment', + 'known_states', + 'unknown_states', + 'threshold', ] variables = self.__dict__ for arg in args: @@ -251,40 +290,109 @@ def validate_arguments(self): log.exception(err) raise ValueError(err) - def run_prediction(self): + # check the parameters for each state are present + known_states = get_nested(self.config, + 'analysis_params.known_states') + if known_states is None: + err = 'Configuration did not provide any known_states' + log.exception(err) + raise ValueError(err) + + for s in known_states: + if 'expected_length' not in s: + err = f'{s["name"]} did not provide an expected_length' + log.exception(err) + raise ValueError(err) + if 'expected_fraction' not in s: + err = f'{s["name"]} did not provide an expected_fraction' + log.exception(err) + raise ValueError(err) + + unknown_states = get_nested(self.config, + 'analysis_params.unknown_states') + if unknown_states is not None: + for s in unknown_states: + if 'expected_length' not in s: + err = f'{s["name"]} did not provide an expected_length' + log.exception(err) + raise ValueError(err) + if 'expected_fraction' not in s: + err = f'{s["name"]} did not provide an expected_fraction' + log.exception(err) + raise ValueError(err) + + reference = get_nested(self.config, + 'analysis_params.reference') + if reference is None: + err = f'Configuration did not specify a reference strain' + log.exception(err) + raise ValueError(err) + + return True + + def run_prediction(self, only_poly_sites=True): ''' Run prediction with this predictor object ''' - self.emission_symbols = get_emis_symbols(self.known_states) + self.validate_arguments() + + hmm_builder = HMM_Builder(self.config) + hmm_builder.set_expected_values() + self.emission_symbols = \ + hmm_builder.update_emission_symbols(len(self.known_states)) with open(self.hmm_initial, 'w') as initial, \ open(self.hmm_trained, 'w') as trained, \ - gzip.open(self.positions, 'wt') as positions, \ gzip.open(self.probabilities, 'wt') as probabilities, \ ExitStack() as stack: + self.write_hmm_header(initial) + self.write_hmm_header(trained) + + if self.positions is not None: + positions = stack.enter_context( + gzip.open(self.positions, 'wt')) + else: + positions = None + block_writers = {state: stack.enter_context( open(self.blocks.format(state=state), 'w')) for state in - self.known_states + self.unknown_states} + self.states} + for writer in block_writers.values(): + self.write_blocks_header(writer) 
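For readers unfamiliar with the pattern: the ExitStack above is what allows a number of open block writers that is only known at runtime (one per state) to share a single with-statement. A minimal, self-contained sketch of the same idea, with placeholder states and file names:

    from contextlib import ExitStack

    states = ['s1', 's2', 'unknown']
    with ExitStack() as stack:
        writers = {state: stack.enter_context(
                       open(f'blocks_{state}.txt', 'w'))
                   for state in states}
        for writer in writers.values():
            writer.write('strain\tchromosome\n')
    # every file registered on the stack is closed here, even on error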
- self.write_hmm_header(initial) - self.write_hmm_header(trained) - - for chrom in chromosomes: - for strain in strains: + for chrom in self.chromosomes: + for strain in self.strains: log.info(f'working on: {strain} {chrom}') - alignment_file = alignment.format(strain=strain, chrom=chrom) - headers, sequences = read_fasta(alignment_file) + # get sequences and encode + alignment_file = self.alignment.format( + strain=strain, chrom=chrom) + + hmm_initial, hmm_trained, pos = hmm_builder.run_hmm( + alignment_file, only_poly_sites) + + self.write_hmm(hmm_initial, initial, strain, chrom) + self.write_hmm(hmm_trained, trained, strain, chrom) - references = sequences[:-1] - predicted = sequences[-1] + # process and threshold hmm result + predicted_states, probs = self.process_path(hmm_trained) + state_blocks = self.convert_to_blocks(predicted_states) - states, probabilities, hmm_trained, hmm_initial, positions =\ - predict_introgressed(references, predicted, - ARGS, train=True) + if positions is not None: + self.write_positions(pos, positions, strain, chrom) + + for state, block in state_blocks.items(): + self.write_blocks(block, + pos, + block_writers[state], + strain, + chrom, + state) + + self.write_state_probs(probs, probabilities, strain, chrom) def write_hmm_header(self, writer: TextIO) -> None: ''' @@ -308,389 +416,519 @@ def write_hmm_header(self, writer: TextIO) -> None: writer.write('\n') -def set_expectations(args: Dict, n: int) -> None: - ''' - sets expected number of tracts and bases for each reference - based on expected length of introgressed tracts and expected - total fraction of genome - also takes n, length of the sequence to analyze - ''' + def write_hmm(self, + hmm: hmm_bw.HMM, + writer: TextIO, + strain: str, + chrm: str): + ''' + Write information on the provided hmm as a line to the supplied textIO + object. 
+ Output is tab delimited with: + strain chromosome initial_probs emissions transitions + ''' + writer.write(f'{strain}\t{chrm}\t') - species_to = args['known_states'][0] - species_from = args['known_states'][1:] + states = len(hmm.hidden_states) + writer.write('\t'.join( + [f'{p}' for p in hmm.initial_p] + # initial + [f'{hmm.emissions[i, hmm.symbol_to_ind[symbol]]}' + if symbol in hmm.symbol_to_ind else '0.0' + for i in range(states) + for symbol in self.emission_symbols] + # emission + [f'{hmm.transitions[i, j]}' + for i in range(states) + for j in range(states)] # transition + )) + writer.write('\n') - args['expected_num_tracts'] = {} - args['expected_bases'] = {} - for s in species_from: - args['expected_num_tracts'][s] = \ - args['expected_frac'][s] * n / args['expected_length'][s] - args['expected_bases'][s] = args['expected_num_tracts'][s] * \ - args['expected_length'][s] + def write_blocks_header(self, writer: TextIO) -> None: + ''' + Write header line to tab delimited block file: + strain chromosome predicted_species start end num_sites_hmm + ''' + # NOTE: num_sites_hmm represents the sites considered by the HMM, + # so it might exclude non-polymorphic sites in addition to gaps + writer.write('\t'.join(['strain', + 'chromosome', + 'predicted_species', + 'start', + 'end', + 'num_sites_hmm']) + + '\n') + + def write_blocks(self, + state_seq_blocks: List[Tuple[int, int]], + positions: np.array, + writer: TextIO, + strain: str, + chrm: str, + species_pred: str) -> None: + ''' + Write entry into tab delimited block file, with columns: + strain chromosome predicted_species start end num_sites_hmm + ''' + writer.write('\n'.join( + ['\t'.join([strain, + chrm, + species_pred, + str(positions[start]), + str(positions[end]), + str(end - start + 1)]) + for start, end in state_seq_blocks])) + if state_seq_blocks: # ensure ends with \n + writer.write('\n') + + def write_positions(self, + positions: np.array, + writer: TextIO, + strain: str, + chrm: str) -> None: + ''' + Write the positions of the specific strain, chromosome as a line to the + provided textIO object + ''' + writer.write(f'{strain}\t{chrm}\t' + + '\t'.join([str(x) for x in positions]) + '\n') + + def write_state_probs(self, + probs: Dict[str, List[float]], + writer: TextIO, + strain: str, + chrm: str) -> None: + ''' + Write the probability of each state to the supplied textIO object + Output is tab delimited with: + strain chrom state1:prob1,prob2,...,probn state2... 
+        '''
+        writer.write(f'{strain}\t{chrm}\t')

-    species_to = args['known_states'][0]
-    species_from = args['known_states'][1:]
+        writer.write('\t'.join(
+            [f'{state}:' +
+             ','.join([f'{site[i]:.5f}' for site in probs])
+             for i, state in enumerate(self.states)]))

-    args['expected_num_tracts'] = {}
-    args['expected_bases'] = {}
-    for s in species_from:
-        args['expected_num_tracts'][s] = \
-            args['expected_frac'][s] * n / args['expected_length'][s]
-        args['expected_bases'][s] = args['expected_num_tracts'][s] * \
-            args['expected_length'][s]
+        writer.write('\n')

-    args['expected_bases'][species_to] = \
-        n - sum([args['expected_bases'][s] for s in species_from])
+    def process_path(self, hmm: hmm_bw.HMM) -> Tuple[List[str], np.array]:
+        '''
+        Process the hmm path based on the predictor threshold value
+        Return the predicted states and the probabilities of the master
+        reference sequence
+        '''
+        probabilities = hmm.posterior_decoding()[0]

-    args['expected_num_tracts'][species_to] = \
-        sum([args['expected_num_tracts'][s] for s in species_from]) + 1
+        # posterior
+        if type(self.threshold) is float:
+            path, path_probs = sim_process.get_max_path(probabilities,
+                                                        hmm.hidden_states)
+            path_t = sim_process.threshold_predicted(path, path_probs,
+                                                     self.threshold,
+                                                     self.known_states[0])
+            return path_t, probabilities

-    args['expected_length'][species_to] = \
-        args['expected_bases'][species_to] /\
-        args['expected_num_tracts'][species_to]
+        else:
+            predicted = sim_predict.convert_predictions(hmm.viterbi(),
+                                                        self.states)
+            return predicted, probabilities

-def ungap_and_code(predict_seq: str,
-                   ref_seqs: List[str],
-                   index_ref: int = 0) -> Tuple[np.array, np.array]:
-    '''
-    Remove any sequence locations where a gap is present and code
-    into matching or mismatching sequence
-    Returns the coded sequences, by default an array of + where matching, -
-    where mismatching. Also return the positions where the sequences are not
-    gapped.
-    '''
-    # index_ref is index of reference strain to index relative to
-    # build character array
-    sequences = np.array([list(predict_seq)] +
-                         [list(r) for r in ref_seqs])
+    def convert_to_blocks(self,
+                          state_seq: List[str]) -> Dict[
+                              str, List[Tuple[int, int]]]:
+        '''
+        Convert a list of states into a structure of start and end positions
+        Return structure is a dict keyed on species with values of Lists of
+        each block, which is a tuple with start and end positions
+        '''
+        # single individual state sequence
+        blocks = {}
+        for state in self.states:
+            blocks[state] = []
+        prev_species = state_seq[0]
+        block_start = 0
+        block_end = 0
+        for i in range(len(state_seq)):
+            if state_seq[i] == prev_species:
+                block_end = i
+            else:
+                blocks[prev_species].append((block_start, block_end))
+                block_start = i
+                block_end = i
+                prev_species = state_seq[i]
+        # add last block
+        if prev_species not in blocks:
+            blocks[prev_species] = []
+        blocks[prev_species].append((block_start, block_end))

-    isbase = sequences != gp.gap_symbol
+        return blocks

-    # make boolean for valid characters
-    isvalid = np.logical_and(sequences != gp.gap_symbol,
-                             sequences != gp.unsequenced_symbol)

-    # positions are where everything is valid, index where the reference is
-    # valid. The +1 removes the predict sequence at index 0
-    positions = np.where(
-        np.all(isvalid[:, isbase[index_ref+1, :]], axis=0))[0]

-    matches = np.where(sequences[0] == sequences[1:],
-                       gp.match_symbol,
-                       gp.mismatch_symbol)
+class HMM_Builder():
+    def __init__(self, configuration):
+        self.config = configuration
+        self.symbols = {
+            'match': '+',
+            'mismatch': '-',
+            'unknown': '?',
+            'unsequenced': 'n',
+            'gap': '-',
+            'unaligned': '?',
+            'masked': 'x'
+        }
+        config_symbols = get_nested(self.config, 'HMM_symbols')
+        if config_symbols is not None:
+            for k, v in config_symbols.items():
+                if k not in self.symbols:
+                    log.warning("Unused symbol in configuration: "
+                                f"{k} -> '{v}'")
+                else:
+                    self.symbols[k] = v
+                    log.debug(f"Overwriting default symbol for {k} with '{v}'")
+
+            for k, v in self.symbols.items():
+                if k not in config_symbols:
+                    log.warning(f'Symbol for {k} unset in config, '
+                                f"using default '{v}'")

-    matches = np.fromiter((''.join(row)
-                           for row in np.transpose(
-                               matches[:, np.all(isvalid, axis=0)])),
-                          dtype=f'U{len(sequences) - 1}')
+        else:
+            for k, v in self.symbols.items():
+                log.warning(f'Symbol for {k} unset in config, '
+                            f"using default '{v}'")
+
+        self.convergence = get_nested(self.config,
+                                      'analysis_params.convergence_threshold')
+        if self.convergence is None:
+            log.warning('No value set for convergence_threshold, using '
+                        'default of 0.001')
+            self.convergence = 0.001
+
+    def update_emission_symbols(self, repeats: int):
+        '''
+        Generate all permutations of match and mismatch symbols with
+        repeats number of characters, in lexicographical order.
+        Sets internal state and returns the emission symbols
+        '''
+        syms = [self.symbols['match'], self.symbols['mismatch']]
+        emis_symbols = [''.join(x) for x in
+                        itertools.product(syms,
+                                          repeat=repeats)]
+        emis_symbols.sort()
+        self.emission_symbols = emis_symbols
+        return emis_symbols
+
+    def get_symbol_freqs(self, sequence: np.array) -> Tuple[Dict, List]:
+        '''
+        Calculate metrics from the provided, coded sequence
+        Returns:
+        the fraction of each matching pattern (e.g. +--++)
+        the weighted fraction of matches for each species
+        '''

-    return matches, positions
+        weighted = []
+        symbols = defaultdict(int, Counter(sequence))
+        total = len(sequence)
+        for k in symbols:
+            symbols[k] /= total

-def poly_sites(sequences: np.array,
-               positions: np.array) -> Tuple[np.array, np.array]:
-    '''
-    Remove all sequences where the sequence is all match_symbol
-    Returns the filtered sequence and position
-    '''
-    seq_len = len(sequences[0])
-    # check if seq only contains match_symbol
-    retain = np.vectorize(
-        lambda x: x.count(gp.match_symbol) != seq_len)(sequences)
-    indices = np.where(retain)[0]
+        sequence = np.array([list(s) for s in sequence])

-    ps_poly = positions[indices]
-    seq_poly = sequences[indices]
+        # look along species
+        for s in np.transpose(sequence):
+            s = ''.join(s)
+            counts = Counter(s)
+            weighted.append(counts[self.symbols['match']])

-    return seq_poly, ps_poly
+        total = sum(weighted)
+        weighted = [w / total for w in weighted]
+        return symbols, weighted

+    def set_expected_values(self):
+        '''
+        Get expected lengths and fractions for each state.
+ Assumes config has been validated by Predictor prior to running + ''' + self.expected_lengths = {} + self.expected_fractions = {} + known_states = get_nested(self.config, + 'analysis_params.known_states') + for state in known_states: + self.expected_lengths[state['name']] = state['expected_length'] + self.expected_fractions[state['name']] = state['expected_fraction'] + + unknown_states = get_nested(self.config, + 'analysis_params.unknown_states') + for state in unknown_states: + self.expected_lengths[state['name']] = state['expected_length'] + self.expected_fractions[state['name']] = state['expected_fraction'] + + reference = get_nested(self.config, + 'analysis_params.reference') + # expected fraction of reference is the remainder after other states + # are specified + self.expected_fractions[reference['name']] =\ + 1 - sum(self.expected_fractions.values()) -def get_symbol_freqs(sequence: np.array) -> Tuple[Dict, Dict, List]: - ''' - Calculate metrics from the provided, coded sequence - Returns: - the fraction matching for each species - the fraction of each matching pattern (e.g. +--++) - the weighted fraction of matches for each species - ''' + self.known_states, self.unknown_states = get_states(self.config) - individual = [] - weighted = [] + self.ref_state = get_nested(self.config, + 'analysis_params.reference.name') - symbols = defaultdict(int, Counter(sequence)) - total = len(sequence) - for k in symbols: - symbols[k] /= total + # have to remove effect of unknown of these values for later + self.ref_fraction = self.expected_fractions[self.ref_state] + \ + sum([self.expected_fractions[s] for s in self.unknown_states]) + # sum of fraction / length, or 1 / tract length + self.other_sum = sum([self.expected_fractions[s['name']] / + self.expected_lengths[s['name']] + for s in known_states]) - sequence = np.array([list(s) for s in sequence]) + def update_expected_length(self, total_length: int): + ''' + Updates the expected length for the reference state + based on the provided total_length of the sequence. + This is the expected length of a single tract, determined as the sum + of the total length (sequence length * fraction) divided by the number + of tracts (sequence length * 1 / other's tracts). The + 1 assumes that + the sequence will start and end with the reference. 
+ ''' + self.expected_lengths[self.ref_state] = ( + total_length * self.ref_fraction / + (total_length * self.other_sum + 1)) - # look along species - for s in np.transpose(sequence): - s = ''.join(s) - counts = Counter(s) - weighted.append(counts[gp.match_symbol]) - total = sum(counts.values()) - for k in counts: - counts[k] /= total - individual.append(defaultdict(int, counts)) + def initial_probabilities(self, + weighted_match_freqs: List[float]) -> np.array: + ''' + Estimate the initial probability of being in each state + based on the number of states and their expected fractions + Returns the initial probability of each state + ''' - total = sum(weighted) - weighted = [w / total for w in weighted] - return individual, symbols, weighted + init = [] + expectation_weight = .9 + for s, state in enumerate(self.known_states): + expected = self.expected_fractions[state] + estimated = weighted_match_freqs[s] + init.append(expected * expectation_weight + + estimated * (1 - expectation_weight)) + for state in self.unknown_states: + expected_frac = self.expected_fractions[state] + init.append(expected_frac) -def initial_probabilities(known_states: List[str], - unknown_states: List[str], - expected_frac: Dict, - weighted_match_freqs: List[float]) -> np.array: - ''' - Estimate the initial probability of being in each state - based on the number of states and their expected fractions - Returns the initial probability of each state - ''' + return init / np.sum(init) - init = [] - expectation_weight = .9 - for s, state in enumerate(known_states): - expected = expected_frac[state] - estimated = weighted_match_freqs[s] - init.append(expected * expectation_weight + - estimated * (1 - expectation_weight)) + def emission_probabilities(self, + symbols: List[str]) -> List[Dict]: + ''' + Estimate initial emission probabilities + Return estimates as list of default dict of probabilities + ''' - for state in unknown_states: - expected_frac = expected_frac[state] - init.append(expected_frac) + match = self.symbols['match'] + mismatch = self.symbols['mismatch'] + probabilities = { + mismatch + match: 0.9, + match + match: 0.09, + mismatch + mismatch: 0.009, + match + mismatch: 0.001, + } + + mismatch_bias = .99 + + num_per_category = 2 ** (len(self.known_states) - 2) + for key in probabilities: + probabilities[key] *= num_per_category + + # for known states + symbol_array = np.array([list(s) for s in symbols], dtype=' np.array: + ''' + Estimate initial transition probabilities + ''' - return init / np.sum(init) + # doesn't depend on sequence observations but maybe it should? + # also should we care about number of tracts rather than fraction + # of genome? 
maybe theoretically, but that number is a lot more + # suspect -def emission_probabilities(known_states: List[str], - unknown_states: List[str], - symbols: List[str]) -> List[Dict]: - ''' - Estimate initial emission probabilities - Return estimates as list of default dict of probabilities - ''' + states = self.known_states + self.unknown_states - probabilities = { - gp.mismatch_symbol + gp.match_symbol: 0.9, - gp.match_symbol + gp.match_symbol: 0.09, - gp.mismatch_symbol + gp.mismatch_symbol: 0.009, - gp.match_symbol + gp.mismatch_symbol: 0.001, - } - - mismatch_bias = .99 - - num_per_category = 2 ** (len(known_states) - 2) - for key in probabilities: - probabilities[key] *= num_per_category - - # for known states - symbol_array = np.array([list(s) for s in symbols], dtype=' np.array: - ''' - Estimate initial transition probabilities - ''' + def build_initial_hmm(self, seq: np.array) -> hmm_bw.HMM: + ''' + Build a HMM object initialized based on expected values and sequence + ''' - # doesn't depend on sequence observations but maybe it should? + # get frequencies of individual symbols (e.g. '+') and all full + # combinations of symbols (e.g. '+++-') + (symbol_freqs, + weighted_match_freqs) = self.get_symbol_freqs(seq) - # also should we care about number of tracts rather than fraction - # of genome? maybe theoretically, but that number is a lot more - # suspect + # new Hidden Markov Model + hmm = hmm_bw.HMM() - states = known_states + unknown_states + hmm.set_initial_p(self.initial_probabilities(weighted_match_freqs)) + hmm.set_emissions(self.emission_probabilities(symbol_freqs.keys())) + hmm.set_transitions(self.transition_probabilities()) + return hmm - fractions = np.array([expected_frac[s] for s in states]) - lengths = 1/np.array([expected_lengths[s] for s in states]) + def run_hmm(self, + alignment_file: str, + only_poly_sites: bool = True) -> Tuple[hmm_bw.HMM, + hmm_bw.HMM, + np.array]: + ''' + Runs the hmm training, returning the initial and trained HMM along + with the positions of hmm importance + ''' + coded_sequence, positions, len_seq = \ + self.encode_sequence(alignment_file, only_poly_sites) - # general case, - # trans[i,j] = 1/ length[i] * expected[j] * 1 /(1 - fraction[i]) - transitions = np.outer( - np.multiply(lengths, 1/(1-fractions)), - fractions) - # when i == j, trans[i,j] = 1 - 1/length[i] - np.fill_diagonal(transitions, 1-lengths) + self.update_expected_length(len_seq) + # set initial hmm parameters based on combination of (1) initial + # expectations (length of introgressed tract and fraction of + # genome/total number tracts and bases) and (2) number of sites at + # which predict seq matches each reference + hmm = self.build_initial_hmm(coded_sequence) - # normalize - return transitions / transitions.sum(axis=1)[:, None] + # set states and initial probabilties + hmm.set_hidden_states(self.known_states + self.unknown_states) + # copy before setting observations to save memory + hmm_init = copy.deepcopy(hmm) -def initial_hmm_parameters(seq: np.array, - known_states: List[str], - unknown_states: List[str], - expected_frac: Dict, - expected_lengths: Dict) -> hmm_bw.HMM: - ''' - Build a HMM object initialized based on expected values and provided data - ''' + # set obs + hmm.set_observations([coded_sequence]) - # get frequencies of individual symbols (e.g. '+') and all full - # combinations of symbols (e.g. 
'+++-') - (individual_symbol_freqs, - symbol_freqs, - weighted_match_freqs) = get_symbol_freqs(seq) - - init = initial_probabilities(known_states, unknown_states, - expected_frac, weighted_match_freqs) - emis = emission_probabilities(known_states, - unknown_states, - symbol_freqs.keys()) - trans = transition_probabilities(known_states, unknown_states, - expected_frac, expected_lengths) - - # new Hidden Markov Model - hmm = hmm_bw.HMM() - - hmm.set_initial_p(init) - hmm.set_emissions(emis) - hmm.set_transitions(trans) - return hmm - - -def predict_introgressed(ref_seqs: np.array, - predict_seq: np.array, - predict_args: Dict, - train: bool = True, - only_poly_sites: bool = True, - return_positions: bool = False) -> Tuple[ - List[str], - np.array, - hmm_bw.HMM, - hmm_bw.HMM, - np.array - ]: - ''' - Predict regions of introgression within the predicted sequence - ref_seqs: 2d np character array of the reference sequences - predict_seq: np character array of the sequence to perform prediction on - train: control whether or not to perform Baum-Welch estimation on HMM - only_poly_sites: control if only polymorphic sites should be considered - return_positions: if true, only the position of sites in reference sequence - is returned - Generally will return a tuple of the following: - The predicted types as a list of states - The posterior decoding of the trained HMM - The trained HMM object - The untrained HMM without sequences - The positions of sites with respect to the reference sequence - ''' + # Baum-Welch parameter estimation + hmm.train(self.convergence) - # code sequence by which reference it matches at each site; - # positions are relative to master (first) reference sequence - seq_coded, positions = ungap_and_code(predict_seq, ref_seqs) - if only_poly_sites: - seq_coded, positions = poly_sites(seq_coded, positions) - if return_positions: - return positions - - set_expectations(predict_args, len(predict_seq)) - - # set initial hmm parameters based on combination of (1) initial - # expectations (length of introgressed tract and fraction of - # genome/total number tracts and bases) and (2) number of sites at - # which predict seq matches each reference - hmm = initial_hmm_parameters(seq_coded, - predict_args['known_states'], - predict_args['unknown_states'], - predict_args['expected_frac'], - predict_args['expected_length']) - - # make predictions - - # set states and initial probabilties - hmm.set_hidden_states(predict_args['states']) - - # copy before setting observations to save memory - hmm_init = copy.deepcopy(hmm) - - # set obs - hmm.set_observations([seq_coded]) - - # optional Baum-Welch parameter estimation - if train: - hmm.train(predict_args['improvement_frac']) - - p = hmm.posterior_decoding() - path, path_probs = sim_process.get_max_path(p[0], hmm.hidden_states) - - # posterior - if type(predict_args['threshold']) is float: - path_t = sim_process.threshold_predicted(path, path_probs, - predict_args['threshold'], - predict_args['states'][0]) - return path_t, p[0], hmm, hmm_init, positions - - else: - hmm.set_observations([seq_coded]) - predicted = sim_predict.convert_predictions(hmm.viterbi(), - predict_args['states']) - return predicted, p[0], hmm, hmm_init, positions - - -def convert_to_blocks(state_seq: List[str], - states: List[str]) -> Dict[ - str, List[Tuple[int, int]]]: - ''' - Convert a list of sequences into a structure with start and end positions - Return structure is a dict keyed on species with values of Lists of - each block, which is a tuple with start and end 
positions - ''' - # single individual state sequence - blocks = {} - for state in states: - blocks[state] = [] - prev_species = state_seq[0] - block_start = 0 - block_end = 0 - for i in range(len(state_seq)): - if state_seq[i] == prev_species: - block_end = i - else: - blocks[prev_species].append((block_start, block_end)) - block_start = i - block_end = i - prev_species = state_seq[i] - # add last block - if prev_species not in blocks: - blocks[prev_species] = [] - blocks[prev_species].append((block_start, block_end)) - - return blocks - - -def write_positions(positions: np.array, - writer: TextIO, - strain: str, - chrm: str) -> None: - ''' - Write the positions of the specific strain, chromosome as a line to the - provided textIO object - ''' - writer.write(f'{strain}\t{chrm}\t' + - '\t'.join([str(x) for x in positions]) + '\n') + return hmm_init, hmm, positions + + def encode_sequence(self, + alignment_file: str, + only_poly_sites: bool = True) -> Tuple[ + np.array, + np.array, + int]: + ''' + open the supplied alignment file, encode, and return the coded + sequence along with the positions. If only_poly_sites is True, + also filter out non-polymorphic sites. + Returns the encoded sequence, positions, and length of original seq + ''' + _, sequences = read_fasta(alignment_file) + + references = sequences[:-1] + predicted = sequences[-1] + + seq_coded, positions = self.ungap_and_code(predicted, references) + if only_poly_sites: + seq_coded, positions = self.poly_sites(seq_coded, positions) + + return seq_coded, positions, len(predicted) + + def ungap_and_code(self, + predict_seq: str, + ref_seqs: List[str], + index_ref: int = 0) -> Tuple[np.array, np.array]: + ''' + Remove any sequence locations where a gap is present and code + into matching or mismatching sequence + Returns the coded sequences, by default an array of + where matching, - + where mismatching. Also return the positions where the sequences are + not gapped. + ''' + # index_ref is index of reference strain to index relative to + # build character array + sequences = np.array([list(predict_seq)] + + [list(r) for r in ref_seqs]) + + isbase = sequences != self.symbols['gap'] + + # make boolean for valid characters + isvalid = np.logical_and(sequences != self.symbols['gap'], + sequences != self.symbols['unsequenced']) + + # positions are where everything is valid, index where the reference is + # valid. The +1 removes the predict sequence at index 0 + positions = np.where( + np.all(isvalid[:, isbase[index_ref+1, :]], axis=0))[0] + + matches = np.where(sequences[0] == sequences[1:], + self.symbols['match'], + self.symbols['mismatch']) + + matches = np.fromiter((''.join(row) + for row in np.transpose( + matches[:, np.all(isvalid, axis=0)])), + dtype=f'U{len(sequences) - 1}') + + return matches, positions + + def poly_sites(self, + sequences: np.array, + positions: np.array) -> Tuple[np.array, np.array]: + ''' + Remove all sequences where the sequence is all match_symbol + Returns the filtered sequence and position + ''' + seq_len = len(sequences[0]) + # check if seq only contains match_symbol + retain = np.vectorize( + lambda x: x.count(self.symbols['match']) != seq_len)(sequences) + indices = np.where(retain)[0] + + ps_poly = positions[indices] + seq_poly = sequences[indices] + + return seq_poly, ps_poly def read_positions(filename: str) -> Dict[str, Dict[str, List[int]]]: @@ -699,7 +937,7 @@ def read_positions(filename: str) -> Dict[str, Dict[str, List[int]]]: keyed first by the strain, then chromosome. 
Returned positions are lists of ints ''' - with gzip.open(filename, 'rb') as reader: + with gzip.open(filename, 'rt') as reader: result = defaultdict({}) for line in reader: line = line.split() @@ -709,44 +947,6 @@ def read_positions(filename: str) -> Dict[str, Dict[str, List[int]]]: return result -def write_blocks_header(writer: TextIO) -> None: - ''' - Write header line to tab delimited block file: - strain chromosome predicted_species start end num_sites_hmm - ''' - # NOTE: num_sites_hmm represents the sites considered by the HMM, - # so it might exclude non-polymorphic sites in addition to gaps - writer.write('\t'.join(['strain', - 'chromosome', - 'predicted_species', - 'start', - 'end', - 'num_sites_hmm']) - + '\n') - - -def write_blocks(state_seq_blocks: List[Tuple[int, int]], - positions: np.array, - writer: TextIO, - strain: str, - chrm: str, - species_pred: str) -> None: - ''' - Write entry into tab delimited block file, with columns: - strain chromosome predicted_species start end num_sites_hmm - ''' - writer.write('\n'.join( - ['\t'.join([strain, - chrm, - species_pred, - str(positions[start]), - str(positions[end]), - str(end - start + 1)]) - for start, end in state_seq_blocks])) - if state_seq_blocks: # ensure ends with \n - writer.write('\n') - - def read_blocks(filename: str, labeled: bool = False) -> Dict[ str, Dict[str, Tuple[int, int, int, str]]]: @@ -771,64 +971,3 @@ def read_blocks(filename: str, item = (int(start), int(end), int(number_non_gap)) result[strain][chrm].append(item) return result - - -def get_emis_symbols(known_states: List[str]) -> List[str]: - ''' - Generate all permutations of match and mismatch symbols with - len(known_states) characters, in lexigraphical order - ''' - symbols = [gp.match_symbol, gp.mismatch_symbol] - emis_symbols = [''.join(x) for x in - itertools.product(symbols, repeat=len(known_states))] - emis_symbols.sort() - return emis_symbols - - - - -def write_hmm(hmm: hmm_bw.HMM, - writer: TextIO, - strain: str, - chrm: str, - emis_symbols: List[str]): - ''' - Write information on the provided hmm as a line to the supplied textIO - object. - Output is tab delimited with: - strain chromosome initial_probs emissions transitions - ''' - writer.write(f'{strain}\t{chrm}\t') - - states = len(hmm.hidden_states) - writer.write('\t'.join( - [f'{p}' for p in hmm.initial_p] + # initial - [f'{hmm.emissions[i, hmm.symbol_to_ind[symbol]]}' - if symbol in hmm.symbol_to_ind else '0.0' - for i in range(states) - for symbol in emis_symbols] + # emission - [f'{hmm.transitions[i, j]}' - for i in range(states) - for j in range(states)] # transition - )) - writer.write('\n') - - -def write_state_probs(probs: Dict[str, List[float]], - writer: TextIO, - strain: str, - chrm: str, - states: List[str]) -> None: - ''' - Write the probability each state to the supplied textIO object - Output is tab delimited with: - strain chrom state1:prob1,prob2,...,probn state2... 
- ''' - writer.write(f'{strain}\t{chrm}\t') - - writer.write('\t'.join( - [f'{state}:' + - ','.join([f'{site[i]:.5f}' for site in probs]) - for i, state in enumerate(states)])) - - writer.write('\n') diff --git a/code/analyze/predict_main.py b/code/analyze/predict_main.py deleted file mode 100644 index 3b7d8c9..0000000 --- a/code/analyze/predict_main.py +++ /dev/null @@ -1,122 +0,0 @@ -import sys -import os -import predict -import read_args -import gzip -import predict -import global_params as gp -from misc import read_fasta - - -''' -Predict states from aligned sequences -Input files: --refs_{strain}_chr{chromosome}_mafft.fa - -Output files: --blocks{species}.txt --hmm_init.txt --hmm.txt --positions.txt --probs.txt -''' -# read in analysis parameters -args = predict.process_predict_args(sys.argv[1:]) - -strain_dirs = align_helpers.get_strains( - align_helpers.flatten(gp.non_ref_dirs.values())) - -##====== -# output files and if and where to resume -##====== - -base_dir = f'{gp.analysis_out_dir_absolute}{args["tag"]}' -if not os.path.isdir(base_dir): - os.makedirs(base_dir) - -# introgressed blocks -blocks_f = {} -for s in args['states']: - blocks_f[s] = open(f'{base_dir}/blocks_{s}_{args["tag"]}.txt', 'w') - predict.write_blocks_header(blocks_f[s]) - -# HMM parameters -emis_symbols = predict.get_emis_symbols(args['known_states']) - -hmm_init_f = open(f'{base_dir}/hmm_init_{args["tag"]}.txt', 'w') -predict.write_hmm_header(args['known_states'], args['unknown_states'], - emis_symbols, hmm_init_f) - -hmm_f = open(f'{base_dir}/hmm_{args["tag"]}.txt', 'w') -predict.write_hmm_header(args['known_states'], args['unknown_states'], - emis_symbols, hmm_f) - -# posterior probabilities - -write_ps = True -if write_ps: - ps_f = gzip.open(f'{base_dir}/positions_{args["tag"]}.txt.gz', 'wt') - -probs_f = gzip.open(f'{base_dir}/probs_{args["tag"]}.txt.gz', 'wt') - -# loop through all sequences and predict introgression - - -for chrm in gp.chrms: - - for strain, strain_dir in strain_dirs: - - print(f'working on: {strain} {chrm}') - - ref_prefix = '_'.join(args['known_states']) - fn = (f'{args["setup_args"]["alignments_directory"]}' - f'{ref_prefix}_{strain}' - f'_chr{chrm}_mafft{gp.alignment_suffix}') - - if not os.path.exists(fn): - print(fn) - print(f'no alignment for {strain} {chrm}') - continue - - headers, seqs = read_fasta.read_fasta(fn) - - ref_seqs = seqs[:-1] - predict_seq = seqs[-1] - - # predict introgressed/non-introgressed tracts - - state_seq, probs, hmm, hmm_init, ps = \ - predict.predict_introgressed(ref_seqs, predict_seq, - args, train=True) - - state_seq_blocks = predict.convert_to_blocks(state_seq, args['states']) - - # output - - # the positions actually used in predictions - # (alignment columns with no gaps) - if write_ps: - predict.write_positions(ps, ps_f, strain, chrm) - - # blocks predicted to be introgressed, separate files for each species - for s in state_seq_blocks: - predict.write_blocks(state_seq_blocks[s], ps, blocks_f[s], - strain, chrm, s) - - # summary info about HMM (before training) - predict.write_hmm(hmm_init, hmm_init_f, strain, chrm, emis_symbols) - - # summary info about HMM (after training) - predict.write_hmm(hmm, hmm_f, strain, chrm, emis_symbols) - - # probabilities at each site - predict.write_state_probs(probs, probs_f, strain, - chrm, hmm.hidden_states) - -for k in blocks_f: - blocks_f[k].close() - -ps_f.close() -hmm_init_f.close() -hmm_f.close() -probs_f.close() diff --git a/code/config.yaml b/code/config.yaml index c4bc225..95f559f 100644 --- 
a/code/config.yaml +++ b/code/config.yaml @@ -75,7 +75,8 @@ chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', analysis_params: tag: p2e4 - improvement_frac: 0.001 + convergence_threshold: 0.001 + # threshold can be 'viterbi' or a float to threshold probabilities threshold: viterbi input_root: /tigress/AKEY/akey_vol2/aclark4/nobackup @@ -89,20 +90,30 @@ analysis_params: - name: CBS432 base_dir: /tigress/anneec/projects/introgression/data/CBS432/ gene_bank_dir: __INPUT_ROOT__/CBS432/ + expected_length: 10000 + expected_fraction: 0.025 - name: N_45 base_dir: __INPUT_ROOT__/para_sgrp/strains/N_45/ gene_bank_dir: null + expected_length: 10000 + expected_fraction: 0.025 - name: DBVPG6304 base_dir: __INPUT_ROOT__/para_sgrp/strains/DBVPG6304/ gene_bank_dir: null + expected_length: 10000 + expected_fraction: 0.025 - name: UWOPS91_917_1 base_dir: __INPUT_ROOT__/para_sgrp/strains/UWOPS91_917_1/ gene_bank_dir: null + expected_length: 10000 + expected_fraction: 0.025 unknown_states: - name: unknown + expected_length: 1000 + expected_fraction: 0.01 gene_bank_all: __INPUT_ROOT__/100_genomes/sequence.gb diff --git a/code/hmm/hmm_bw.py b/code/hmm/hmm_bw.py index 081c002..ea667d6 100644 --- a/code/hmm/hmm_bw.py +++ b/code/hmm/hmm_bw.py @@ -1,5 +1,6 @@ import numpy as np from typing import List, Dict, Tuple +import logging as log class HMM: @@ -92,40 +93,36 @@ def print_results(self, iterations: int, LL: float) -> None: ''' Write current state of HMM to stdout ''' - print( - f'''Iterations: {iterations} + message = f'Iterations: {iterations}\n\nLog Likelihood:\n{LL:.30e}' -Log Likelihood: -{LL:.30e} - -Initial State Probabilities:''' - ) + message += '\n\nInitial State Probabilities:\n' for i in range(len(self.hidden_states)): - print(f"{self.hidden_states[i]}={self.initial_p[i]:.30e}") - print() - print("Transition Probabilities:") + message += f'{self.hidden_states[i]}={self.initial_p[i]:.30e}\n' + + message += '\nTransition Probabilities:\n' for i in range(len(self.hidden_states)): for j in range(len(self.hidden_states)): - print(f"{self.hidden_states[i]},{self.hidden_states[j]}\ - ={self.transitions[i][j]:.30e}") - print() - print("Emission Probabilities:") + message += f"{self.hidden_states[i]},{self.hidden_states[j]}\ + ={self.transitions[i][j]:.30e}\n" + + message += '\nEmission Probabilities:\n' for i in range(len(self.hidden_states)): for k in sorted(self.observed_states): - print(f"{self.hidden_states[i]},{k}=\ - {self.emissions[i, self.symbol_to_ind[k]]:.30e}") - print() + message += f"{self.hidden_states[i]},{k}=\ + {self.emissions[i, self.symbol_to_ind[k]]:.30e}\n" + message += '\n' + log.debug(message) def train(self, - improvement_frac: float = .01, - max_iterations: int = None) -> None: + improvement_frac: float = 0.01, + max_iterations: int = None) -> None: ''' Train the hmm until either the max iterations is reached or the log likelihood fails to improve beyond the improvement factor ''' # calculate current log likelihood - print("calculating alpha") + log.debug('calculating alpha') alpha = self.forward() LL = self.log_likelihood(alpha) @@ -141,32 +138,32 @@ def train(self, and iterations < max_iterations)\ or LL - prev_LL > threshold: - print(f"Iteration {iterations}") + log.info(f'Iteration {iterations}') - print("calculating beta") + log.debug('calculating beta') beta = self.backward() - print("calculating gamma") + log.debug('calculating gamma') gamma = self.state_probs(alpha, beta) - print("calculating xi") + log.debug('calculating xi') xi = self.bw(alpha, beta) - 
print("updating parameters") + log.debug('updating parameters') self.initial_p = self.initial_probabilities(gamma) self.transitions = self.transition_probabilities(xi, gamma) self.emissions = self.emission_probabilities(gamma) assert np.isclose(np.sum(self.initial_p), 1), \ - f"{beta}\n{np.sum(self.initial_p)} {self.initial_p}" + f'{beta}\n{np.sum(self.initial_p)} {self.initial_p}' for t in self.transitions: assert np.isclose(np.sum(t), 1), \ - f"{xi} {gamma} {np.sum(t)} {t}" + f'{xi} {gamma} {np.sum(t)} {t}' for e in self.emissions: - assert np.isclose(np.sum(e), 1), f"{np.sum(e.values())} {e}" + assert np.isclose(np.sum(e), 1), f'{np.sum(e.values())} {e}' iterations += 1 - print("calculating alpha") + log.debug("calculating alpha") alpha = self.forward() prev_LL = LL @@ -176,11 +173,11 @@ def train(self, self.print_results(iterations, LL) if LL < prev_LL and not np.isclose(LL, prev_LL): - # NOTE does not stop execution - print('PROBLEM: log-likelihood stopped increasing; \ - stopping training now') + log.error('PROBLEM: log-likelihood stopped increasing; ' + 'stopping training now') + return - print(f"finished in {iterations} iterations") + log.info(f'finished in {iterations} iterations') def log_likelihood(self, alpha: np.array) -> float: ''' diff --git a/code/misc/config_utils.py b/code/misc/config_utils.py index 5ba7e40..428855a 100644 --- a/code/misc/config_utils.py +++ b/code/misc/config_utils.py @@ -176,6 +176,8 @@ def get_nested(config: Dict, keys: str): Return the value of the nested keys, or none if the key is invalid keys is a period separated list of keys as a string ''' + if config is None: + return None keys = keys.split('.') value = config try: diff --git a/code/misc/read_fasta.py b/code/misc/read_fasta.py index b09c691..a422f30 100644 --- a/code/misc/read_fasta.py +++ b/code/misc/read_fasta.py @@ -7,7 +7,7 @@ def read_fasta(fn: str, gz: bool = False) -> Tuple[ List[str], np.array]: ''' Read the provided fasta file, returning the - headers (lines startin with >) and sequences + headers (lines starting with >) and sequences ''' headers = [] diff --git a/code/sim/sim_analyze_hmm_bw.py b/code/sim/sim_analyze_hmm_bw.py index 0ca33f2..534329e 100644 --- a/code/sim/sim_analyze_hmm_bw.py +++ b/code/sim/sim_analyze_hmm_bw.py @@ -380,11 +380,10 @@ def get_symbol_freqs_one(states, seqs, unknown_species, \ return individual_symbol_freqs, symbol_freqs, weighted_match_freqs + def convert_predictions(path, states): - new_path = [] - for p in path: - new_path.append(states[p]) - return new_path + return [states[p] for p in path] + def initial_hmm_parameters(seqs, predict_species, index_to_species, states, \ unknown_species, \ diff --git a/code/sim/sim_process.py b/code/sim/sim_process.py index 57325d8..7db624a 100644 --- a/code/sim/sim_process.py +++ b/code/sim/sim_process.py @@ -318,10 +318,8 @@ def threshold_predicted(predicted: List[str], Given a list of states, predicted, and the associated probabilities, probs Converts any states with probability < threshold to the default state ''' - predicted_thresholded = np.array(predicted) probs = np.array(probs) - predicted_thresholded[probs < threshold] = default_state - return list(predicted_thresholded) + return list(np.where(probs < threshold, default_state, predicted)) def fill_seq(seq, polymorphic_sites, nsites, fill): diff --git a/code/test/analyze/test_main.py b/code/test/analyze/test_main.py index a847f7b..e771db4 100644 --- a/code/test/analyze/test_main.py +++ b/code/test/analyze/test_main.py @@ -2,6 +2,7 @@ from click.testing import 
CliRunner import analyze.main as main import yaml +import logging as log @pytest.fixture @@ -9,13 +10,18 @@ def runner(): return CliRunner() -def test_main_cli(runner, mocker): +def test_main_cli_configs(runner, mocker): result = runner.invoke(main.cli) assert result.exit_code == 0 with runner.isolated_filesystem(): - clean = mocker.patch('analyze.main.config_utils.clean_config', - return_value=dict()) + mock_clean = mocker.patch('analyze.main.config_utils.clean_config', + side_effect=lambda x: x) + mock_echo = mocker.patch('analyze.main.click.echo_via_pager') + mock_log_info = mocker.patch('analyze.main.log.info') + mock_log_debug = mocker.patch('analyze.main.log.debug') + mock_log_lvl = mocker.patch('analyze.main.log.basicConfig') + with open('config1.yaml', 'w') as f: yaml.dump({'test': '123'}, f) with open('config2.yaml', 'w') as f: @@ -25,6 +31,93 @@ def test_main_cli(runner, mocker): main.cli, '--config config1.yaml --config config2.yaml'.split()) assert result.exit_code == 0 - clean.assert_called_with( + mock_clean.assert_called_with( {'test': '23', 'test2': '34'}) + + # since no subcommand was called + mock_echo.assert_called_once() + + mock_log_lvl.assert_called_once_with(level=log.WARNING) + assert mock_log_info.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 2 config files') + ] + assert mock_log_debug.call_args_list == [ + mocker.call('Cleaned config:\ntest - 23\ntest2 - 34\n') + ] + + +def test_main_cli_verbosity(runner, mocker): + mock_log_info = mocker.patch('analyze.main.log.info') + mock_log_lvl = mocker.patch('analyze.main.log.basicConfig') + + result = runner.invoke( + main.cli, + '-v') + assert result.exit_code == 0 + mock_log_lvl.assert_called_once_with(level=log.CRITICAL) + assert mock_log_info.call_args_list == [ + mocker.call('Verbosity set to CRITICAL'), + mocker.call('Reading in 0 config files') + ] + + mock_log_info.reset_mock() + mock_log_lvl.reset_mock() + result = runner.invoke( + main.cli, + '-vv') + assert result.exit_code == 0 + mock_log_lvl.assert_called_once_with(level=log.ERROR) + assert mock_log_info.call_args_list == [ + mocker.call('Verbosity set to ERROR'), + mocker.call('Reading in 0 config files') + ] + + mock_log_info.reset_mock() + mock_log_lvl.reset_mock() + result = runner.invoke( + main.cli, + '-vvv') + assert result.exit_code == 0 + mock_log_lvl.assert_called_once_with(level=log.WARNING) + assert mock_log_info.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 0 config files') + ] + + mock_log_info.reset_mock() + mock_log_lvl.reset_mock() + result = runner.invoke( + main.cli, + '-vvvv') + assert result.exit_code == 0 + mock_log_lvl.assert_called_once_with(level=log.INFO) + assert mock_log_info.call_args_list == [ + mocker.call('Verbosity set to INFO'), + mocker.call('Reading in 0 config files') + ] + + mock_log_info.reset_mock() + mock_log_lvl.reset_mock() + result = runner.invoke( + main.cli, + '-vvvvv') + assert result.exit_code == 0 + mock_log_lvl.assert_called_once_with(level=log.DEBUG) + assert mock_log_info.call_args_list == [ + mocker.call('Verbosity set to DEBUG'), + mocker.call('Reading in 0 config files') + ] + + mock_log_info.reset_mock() + mock_log_lvl.reset_mock() + result = runner.invoke( + main.cli, + '-vvvvvv') + assert result.exit_code == 0 + mock_log_lvl.assert_called_once_with(level=log.DEBUG) + assert mock_log_info.call_args_list == [ + mocker.call('Verbosity set to DEBUG'), + mocker.call('Reading in 0 config files') + ] diff --git 
a/code/test/analyze/test_main_predict_args.py b/code/test/analyze/test_main_predict_args.py new file mode 100644 index 0000000..fa2d1c8 --- /dev/null +++ b/code/test/analyze/test_main_predict_args.py @@ -0,0 +1,285 @@ +import pytest +from click.testing import CliRunner +import analyze.main as main +import yaml +from analyze import predict +from pathlib import Path + + +''' +Unit tests for the predict command of main.py when all parameters are +provided by the config file +''' + + +@pytest.fixture +def runner(): + return CliRunner() + + +def test_threshold(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold 0.05') + + assert result.exit_code != 0 + assert str(result.exception) == 'No block file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is '0.05'") + ] + + +def test_block(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt' + ) + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'Unable to build prefix, no known states provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + ] + + +def test_prefix(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt --prefix s1_s2' + ) + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'Unable to find strains in config and no test_strains provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), + ] + + +def test_test_strains(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + }, f) + + Path('s1_chrI.fa').touch() + Path('s1_chrII.fa').touch() + Path('s1_chrIII.fa').touch() + Path('s2_chrI.fa').touch() + Path('s2_chrII.fa').touch() + Path('s2_chrIII.fa').touch() + + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt --prefix s1_s2 ' + '--test-strains {strain}_chr{chrom}.fa' + ) + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'No initial hmm file provided' + + print(mock_log.call_args_list) + assert mock_log.call_args_list == [ + 
mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), + mocker.call('searching for *_chr*.fa'), + mocker.call('Found 1 test strain'), + mocker.call('Found 2 unique strains'), + ] + + +def test_outputs(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + mock_calls = [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), + mocker.call('No test_strains provided'), + mocker.call('Found 2 unique strains'), + ] + + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + }, f) + + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt --prefix s1_s2 ' + '--hmm-initial hmm_init.txt ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'No trained hmm file provided' + assert mock_log.call_args_list == mock_calls + + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + }, f) + + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt --prefix s1_s2 ' + '--hmm-initial hmm_init.txt ' + '--hmm-trained hmm_trained.txt ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'No probabilities file provided' + assert mock_log.call_args_list == mock_calls + + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + }, f) + + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt --prefix s1_s2 ' + '--hmm-initial hmm_init.txt ' + '--hmm-trained hmm_trained.txt ' + '--probabilities probs.txt.gz ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'No alignment file provided' + assert mock_log.call_args_list == mock_calls + + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + }, f) + + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt --prefix s1_s2 ' + '--hmm-initial hmm_init.txt ' + '--hmm-trained hmm_trained.txt ' + '--probabilities probs.txt.gz ' + '--alignment {prefix}_{strain}_chr{chrom}.maf ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == \ + 's1 did not provide an expected_length' + assert mock_log.call_args_list == mock_calls + [ + mocker.call("Hmm_initial file is 'hmm_init.txt'"), + mocker.call("Hmm_trained file is 'hmm_trained.txt'"), + mocker.call("Positions file is 'None'"), + mocker.call("Probabilities file is 'probs.txt.gz'"), + mocker.call("Alignment file is 
's1_s2_{strain}_chr{chrom}.maf'")] + + mock_predict = mocker.patch.object(predict.Predictor, 'run_prediction') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + }, f) + + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt --prefix s1_s2 ' + '--hmm-initial hmm_init.txt ' + '--hmm-trained hmm_trained.txt ' + '--probabilities probs.txt.gz ' + '--positions pos.txt.gz ' + '--alignment {prefix}_{strain}_chr{chrom}.maf ' + ) + + assert result.exit_code == 0 + assert mock_log.call_args_list == mock_calls + [ + mocker.call("Hmm_initial file is 'hmm_init.txt'"), + mocker.call("Hmm_trained file is 'hmm_trained.txt'"), + mocker.call("Positions file is 'pos.txt.gz'"), + mocker.call("Probabilities file is 'probs.txt.gz'"), + mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'")] + mock_predict.called_once_with(True) diff --git a/code/test/analyze/test_main_predict_config.py b/code/test/analyze/test_main_predict_config.py new file mode 100644 index 0000000..c7ff871 --- /dev/null +++ b/code/test/analyze/test_main_predict_config.py @@ -0,0 +1,397 @@ +import pytest +from click.testing import CliRunner +import analyze.main as main +import yaml +from analyze import predict +from pathlib import Path + + +''' +Unit tests for the predict command of main.py when all parameters are +provided by the config file +''' + + +@pytest.fixture +def runner(): + return CliRunner() + + +def test_chroms(runner, mocker): + result = runner.invoke( + main.cli, + 'predict') + assert result.exit_code != 0 + assert str(result.exception) == 'No chromosomes specified in config file!' 
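These tests rely on CliRunner catching any exception raised inside the command and exposing it on the result object, which is why they assert on `str(result.exception)` instead of wrapping the invocation in `pytest.raises`. A minimal sketch of that behavior with a throwaway, hypothetical command:

    import click
    from click.testing import CliRunner

    @click.command()
    def boom():
        raise ValueError('expected failure')

    result = CliRunner().invoke(boom)
    assert result.exit_code != 0
    assert str(result.exception) == 'expected failure'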
+ + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split() + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml predict') + + assert result.exit_code != 0 + assert str(result.exception) == 'No threshold provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config') + ] + + +def test_threshold(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'threshold': 'viterbi' + } + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml predict') + + assert result.exit_code != 0 + assert str(result.exception) == 'No block file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'") + ] + + +def test_block(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'threshold': 'viterbi' + }, + 'paths': {'analysis': { + 'block_files': 'blocks_{state}.txt', + }}, + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml predict') + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'Unable to build prefix, no known states provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + ] + + +def test_prefix(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'threshold': 'viterbi', + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + 'paths': {'analysis': { + 'block_files': 'blocks_{state}.txt', + }}, + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml predict') + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'Unable to find strains in config and no test_strains provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), + ] + + +def test_strains(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'threshold': 'viterbi', + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + 'paths': {'analysis': { + 'block_files': 'blocks_{state}.txt', + }}, + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml predict') + + assert result.exit_code != 0 + assert str(result.exception) == \ 
+ 'No initial hmm file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), + mocker.call('No test_strains provided'), + mocker.call('Found 2 unique strains'), + ] + + +def test_test_strains(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'threshold': 'viterbi', + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + 'paths': { + 'analysis': { + 'block_files': 'blocks_{state}.txt', + }, + 'test_strains': ['{strain}_chr{chrom}.fa']}, + }, f) + + Path('s1_chrI.fa').touch() + Path('s1_chrII.fa').touch() + Path('s1_chrIII.fa').touch() + Path('s2_chrI.fa').touch() + Path('s2_chrII.fa').touch() + Path('s2_chrIII.fa').touch() + + result = runner.invoke( + main.cli, + '--config config.yaml predict') + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'No initial hmm file provided' + + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), + mocker.call('searching for *_chr*.fa'), + mocker.call('Found 1 test strain'), + mocker.call('Found 2 unique strains'), + ] + + +def test_outputs(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + mock_calls = [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Reading in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Threshold value is 'viterbi'"), + mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), + mocker.call('No test_strains provided'), + mocker.call('Found 2 unique strains'), + ] + + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'threshold': 'viterbi', + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + 'paths': {'analysis': { + 'block_files': 'blocks_{state}.txt', + 'hmm_initial': 'hmm_init.txt', + }}, + }, f) + + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict') + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'No trained hmm file provided' + assert mock_log.call_args_list == mock_calls + + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'threshold': 'viterbi', + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + 'paths': {'analysis': { + 'block_files': 'blocks_{state}.txt', + 'hmm_initial': 'hmm_init.txt', + 'hmm_trained': 'hmm_trained.txt', + }}, + }, f) + + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict') + + assert result.exit_code != 0 + assert str(result.exception) == \ + 'No probabilities file provided' + assert mock_log.call_args_list == mock_calls + + with runner.isolated_filesystem(): + with 
open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    'chromosomes': 'I II III'.split(),
+                    'strains': 'str1 str2 str1'.split(),
+                    'analysis_params': {
+                        'threshold': 'viterbi',
+                        'known_states': [
+                            {'name': 's1'},
+                            {'name': 's2'}],
+                    },
+                    'paths': {'analysis': {
+                        'block_files': 'blocks_{state}.txt',
+                        'hmm_initial': 'hmm_init.txt',
+                        'hmm_trained': 'hmm_trained.txt',
+                        'probabilities': 'probs.txt.gz',
+                    }},
+                }, f)
+
+        mock_log.reset_mock()
+        result = runner.invoke(
+            main.cli,
+            '--config config.yaml predict')
+
+        assert result.exit_code != 0
+        assert str(result.exception) == \
+            'No alignment file provided'
+        assert mock_log.call_args_list == mock_calls
+
+    with runner.isolated_filesystem():
+        with open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    'chromosomes': 'I II III'.split(),
+                    'strains': 'str1 str2 str1'.split(),
+                    'analysis_params': {
+                        'threshold': 'viterbi',
+                        'known_states': [
+                            {'name': 's1'},
+                            {'name': 's2'}],
+                    },
+                    'paths': {'analysis': {
+                        'block_files': 'blocks_{state}.txt',
+                        'hmm_initial': 'hmm_init.txt',
+                        'hmm_trained': 'hmm_trained.txt',
+                        'probabilities': 'probs.txt.gz',
+                        'alignment': '{prefix}_{strain}_chr{chrom}.maf',
+                    }},
+                }, f)
+
+        mock_log.reset_mock()
+        result = runner.invoke(
+            main.cli,
+            '--config config.yaml predict')
+
+        assert result.exit_code != 0
+        assert str(result.exception) == \
+            's1 did not provide an expected_length'
+        assert mock_log.call_args_list == mock_calls + [
+            mocker.call("Hmm_initial file is 'hmm_init.txt'"),
+            mocker.call("Hmm_trained file is 'hmm_trained.txt'"),
+            mocker.call("Positions file is 'None'"),
+            mocker.call("Probabilities file is 'probs.txt.gz'"),
+            mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'")]
+
+    mock_predict = mocker.patch.object(predict.Predictor, 'run_prediction')
+    with runner.isolated_filesystem():
+        with open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    'chromosomes': 'I II III'.split(),
+                    'strains': 'str1 str2 str1'.split(),
+                    'analysis_params': {
+                        'threshold': 'viterbi',
+                        'known_states': [
+                            {'name': 's1'},
+                            {'name': 's2'}],
+                    },
+                    'paths': {'analysis': {
+                        'block_files': 'blocks_{state}.txt',
+                        'hmm_initial': 'hmm_init.txt',
+                        'hmm_trained': 'hmm_trained.txt',
+                        'positions': 'pos.txt.gz',
+                        'probabilities': 'probs.txt.gz',
+                        'alignment': '{prefix}_{strain}_chr{chrom}.maf',
+                    }},
+                }, f)
+
+        mock_log.reset_mock()
+        result = runner.invoke(
+            main.cli,
+            '--config config.yaml predict')
+
+        assert result.exit_code == 0
+        assert mock_log.call_args_list == mock_calls + [
+            mocker.call("Hmm_initial file is 'hmm_init.txt'"),
+            mocker.call("Hmm_trained file is 'hmm_trained.txt'"),
+            mocker.call("Positions file is 'pos.txt.gz'"),
+            mocker.call("Probabilities file is 'probs.txt.gz'"),
+            mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'")]
+        mock_predict.assert_called_once_with(True)
diff --git a/code/test/analyze/test_predict.py b/code/test/analyze/test_predict.py
deleted file mode 100644
index 8cb67c8..0000000
--- a/code/test/analyze/test_predict.py
+++ /dev/null
@@ -1,715 +0,0 @@
-from analyze import predict
-from hmm import hmm_bw as hmm
-import pytest
-from pytest import approx
-from io import StringIO
-from collections import defaultdict
-import random
-import numpy as np
-
-
-def test_gp_symbols():
-    # because the following tests use symbols instead of the variables
-    assert predict.gp.match_symbol == '+'
-    assert predict.gp.mismatch_symbol == '-'
-    assert predict.gp.gap_symbol == '-'
-    assert predict.gp.unsequenced_symbol == 'n'
-
-
-@pytest.fixture
-def args():
-    args = {}
-
args['tag'] = 'p4e2' - args['improvement_frac'] = 0.001 - args['threshold'] = 'viterbi' - - args['known_states'] = ['S288c', 'CBS432', 'N_45', - 'DBVPG6304', 'UWOPS91_917_1'] - args['unknown_states'] = ['unknown'] - args['states'] = args['known_states'] + ['unknown'] - - args['expected_frac'] = {'DBVPG6304': 0.025, - 'UWOPS91_917_1': 0.025, - 'unknown': 0.01, - 'CBS432': 0.025, - 'N_45': 0.025, - 'S288c': 0.89} - - args['expected_length'] = {'DBVPG6304': 10000.0, - 'UWOPS91_917_1': 10000.0, - 'unknown': 1000.0, - 'CBS432': 10000.0, - 'N_45': 10000.0, - 'S288c': 0} - args['expected_num_tracts'] = {} - args['expected_bases'] = {} - return args - - -def old_test_process_predict_args(): - # test with default args - args = predict.process_predict_args('p4e2 .001 viterbi 10000 .025 10000\ - .025 10000 .025 10000 .025 unknown\ - 1000 .01'.split()) - assert args['tag'] == 'p4e2' - assert args['improvement_frac'] == 0.001 - assert args['threshold'] == 'viterbi' - - assert args['known_states'] == predict.gp.alignment_ref_order - assert args['unknown_states'] == ['unknown'] - assert args['states'] == predict.gp.alignment_ref_order + ['unknown'] - - assert args['expected_frac'] == {'DBVPG6304': 0.025, - 'UWOPS91_917_1': 0.025, - 'unknown': 0.01, - 'CBS432': 0.025, - 'N_45': 0.025, - 'S288c': 0.89} - - assert args['expected_length'] == {'DBVPG6304': 10000.0, - 'UWOPS91_917_1': 10000.0, - 'unknown': 1000.0, - 'CBS432': 10000.0, - 'N_45': 10000.0, - 'S288c': 0} - assert args['expected_num_tracts'] == {} - assert args['expected_bases'] == {} - - assert len(args.keys()) == 10 - - -def old_test_process_predict_args_threshold(): - args = predict.process_predict_args('p4e2 .001 test 10000 .025 10000\ - .025 10000 .025 10000 .025 unknown\ - 1000 .01'.split()) - assert args['threshold'] == 'viterbi' - - args = predict.process_predict_args('p4e2 .001 0.1 10000 .025 10000\ - .025 10000 .025 10000 .025 unknown\ - 1000 .01'.split()) - assert args['threshold'] == 0.1 - - -def old_test_process_predict_args_exceptions(): - # not enough unknown values - with pytest.raises(IndexError): - predict.process_predict_args('p4e2 .001 0.1 10000 .025 10000\ - .025 10000 .025 10000 .025 unknown\ - 1000'.split()) - - # not enough arg values - with pytest.raises(ValueError): - predict.process_predict_args('p4e2 .001 0.1 10000 .025 10000\ - .025 10000 .025 .025 unknown\ - 1000 0.01'.split()) - - with pytest.raises(ValueError): - predict.process_predict_args('p4e2 .001 0.1 10000 .025 10000\ - .025 10000 .025 10000 .025 unknown\ - 1000 NotADouble'.split()) - - -def test_write_blocks_header(): - writer = StringIO() - predict.write_blocks_header(writer) - - assert writer.getvalue() == '\t'.join(['strain', - 'chromosome', - 'predicted_species', - 'start', - 'end', - 'num_sites_hmm']) + '\n' - - -def test_get_emis_symbols(): - assert predict.get_emis_symbols([1]*1) == ['+', '-'] - assert predict.get_emis_symbols([1]*3) == ['+++', - '++-', - '+-+', - '+--', - '-++', - '-+-', - '--+', - '---', - ] - - -def test_write_hmm_header(): - # TODO this has a return value which doesn't seem to be used - writer = StringIO() - predict.write_hmm_header([], [], [], writer) - assert writer.getvalue() == 'strain\tchromosome\t\n' - - writer = StringIO() - predict.write_hmm_header(['s1', 's2'], ['u1'], ['-', '+'], writer) - - header = 'strain\tchromosome\t' - header += '\t'.join( - ['init_{}'.format(s) for s in ['s1', 's2', 'u1']] + - ['emis_{}_{}'.format(s, sym) - for s in ['s1', 's2', 'u1'] - for sym in ['-', '+']] + - ['trans_{}_{}'.format(s, s2) - for s 
in ['s1', 's2', 'u1'] - for s2 in ['s1', 's2', 'u1']]) - - assert writer.getvalue() == header + '\n' - - -def test_ungap_and_code(): - # nothing in prediction - sequence, positions = predict.ungap_and_code( - '---', # predicted reference string - ['abc', 'def', 'ghi'], # several references - 0) # reference index - assert positions == approx([]) - assert sequence == approx([]) - - # one match - sequence, positions = predict.ungap_and_code( - 'a--', - ['abc', 'def', 'ghi'], - 0) - assert positions == approx([0]) - assert sequence == ['+--'] - - # no match from refs - sequence, positions = predict.ungap_and_code( - 'a--', - ['abc', 'def', '-hi'], - 0) - assert positions == approx([]) - assert sequence == approx([]) - - # two matches - sequence, positions = predict.ungap_and_code( - 'ae-', - ['abc', 'def', 'gei'], - 0) - assert positions == approx([0, 1]) - assert (sequence == ['+--', '-++']).all() - - # mess with ref index - sequence, positions = predict.ungap_and_code( - 'a--e-', - ['a--bc', 'deeef', 'geeei'], - 0) - assert positions == approx([0, 1]) - assert (sequence == ['+--', '-++']).all() - sequence, positions = predict.ungap_and_code( - 'a--e-', - ['a--bc', 'deeef', 'geeei'], - 1) - assert positions == approx([0, 3]) - assert (sequence == ['+--', '-++']).all() - - sequence, positions = predict.ungap_and_code( - 'a---ef--i', - ['ab-dhfghi', - 'a-cceeg-i', - 'a-ceef-hh'], - 0) - - assert (sequence == '+++ -++ +-+ ++-'.split()).all() - assert positions == approx([0, 3, 4, 7]) - - -def test_poly_sites(): - sequence, positions = predict.poly_sites( - np.array('+++ -++ +-+ ++-'.split()), - np.array([0, 3, 4, 7]) - ) - assert (sequence == '-++ +-+ ++-'.split()).all() - assert positions == approx([3, 4, 7]) - - -def test_set_expectations_default(args): - prev_tract = dict(args['expected_length']) - assert args['expected_num_tracts'] == {} - assert args['expected_bases'] == {} - predict.set_expectations(args, 1e5) # made number arbitrary - print(args) - assert args['expected_num_tracts'] == {'DBVPG6304': 0.025 * 10, - 'UWOPS91_917_1': 0.025 * 10, - 'CBS432': 0.025 * 10, - 'N_45': 0.025 * 10, - 'S288c': 1 + 1} - - assert args['expected_bases'] == {'DBVPG6304': 0.025 * 1e5, - 'UWOPS91_917_1': 0.025 * 1e5, - 'CBS432': 0.025 * 1e5, - 'N_45': 0.025 * 1e5, - 'S288c': 1e5 - 1e4} - prev_tract['S288c'] = 45000 - assert args['expected_length'] == prev_tract - - -def test_get_symbol_freqs(): - sequence = '-++ +-+ ++- ---'.split() - symbol_test_helper(sequence) - # TODO throw better exception or handle better - # symbol_test_helper([]) - symbol_test_helper(['+']) - # get all len 10 symbols - syms = predict.get_emis_symbols([1]*10) - - random.seed(0) - for i in range(10): - sequence = [random.choice(syms) for j in range(100)] - symbol_test_helper(sequence) - - -def symbol_test_helper(sequence): - ind, symb, weigh = predict.get_symbol_freqs(np.array(sequence)) - - num_states = len(sequence[0]) - num_sites = len(sequence) - - individual_symbol_freqs = [] - for s in range(num_states): - d = defaultdict(int) - for i in range(num_sites): - d[sequence[i][s]] += 1 - for sym in d: - d[sym] /= num_sites - individual_symbol_freqs.append(d) - - symbol_freqs = defaultdict(int) - for i in range(num_sites): - symbol_freqs[sequence[i]] += 1 - for sym in symbol_freqs: - symbol_freqs[sym] /= num_sites - - # for each state, how often seq matches that state relative to - # others - weighted_match_freqs = [] - for s in range(num_states): - weighted_match_freqs.append( - individual_symbol_freqs[s][predict.gp.match_symbol]) - - 
weighted_match_freqs /= np.sum(weighted_match_freqs) - - assert ind == individual_symbol_freqs - assert symb == symbol_freqs - assert weigh == approx(weighted_match_freqs) - - -def test_initial_probabilities(args): - probs = predict.initial_probabilities(args['known_states'], - args['unknown_states'], - args['expected_frac'], - [0.1, 0.2, 0.3, 0.4, 0.5]) - - assert args['expected_frac'] == {'DBVPG6304': 0.025, - 'UWOPS91_917_1': 0.025, - 'unknown': 0.01, - 'CBS432': 0.025, - 'N_45': 0.025, - 'S288c': 0.89} - p = [0.1 + (0.89 - 0.1) * 0.9, - 0.2 + (0.025 - 0.2) * 0.9, - 0.3 + (0.025 - 0.3) * 0.9, - 0.4 + (0.025 - 0.4) * 0.9, - 0.5 + (0.025 - 0.5) * 0.9, - 0.01] - - p = p / np.sum(p, dtype=np.float) - - assert probs == approx(p) - - -def test_emission_probabilities(args): - # normal mode - symbols = predict.get_emis_symbols([1]*5) - - emis = predict.emission_probabilities(args['known_states'], - args['unknown_states'], - symbols) - - np_emis = np_emission(args, symbols) - is_approx_equal_list_dict(emis, np_emis) - - # too many symbols - symbols = predict.get_emis_symbols([1]*6) - emis = predict.emission_probabilities(args['known_states'], - args['unknown_states'], - symbols) - np_emis = np_emission(args, symbols) - is_approx_equal_list_dict(emis, np_emis) - - # more unknowns - args['unknown_states'].append('test') - symbols = predict.get_emis_symbols([1]*5) - emis = predict.emission_probabilities(args['known_states'], - args['unknown_states'], - symbols) - np_emis = np_emission(args, symbols) - is_approx_equal_list_dict(emis, np_emis) - - # no unknowns - args['unknown_states'] = [] - symbols = predict.get_emis_symbols([1]*5) - emis = predict.emission_probabilities(args['known_states'], - args['unknown_states'], - symbols) - np_emis = np_emission(args, symbols) - is_approx_equal_list_dict(emis, np_emis) - - -def np_emission(args, symbols): - probs = {'-+': 0.9, - '++': 0.09, - '--': 0.009, - '+-': 0.001} - mismatch_bias = 0.99 - - known_len = len(args['known_states']) - for k in probs: - probs[k] *= 2**(known_len - 2) - - emis = [] - # using older, iterative version - for s in range(known_len): - emis.append(defaultdict(float)) - for symbol in symbols: - key = symbol[0] + symbol[s] - emis[s][symbol] = probs[key] - - emis[s] = mynorm(emis[s]) - - symbol_len = len(symbols[0]) - for s in range(len(args['unknown_states'])): - emis.append(defaultdict(float)) - for symbol in symbols: - match_count = symbol.count('+') - mismatch_count = symbol_len - match_count - emis[s + known_len][symbol] = \ - (match_count * (1 - mismatch_bias) - + mismatch_count * mismatch_bias) - emis[s + known_len] = mynorm(emis[s + known_len]) - - return emis - - -def is_approx_equal_list_dict(actual, expected): - for i in range(len(actual)): - for k in actual[i]: - assert actual[i][k] == approx(expected[i][k]),\ - "failed at i={}, k={}".format(i, k) - - -def mynorm(d): - total = float(sum(d.values())) - return {k: v/total for k, v in d.items()} - - -def test_transition_probabilities(args): - args['expected_length']['S288c'] = 45000 - trans = predict.transition_probabilities(args['known_states'], - args['unknown_states'], - args['expected_frac'], - args['expected_length']) - - np_trans = np_transition(args) - for i in range(len(trans)): - assert trans[i] == approx(np_trans[i]) - - -def np_transition(args): - states = args['known_states'] + args['unknown_states'] - expected_frac = args['expected_frac'] - expected_length = args['expected_length'] - trans = [] - for i in range(len(states)): - state_from = states[i] - 
trans.append([]) - scale_other = 1 / (1 - expected_frac[state_from]) - for j in range(len(states)): - state_to = states[j] - if state_from == state_to: - trans[i].append(1 - 1./expected_length[state_from]) - else: - trans[i].append(1./expected_length[state_from] * - expected_frac[state_to] * scale_other) - - trans[i] /= np.sum(trans[i]) - - return trans - - -def test_initial_hmm_parameters(args): - args['expected_length']['S288c'] = 45000 - symbols = predict.get_emis_symbols([1]*5) - hm = predict.initial_hmm_parameters( - symbols, - args['known_states'], - args['unknown_states'], - args['expected_frac'], - args['expected_length']) - - assert args['expected_frac'] == {'DBVPG6304': 0.025, - 'UWOPS91_917_1': 0.025, - 'unknown': 0.01, - 'CBS432': 0.025, - 'N_45': 0.025, - 'S288c': 0.89} - p = [0.2 + (0.89 - 0.2) * 0.9, - 0.2 + (0.025 - 0.2) * 0.9, - 0.2 + (0.025 - 0.2) * 0.9, - 0.2 + (0.025 - 0.2) * 0.9, - 0.2 + (0.025 - 0.2) * 0.9, - 0.01] - - p = p / np.sum(p, dtype=np.float) - assert hm.initial_p == approx(p) - - np_emis = np_emission(args, symbols) - hm2 = hmm.HMM() - hm2.set_emissions(np_emis) - assert hm.emissions == approx(hm2.emissions) - - np_trans = np_transition(args) - for i in range(len(hm.transitions)): - assert hm.transitions[i] == approx(np_trans[i]) - - -def test_predict_introgressed(args, capsys): - seqs = [list('NNENNENNEN'), # S2288c - list('NNNENEENNN'), # CBS432 - list('NN-NNEENNN'), # N_45 - list('NEENN-ENEN'), # DBVPG6304 - list('ENENNEENEN'), # UWOPS.. - list('NNENNEENEN'), # predicted - ] - ref = seqs[:-1] - pred = seqs[-1] - - path, prob, hmm, hmm_init, ps = predict.predict_introgressed( - ref, pred, args, train=True) - - # check hmm output - captured = capsys.readouterr() - out = captured.out.split('\n') - assert 'finished in 10 iterations' in out[-2] - - # ps are locations of polymorphic sites, not counting missing '-' - assert ps == approx([0, 1, 3, 6, 8]) - assert np.array_equal(hmm.initial_p, np.array([1, 0, 0, 0, 0, 0])) - - # check path - assert path == ['S288c', 'S288c', 'UWOPS91_917_1', - 'UWOPS91_917_1', 'UWOPS91_917_1'] - - assert prob[0][0] == 1 - - -def test_write_positions(): - output = StringIO() - predict.write_positions([0, 1, 3, 5, 7], output, 'test', 'I') - assert output.getvalue() == "{}\t{}\t{}\n".format( - "test", - "I", - "\t".join([str(i) for i in (0, 1, 3, 5, 7)])) - - -def test_write_blocks(): - output = StringIO() - block = [] - pos = [i * 2 for i in range(20)] - predict.write_blocks(block, - pos, - output, 'test', 'I', 'pred') - - assert output.getvalue() == '' - - output = StringIO() - block = [(0, 1), (4, 6), (10, 8)] - pos = [i * 2 for i in range(20)] - predict.write_blocks(block, - pos, - output, 'test', 'I', 'pred') - - result = "\n".join( - ["\t".join(['test', 'I', 'pred', - str(pos[s]), str(pos[e]), str(e - s + 1)]) - for s, e in block]) + "\n" - - assert output.getvalue() == result - - -def test_read_blocks(mocker): - block_in = StringIO(''' -''') - - mocked_file = mocker.patch('analyze.predict.open', - return_value=block_in) - output = predict.read_blocks('mocked') - - mocked_file.assert_called_with('mocked', 'r') - assert list(output.keys()) == [] - - block_in = StringIO('''header -test\tI\tpred\t100\t200\t10 -''') - - mocked_file = mocker.patch('analyze.predict.open', - return_value=block_in) - output = predict.read_blocks('mocked') - - assert len(output) == 1 - assert output['test']['I'] == [(100, 200, 10)] - - block_in = StringIO('''header -test\tI\tpred\t100\t200\t10 -test\tI\tpred\t200\t200\t30 -test\tI\tpred\t300\t400\t40 
-test\tII\tpred\t300\t400\t40 -test2\tIII\tpred\t300\t400\t47 -''') - - mocked_file = mocker.patch('analyze.predict.open', - return_value=block_in) - output = predict.read_blocks('mocked') - - assert len(output) == 2 - assert len(output['test']) == 2 - assert len(output['test2']) == 1 - assert output['test']['I'] == [ - (100, 200, 10), - (200, 200, 30), - (300, 400, 40), - ] - assert output['test']['II'] == [(300, 400, 40)] - assert output['test2']['III'] == [(300, 400, 47)] - - -def test_read_blocks_labeled(mocker): - block_in = StringIO(''' -''') - - mocked_file = mocker.patch('analyze.predict.open', - return_value=block_in) - output = predict.read_blocks('mocked', labeled=True) - - mocked_file.assert_called_with('mocked', 'r') - assert list(output.keys()) == [] - - block_in = StringIO('''header -r1\ttest\tI\tpred\t100\t200\t10 -''') - - mocked_file = mocker.patch('analyze.predict.open', - return_value=block_in) - output = predict.read_blocks('mocked', labeled=True) - - assert len(output) == 1 - assert output['test']['I'] == [('r1', 100, 200, 10)] - - block_in = StringIO('''header -r1\ttest\tI\tpred\t100\t200\t10 -r2\ttest\tI\tpred\t200\t200\t30 -r3\ttest\tI\tpred\t300\t400\t40 -r4\ttest\tII\tpred\t300\t400\t40 -r5\ttest2\tIII\tpred\t300\t400\t47 -''') - - mocked_file = mocker.patch('analyze.predict.open', - return_value=block_in) - output = predict.read_blocks('mocked', labeled=True) - - assert len(output) == 2 - assert len(output['test']) == 2 - assert len(output['test2']) == 1 - assert output['test']['I'] == [ - ('r1', 100, 200, 10), - ('r2', 200, 200, 30), - ('r3', 300, 400, 40), - ] - assert output['test']['II'] == [('r4', 300, 400, 40)] - assert output['test2']['III'] == [('r5', 300, 400, 47)] - - -def test_write_hmm(): - output = StringIO() - - hm = hmm.HMM() - - # empty hmm - predict.write_hmm(hm, output, 'strain', 'I', list('abc')) - assert output.getvalue() == 'strain\tI\t\n' - - hm.set_hidden_states(list('abc')) - hm.set_initial_p([0, 1, 0]) - hm.set_transitions([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) - hm.set_emissions([{'a': 1, 'b': 0, 'c': 0}, - {'a': 0, 'b': 0, 'c': 1}, - {'a': 0, 'b': 1, 'c': 0}, - ]) - - output = StringIO() - predict.write_hmm(hm, output, 'strain', 'I', list('abc')) - - result = 'strain\tI\t' - result += '\t'.join(list('010')) + '\t' # init - result += '\t'.join(list('100001010')) + '\t' # emis - result += '\t'.join(list('010100001')) + '\n' # trans - assert output.getvalue() == result - - -def test_write_state_probs(): - output = StringIO() - predict.write_state_probs([{}], output, 'strain', 'I', []) - - assert output.getvalue() == 'strain\tI\t\n' - - output = StringIO() - predict.write_state_probs([ - [0, 0, 1], - [1, 0, 0], - [0, 1, 1], - ], output, 'strain', 'I', list('abc')) - - assert output.getvalue() == \ - ('strain\tI\t' - 'a:0.00000,1.00000,0.00000\t' - 'b:0.00000,0.00000,1.00000\t' - 'c:1.00000,0.00000,1.00000\n') - - -def test_convert_to_blocks_one(): - random.seed(0) - states = [str(i) for i in range(10)] - # TODO fix this as an error or throw another exception - # help_test_convert_blocks(states, []) - help_test_convert_blocks(states, list('1')) - help_test_convert_blocks(states, list('12')) - help_test_convert_blocks(states, list('1111')) - - for test in range(10): - seq = [str(random.randint(0, 9)) for i in range(100)] - help_test_convert_blocks(states, seq) - - -def help_test_convert_blocks(states, seq): - blocks = predict.convert_to_blocks(seq, states) - - nseq = np.array(seq, int) - # add element to the end to catch repeats on last index - 
nseq = np.append(nseq, nseq[-1]+1)
-    diff = np.diff(nseq)
-    locs = np.nonzero(diff)[0]
-    lens = np.diff(locs)
-    lens = np.append(locs[0]+1, lens)
-
-    current = 0
-    result = defaultdict(list)
-    for i, l in enumerate(locs):
-        result[seq[l]].append((current, current + lens[i] - 1))
-        current += lens[i]
-
-    for k in blocks:
-        assert blocks[k] == result[k]
-
-
-def test_run():
-    pass
diff --git a/code/test/analyze/test_predict_hmm_builder.py b/code/test/analyze/test_predict_hmm_builder.py
new file mode 100644
index 0000000..232a5fd
--- /dev/null
+++ b/code/test/analyze/test_predict_hmm_builder.py
@@ -0,0 +1,527 @@
+from analyze import predict
+from hmm import hmm_bw as hmm
+import pytest
+from pytest import approx
+from io import StringIO
+from collections import defaultdict
+import random
+import numpy as np
+
+
+@pytest.fixture
+def default_builder():
+    builder = predict.HMM_Builder({
+        'analysis_params':
+        {'reference': {'name': 'S228c'},
+         'known_states': [
+             {'name': 'CBS432',
+              'expected_length': 10000,
+              'expected_fraction': 0.025},
+             {'name': 'N_45',
+              'expected_length': 10000,
+              'expected_fraction': 0.025},
+             {'name': 'DBVPG6304',
+              'expected_length': 10000,
+              'expected_fraction': 0.025},
+             {'name': 'UWOPS91_917_1',
+              'expected_length': 10000,
+              'expected_fraction': 0.025},
+         ],
+         'unknown_states': [{'name': 'unknown',
+                             'expected_length': 1000,
+                             'expected_fraction': 0.01},
+                            ]
+         }
+    })
+    builder.set_expected_values()
+    builder.update_expected_length(1e5)
+    return builder
+
+
+@pytest.fixture
+def builder():
+    return predict.HMM_Builder(None)
+
+
+def test_builder(builder):
+    assert builder.config is None
+    assert builder.symbols == {
+        'match': '+',
+        'mismatch': '-',
+        'unknown': '?',
+        'unsequenced': 'n',
+        'gap': '-',
+        'unaligned': '?',
+        'masked': 'x'
+    }
+
+
+def test_init(mocker):
+    mock_log = mocker.patch('analyze.predict.log')
+    predict.HMM_Builder(None)
+    # no config, all warnings
+    mock_log.warning.assert_has_calls([
+        mocker.call("Symbol for match unset in config, using default '+'"),
+        mocker.call("Symbol for mismatch unset in config, using default '-'"),
+        mocker.call("Symbol for unknown unset in config, using default '?'"),
+        mocker.call("Symbol for unsequenced unset in config, "
+                    "using default 'n'"),
+        mocker.call("Symbol for gap unset in config, using default '-'"),
+        mocker.call("Symbol for unaligned unset in config, using default '?'"),
+        mocker.call("Symbol for masked unset in config, using default 'x'")
+    ])
+
+    # config, same warnings as above along with unused
+    mock_log = mocker.patch('analyze.predict.log')
+    predict.HMM_Builder({'HMM_symbols': {'unused': 'X'}})
+    mock_log.warning.assert_has_calls([
+        mocker.call("Unused symbol in configuration: unused -> 'X'"),
+        mocker.call("Symbol for mismatch unset in config, using default '-'"),
+        mocker.call("Symbol for unknown unset in config, using default '?'"),
+        mocker.call("Symbol for unsequenced unset in config, "
+                    "using default 'n'"),
+        mocker.call("Symbol for gap unset in config, using default '-'"),
+        mocker.call("Symbol for unaligned unset in config, using default '?'"),
+        mocker.call("Symbol for masked unset in config, using default 'x'")
+    ])
+
+    # overwrite
+    mock_log = mocker.patch('analyze.predict.log')
+    predict.HMM_Builder({'HMM_symbols': {'masked': 'X'}})
+    mock_log.debug.assert_has_calls([
+        mocker.call("Overwriting default symbol for masked with 'X'")
+    ])
+
+
+def test_update_emission_symbols(builder):
+    assert builder.update_emission_symbols(1) == ['+', '-']
+    assert builder.update_emission_symbols(3) == ['+++',
'++-', + '+-+', + '+--', + '-++', + '-+-', + '--+', + '---', + ] + + +def test_get_symbol_freqs(builder): + sequence = '-++ +-+ ++- ---'.split() + symbol_test_helper(sequence, builder) + symbol_test_helper(['+'], builder) + # get all len 10 symbols + syms = builder.update_emission_symbols(10) + + random.seed(0) + for i in range(10): + sequence = [random.choice(syms) for j in range(100)] + symbol_test_helper(sequence, builder) + + +def symbol_test_helper(sequence, builder): + symb, weigh = builder.get_symbol_freqs(np.array(sequence)) + + num_states = len(sequence[0]) + num_sites = len(sequence) + + individual_symbol_freqs = [] + for s in range(num_states): + d = defaultdict(int) + for i in range(num_sites): + d[sequence[i][s]] += 1 + for sym in d: + d[sym] /= num_sites + individual_symbol_freqs.append(d) + + symbol_freqs = defaultdict(int) + for i in range(num_sites): + symbol_freqs[sequence[i]] += 1 + for sym in symbol_freqs: + symbol_freqs[sym] /= num_sites + + # for each state, how often seq matches that state relative to + # others + weighted_match_freqs = [] + for s in range(num_states): + weighted_match_freqs.append( + individual_symbol_freqs[s][builder.symbols['match']]) + + weighted_match_freqs /= np.sum(weighted_match_freqs) + + assert symb == symbol_freqs + assert weigh == approx(weighted_match_freqs) + + +def test_set_expected_values(builder): + builder.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 10, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 10, + 'expected_fraction': 0.01}, + {'name': 'DBVPG6304', + 'expected_length': 10, + 'expected_fraction': 0.01}, + {'name': 'UWOPS91_917_1', + 'expected_length': 10, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 10, + 'expected_fraction': 0.01}, + ] + } + } + builder.set_expected_values() + assert builder.expected_lengths == { + 'CBS432': 10, + 'N_45': 10, + 'DBVPG6304': 10, + 'UWOPS91_917_1': 10, + 'unknown': 10} + + assert builder.expected_fractions == { + 'S228c': 0.95, + 'CBS432': 0.01, + 'N_45': 0.01, + 'DBVPG6304': 0.01, + 'UWOPS91_917_1': 0.01, + 'unknown': 0.01} + assert builder.ref_fraction == 0.96 + assert builder.other_sum == 0.004 + assert builder.ref_state == 'S228c' + + +def test_update_expected_length(builder): + builder.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'N_45', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'DBVPG6304', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'UWOPS91_917_1', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1000, + 'expected_fraction': 0.01}, + ] + } + } + builder.set_expected_values() + + assert builder.expected_lengths == { + 'CBS432': 10000, + 'N_45': 10000, + 'DBVPG6304': 10000, + 'UWOPS91_917_1': 10000, + 'unknown': 1000} + + assert builder.expected_fractions == { + 'S228c': 0.89, + 'CBS432': 0.025, + 'N_45': 0.025, + 'DBVPG6304': 0.025, + 'UWOPS91_917_1': 0.025, + 'unknown': 0.01} + + assert builder.ref_fraction == 0.9 + assert builder.other_sum == 1e-5 + assert builder.ref_state == 'S228c' + + builder.update_expected_length(1e5) + assert builder.expected_lengths['S228c'] == 45000 + + +def test_initial_probabilities(default_builder): + probs = 
default_builder.initial_probabilities( + [0.1, 0.2, 0.3, 0.4, 0.5]) + + assert default_builder.expected_fractions == ( + {'DBVPG6304': 0.025, + 'UWOPS91_917_1': 0.025, + 'unknown': 0.01, + 'CBS432': 0.025, + 'N_45': 0.025, + 'S228c': 0.89}) + + p = [0.1 + (0.89 - 0.1) * 0.9, + 0.2 + (0.025 - 0.2) * 0.9, + 0.3 + (0.025 - 0.3) * 0.9, + 0.4 + (0.025 - 0.4) * 0.9, + 0.5 + (0.025 - 0.5) * 0.9, + 0.01] + + p = p / np.sum(p, dtype=np.float) + + assert probs == approx(p) + + +def test_emission_probabilities(default_builder): + # normal mode, 5 known_states + symbols = default_builder.update_emission_symbols(5) + emis = default_builder.emission_probabilities(symbols) + + iter_emis = iter_emission(default_builder, symbols) + is_approx_equal_list_dict(emis, iter_emis) + + # more unknowns + default_builder.unknown_states.append('test') + emis = default_builder.emission_probabilities(symbols) + iter_emis = iter_emission(default_builder, symbols) + is_approx_equal_list_dict(emis, iter_emis) + + # no unknowns + default_builder.unknown_states = [] + emis = default_builder.emission_probabilities(symbols) + iter_emis = iter_emission(default_builder, symbols) + is_approx_equal_list_dict(emis, iter_emis) + + # too many symbols + symbols = default_builder.update_emission_symbols(6) + emis = default_builder.emission_probabilities(symbols) + iter_emis = iter_emission(default_builder, symbols) + is_approx_equal_list_dict(emis, iter_emis) + + +def iter_emission(builder, symbols): + probs = {'-+': 0.9, + '++': 0.09, + '--': 0.009, + '+-': 0.001} + mismatch_bias = 0.99 + + known_len = len(builder.known_states) + for k in probs: + probs[k] *= 2**(known_len - 2) + + emis = [] + # using older, iterative version + for s in range(known_len): + emis.append(defaultdict(float)) + for symbol in symbols: + key = symbol[0] + symbol[s] + emis[s][symbol] = probs[key] + + emis[s] = mynorm(emis[s]) + + symbol_len = len(symbols[0]) + for s in range(len(builder.unknown_states)): + emis.append(defaultdict(float)) + for symbol in symbols: + match_count = symbol.count('+') + mismatch_count = symbol_len - match_count + emis[s + known_len][symbol] = \ + (match_count * (1 - mismatch_bias) + + mismatch_count * mismatch_bias) + emis[s + known_len] = mynorm(emis[s + known_len]) + + return emis + + +def is_approx_equal_list_dict(actual, expected): + for i in range(len(actual)): + for k in actual[i]: + assert actual[i][k] == approx(expected[i][k]),\ + "failed at i={}, k={}".format(i, k) + + +def mynorm(d): + total = float(sum(d.values())) + return {k: v/total for k, v in d.items()} + + +def test_transition_probabilities(default_builder): + trans = default_builder.transition_probabilities() + + iter_trans = iter_transition(default_builder) + for i in range(len(trans)): + assert trans[i] == approx(iter_trans[i]) + + +def iter_transition(builder): + states = builder.known_states + builder.unknown_states + expected_frac = builder.expected_fractions + expected_length = builder.expected_lengths + trans = [] + for i in range(len(states)): + state_from = states[i] + trans.append([]) + scale_other = 1 / (1 - expected_frac[state_from]) + for j in range(len(states)): + state_to = states[j] + if state_from == state_to: + trans[i].append(1 - 1./expected_length[state_from]) + else: + trans[i].append(1./expected_length[state_from] * + expected_frac[state_to] * scale_other) + + trans[i] /= np.sum(trans[i]) + + return trans + + +def test_build_initial_hmm(default_builder): + symbols = default_builder.update_emission_symbols(5) + hm = 
default_builder.build_initial_hmm(
        symbols)
+
+    assert default_builder.expected_fractions == (
+        {'DBVPG6304': 0.025,
+         'UWOPS91_917_1': 0.025,
+         'unknown': 0.01,
+         'CBS432': 0.025,
+         'N_45': 0.025,
+         'S228c': 0.89})
+
+    p = [0.2 + (0.89 - 0.2) * 0.9,
+         0.2 + (0.025 - 0.2) * 0.9,
+         0.2 + (0.025 - 0.2) * 0.9,
+         0.2 + (0.025 - 0.2) * 0.9,
+         0.2 + (0.025 - 0.2) * 0.9,
+         0.01]
+
+    p = p / np.sum(p, dtype=float)
+    assert hm.initial_p == approx(p)
+
+    iter_emis = iter_emission(default_builder, symbols)
+    hm2 = hmm.HMM()
+    hm2.set_emissions(iter_emis)
+    assert hm.emissions == approx(hm2.emissions)
+
+    iter_trans = iter_transition(default_builder)
+    for i in range(len(hm.transitions)):
+        assert hm.transitions[i] == approx(iter_trans[i])
+
+
+def test_run_hmm(default_builder, capsys, mocker):
+    seqs = [list('NNENNENNEN'),  # S228c
+            list('NNNENEENNN'),  # CBS432
+            list('NN-NNEENNN'),  # N_45
+            list('NEENN-ENEN'),  # DBVPG6304
+            list('ENENNEENEN'),  # UWOPS..
+            list('NNENNEENEN'),  # predicted
+            ]
+    mock_fasta = mocker.patch('analyze.predict.read_fasta',
+                              return_value=(None, seqs))
+    mock_log_hmm = mocker.patch('hmm.hmm_bw.log.info')
+
+    hmm_init, hmm, positions = default_builder.run_hmm('MOCKED', True)
+
+    mock_fasta.assert_called_with('MOCKED')
+
+    # check hmm output
+    assert mock_log_hmm.call_args_list[-3:] == \
+        [mocker.call('Iteration 8'),
+         mocker.call('Iteration 9'),
+         mocker.call('finished in 10 iterations')]
+
+    # ps are locations of polymorphic sites, not counting missing '-'
+    assert positions == approx([0, 1, 3, 6, 8])
+    assert np.array_equal(hmm.initial_p, np.array([1, 0, 0, 0, 0, 0]))
+    np.testing.assert_allclose(
+        hmm_init.initial_p,
+        np.array([0.8212314, 0.03825122, 0.04350912,
+                  0.04350912, 0.04350912, 0.00999001]))
+
+
+def test_encode_sequence(builder, mocker):
+    mock_fasta = mocker.patch('analyze.predict.read_fasta',
+                              return_value=(None,
+                                            [
+                                                list('abcd'),
+                                                list('abed'),
+                                                list('bbcf'),
+                                            ]))
+
+    seq_coded, positions, len_pred = builder.encode_sequence('test', True)
+    assert (seq_coded == '-- +- --'.split()).all()
+    assert (positions == [0, 2, 3]).all()
+    assert len_pred == 4
+    mock_fasta.assert_called_with('test')
+
+    seq_coded, positions, len_pred = builder.encode_sequence('test2', False)
+    assert (seq_coded == '-- ++ +- --'.split()).all()
+    assert (positions == [0, 1, 2, 3]).all()
+    assert len_pred == 4
+    mock_fasta.assert_called_with('test2')
+
+
+def test_ungap_and_code(builder):
+    # nothing in prediction
+    sequence, positions = builder.ungap_and_code(
+        '---',  # predicted reference string
+        ['abc', 'def', 'ghi'],  # several references
+        0)  # reference index
+    assert positions == approx([])
+    assert sequence == approx([])
+
+    # one match
+    sequence, positions = builder.ungap_and_code(
+        'a--',
+        ['abc', 'def', 'ghi'],
+        0)
+    assert positions == approx([0])
+    assert sequence == ['+--']
+
+    # no match from refs
+    sequence, positions = builder.ungap_and_code(
+        'a--',
+        ['abc', 'def', '-hi'],
+        0)
+    assert positions == approx([])
+    assert sequence == approx([])
+
+    # two matches
+    sequence, positions = builder.ungap_and_code(
+        'ae-',
+        ['abc', 'def', 'gei'],
+        0)
+    assert positions == approx([0, 1])
+    assert (sequence == ['+--', '-++']).all()
+
+    # mess with ref index
+    sequence, positions = builder.ungap_and_code(
+        'a--e-',
+        ['a--bc', 'deeef', 'geeei'],
+        0)
+    assert positions == approx([0, 1])
+    assert (sequence == ['+--', '-++']).all()
+    sequence, positions = builder.ungap_and_code(
+        'a--e-',
+        ['a--bc', 'deeef', 'geeei'],
+        1)
+    assert positions == approx([0, 3])
+    assert (sequence ==
['+--', '-++']).all() + + sequence, positions = builder.ungap_and_code( + 'a---ef--i', + ['ab-dhfghi', + 'a-cceeg-i', + 'a-ceef-hh'], + 0) + + assert (sequence == '+++ -++ +-+ ++-'.split()).all() + assert positions == approx([0, 3, 4, 7]) + + +def test_poly_sites(builder): + sequence, positions = builder.poly_sites( + np.array('+++ -++ +-+ ++-'.split()), + np.array([0, 3, 4, 7]) + ) + assert (sequence == '-++ +-+ ++-'.split()).all() + assert positions == approx([3, 4, 7]) diff --git a/code/test/analyze/test_predict_predictor.py b/code/test/analyze/test_predict_predictor.py new file mode 100644 index 0000000..30e5434 --- /dev/null +++ b/code/test/analyze/test_predict_predictor.py @@ -0,0 +1,1052 @@ +from analyze import predict +from hmm import hmm_bw as hmm +import pytest +from pytest import approx +from io import StringIO +from collections import defaultdict +import random +import numpy as np + + +@pytest.fixture +def predictor(): + result = predict.Predictor( + configuration={ + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432'}, + {'name': 'N_45'}, + {'name': 'DBVPG6304'}, + {'name': 'UWOPS91_917_1'}, + ], + 'unknown_states': [{'name': 'unknown'}] + } + } + ) + return result + + +def test_predictor(predictor): + assert predictor.known_states ==\ + 'S228c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() + assert predictor.unknown_states == ['unknown'] + + +def test_set_chromosomes(predictor): + with pytest.raises(ValueError) as e: + predictor.set_chromosomes() + assert 'No chromosomes specified in config file!' in str(e) + + predictor.config = {'chromosomes': ['I']} + predictor.set_chromosomes() + assert predictor.chromosomes == ['I'] + + +def test_set_blocks_file(predictor): + with pytest.raises(ValueError) as e: + predictor.set_blocks_file('blocks_file') + assert '{state} not found in blocks_file' in str(e) + + predictor.set_blocks_file('blocks_file{state}') + assert predictor.blocks == 'blocks_file{state}' + + with pytest.raises(ValueError) as e: + predictor.set_blocks_file() + assert 'No block file provided' in str(e) + + predictor.config = {'paths': {'analysis': {'block_files': 'blocks_file'}}} + with pytest.raises(ValueError) as e: + predictor.set_blocks_file() + assert '{state} not found in blocks_file' in str(e) + + predictor.config = {'paths': {'analysis': {'block_files': + 'blocks_file{state}'}}} + predictor.set_blocks_file() + assert predictor.blocks == 'blocks_file{state}' + + +def test_set_prefix(predictor): + predictor.known_states = ['s1'] + predictor.set_prefix() + assert predictor.prefix == 's1' + + predictor.known_states = 's1 s2'.split() + predictor.set_prefix() + assert predictor.prefix == 's1_s2' + + predictor.set_prefix('prefix') + assert predictor.prefix == 'prefix' + + predictor.known_states = [] + with pytest.raises(ValueError) as e: + predictor.set_prefix() + assert 'Unable to build prefix, no known states provided' in str(e) + + +def test_set_threshold(predictor): + with pytest.raises(ValueError) as e: + predictor.set_threshold() + assert 'No threshold provided' in str(e) + + predictor.config = {'analysis_params': {'threshold': 'asdf'}} + with pytest.raises(ValueError) as e: + predictor.set_threshold() + assert 'Unsupported threshold value: asdf' in str(e) + + predictor.set_threshold(0.05) + assert predictor.threshold == 0.05 + + predictor.config = {'analysis_params': + {'threshold': 'viterbi'}} + predictor.set_threshold() + assert predictor.threshold == 'viterbi' + + +def test_set_strains(predictor, mocker): + mock_find = 
mocker.patch.object(predict.Predictor, 'find_strains') + + predictor.set_strains() + mock_find.called_with(None) + + with pytest.raises(ValueError) as e: + predictor.config = {'paths': {'test_strains': ['test']}} + predictor.set_strains() + assert '{strain} not found in test' in str(e) + + with pytest.raises(ValueError) as e: + predictor.config = {'paths': {'test_strains': ['test{strain}']}} + predictor.set_strains() + assert '{chrom} not found in test{strain}' in str(e) + + predictor.config = {'paths': {'test_strains': + ['test{strain}{chrom}']}} + predictor.set_strains() + mock_find.called_with(['test{strain}{chrom}']) + + predictor.set_strains('test{strain}{chrom}') + mock_find.called_with(['test{strain}{chrom}']) + + +def test_find_strains(predictor, mocker): + with pytest.raises(ValueError) as e: + predictor.find_strains() + assert ('Unable to find strains in config and ' + 'no test_strains provided') in str(e) + + predictor.config = {'strains': ['test2', 'test1']} + predictor.find_strains() + # sorted + assert predictor.strains == 'test1 test2'.split() + + predictor.config = {} + predictor.chromosomes = ['I'] + + # too many chroms for s1 + mock_glob = mocker.patch('analyze.predict.glob.iglob', + side_effect=[[ + 'test_prefix_s1_c1.fa', + 'test_prefix_s2_c1.fa', + 'test_prefix_s1_c2.fa', + 'test_prefix.fa', + ]]) + mock_log = mocker.patch('analyze.predict.log') + with pytest.raises(ValueError) as e: + predictor.find_strains(['test_prefix_{strain}_{chrom}.fa']) + + assert 'Strain s1 has incorrect number of chromosomes. Expected 1 found 2'\ + in str(e) + mock_glob.assert_called_with('test_prefix_*_*.fa') + mock_log.info.assert_called_with('searching for test_prefix_*_*.fa') + assert mock_log.debug.call_args_list == \ + [mocker.call("matched with ('s1', 'c1')"), + mocker.call("matched with ('s2', 'c1')"), + mocker.call("matched with ('s1', 'c2')"), + ] + + # no matches + mock_glob = mocker.patch('analyze.predict.glob.iglob', + side_effect=[[ + 'test_prefix.fa', + ]]) + mock_log = mocker.patch('analyze.predict.log') + with pytest.raises(ValueError) as e: + predictor.find_strains(['test_prefix_{strain}_{chrom}.fa']) + assert ('Found no chromosome sequence files in ' + "['test_prefix_{strain}_{chrom}.fa']") in str(e) + mock_glob.assert_called_with('test_prefix_*_*.fa') + mock_log.info.assert_called_with('searching for test_prefix_*_*.fa') + assert mock_log.debug.call_args_list == [] + + # correct, with second test_strains + mock_glob = mocker.patch('analyze.predict.glob.iglob', + side_effect=[ + [ + 'test_prefix_s1_c1.fa', + 'test_prefix_s2_c1.fa', + 'test_prefix.fa', + ], + ['test_prefix_c2_s3.fa'] + ]) + mock_log = mocker.patch('analyze.predict.log') + predictor.find_strains(['test_prefix_{strain}_{chrom}.fa', + 'test_prefix_{chrom}_{strain}.fa']) + assert mock_glob.call_args_list == \ + [mocker.call('test_prefix_*_*.fa'), + mocker.call('test_prefix_*_*.fa')] + assert mock_log.info.call_args_list ==\ + [mocker.call('searching for test_prefix_*_*.fa'), + mocker.call('searching for test_prefix_*_*.fa')] + assert mock_log.debug.call_args_list == \ + [mocker.call("matched with ('s1', 'c1')"), + mocker.call("matched with ('s2', 'c1')"), + mocker.call("matched with ('s3', 'c2')"), + ] + assert predictor.strains == ['s1', 's2', 's3'] + + +def test_set_output_files(predictor): + with pytest.raises(ValueError) as e: + predictor.set_output_files('', '', '', '', '') + assert 'No initial hmm file provided' in str(e) + + with pytest.raises(ValueError) as e: + predictor.set_output_files('init', '', '', 
'', '') + assert 'No trained hmm file provided' in str(e) + + with pytest.raises(ValueError) as e: + predictor.set_output_files('init', 'trained', 'pos', 'prob', '') + assert 'No alignment file provided' in str(e) + + with pytest.raises(ValueError) as e: + predictor.set_output_files('init', 'trained', 'pos', 'prob', 'align') + assert '{prefix} not found in align' in str(e) + + with pytest.raises(ValueError) as e: + predictor.set_output_files('init', 'trained', 'pos', 'prob', + 'align{prefix}') + assert '{strain} not found in align{prefix}' in str(e) + + with pytest.raises(ValueError) as e: + predictor.set_output_files('init', 'trained', 'pos', 'prob', + 'align{prefix}{strain}') + assert '{chrom} not found in align{prefix}{strain}' in str(e) + + predictor.prefix = 'pre' + predictor.set_output_files('init', 'trained', 'pos', 'prob', + 'align{prefix}{strain}{chrom}') + assert predictor.hmm_initial == 'init' + assert predictor.hmm_trained == 'trained' + assert predictor.positions == 'pos' + assert predictor.probabilities == 'prob' + assert predictor.alignment == 'alignpre{strain}{chrom}' + + predictor.set_output_files('init', 'trained', '', 'prob', + 'align{prefix}{strain}{chrom}') + assert predictor.hmm_initial == 'init' + assert predictor.hmm_trained == 'trained' + assert predictor.positions is None + assert predictor.probabilities == 'prob' + assert predictor.alignment == 'alignpre{strain}{chrom}' + + with pytest.raises(ValueError) as e: + predictor.config = {'paths': {'analysis': {'hmm_initial': 'init'}}} + predictor.set_output_files('', '', '', '', '') + assert 'No trained hmm file provided' in str(e) + + with pytest.raises(ValueError) as e: + predictor.config = {'paths': {'analysis': {'hmm_initial': 'init', + 'hmm_trained': 'trained', + 'positions': 'pos' + }}} + predictor.set_output_files('', '', '', '', '') + assert 'No probabilities file provided' in str(e) + + with pytest.raises(ValueError) as e: + predictor.config = {'paths': {'analysis': {'hmm_initial': 'init', + 'hmm_trained': 'trained', + 'positions': 'pos', + 'probabilities': 'prob' + }}} + predictor.set_output_files('', '', '', '', '') + assert 'No alignment file provided' in str(e) + + predictor.config = {'paths': {'analysis': { + 'hmm_initial': 'init', + 'hmm_trained': 'trained', + 'positions': 'pos', + 'probabilities': 'prob', + 'alignment': 'align{prefix}{strain}{chrom}' + }}} + predictor.set_output_files('', '', '', '', '') + + assert predictor.hmm_initial == 'init' + assert predictor.hmm_trained == 'trained' + assert predictor.positions == 'pos' + assert predictor.probabilities == 'prob' + assert predictor.alignment == 'alignpre{strain}{chrom}' + + predictor.config = {'paths': {'analysis': { + 'hmm_initial': 'init', + 'hmm_trained': 'trained', + 'probabilities': 'prob', + 'alignment': 'align{prefix}{strain}{chrom}' + }}} + predictor.set_output_files('', '', '', '', '') + + assert predictor.hmm_initial == 'init' + assert predictor.hmm_trained == 'trained' + assert predictor.positions is None + assert predictor.probabilities == 'prob' + assert predictor.alignment == 'alignpre{strain}{chrom}' + + +def test_validate_arguments(predictor): + predictor.chromosomes = 1 + predictor.blocks = 1 + predictor.prefix = 1 + predictor.strains = 1 + predictor.hmm_initial = 1 + predictor.hmm_trained = 1 + predictor.probabilities = 1 + predictor.alignment = 1 + predictor.known_states = 1 + predictor.unknown_states = 1 + predictor.threshold = 1 + predictor.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ 
+ {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'DBVPG6304', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'UWOPS91_917_1', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + + assert predictor.validate_arguments() + + args = [ + 'chromosomes', + 'blocks', + 'prefix', + 'strains', + 'hmm_initial', + 'hmm_trained', + 'probabilities', + 'alignment', + 'known_states', + 'unknown_states', + 'threshold' + ] + + for arg in args: + predictor.__dict__[arg] = None + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert ('Failed to validate Predictor, ' + f'required argument {arg} was unset') in str(e) + predictor.__dict__[arg] = 1 + + predictor.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'Configuration did not provide any known_states' in str(e) + + predictor.config = { + 'analysis_params': + {'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'Configuration did not specify a reference strain' in str(e) + + predictor.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'CBS432 did not provide an expected_length' in str(e) + + predictor.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + }, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'N_45 did not provide an expected_fraction' in str(e) + + predictor.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'unknown did not provide an expected_length' in str(e) + + predictor.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + }, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'unknown did not 
provide an expected_fraction' in str(e) + + +def test_run_prediction_no_pos(predictor, mocker, capsys): + predictor.chromosomes = ['I', 'II'] + predictor.blocks = 'blocks{state}.txt' + predictor.prefix = 'prefix' + predictor.strains = ['s1', 's2'] + predictor.hmm_initial = 'hmm_initial.txt' + predictor.hmm_trained = 'hmm_trained.txt' + predictor.probabilities = 'probs.txt' + predictor.alignment = 'prefix_{strain}_chr{chrom}.maf' + predictor.known_states = 'S228c CBS432 N_45 DBVP UWOP'.split() + predictor.unknown_states = ['unknown'] + predictor.states = predictor.known_states + predictor.unknown_states + predictor.threshold = 'viterbi' + predictor.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'N_45', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'DBVP', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'UWOP', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1000, + 'expected_fraction': 0.01}, + ] + } + } + mock_files = [mocker.MagicMock() for i in range(8)] + mocker.patch('analyze.predict.open', + side_effect=mock_files) + mock_gzip = mocker.patch('analyze.predict.gzip.open') + mocker.patch('analyze.predict.log') + mocker.patch('analyze.predict.read_fasta', + return_value=(None, + [list('NNENNENNEN'), # S228c + list('NNNENEENNN'), # CBS432 + list('NN-NNEENNN'), # N_45 + list('NEENN-ENEN'), # DBVPG6304 + list('ENENNEENEN'), # UWOPS.. + list('NNENNEENEN'), # predicted + ] + )) + + mock_log_hmm = mocker.patch('hmm.hmm_bw.log.info') + + predictor.run_prediction(only_poly_sites=True) + + # check hmm output + assert mock_log_hmm.call_args_list[-3:] == \ + [mocker.call('Iteration 8'), + mocker.call('Iteration 9'), + mocker.call('finished in 10 iterations')] + + assert mock_gzip.call_args_list == [mocker.call('probs.txt', 'wt')] + + # probs and pos interspersed + print(mock_gzip.return_value.__enter__().write.call_args_list) + assert mock_gzip.return_value.__enter__().write.call_args_list == \ + [ + mocker.call('s1\tI\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s2\tI\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s1\tII\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s2\tII\t'), + mocker.ANY, + mocker.call('\n'), + + ] + + +def test_run_prediction_full(predictor, mocker): + predictor.chromosomes = ['I', 'II'] + predictor.blocks = 'blocks{state}.txt' + predictor.prefix = 'prefix' + predictor.strains = ['s1', 's2'] + predictor.hmm_initial = 'hmm_initial.txt' + predictor.hmm_trained = 'hmm_trained.txt' + predictor.probabilities = 'probs.txt' + predictor.positions = 'pos.txt' + predictor.alignment = 'prefix_{strain}_chr{chrom}.maf' + predictor.known_states = 'S228c CBS432 N_45 DBVP UWOP'.split() + predictor.unknown_states = ['unknown'] + predictor.states = predictor.known_states + predictor.unknown_states + predictor.threshold = 'viterbi' + predictor.config = { + 'analysis_params': + {'reference': {'name': 'S228c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'N_45', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'DBVP', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + {'name': 'UWOP', + 'expected_length': 10000, + 'expected_fraction': 0.025}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1000, + 
'expected_fraction': 0.01}, + ] + } + } + mock_files = [mocker.MagicMock() for i in range(8)] + mock_open = mocker.patch('analyze.predict.open', + side_effect=mock_files) + mock_gzip = mocker.patch('analyze.predict.gzip.open') + mock_log = mocker.patch('analyze.predict.log') + mock_fasta = mocker.patch('analyze.predict.read_fasta', + return_value=(None, + [list('NNENNENNEN'), # S228c + list('NNNENEENNN'), # CBS432 + list('NN-NNEENNN'), # N_45 + list('NEENN-ENEN'), # DBVPG6304 + list('ENENNEENEN'), # UWOPS.. + list('NNENNEENEN'), # predicted + ] + )) + mock_log_hmm = mocker.patch('hmm.hmm_bw.log.info') + + predictor.run_prediction(only_poly_sites=True) + + # check hmm output + assert mock_log_hmm.call_args_list[-3:] == \ + [mocker.call('Iteration 8'), + mocker.call('Iteration 9'), + mocker.call('finished in 10 iterations')] + + mock_open.assert_has_calls([ + mocker.call('hmm_initial.txt', 'w'), + mocker.call('hmm_trained.txt', 'w'), + mocker.call('blocksS228c.txt', 'w'), + mocker.call('blocksCBS432.txt', 'w'), + mocker.call('blocksN_45.txt', 'w'), + mocker.call('blocksDBVP.txt', 'w'), + mocker.call('blocksUWOP.txt', 'w'), + mocker.call('blocksunknown.txt', 'w')]) + + # hmm_initial + mock_files[0].__enter__().write.assert_has_calls( + [ + mocker.call('strain\tchromosome\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s1\tI\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s2\tI\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s1\tII\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s2\tII\t'), + mocker.ANY, + mocker.call('\n') + ]) + # trained + mock_files[1].__enter__().write.assert_has_calls( + [ + mocker.call('strain\tchromosome\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s1\tI\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s2\tI\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s1\tII\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s2\tII\t'), + mocker.ANY, + mocker.call('\n') + ]) + # check initial probability (5th write, dereference to get string...) 
+ hmm_entry = mock_files[1].__enter__().write.\ + call_args_list[4][0][0].split('\t') + assert hmm_entry[0] == '1.0' + assert hmm_entry[1] == '0.0' + assert hmm_entry[2] == '0.0' + assert hmm_entry[3] == '0.0' + assert hmm_entry[4] == '0.0' + assert hmm_entry[5] == '0.0' + assert hmm_entry[6] == '0.0' + + # blocks S228c + mock_files[2].__enter__().write.assert_has_calls( + [ + mocker.call('strain\tchromosome\tpredicted_species' + '\tstart\tend\tnum_sites_hmm\n'), + mocker.call('s1\tI\tS228c\t0\t1\t2'), + mocker.call('\n'), + mocker.call('s2\tI\tS228c\t0\t1\t2'), + mocker.call('\n'), + mocker.call('s1\tII\tS228c\t0\t1\t2'), + mocker.call('\n'), + mocker.call('s2\tII\tS228c\t0\t1\t2'), + mocker.call('\n') + ]) + # blocks CBS432 + mock_files[3].__enter__().write.assert_has_calls( + [ + mocker.call('strain\tchromosome\tpredicted_species' + '\tstart\tend\tnum_sites_hmm\n'), + mocker.call(''), + mocker.call(''), + mocker.call(''), + mocker.call('')]) + # blocks N_45 + mock_files[4].__enter__().write.assert_has_calls( + [ + mocker.call('strain\tchromosome\tpredicted_species' + '\tstart\tend\tnum_sites_hmm\n'), + mocker.call(''), + mocker.call(''), + mocker.call(''), + mocker.call('')]) + # blocks DBVP + mock_files[5].__enter__().write.assert_has_calls( + [ + mocker.call('strain\tchromosome\tpredicted_species' + '\tstart\tend\tnum_sites_hmm\n'), + mocker.call(''), + mocker.call(''), + mocker.call(''), + mocker.call('')]) + # blocks UWOP + mock_files[6].__enter__().write.assert_has_calls( + [ + mocker.call('strain\tchromosome\tpredicted_species' + '\tstart\tend\tnum_sites_hmm\n'), + mocker.call('s1\tI\tUWOP\t3\t8\t3'), + mocker.call('\n'), + mocker.call('s2\tI\tUWOP\t3\t8\t3'), + mocker.call('\n'), + mocker.call('s1\tII\tUWOP\t3\t8\t3'), + mocker.call('\n'), + mocker.call('s2\tII\tUWOP\t3\t8\t3'), + mocker.call('\n') + ]) + # blocks unknown + mock_files[7].__enter__().write.assert_has_calls( + [ + mocker.call('strain\tchromosome\tpredicted_species' + '\tstart\tend\tnum_sites_hmm\n'), + mocker.call(''), + mocker.call(''), + mocker.call(''), + mocker.call('')]) + + mock_gzip.assert_any_call('probs.txt', 'wt') + mock_gzip.assert_any_call('pos.txt', 'wt') + + # probs and pos interspersed + mock_gzip.return_value.__enter__().write.assert_has_calls( + [ + mocker.call('s1\tI\t0\t1\t3\t6\t8\n'), + mocker.call('s1\tI\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s2\tI\t0\t1\t3\t6\t8\n'), + mocker.call('s2\tI\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s1\tII\t0\t1\t3\t6\t8\n'), + mocker.call('s1\tII\t'), + mocker.ANY, + mocker.call('\n'), + mocker.call('s2\tII\t0\t1\t3\t6\t8\n'), + mocker.call('s2\tII\t'), + mocker.ANY, + mocker.call('\n'), + + ]) + + mock_fasta.assert_has_calls([ + mocker.call('prefix_s1_chrI.maf'), + mocker.call('prefix_s2_chrI.maf'), + mocker.call('prefix_s1_chrII.maf'), + mocker.call('prefix_s2_chrII.maf') + ]) + + mock_log.info.assert_has_calls([ + mocker.call('working on: s1 I'), + mocker.call('working on: s2 I'), + mocker.call('working on: s1 II'), + mocker.call('working on: s2 II') + ]) + + +def test_write_hmm_header(predictor): + predictor.known_states = [] + predictor.unknown_states = [] + predictor.emission_symbols = [] + writer = StringIO() + predictor.write_hmm_header(writer) + assert writer.getvalue() == 'strain\tchromosome\t\n' + + predictor.known_states = ['s1', 's2'] + predictor.unknown_states = ['u1'] + predictor.emission_symbols = ['-', '+'] + writer = StringIO() + predictor.write_hmm_header(writer) + + header = 'strain\tchromosome\t' + header += 
'\t'.join( + ['init_{}'.format(s) for s in ['s1', 's2', 'u1']] + + ['emis_{}_{}'.format(s, sym) + for s in ['s1', 's2', 'u1'] + for sym in ['-', '+']] + + ['trans_{}_{}'.format(s, s2) + for s in ['s1', 's2', 'u1'] + for s2 in ['s1', 's2', 'u1']]) + + assert writer.getvalue() == header + '\n' + + +def test_write_hmm(predictor): + predictor.emission_symbols = list('abc') + output = StringIO() + + hm = hmm.HMM() + + # empty hmm + predictor.write_hmm(hm, output, 'strain', 'I') + assert output.getvalue() == 'strain\tI\t\n' + + hm.set_hidden_states(list('abc')) + hm.set_initial_p([0, 1, 0]) + hm.set_transitions([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + hm.set_emissions([{'a': 1, 'b': 0, 'c': 0}, + {'a': 0, 'b': 0, 'c': 1}, + {'a': 0, 'b': 1, 'c': 0}, + ]) + + output = StringIO() + predictor.write_hmm(hm, output, 'strain', 'I') + + result = 'strain\tI\t' + result += '\t'.join(list('010')) + '\t' # init + result += '\t'.join(list('100001010')) + '\t' # emis + result += '\t'.join(list('010100001')) + '\n' # trans + assert output.getvalue() == result + + +def test_write_blocks_header(predictor): + writer = StringIO() + predictor.write_blocks_header(writer) + + assert writer.getvalue() == '\t'.join(['strain', + 'chromosome', + 'predicted_species', + 'start', + 'end', + 'num_sites_hmm']) + '\n' + + +def test_write_blocks(predictor): + output = StringIO() + block = [] + pos = [i * 2 for i in range(20)] + predictor.write_blocks(block, + pos, + output, 'test', 'I', 'pred') + + assert output.getvalue() == '' + + output = StringIO() + block = [(0, 1), (4, 6), (10, 8)] + pos = [i * 2 for i in range(20)] + predictor.write_blocks(block, + pos, + output, 'test', 'I', 'pred') + + result = "\n".join( + ["\t".join(['test', 'I', 'pred', + str(pos[s]), str(pos[e]), str(e - s + 1)]) + for s, e in block]) + "\n" + + assert output.getvalue() == result + + +def test_write_positions(predictor): + output = StringIO() + predictor.write_positions([0, 1, 3, 5, 7], output, 'test', 'I') + assert output.getvalue() == "{}\t{}\t{}\n".format( + "test", + "I", + "\t".join([str(i) for i in (0, 1, 3, 5, 7)])) + + +def test_write_state_probs(predictor): + output = StringIO() + predictor.states = [] + predictor.write_state_probs([{}], output, 'strain', 'I') + + assert output.getvalue() == 'strain\tI\t\n' + + output = StringIO() + predictor.states = list('abc') + predictor.write_state_probs([ + [0, 0, 1], + [1, 0, 0], + [0, 1, 1], + ], output, 'strain', 'I') + + assert output.getvalue() == \ + ('strain\tI\t' + 'a:0.00000,1.00000,0.00000\t' + 'b:0.00000,0.00000,1.00000\t' + 'c:1.00000,0.00000,1.00000\n') + + +def test_process_path(predictor, hm): + probs = hm.posterior_decoding()[0] + predictor.set_threshold(0.8) + predictor.states = 'N E'.split() + predictor.known_states = 'N E'.split() + path, probability = predictor.process_path(hm) + assert (probability == probs).all() + assert path == 'E E N E E N E E N N'.split() + + predictor.set_threshold('viterbi') + path, probability = predictor.process_path(hm) + + assert (probability == probs).all() + assert path == 'E E N E E N E E N E'.split() + + +def test_convert_to_blocks(predictor): + random.seed(0) + states = [str(i) for i in range(10)] + help_test_convert_blocks(states, list('1'), predictor) + help_test_convert_blocks(states, list('12'), predictor) + help_test_convert_blocks(states, list('1111'), predictor) + + for test in range(10): + seq = [str(random.randint(0, 9)) for i in range(100)] + help_test_convert_blocks(states, seq, predictor) + + +def help_test_convert_blocks(states, seq, 
predictor): + predictor.states = states + blocks = predictor.convert_to_blocks(seq) + + nseq = np.array(seq, int) + # add element to the end to catch repeats on last index + nseq = np.append(nseq, nseq[-1]+1) + diff = np.diff(nseq) + locs = np.nonzero(diff)[0] + lens = np.diff(locs) + lens = np.append(locs[0]+1, lens) + + current = 0 + result = defaultdict(list) + for i, l in enumerate(locs): + result[seq[l]].append((current, current + lens[i] - 1)) + current += lens[i] + + for k in blocks: + assert blocks[k] == result[k] + + +def test_read_blocks(mocker): + block_in = StringIO(''' +''') + + mocked_file = mocker.patch('analyze.predict.open', + return_value=block_in) + output = predict.read_blocks('mocked') + + mocked_file.assert_called_with('mocked', 'r') + assert list(output.keys()) == [] + + block_in = StringIO('''header +test\tI\tpred\t100\t200\t10 +''') + + mocked_file = mocker.patch('analyze.predict.open', + return_value=block_in) + output = predict.read_blocks('mocked') + + assert len(output) == 1 + assert output['test']['I'] == [(100, 200, 10)] + + block_in = StringIO('''header +test\tI\tpred\t100\t200\t10 +test\tI\tpred\t200\t200\t30 +test\tI\tpred\t300\t400\t40 +test\tII\tpred\t300\t400\t40 +test2\tIII\tpred\t300\t400\t47 +''') + + mocked_file = mocker.patch('analyze.predict.open', + return_value=block_in) + output = predict.read_blocks('mocked') + + assert len(output) == 2 + assert len(output['test']) == 2 + assert len(output['test2']) == 1 + assert output['test']['I'] == [ + (100, 200, 10), + (200, 200, 30), + (300, 400, 40), + ] + assert output['test']['II'] == [(300, 400, 40)] + assert output['test2']['III'] == [(300, 400, 47)] + + +def test_read_blocks_labeled(mocker): + block_in = StringIO(''' +''') + + mocked_file = mocker.patch('analyze.predict.open', + return_value=block_in) + output = predict.read_blocks('mocked', labeled=True) + + mocked_file.assert_called_with('mocked', 'r') + assert list(output.keys()) == [] + + block_in = StringIO('''header +r1\ttest\tI\tpred\t100\t200\t10 +''') + + mocked_file = mocker.patch('analyze.predict.open', + return_value=block_in) + output = predict.read_blocks('mocked', labeled=True) + + assert len(output) == 1 + assert output['test']['I'] == [('r1', 100, 200, 10)] + + block_in = StringIO('''header +r1\ttest\tI\tpred\t100\t200\t10 +r2\ttest\tI\tpred\t200\t200\t30 +r3\ttest\tI\tpred\t300\t400\t40 +r4\ttest\tII\tpred\t300\t400\t40 +r5\ttest2\tIII\tpred\t300\t400\t47 +''') + + mocked_file = mocker.patch('analyze.predict.open', + return_value=block_in) + output = predict.read_blocks('mocked', labeled=True) + + assert len(output) == 2 + assert len(output['test']) == 2 + assert len(output['test2']) == 1 + assert output['test']['I'] == [ + ('r1', 100, 200, 10), + ('r2', 200, 200, 30), + ('r3', 300, 400, 40), + ] + assert output['test']['II'] == [('r4', 300, 400, 40)] + assert output['test2']['III'] == [('r5', 300, 400, 47)] diff --git a/code/test/hmm/test_hmm_bw.py b/code/test/hmm/test_hmm_bw.py index 6e7c538..b271bc0 100644 --- a/code/test/hmm/test_hmm_bw.py +++ b/code/test/hmm/test_hmm_bw.py @@ -2,6 +2,7 @@ import pytest from pytest import approx import numpy as np +import logging as log def test_init(): @@ -42,11 +43,15 @@ def test_setters(): assert '[0, 0] 0' in str(e) -def test_print_results(capsys, hm): +def test_print_results(mocker, hm): + mock_debug = mocker.patch('hmm.hmm_bw.log.debug') + log.basicConfig(level=log.DEBUG) + hm.print_results(0, 1) - captured = capsys.readouterr() - out = captured.out.split('\n') + captured = 
mock_debug.call_args[0][0]
+    out = captured.split('\n')
+
+    print('\n'.join(out))
     assert out[0] == 'Iterations: 0'
     assert out[2] == 'Log Likelihood:'
     assert out[5] == 'Initial State Probabilities:'
@@ -69,10 +74,14 @@ def test_print_results(capsys, hm):
     assert float(out[19].split('=')[1]) == 0.8
 
 
-def test_train(capsys, hm):
+def test_train(mocker, hm):
+    mock_debug = mocker.patch('hmm.hmm_bw.log.debug')
+    log.basicConfig(level=log.DEBUG)
+
     hm.train()
 
     # get output from last report
-    out = capsys.readouterr().out.split('\n')[-23:]
+    captured = mock_debug.call_args[0][0]
+    out = captured.split('\n')[-23:]
     assert out[0] == 'Iterations: 2'
     assert out[2] == 'Log Likelihood:'
     assert out[5] == 'Initial State Probabilities:'
diff --git a/code/test/misc/test_config_utils.py b/code/test/misc/test_config_utils.py
index aa855a0..f0bed03 100644
--- a/code/test/misc/test_config_utils.py
+++ b/code/test/misc/test_config_utils.py
@@ -141,6 +141,7 @@ def test_get_nested():
     assert get_nested({'a': {'b': 2}}, 'a.b') == 2
     assert get_nested({'a': {'b': 2}}, 'a.c') is None
     assert get_nested({'a': {'b': {'c': 3}}}, 'a.b.c') == 3
+    assert get_nested(None, 'key') is None
 
 
 def test_check_wildcards(mocker):
 

From 85eaac8d2a7b94af72127c0fc24195052f5c713c Mon Sep 17 00:00:00 2001
From: Troy Comi
Date: Thu, 25 Apr 2019 15:07:02 -0400
Subject: [PATCH 16/33] Fixed readme

---
 README.md | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index b1aa71d..64b6ac8 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@ With the package installed and the conda environment activated, main methods
 are accessed with the `introgression` command. Some documentation is provided
 by adding the argument `--help` to introgression or any of its subcommands.
 
-### introgression
+#### introgression
 Options include:
 - --config: specify one or more configuration files. Files are evaluated in
 order. Conflicting values are overwritten by the newest file. This allows a
@@ -46,8 +46,13 @@ base configuration for the system and analysis-specific configurations added as
 needed.
 - verbosity: set by varying the number of v's attached to the option, with
 `-v` indicating a log level of critical and `-vvvvv` indicating debug logging.
+
 Available subcommands are:
-- predict
+##### predict
+A brief description of what this does.
+
+Available options are:
+- forthcoming...
 
 ## License
 TBD

From 29ae638022d4de18c3dcaf0e292506812d6a77a9 Mon Sep 17 00:00:00 2001
From: Troy Comi
Date: Fri, 26 Apr 2019 09:36:17 -0400
Subject: [PATCH 17/33] Tested predict main

Modified behavior with missing files to match the previous implementation
(skip the strain/chromosome pair and continue). Currently matching the
original implementation on chromosome 1.
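A minimal sketch of the skip behavior described above, using the same names
as the predict.py hunk below (which also adds the required `import os`):

    alignment_file = self.alignment.format(strain=strain, chrom=chrom)
    if not os.path.exists(alignment_file):
        log.info(f'skipping, file {alignment_file} not found')
        continue
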
--- code/analyze/main.py | 3 +- code/analyze/predict.py | 19 ++-- code/config.yaml | 9 +- code/test/analyze/test_main_predict_args.py | 15 +++ code/test/analyze/test_predict_hmm_builder.py | 21 ++--- code/test/analyze/test_predict_predictor.py | 91 ++++++++++--------- code/test/helper_scripts/compare_outputs.sh | 6 +- code/test/helper_scripts/test_predict.slurm | 18 ++-- 8 files changed, 100 insertions(+), 82 deletions(-) diff --git a/code/analyze/main.py b/code/analyze/main.py index 281c307..0fb56e5 100644 --- a/code/analyze/main.py +++ b/code/analyze/main.py @@ -72,7 +72,8 @@ def cli(ctx, config, verbosity): help='Alignment file location with ' '{prefix}, {strain}, and {chrom}') @click.option('--only-poly-sites/--all-sites', default=True, - help='Consider only polymorphic sites or all sites') + help='Consider only polymorphic sites or all sites. ' + 'Default is only polymorphic.') def predict(ctx, blocks, prefix, diff --git a/code/analyze/predict.py b/code/analyze/predict.py index e819b55..46ccac9 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -2,6 +2,7 @@ import gzip import glob import re +import os import itertools from collections import defaultdict, Counter from hmm import hmm_bw @@ -191,7 +192,6 @@ def find_strains(self, test_strains: List[str] = None): log.info(f'searching for {strain_glob}') for fname in glob.iglob(strain_glob): # extract wildcard matches - print(fname) match = re.match( test_strain.format( strain='(?P.*?)', @@ -203,8 +203,8 @@ def find_strains(self, test_strains: List[str] = None): f'matched with {match.group("strain", "chrom")}') strain, chrom = match.group('strain', 'chrom') if strain not in strains: - strains[strain] = [] - strains[strain].append(chrom) + strains[strain] = set() + strains[strain].add(chrom) if len(strains) == 0: err = ('Found no chromosome sequence files ' @@ -212,11 +212,13 @@ def find_strains(self, test_strains: List[str] = None): log.exception(err) raise ValueError(err) + # check if requested chromosomes are within the list of chroms + chrom_set = set(self.chromosomes) for strain, chroms in strains.items(): - if len(self.chromosomes) != len(chroms): - err = (f'Strain {strain} has incorrect number of ' - f'chromosomes. Expected {len(self.chromosomes)} ' - f'found {len(chroms)}') + if not chrom_set.issubset(chroms): + not_found = chrom_set.difference(chroms).pop() + err = (f'Strain {strain} is missing chromosomes. 
' + f'Unable to find chromosome \'{not_found}\'') log.exception(err) raise ValueError(err) @@ -371,6 +373,9 @@ def run_prediction(self, only_poly_sites=True): alignment_file = self.alignment.format( strain=strain, chrom=chrom) + if not os.path.exists(alignment_file): + log.info(f'skipping, file {alignment_file} not found') + continue hmm_initial, hmm_trained, pos = hmm_builder.run_hmm( alignment_file, only_poly_sites) diff --git a/code/config.yaml b/code/config.yaml index 95f559f..4d34e02 100644 --- a/code/config.yaml +++ b/code/config.yaml @@ -38,7 +38,7 @@ paths: suffix: .txt analysis: - analysis_base: __OUTPUT_ROOT__/analysis + analysis_base: __OUTPUT_ROOT__/analysis_chr1_test regions: __ANALYSIS_BASE__/regions/ genes: __ANALYSIS_BASE__/genes/ block_files: __ANALYSIS_BASE__/blocks_{state}.txt @@ -63,8 +63,9 @@ paths: ldselect: __ROOT_INSTALL__/ldSelect/ structure: __ROOT_INSTALL__/structure/ -chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', - 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'] +chromosomes: ['I'] +# chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', +# 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'] # can optionally list all strains to consider # if blank will glob with TEST_STRAINS paths @@ -82,7 +83,7 @@ analysis_params: # master known state, prepeded to list of known states reference: - name: S228c + name: S288c base_dir: __INPUT_ROOT__/100_genomes/genomes/S288c_SGD-R64/ gene_bank_dir: __INPUT_ROOT__/S288c/ diff --git a/code/test/analyze/test_main_predict_args.py b/code/test/analyze/test_main_predict_args.py index fa2d1c8..52e38e0 100644 --- a/code/test/analyze/test_main_predict_args.py +++ b/code/test/analyze/test_main_predict_args.py @@ -283,3 +283,18 @@ def test_outputs(runner, mocker): mocker.call("Probabilities file is 'probs.txt.gz'"), mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'")] mock_predict.called_once_with(True) + + mock_predict.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict --threshold viterbi ' + '--blocks blocks_{state}.txt --prefix s1_s2 ' + '--hmm-initial hmm_init.txt ' + '--hmm-trained hmm_trained.txt ' + '--probabilities probs.txt.gz ' + '--positions pos.txt.gz ' + '--alignment {prefix}_{strain}_chr{chrom}.maf ' + '--all-sites' + ) + + mock_predict.called_once_with(False) diff --git a/code/test/analyze/test_predict_hmm_builder.py b/code/test/analyze/test_predict_hmm_builder.py index 232a5fd..1704cbd 100644 --- a/code/test/analyze/test_predict_hmm_builder.py +++ b/code/test/analyze/test_predict_hmm_builder.py @@ -2,7 +2,6 @@ from hmm import hmm_bw as hmm import pytest from pytest import approx -from io import StringIO from collections import defaultdict import random import numpy as np @@ -12,7 +11,7 @@ def default_builder(): builder = predict.HMM_Builder({ 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 10000, @@ -156,7 +155,7 @@ def symbol_test_helper(sequence, builder): def test_set_expected_values(builder): builder.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 10, @@ -186,7 +185,7 @@ def test_set_expected_values(builder): 'unknown': 10} assert builder.expected_fractions == { - 'S228c': 0.95, + 'S288c': 0.95, 'CBS432': 0.01, 'N_45': 0.01, 'DBVPG6304': 0.01, @@ -194,13 +193,13 @@ def test_set_expected_values(builder): 'unknown': 0.01} assert 
builder.ref_fraction == 0.96 assert builder.other_sum == 0.004 - assert builder.ref_state == 'S228c' + assert builder.ref_state == 'S288c' def test_update_expected_length(builder): builder.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 10000, @@ -231,7 +230,7 @@ def test_update_expected_length(builder): 'unknown': 1000} assert builder.expected_fractions == { - 'S228c': 0.89, + 'S288c': 0.89, 'CBS432': 0.025, 'N_45': 0.025, 'DBVPG6304': 0.025, @@ -240,10 +239,10 @@ def test_update_expected_length(builder): assert builder.ref_fraction == 0.9 assert builder.other_sum == 1e-5 - assert builder.ref_state == 'S228c' + assert builder.ref_state == 'S288c' builder.update_expected_length(1e5) - assert builder.expected_lengths['S228c'] == 45000 + assert builder.expected_lengths['S288c'] == 45000 def test_initial_probabilities(default_builder): @@ -256,7 +255,7 @@ def test_initial_probabilities(default_builder): 'unknown': 0.01, 'CBS432': 0.025, 'N_45': 0.025, - 'S228c': 0.89}) + 'S288c': 0.89}) p = [0.1 + (0.89 - 0.1) * 0.9, 0.2 + (0.025 - 0.2) * 0.9, @@ -385,7 +384,7 @@ def test_build_initial_hmm(default_builder): 'unknown': 0.01, 'CBS432': 0.025, 'N_45': 0.025, - 'S228c': 0.89}) + 'S288c': 0.89}) p = [0.2 + (0.89 - 0.2) * 0.9, 0.2 + (0.025 - 0.2) * 0.9, diff --git a/code/test/analyze/test_predict_predictor.py b/code/test/analyze/test_predict_predictor.py index 30e5434..d91f0e6 100644 --- a/code/test/analyze/test_predict_predictor.py +++ b/code/test/analyze/test_predict_predictor.py @@ -1,7 +1,6 @@ from analyze import predict from hmm import hmm_bw as hmm import pytest -from pytest import approx from io import StringIO from collections import defaultdict import random @@ -13,7 +12,7 @@ def predictor(): result = predict.Predictor( configuration={ 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432'}, {'name': 'N_45'}, @@ -29,7 +28,7 @@ def predictor(): def test_predictor(predictor): assert predictor.known_states ==\ - 'S228c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() + 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() assert predictor.unknown_states == ['unknown'] @@ -145,23 +144,23 @@ def test_find_strains(predictor, mocker): # too many chroms for s1 mock_glob = mocker.patch('analyze.predict.glob.iglob', side_effect=[[ - 'test_prefix_s1_c1.fa', - 'test_prefix_s2_c1.fa', - 'test_prefix_s1_c2.fa', + 'test_prefix_s1_cII.fa', + 'test_prefix_s2_cII.fa', + 'test_prefix_s1_cIII.fa', 'test_prefix.fa', ]]) mock_log = mocker.patch('analyze.predict.log') with pytest.raises(ValueError) as e: - predictor.find_strains(['test_prefix_{strain}_{chrom}.fa']) + predictor.find_strains(['test_prefix_{strain}_c{chrom}.fa']) - assert 'Strain s1 has incorrect number of chromosomes. Expected 1 found 2'\ + assert "Strain s1 is missing chromosomes. 
Unable to find chromosome 'I'"\ in str(e) - mock_glob.assert_called_with('test_prefix_*_*.fa') - mock_log.info.assert_called_with('searching for test_prefix_*_*.fa') + mock_glob.assert_called_with('test_prefix_*_c*.fa') + mock_log.info.assert_called_with('searching for test_prefix_*_c*.fa') assert mock_log.debug.call_args_list == \ - [mocker.call("matched with ('s1', 'c1')"), - mocker.call("matched with ('s2', 'c1')"), - mocker.call("matched with ('s1', 'c2')"), + [mocker.call("matched with ('s1', 'II')"), + mocker.call("matched with ('s2', 'II')"), + mocker.call("matched with ('s1', 'III')"), ] # no matches @@ -178,29 +177,31 @@ def test_find_strains(predictor, mocker): mock_log.info.assert_called_with('searching for test_prefix_*_*.fa') assert mock_log.debug.call_args_list == [] - # correct, with second test_strains + # correct, with second test_strains, extra chromosomes mock_glob = mocker.patch('analyze.predict.glob.iglob', side_effect=[ [ - 'test_prefix_s1_c1.fa', - 'test_prefix_s2_c1.fa', + 'test_prefix_s1_cI.fa', + 'test_prefix_s2_cI.fa', + 'test_prefix_s2_cII.fa', 'test_prefix.fa', ], - ['test_prefix_c2_s3.fa'] + ['test_prefix_cI_s3.fa'] ]) mock_log = mocker.patch('analyze.predict.log') - predictor.find_strains(['test_prefix_{strain}_{chrom}.fa', - 'test_prefix_{chrom}_{strain}.fa']) + predictor.find_strains(['test_prefix_{strain}_c{chrom}.fa', + 'test_prefix_c{chrom}_{strain}.fa']) assert mock_glob.call_args_list == \ - [mocker.call('test_prefix_*_*.fa'), - mocker.call('test_prefix_*_*.fa')] + [mocker.call('test_prefix_*_c*.fa'), + mocker.call('test_prefix_c*_*.fa')] assert mock_log.info.call_args_list ==\ - [mocker.call('searching for test_prefix_*_*.fa'), - mocker.call('searching for test_prefix_*_*.fa')] + [mocker.call('searching for test_prefix_*_c*.fa'), + mocker.call('searching for test_prefix_c*_*.fa')] assert mock_log.debug.call_args_list == \ - [mocker.call("matched with ('s1', 'c1')"), - mocker.call("matched with ('s2', 'c1')"), - mocker.call("matched with ('s3', 'c2')"), + [mocker.call("matched with ('s1', 'I')"), + mocker.call("matched with ('s2', 'I')"), + mocker.call("matched with ('s2', 'II')"), + mocker.call("matched with ('s3', 'I')"), ] assert predictor.strains == ['s1', 's2', 's3'] @@ -315,7 +316,7 @@ def test_validate_arguments(predictor): predictor.threshold = 1 predictor.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 1, @@ -363,7 +364,7 @@ def test_validate_arguments(predictor): predictor.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'unknown_states': [{'name': 'unknown', 'expected_length': 1, 'expected_fraction': 0.01}, @@ -396,7 +397,7 @@ def test_validate_arguments(predictor): predictor.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_fraction': 0.01}, @@ -416,7 +417,7 @@ def test_validate_arguments(predictor): predictor.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 1, @@ -437,7 +438,7 @@ def test_validate_arguments(predictor): predictor.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 1, @@ -457,7 +458,7 @@ def test_validate_arguments(predictor): predictor.config = { 'analysis_params': 
- {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 1, @@ -486,13 +487,13 @@ def test_run_prediction_no_pos(predictor, mocker, capsys): predictor.hmm_trained = 'hmm_trained.txt' predictor.probabilities = 'probs.txt' predictor.alignment = 'prefix_{strain}_chr{chrom}.maf' - predictor.known_states = 'S228c CBS432 N_45 DBVP UWOP'.split() + predictor.known_states = 'S288c CBS432 N_45 DBVP UWOP'.split() predictor.unknown_states = ['unknown'] predictor.states = predictor.known_states + predictor.unknown_states predictor.threshold = 'viterbi' predictor.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 10000, @@ -518,9 +519,10 @@ def test_run_prediction_no_pos(predictor, mocker, capsys): side_effect=mock_files) mock_gzip = mocker.patch('analyze.predict.gzip.open') mocker.patch('analyze.predict.log') + mocker.patch('analyze.predict.os.path.exists', return_value=True) mocker.patch('analyze.predict.read_fasta', return_value=(None, - [list('NNENNENNEN'), # S228c + [list('NNENNENNEN'), # S288c list('NNNENEENNN'), # CBS432 list('NN-NNEENNN'), # N_45 list('NEENN-ENEN'), # DBVPG6304 @@ -571,13 +573,13 @@ def test_run_prediction_full(predictor, mocker): predictor.probabilities = 'probs.txt' predictor.positions = 'pos.txt' predictor.alignment = 'prefix_{strain}_chr{chrom}.maf' - predictor.known_states = 'S228c CBS432 N_45 DBVP UWOP'.split() + predictor.known_states = 'S288c CBS432 N_45 DBVP UWOP'.split() predictor.unknown_states = ['unknown'] predictor.states = predictor.known_states + predictor.unknown_states predictor.threshold = 'viterbi' predictor.config = { 'analysis_params': - {'reference': {'name': 'S228c'}, + {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', 'expected_length': 10000, @@ -603,9 +605,10 @@ def test_run_prediction_full(predictor, mocker): side_effect=mock_files) mock_gzip = mocker.patch('analyze.predict.gzip.open') mock_log = mocker.patch('analyze.predict.log') + mocker.patch('analyze.predict.os.path.exists', return_value=True) mock_fasta = mocker.patch('analyze.predict.read_fasta', return_value=(None, - [list('NNENNENNEN'), # S228c + [list('NNENNENNEN'), # S288c list('NNNENEENNN'), # CBS432 list('NN-NNEENNN'), # N_45 list('NEENN-ENEN'), # DBVPG6304 @@ -626,7 +629,7 @@ def test_run_prediction_full(predictor, mocker): mock_open.assert_has_calls([ mocker.call('hmm_initial.txt', 'w'), mocker.call('hmm_trained.txt', 'w'), - mocker.call('blocksS228c.txt', 'w'), + mocker.call('blocksS288c.txt', 'w'), mocker.call('blocksCBS432.txt', 'w'), mocker.call('blocksN_45.txt', 'w'), mocker.call('blocksDBVP.txt', 'w'), @@ -682,18 +685,18 @@ def test_run_prediction_full(predictor, mocker): assert hmm_entry[5] == '0.0' assert hmm_entry[6] == '0.0' - # blocks S228c + # blocks S288c mock_files[2].__enter__().write.assert_has_calls( [ mocker.call('strain\tchromosome\tpredicted_species' '\tstart\tend\tnum_sites_hmm\n'), - mocker.call('s1\tI\tS228c\t0\t1\t2'), + mocker.call('s1\tI\tS288c\t0\t1\t2'), mocker.call('\n'), - mocker.call('s2\tI\tS228c\t0\t1\t2'), + mocker.call('s2\tI\tS288c\t0\t1\t2'), mocker.call('\n'), - mocker.call('s1\tII\tS228c\t0\t1\t2'), + mocker.call('s1\tII\tS288c\t0\t1\t2'), mocker.call('\n'), - mocker.call('s2\tII\tS228c\t0\t1\t2'), + mocker.call('s2\tII\tS288c\t0\t1\t2'), mocker.call('\n') ]) # blocks CBS432 diff --git a/code/test/helper_scripts/compare_outputs.sh 
b/code/test/helper_scripts/compare_outputs.sh index 5c0f4a7..40544c2 100755 --- a/code/test/helper_scripts/compare_outputs.sh +++ b/code/test/helper_scripts/compare_outputs.sh @@ -1,13 +1,13 @@ #! /bin/bash -actual=/tigress/tcomi/aclark4_temp/results/analysis_test/ -expected=/tigress/tcomi/aclark4_temp/results/analysisp4e2/ +actual=/tigress/tcomi/aclark4_temp/results/analysis_chr1_test/ +expected=/tigress/tcomi/aclark4_temp/results/analysis_chr1/ echo starting comarison of $(basename $actual) to $(basename $expected) module load anaconda3 for file in $(ls $expected); do - act=$(echo $file | sed 's/\(.*\)p4e2\(\.txt.*\)/\1_test\2/') + act=$(echo $file | sed 's/__chr1//') if [[ $file = hmm* ]]; then cmp <(cat $actual$act | python hmm_format.py) \ <(cat $expected$file | python hmm_format.py) \ diff --git a/code/test/helper_scripts/test_predict.slurm b/code/test/helper_scripts/test_predict.slurm index f0033b3..dffeae8 100755 --- a/code/test/helper_scripts/test_predict.slurm +++ b/code/test/helper_scripts/test_predict.slurm @@ -1,23 +1,17 @@ #!/bin/bash -## SBATCH --array=1 #SBATCH --time=6-0 - #SBATCH -n 1 #SBATCH -o "/tigress/tcomi/aclark4_temp/results/predict_%A" -# ARGS=$(head -n $SLURM_ARRAY_TASK_ID predict_args.txt | tail -n 1) - -export PYTHONPATH=/home/tcomi/projects/aclark4_introgression/code/ - +config=/home/tcomi/projects/aclark4_introgression/code/config.yaml #Make sure chrms is set to only I -#ARGS="_chr1_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" -ARGS="_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" module load anaconda3 conda activate introgression3 -#python $PYTHONPATH/analyze/predict_main.py $ARGS -# gzip after -gzip /tigress/tcomi/aclark4_temp/results/analysis_test/positions__test.txt -gzip /tigress/tcomi/aclark4_temp/results/analysis_test/probs__test.txt +introgression \ + --config $config \ + -vvvv \ + predict + From 9438339886cdf7c51166be92e98aa243ce05a220 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Fri, 26 Apr 2019 10:50:36 -0400 Subject: [PATCH 18/33] Log file, progress bar Added log file option and config. When set, a progress bar is displayed on the console with click. --- code/analyze/main.py | 23 +++++++++++++++---- code/analyze/predict.py | 18 ++++++++++++++- code/config.yaml | 5 +++- code/test/analyze/test_main.py | 14 +++++------ code/test/analyze/test_main_predict_args.py | 10 ++++---- code/test/analyze/test_main_predict_config.py | 14 +++++------ code/test/analyze/test_predict_predictor.py | 8 +++---- code/test/helper_scripts/test_predict.slurm | 1 + 8 files changed, 63 insertions(+), 30 deletions(-) diff --git a/code/analyze/main.py b/code/analyze/main.py index 0fb56e5..57519e3 100644 --- a/code/analyze/main.py +++ b/code/analyze/main.py @@ -12,8 +12,11 @@ type=click.File('r'), help='Base configuration yaml.') @click.option('-v', '--verbosity', count=True, default=3) +@click.option('--log-file', + default='', + help='Optional log file. 
If unset print to stdout.') @click.pass_context -def cli(ctx, config, verbosity): +def cli(ctx, config, verbosity, log_file): ''' Main entry script to run analyze methods ''' @@ -28,18 +31,28 @@ def cli(ctx, config, verbosity): ('DEBUG', log.DEBUG), ][verbosity] - log.basicConfig(level=level) - log.info(f'Verbosity set to {levelstr}') - ctx.ensure_object(dict) confs = len(config) - log.info(f'Reading in {confs} config file{"" if confs == 1 else "s"}') for path in config: conf = yaml.safe_load(path) ctx.obj = config_utils.merge_dicts(ctx.obj, conf) ctx.obj = config_utils.clean_config(ctx.obj) + + if log_file == '': + log_file = config_utils.get_nested(ctx.obj, 'paths.log_file') + + if config_utils.get_nested(ctx.obj, 'paths'): + ctx.obj['paths']['log_file'] = log_file + + if log_file is not None: + log.basicConfig(level=level, filename=log_file, filemode='w') + else: + log.basicConfig(level=level) + log.info(f'Verbosity set to {levelstr}') + + log.info(f'Read in {confs} config file{"" if confs == 1 else "s"}') log.debug('Cleaned config:\n' + config_utils.print_dict(ctx.obj)) if ctx.invoked_subcommand is None: diff --git a/code/analyze/predict.py b/code/analyze/predict.py index 46ccac9..f3affaa 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -4,6 +4,7 @@ import re import os import itertools +import click from collections import defaultdict, Counter from hmm import hmm_bw from sim import sim_predict @@ -365,9 +366,21 @@ def run_prediction(self, only_poly_sites=True): for writer in block_writers.values(): self.write_blocks_header(writer) + counter = 0 + total = len(self.chromosomes) * len(self.strains) + # logging to file + progress_bar = None + if get_nested(self.config, 'paths.log_file'): + progress_bar = stack.enter_context( + click.progressbar( + length=total, + label='Running prediction')) + for chrom in self.chromosomes: for strain in self.strains: - log.info(f'working on: {strain} {chrom}') + counter += 1 + log.info(f'working on: {strain} {chrom} ' + f'({counter} of {total})') # get sequences and encode alignment_file = self.alignment.format( @@ -399,6 +412,9 @@ def run_prediction(self, only_poly_sites=True): self.write_state_probs(probs, probabilities, strain, chrom) + if progress_bar: + progress_bar.update(1) + def write_hmm_header(self, writer: TextIO) -> None: ''' Write the header line for an hmm file to the provided textIO object diff --git a/code/config.yaml b/code/config.yaml index 4d34e02..219b9a6 100644 --- a/code/config.yaml +++ b/code/config.yaml @@ -20,6 +20,8 @@ output_root: /tigress/tcomi/aclark4_temp/results input_root: /tigress/AKEY/akey_vol2/aclark4/nobackup paths: + # optional log file + # log_file: introgression.log fasta_suffix: .fa # suffix for _all_ fasta files # suffix for _all_ alignment files # this needs to match the suffix output by mugsy @@ -44,9 +46,10 @@ paths: block_files: __ANALYSIS_BASE__/blocks_{state}.txt hmm_initial: __ANALYSIS_BASE__/hmm_initial.txt hmm_trained: __ANALYSIS_BASE__/hmm_trained.txt - positions: __ANALYSIS_BASE__/positions.txt.gz probabilities: __ANALYSIS_BASE__/probabilities.txt.gz alignment: __ALIGNMENTS__/{prefix}_{strain}_chr{chrom}_mafft.maf + # positions are optional + positions: __ANALYSIS_BASE__/positions.txt.gz # software install locations software: diff --git a/code/test/analyze/test_main.py b/code/test/analyze/test_main.py index e771db4..ae01a95 100644 --- a/code/test/analyze/test_main.py +++ b/code/test/analyze/test_main.py @@ -41,7 +41,7 @@ def test_main_cli_configs(runner, mocker): 
mock_log_lvl.assert_called_once_with(level=log.WARNING) assert mock_log_info.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 2 config files') + mocker.call('Read in 2 config files') ] assert mock_log_debug.call_args_list == [ mocker.call('Cleaned config:\ntest - 23\ntest2 - 34\n') @@ -59,7 +59,7 @@ def test_main_cli_verbosity(runner, mocker): mock_log_lvl.assert_called_once_with(level=log.CRITICAL) assert mock_log_info.call_args_list == [ mocker.call('Verbosity set to CRITICAL'), - mocker.call('Reading in 0 config files') + mocker.call('Read in 0 config files') ] mock_log_info.reset_mock() @@ -71,7 +71,7 @@ def test_main_cli_verbosity(runner, mocker): mock_log_lvl.assert_called_once_with(level=log.ERROR) assert mock_log_info.call_args_list == [ mocker.call('Verbosity set to ERROR'), - mocker.call('Reading in 0 config files') + mocker.call('Read in 0 config files') ] mock_log_info.reset_mock() @@ -83,7 +83,7 @@ def test_main_cli_verbosity(runner, mocker): mock_log_lvl.assert_called_once_with(level=log.WARNING) assert mock_log_info.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 0 config files') + mocker.call('Read in 0 config files') ] mock_log_info.reset_mock() @@ -95,7 +95,7 @@ def test_main_cli_verbosity(runner, mocker): mock_log_lvl.assert_called_once_with(level=log.INFO) assert mock_log_info.call_args_list == [ mocker.call('Verbosity set to INFO'), - mocker.call('Reading in 0 config files') + mocker.call('Read in 0 config files') ] mock_log_info.reset_mock() @@ -107,7 +107,7 @@ def test_main_cli_verbosity(runner, mocker): mock_log_lvl.assert_called_once_with(level=log.DEBUG) assert mock_log_info.call_args_list == [ mocker.call('Verbosity set to DEBUG'), - mocker.call('Reading in 0 config files') + mocker.call('Read in 0 config files') ] mock_log_info.reset_mock() @@ -119,5 +119,5 @@ def test_main_cli_verbosity(runner, mocker): mock_log_lvl.assert_called_once_with(level=log.DEBUG) assert mock_log_info.call_args_list == [ mocker.call('Verbosity set to DEBUG'), - mocker.call('Reading in 0 config files') + mocker.call('Read in 0 config files') ] diff --git a/code/test/analyze/test_main_predict_args.py b/code/test/analyze/test_main_predict_args.py index 52e38e0..d67d374 100644 --- a/code/test/analyze/test_main_predict_args.py +++ b/code/test/analyze/test_main_predict_args.py @@ -34,7 +34,7 @@ def test_threshold(runner, mocker): assert str(result.exception) == 'No block file provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is '0.05'") ] @@ -60,7 +60,7 @@ def test_block(runner, mocker): 'Unable to build prefix, no known states provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), @@ -87,7 +87,7 @@ def test_prefix(runner, mocker): 'Unable to find strains in config and no test_strains provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output 
blocks file is 'blocks_{state}.txt'"), @@ -125,7 +125,7 @@ def test_test_strains(runner, mocker): print(mock_log.call_args_list) assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), @@ -140,7 +140,7 @@ def test_outputs(runner, mocker): mock_log = mocker.patch('analyze.main.log.info') mock_calls = [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), diff --git a/code/test/analyze/test_main_predict_config.py b/code/test/analyze/test_main_predict_config.py index c7ff871..ccf7070 100644 --- a/code/test/analyze/test_main_predict_config.py +++ b/code/test/analyze/test_main_predict_config.py @@ -40,7 +40,7 @@ def test_chroms(runner, mocker): assert str(result.exception) == 'No threshold provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config') ] @@ -65,7 +65,7 @@ def test_threshold(runner, mocker): assert str(result.exception) == 'No block file provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'") ] @@ -95,7 +95,7 @@ def test_block(runner, mocker): 'Unable to build prefix, no known states provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), @@ -129,7 +129,7 @@ def test_prefix(runner, mocker): 'Unable to find strains in config and no test_strains provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), @@ -165,7 +165,7 @@ def test_strains(runner, mocker): 'No initial hmm file provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), @@ -212,7 +212,7 @@ def test_test_strains(runner, mocker): assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), @@ -227,7 +227,7 @@ def test_outputs(runner, mocker): mock_log = mocker.patch('analyze.main.log.info') mock_calls = [ mocker.call('Verbosity set to WARNING'), - mocker.call('Reading in 1 config file'), + 
mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), diff --git a/code/test/analyze/test_predict_predictor.py b/code/test/analyze/test_predict_predictor.py index d91f0e6..30fb073 100644 --- a/code/test/analyze/test_predict_predictor.py +++ b/code/test/analyze/test_predict_predictor.py @@ -783,10 +783,10 @@ def test_run_prediction_full(predictor, mocker): ]) mock_log.info.assert_has_calls([ - mocker.call('working on: s1 I'), - mocker.call('working on: s2 I'), - mocker.call('working on: s1 II'), - mocker.call('working on: s2 II') + mocker.call('working on: s1 I (1 of 4)'), + mocker.call('working on: s2 I (2 of 4)'), + mocker.call('working on: s1 II (3 of 4)'), + mocker.call('working on: s2 II (4 of 4)') ]) diff --git a/code/test/helper_scripts/test_predict.slurm b/code/test/helper_scripts/test_predict.slurm index dffeae8..d407187 100755 --- a/code/test/helper_scripts/test_predict.slurm +++ b/code/test/helper_scripts/test_predict.slurm @@ -13,5 +13,6 @@ conda activate introgression3 introgression \ --config $config \ -vvvv \ + --log-file test.log \ predict From 3fef2526a99f962ecff287326facdefa2265af11 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Tue, 30 Apr 2019 09:22:42 -0400 Subject: [PATCH 19/33] Refactor id_regions Added class for adding region ids and integrated with the main click method. Created new class to hold configuration and handle setting logic, as that was heavily reused between main methods. --- .gitignore | 1 + code/analyze/id_regions.py | 70 +++ code/analyze/id_regions_main.py | 61 -- code/analyze/introgression_configuration.py | 395 ++++++++++++ code/analyze/main.py | 107 ++-- code/analyze/predict.py | 400 ++---------- code/config.yaml | 12 +- code/misc/config_utils.py | 29 +- code/test/analyze/test_id_regions.py | 142 +++++ code/test/analyze/test_id_regions_main.py | 112 ---- .../test_introgression_configuration.py | 592 ++++++++++++++++++ code/test/analyze/test_main.py | 9 +- code/test/analyze/test_main_id_config.py | 143 +++++ code/test/analyze/test_main_predict_args.py | 45 +- code/test/analyze/test_main_predict_config.py | 57 +- code/test/analyze/test_predict_hmm_builder.py | 39 +- code/test/analyze/test_predict_predictor.py | 555 ++-------------- code/test/helper_scripts/test_id_main.slurm | 17 +- code/test/helper_scripts/test_predict.slurm | 2 +- code/test/misc/test_config_utils.py | 45 +- 20 files changed, 1663 insertions(+), 1170 deletions(-) create mode 100644 code/analyze/id_regions.py delete mode 100644 code/analyze/id_regions_main.py create mode 100644 code/analyze/introgression_configuration.py create mode 100644 code/test/analyze/test_id_regions.py delete mode 100644 code/test/analyze/test_id_regions_main.py create mode 100644 code/test/analyze/test_introgression_configuration.py create mode 100644 code/test/analyze/test_main_id_config.py diff --git a/.gitignore b/.gitignore index b00c0bd..967c89c 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,4 @@ code/setup/* *.swp *egg-info tags +*.log diff --git a/code/analyze/id_regions.py b/code/analyze/id_regions.py new file mode 100644 index 0000000..71090d7 --- /dev/null +++ b/code/analyze/id_regions.py @@ -0,0 +1,70 @@ +from contextlib import ExitStack +from operator import itemgetter +from analyze.introgression_configuration import Configuration +from analyze.predict import read_blocks +import click + + +class ID_producer(): + ''' + ID_producer + Adds unique region id to 
block files + ''' + def __init__(self, configuration: Configuration): + self.config = configuration + + def add_ids(self): + ''' + Adds a unique region id to block files, producing labeled text files + ''' + self.config.validate_id_regions_arguments() + regions = dict(zip(self.config.chromosomes, + [[] for _ in self.config.chromosomes])) + with ExitStack() as stack: + writers = {} + + # Progress bars don't seem to show since these complete too fast + progress_bar = None + if self.config.log_file: + progress_bar = stack.enter_context( + click.progressbar( + length=len(self.config.states), + label='Reading in states')) + + for state in self.config.states: + # read in region as dict keyed by strain, chromosome: + # (start, end, number non gapped) + region = read_blocks(self.config.blocks.format(state=state)) + for strain, d_strain in region.items(): + for chrm, d_chrm in d_strain.items(): + for start, end, num in d_chrm: + regions[chrm].append( + (start, end, num, strain, state)) + + # open writer + writers[state] = stack.enter_context( + open(self.config.labeled_blocks.format(state=state), 'w')) + writers[state].write( + 'region_id\tstrain\tchromosome\tpredicted_species\t' + 'start\tend\tnum_sites_hmm\n') + + if progress_bar: + progress_bar.update(1) + id_counter = 1 + + if progress_bar: + progress_bar = stack.enter_context( + click.progressbar( + length=len(regions.keys()), + label='Adding regions')) + + for chrm, entries in regions.items(): + # sort by start, then strain + for start, end, num, strain, state in \ + sorted(entries, key=itemgetter(0, 3)): + writers[state].write( + f'r{id_counter}\t{strain}\t{chrm}\t{state}\t' + f'{start}\t{end}\t{num}\n') + id_counter += 1 + if progress_bar: + progress_bar.update(1) diff --git a/code/analyze/id_regions_main.py b/code/analyze/id_regions_main.py deleted file mode 100644 index adc59c6..0000000 --- a/code/analyze/id_regions_main.py +++ /dev/null @@ -1,61 +0,0 @@ -import sys -from analyze import predict -from operator import itemgetter -import global_params as gp - - -def main() -> None: - ''' - Adds a unique region id to block files, producing labeled text files - Input files: - -blocks_{species}.txt - - Output files: - -blocks_{species}_labeled.txt - ''' - args = predict.process_predict_args(sys.argv[1:]) - - # order regions by chromosome, start (break ties alphabetically by strain) - all_regions_by_chrm = dict(zip(gp.chrms, [[] for chrm in gp.chrms])) - output_files = {} - base_dir = gp.analysis_out_dir_absolute + args['tag'] - for species_from in args['states']: - - # strain chromosome predicted_species start end number_non_gap - fn = f'{base_dir}/blocks_{species_from}_{args["tag"]}.txt' - - # introgressed regions keyed by strain and then chromosome: - # (start, end, number_non_gap) - regions = predict.read_blocks(fn) - - for strain in regions: - for chrm in regions[strain]: - for entry in regions[strain][chrm]: - start, end, number_non_gap = entry - all_regions_by_chrm[chrm].append( - (start, end, number_non_gap, strain, species_from)) - - output_files[species_from] = f'{fn[:-4]}_labeled.txt' - - writers = {} - for species_from in args['states']: - writers[species_from] = open(output_files[species_from], 'w') - writers[species_from].write( - 'region_id\tstrain\tchromosome\tpredicted_species\t' - 'start\tend\tnum_sites_hmm\n') - - idc = 1 - for chrm in gp.chrms: - for entry in sorted(all_regions_by_chrm[chrm], key=itemgetter(0, 3)): - (start, end, number_non_gap, strain, species_from) = entry - writers[species_from].write( - 
f'r{idc}\t{strain}\t{chrm}\t{species_from}\t' - f'{start}\t{end}\t{number_non_gap}\n') - idc += 1 - - for species_from in args['states']: - writers[species_from].close() - - -if __name__ == "__main__": - main() diff --git a/code/analyze/introgression_configuration.py b/code/analyze/introgression_configuration.py new file mode 100644 index 0000000..fa3d364 --- /dev/null +++ b/code/analyze/introgression_configuration.py @@ -0,0 +1,395 @@ +import glob +import re +from typing import Tuple, Dict, List +import logging as log +from misc.config_utils import (get_nested, clean_config, merge_dicts, + print_dict, validate, check_wildcards) + + +class Configuration(): + def __init__(self): + self.config = {} + self.log_file = None + + def add_config(self, configuration: Dict): + ''' + merge the provided configuration dictionary with this object. + Cleans configuration + ''' + self.config = clean_config( + merge_dicts(self.config, configuration)) + + def get_states(self) -> Tuple[List, List]: + ''' + Build lists of known and unknown states from the analysis params + ''' + + ref = get_nested(self.config, 'analysis_params.reference.name') + if ref is None: + ref = [] + else: + ref = [ref] + + known = get_nested(self.config, 'analysis_params.known_states') + if known is None: + known = [] + + known_states = ref + [s['name'] for s in known] + + unknown = get_nested(self.config, 'analysis_params.unknown_states') + if unknown is None: + unknown = [] + + unknown_states = [s['name'] for s in unknown] + + return known_states, unknown_states + + def set_states(self, states: List[str] = None): + ''' + Set the states for which to perform region naming + ''' + if states is None or states == []: + self.known_states, self.unknown_states = self.get_states() + self.states = self.known_states + self.unknown_states + else: + self.states = states + + if self.states == []: + err = 'No states specified' + log.exception(err) + raise ValueError(err) + + def set_log_file(self, log_file: str = ''): + ''' + sets log file based on provided value or config + ''' + if log_file == '': + self.log_file = get_nested(self.config, 'paths.log_file') + else: + self.log_file = log_file + + def set_chromosomes(self): + ''' + Gets the chromosome list from config, raising a ValueError + if undefined. + ''' + self.chromosomes = validate( + self.config, + 'chromosomes', + 'No chromosomes specified in config file!') + + def set_threshold(self, threshold: str = None): + ''' + Set the threshold. Checks if set and converts to float if possible. + Failing float casting, will store a string if it is 'viterbi', + otherwise throws a ValueError + ''' + self.threshold = validate( + self.config, + 'analysis_params.threshold', + 'No threshold provided', + threshold) + try: + self.threshold = float(self.threshold) + except ValueError: + if self.threshold != 'viterbi': + err = f'Unsupported threshold value: {self.threshold}' + log.exception(err) + raise ValueError(err) + + def set_blocks_file(self, blocks: str = None): + ''' + Set the block wildcard filename. Checks for appropriate wildcards + ''' + self.blocks = validate( + self.config, + 'paths.analysis.block_files', + 'No block file provided', + blocks) + + check_wildcards(self.blocks, 'state') + + def set_labeled_blocks_file(self, blocks: str = None): + ''' + Set the labeled block wildcard filename. 
+ Checks for appropriate wildcards + ''' + self.labeled_blocks = validate( + self.config, + 'paths.analysis.labeled_block_files', + 'No labeled block file provided', + blocks) + + check_wildcards(self.labeled_blocks, 'state') + + def set_prefix(self, prefix: str = ''): + ''' + Set prefix string of the predictor to the supplied value or + build it from the known states + ''' + if prefix == '': + if self.known_states == []: + err = 'Unable to build prefix, no known states provided' + log.exception(err) + raise ValueError(err) + + self.prefix = '_'.join(self.known_states) + else: + self.prefix = prefix + + def set_strains(self, test_strains: str = ''): + ''' + build the strains to perform prediction on + ''' + if test_strains == '': + test_strains = get_nested(self.config, 'paths.test_strains') + else: + # need to support list for test strains + test_strains = [test_strains] + + if test_strains is not None: + for test_strain in test_strains: + check_wildcards(test_strain, 'strain,chrom') + + self.find_strains(test_strains) + + def find_strains(self, test_strains: List[str] = None): + ''' + Helper method to get strains supplied in config, or from test_strains + ''' + strains = get_nested(self.config, 'strains') + self.test_strains = test_strains + + if strains is None: + if test_strains is None: + err = ('Unable to find strains in config and ' + 'no test_strains provided') + log.exception(err) + raise ValueError(err) + + # try to build strains from wildcards in test_strains + strains = {} + for test_strain in test_strains: + # find matching files + strain_glob = test_strain.format( + strain='*', + chrom='*') + log.info(f'searching for {strain_glob}') + for fname in glob.iglob(strain_glob): + # extract wildcard matches + match = re.match( + test_strain.format( + strain='(?P.*?)', + chrom='(?P[^_]*?)' + ), + fname) + if match: + log.debug( + f'matched with {match.group("strain", "chrom")}') + strain, chrom = match.group('strain', 'chrom') + if strain not in strains: + strains[strain] = set() + strains[strain].add(chrom) + + if len(strains) == 0: + err = ('Found no chromosome sequence files ' + f'in {test_strains}') + log.exception(err) + raise ValueError(err) + + # check if requested chromosomes are within the list of chroms + chrom_set = set(self.chromosomes) + for strain, chroms in strains.items(): + if not chrom_set.issubset(chroms): + not_found = chrom_set.difference(chroms).pop() + err = (f'Strain {strain} is missing chromosomes. ' + f'Unable to find chromosome \'{not_found}\'') + log.exception(err) + raise ValueError(err) + + self.strains = list(sorted(strains.keys())) + + else: # strains set in config + self.strains = list(sorted(set(strains))) + + def set_predict_files(self, + hmm_initial: str, + hmm_trained: str, + positions: str, + probabilities: str, + alignment: str): + ''' + Set output files from provided values or config. + Raises value errors if a file is not provided. + Checks alignment for all wildcards and replaces prefix. 
+ ''' + self.hmm_initial = validate(self.config, + 'paths.analysis.hmm_initial', + 'No initial hmm file provided', + hmm_initial) + + self.hmm_trained = validate(self.config, + 'paths.analysis.hmm_trained', + 'No trained hmm file provided', + hmm_trained) + + if positions == '': + self.positions = get_nested(self.config, + 'paths.analysis.positions') + else: + self.positions = positions + + self.probabilities = validate(self.config, + 'paths.analysis.probabilities', + 'No probabilities file provided', + probabilities) + + alignment = validate(self.config, + 'paths.analysis.alignment', + 'No alignment file provided', + alignment) + check_wildcards(alignment, 'prefix,strain,chrom') + self.alignment = alignment.replace('{prefix}', self.prefix) + + def set_HMM_symbols(self): + ''' + Set symbols based on config values, using defaults if unset + ''' + self.symbols = { + 'match': '+', + 'mismatch': '-', + 'unknown': '?', + 'unsequenced': 'n', + 'gap': '-', + 'unaligned': '?', + 'masked': 'x' + } + config_symbols = get_nested(self.config, 'HMM_symbols') + if config_symbols is not None: + for k, v in config_symbols.items(): + if k not in self.symbols: + log.warning("Unused symbol in configuration: " + f"{k} -> '{v}'") + else: + self.symbols[k] = v + log.debug(f"Overwriting default symbol for {k} with '{v}'") + + for k, v in self.symbols.items(): + if k not in config_symbols: + log.warning(f'Symbol for {k} unset in config, ' + f"using default '{v}'") + + else: + for k, v in self.symbols.items(): + log.warning(f'Symbol for {k} unset in config, ' + f"using default '{v}'") + + def set_convergence(self): + ''' + Set convergence for HMM training, using default if unset + ''' + self.convergence = get_nested(self.config, + 'analysis_params.convergence_threshold') + if self.convergence is None: + log.warning('No value set for convergence_threshold, using ' + 'default of 0.001') + self.convergence = 0.001 + + def get(self, key: str): + ''' + Get nested key from underlying dictionary. Returning none if any + key is not in dict + ''' + return get_nested(self.config, key) + + def validate_predict_arguments(self): + ''' + Check that all required instance variables are set to perform a + prediction run. 
Returns true if valid, raises value error otherwise + ''' + args = [ + 'chromosomes', + 'blocks', + 'prefix', + 'strains', + 'hmm_initial', + 'hmm_trained', + 'probabilities', + 'alignment', + 'known_states', + 'unknown_states', + 'threshold', + ] + variables = self.__dict__ + for arg in args: + if arg not in variables or variables[arg] is None: + err = ('Failed to validate Predictor, required argument ' + f"'{arg}' was unset") + log.exception(err) + raise ValueError(err) + + # check the parameters for each state are present + known_states = self.get('analysis_params.known_states') + if known_states is None: + err = 'Configuration did not provide any known_states' + log.exception(err) + raise ValueError(err) + + for s in known_states: + if 'expected_length' not in s: + err = f'{s["name"]} did not provide an expected_length' + log.exception(err) + raise ValueError(err) + if 'expected_fraction' not in s: + err = f'{s["name"]} did not provide an expected_fraction' + log.exception(err) + raise ValueError(err) + + unknown_states = self.get('analysis_params.unknown_states') + if unknown_states is not None: + for s in unknown_states: + if 'expected_length' not in s: + err = f'{s["name"]} did not provide an expected_length' + log.exception(err) + raise ValueError(err) + if 'expected_fraction' not in s: + err = f'{s["name"]} did not provide an expected_fraction' + log.exception(err) + raise ValueError(err) + + reference = self.get('analysis_params.reference') + if reference is None: + err = f'Configuration did not specify a reference strain' + log.exception(err) + raise ValueError(err) + + return True + + def validate_id_regions_arguments(self): + ''' + Check that all required instance variables are set to perform a + id producer run. Returns true if valid, raises value error otherwise + ''' + args = [ + 'chromosomes', + 'blocks', + 'labeled_blocks', + 'states', + ] + variables = self.__dict__ + for arg in args: + if arg not in variables or variables[arg] is None: + err = ('Failed to validate ID Producer, required argument ' + f"'{arg}' was unset") + log.exception(err) + raise ValueError(err) + + return True + + def __repr__(self): + return ('Config file:\n' + + print_dict(self.config) + + '\nSettings:\n' + + print_dict({k: v for k, v in self.__dict__.items() + if k != 'config'}) + ) diff --git a/code/analyze/main.py b/code/analyze/main.py index 57519e3..082f3fd 100644 --- a/code/analyze/main.py +++ b/code/analyze/main.py @@ -1,8 +1,9 @@ import click import yaml import logging as log -from misc import config_utils import analyze.predict +from analyze.introgression_configuration import Configuration +from analyze.id_regions import ID_producer # TODO also check for snakemake object? 
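
The Configuration object above centralizes what used to be scattered dict
lookups. A minimal sketch of driving it directly, outside the click CLI
(not part of the patch; state names, paths, and parameter values are
illustrative, chosen to satisfy validate_predict_arguments):

    from analyze.introgression_configuration import Configuration

    config = Configuration()
    config.add_config({
        'chromosomes': ['I', 'II'],
        'strains': ['strainB', 'strainA'],
        'analysis_params': {
            'threshold': 'viterbi',
            'reference': {'name': 'S288c'},
            'known_states': [{'name': 'CBS432',
                              'expected_length': 10000,
                              'expected_fraction': 0.025}],
            'unknown_states': [{'name': 'unknown',
                                'expected_length': 1000,
                                'expected_fraction': 0.01}],
        },
        'paths': {'analysis': {
            'block_files': 'blocks_{state}.txt',
            'hmm_initial': 'hmm_initial.txt',
            'hmm_trained': 'hmm_trained.txt',
            'probabilities': 'probabilities.txt.gz',
            'alignment': '{prefix}_{strain}_chr{chrom}.maf',
        }},
    })

    config.set_chromosomes()
    config.set_states()       # known: S288c, CBS432; unknown: unknown
    config.set_threshold()    # non-numeric 'viterbi' is kept as a string
    config.set_blocks_file()  # checks for the {state} wildcard
    config.set_prefix()       # 'S288c_CBS432', joined from known states
    config.set_strains()      # from 'strains': deduplicated and sorted
    config.set_predict_files('', '', '', '', '')  # '' falls back to config
    assert config.validate_predict_arguments()

Each setter falls back to the merged config when its argument is empty and
raises a ValueError naming what is missing otherwise, which is what lets the
main.py commands below log one settled value per step.
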
@@ -31,40 +32,31 @@ def cli(ctx, config, verbosity, log_file): ('DEBUG', log.DEBUG), ][verbosity] - ctx.ensure_object(dict) + ctx.ensure_object(Configuration) confs = len(config) for path in config: conf = yaml.safe_load(path) - ctx.obj = config_utils.merge_dicts(ctx.obj, conf) + ctx.obj.add_config(conf) - ctx.obj = config_utils.clean_config(ctx.obj) - - if log_file == '': - log_file = config_utils.get_nested(ctx.obj, 'paths.log_file') - - if config_utils.get_nested(ctx.obj, 'paths'): - ctx.obj['paths']['log_file'] = log_file - - if log_file is not None: - log.basicConfig(level=level, filename=log_file, filemode='w') + ctx.obj.set_log_file(log_file) + if ctx.obj.log_file is not None: + log.basicConfig(level=level, filename=ctx.obj.log_file, filemode='w') else: log.basicConfig(level=level) log.info(f'Verbosity set to {levelstr}') log.info(f'Read in {confs} config file{"" if confs == 1 else "s"}') - log.debug('Cleaned config:\n' + config_utils.print_dict(ctx.obj)) + log.debug('Cleaned config:\n' + repr(ctx.obj)) if ctx.invoked_subcommand is None: click.echo_via_pager( click.style( 'No command supplied. Read in the following config:\n', - fg='yellow') + - config_utils.print_dict(ctx.obj)) + fg='yellow') + repr(ctx.obj)) @cli.command() -@click.pass_context @click.option('--blocks', default='', help='Block file location with {state}') @click.option('--prefix', default='', help='Prefix of test-strain files ' 'default to list of states joined with _.') @@ -87,6 +79,7 @@ def cli(ctx, config, verbosity, log_file): @click.option('--only-poly-sites/--all-sites', default=True, help='Consider only polymorphic sites or all sites. ' 'Default is only polymorphic.') +@click.pass_context def predict(ctx, blocks, prefix, @@ -100,41 +93,73 @@ def predict(ctx, only_poly_sites): config = ctx.obj - predictor = analyze.predict.Predictor(config) - predictor.set_chromosomes() - log.info(f'Found {len(predictor.chromosomes)} chromosomes in config') + config.set_chromosomes() + log.info(f'Found {len(config.chromosomes)} chromosomes in config') - predictor.set_threshold(threshold) - log.info(f'Threshold value is \'{predictor.threshold}\'') + config.set_threshold(threshold) + log.info(f'Threshold value is \'{config.threshold}\'') - predictor.set_blocks_file(blocks) - log.info(f'Output blocks file is \'{predictor.blocks}\'') + config.set_blocks_file(blocks) + log.info(f'Output blocks file is \'{config.blocks}\'') - predictor.set_prefix(prefix) - log.info(f'Prefix is \'{predictor.prefix}\'') + config.set_states() + config.set_prefix(prefix) + log.info(f'Prefix is \'{config.prefix}\'') - predictor.set_strains(test_strains) - if predictor.test_strains is None: + config.set_strains(test_strains) + if config.test_strains is None: log.info(f'No test_strains provided') else: - str_len = len(predictor.test_strains) + str_len = len(config.test_strains) log.info(f'Found {str_len} test strain' f'{"" if str_len == 1 else "s"}') - log.info(f'Found {len(predictor.strains)} unique strains') - - predictor.set_output_files(hmm_initial, - hmm_trained, - positions, - probabilities, - alignment) - log.info(f'Hmm_initial file is \'{predictor.hmm_initial}\'') - log.info(f'Hmm_trained file is \'{predictor.hmm_trained}\'') - log.info(f'Positions file is \'{predictor.positions}\'') - log.info(f'Probabilities file is \'{predictor.probabilities}\'') - log.info(f'Alignment file is \'{predictor.alignment}\'') + str_len = len(config.strains) + log.info(f'Found {str_len} unique strain' + f'{"" if str_len == 1 else "s"}') + + 
config.set_predict_files(hmm_initial, + hmm_trained, + positions, + probabilities, + alignment) + log.info(f'Hmm_initial file is \'{config.hmm_initial}\'') + log.info(f'Hmm_trained file is \'{config.hmm_trained}\'') + log.info(f'Positions file is \'{config.positions}\'') + log.info(f'Probabilities file is \'{config.probabilities}\'') + log.info(f'Alignment file is \'{config.alignment}\'') + predictor = analyze.predict.Predictor(config) + if only_poly_sites: + log.info('Only considering polymorphic sites') + else: + log.info('Considering all sites') predictor.run_prediction(only_poly_sites) +# accept multiple states and pass as list +@cli.command() +@click.option('--blocks', default='', help='Block file location with {state}') +@click.option('--labeled', default='', help='Block file location with {state}') +@click.option('--state', multiple=True, help='States to add ids to') +@click.pass_context +def id_regions(ctx, blocks, labeled, state): + config = ctx.obj + config.set_chromosomes() + log.info(f'Found {len(config.chromosomes)} chromosomes in config') + + state = list(state) + config.set_states(state) + log.info(f'Found {len(config.states)} states to process') + + config.set_blocks_file(blocks) + log.info(f'Input blocks file is \'{config.blocks}\'') + + config.set_labeled_blocks_file(labeled) + log.info(f'Output blocks file is \'{config.labeled_blocks}\'') + + id_producer = ID_producer(config) + id_producer.add_ids() + + if __name__ == '__main__': cli() diff --git a/code/analyze/predict.py b/code/analyze/predict.py index f3affaa..7a544e7 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -1,7 +1,5 @@ import copy import gzip -import glob -import re import os import itertools import click @@ -14,8 +12,7 @@ from contextlib import ExitStack import logging as log from misc.read_fasta import read_fasta -from misc.config_utils import (check_wildcards, validate, - get_states, get_nested) +from analyze.introgression_configuration import Configuration # TODO remove gp references for symbols. pass args or fold into object? @@ -84,333 +81,89 @@ class Predictor(): Predictor class Stores all variables needed to run an HMM prediction ''' - def __init__(self, configuration: Dict): + def __init__(self, configuration: Configuration): self.config = configuration - self.known_states, self.unknown_states = get_states(self.config) - self.states = self.known_states + self.unknown_states - self.chromosomes = None - self.blocks = None - self.prefix = None - self.strains = None - self.hmm_initial = None - self.hmm_trained = None - self.positions = None - self.probabilities = None - self.alignment = None - self.threshold = None - - def set_chromosomes(self): - ''' - Gets the chromosome list from provided config, raising a ValueError - if undefined. - ''' - self.chromosomes = validate( - self.config, - 'chromosomes', - 'No chromosomes specified in config file!') - - def set_blocks_file(self, blocks: str = None): - ''' - Set the block wildcard filename. 
Checks for appropriate wildcards - ''' - self.blocks = validate( - self.config, - 'paths.analysis.block_files', - 'No block file provided', - blocks) - - check_wildcards(self.blocks, 'state') - - def set_prefix(self, prefix: str = ''): - ''' - Set prefix string of the predictor to the supplied value or - build it from the known states - ''' - if prefix == '': - if self.known_states == []: - err = 'Unable to build prefix, no known states provided' - log.exception(err) - raise ValueError(err) - - self.prefix = '_'.join(self.known_states) - else: - self.prefix = prefix - - def set_threshold(self, threshold: str = None): - ''' - Set the threshold. Checks if set and converts to float if possible - ''' - self.threshold = validate( - self.config, - 'analysis_params.threshold', - 'No threshold provided', - threshold) - try: - self.threshold = float(self.threshold) - except ValueError: - if self.threshold != 'viterbi': - err = f'Unsupported threshold value: {self.threshold}' - log.exception(err) - raise ValueError(err) - - def set_strains(self, test_strains: str = ''): - ''' - build the strains to perform prediction on - ''' - if test_strains == '': - test_strains = get_nested(self.config, 'paths.test_strains') - else: - # need to support list for test strains - test_strains = [test_strains] - - if test_strains is not None: - for test_strain in test_strains: - check_wildcards(test_strain, 'strain,chrom') - - self.find_strains(test_strains) - - def find_strains(self, test_strains: List[str] = None): - ''' - Helper method to get strains supplied in config, or from test_strains - ''' - strains = get_nested(self.config, 'strains') - self.test_strains = test_strains - - if strains is None: - if test_strains is None: - err = ('Unable to find strains in config and ' - 'no test_strains provided') - log.exception(err) - raise ValueError(err) - - # try to build strains from wildcards in test_strains - strains = {} - for test_strain in test_strains: - # find matching files - strain_glob = test_strain.format( - strain='*', - chrom='*') - log.info(f'searching for {strain_glob}') - for fname in glob.iglob(strain_glob): - # extract wildcard matches - match = re.match( - test_strain.format( - strain='(?P.*?)', - chrom='(?P[^_]*?)' - ), - fname) - if match: - log.debug( - f'matched with {match.group("strain", "chrom")}') - strain, chrom = match.group('strain', 'chrom') - if strain not in strains: - strains[strain] = set() - strains[strain].add(chrom) - - if len(strains) == 0: - err = ('Found no chromosome sequence files ' - f'in {test_strains}') - log.exception(err) - raise ValueError(err) - - # check if requested chromosomes are within the list of chroms - chrom_set = set(self.chromosomes) - for strain, chroms in strains.items(): - if not chrom_set.issubset(chroms): - not_found = chrom_set.difference(chroms).pop() - err = (f'Strain {strain} is missing chromosomes. ' - f'Unable to find chromosome \'{not_found}\'') - log.exception(err) - raise ValueError(err) - - self.strains = list(sorted(strains.keys())) - - else: # strains set in config - self.strains = list(sorted(set(strains))) - - def set_output_files(self, - hmm_initial: str, - hmm_trained: str, - positions: str, - probabilities: str, - alignment: str): - ''' - Set output files from provided values or config. - Raises value errors if a file is not provided. - Checks alignment for all wildcards and replaces prefix. 
- ''' - self.hmm_initial = validate(self.config, - 'paths.analysis.hmm_initial', - 'No initial hmm file provided', - hmm_initial) - - self.hmm_trained = validate(self.config, - 'paths.analysis.hmm_trained', - 'No trained hmm file provided', - hmm_trained) - - if positions == '': - self.positions = get_nested(self.config, - 'paths.analysis.positions') - else: - self.positions = positions - - self.probabilities = validate(self.config, - 'paths.analysis.probabilities', - 'No probabilities file provided', - probabilities) - - alignment = validate(self.config, - 'paths.analysis.alignment', - 'No alignment file provided', - alignment) - check_wildcards(alignment, 'prefix,strain,chrom') - self.alignment = alignment.replace('{prefix}', self.prefix) - - def validate_arguments(self): - ''' - Check that all required instance variables are set to perform a - prediction run. Returns true if valid, raises value error otherwise - ''' - args = [ - 'chromosomes', - 'blocks', - 'prefix', - 'strains', - 'hmm_initial', - 'hmm_trained', - 'probabilities', - 'alignment', - 'known_states', - 'unknown_states', - 'threshold', - ] - variables = self.__dict__ - for arg in args: - if variables[arg] is None: - err = ('Failed to validate Predictor, required argument ' - f'{arg} was unset') - log.exception(err) - raise ValueError(err) - - # check the parameters for each state are present - known_states = get_nested(self.config, - 'analysis_params.known_states') - if known_states is None: - err = 'Configuration did not provide any known_states' - log.exception(err) - raise ValueError(err) - - for s in known_states: - if 'expected_length' not in s: - err = f'{s["name"]} did not provide an expected_length' - log.exception(err) - raise ValueError(err) - if 'expected_fraction' not in s: - err = f'{s["name"]} did not provide an expected_fraction' - log.exception(err) - raise ValueError(err) - - unknown_states = get_nested(self.config, - 'analysis_params.unknown_states') - if unknown_states is not None: - for s in unknown_states: - if 'expected_length' not in s: - err = f'{s["name"]} did not provide an expected_length' - log.exception(err) - raise ValueError(err) - if 'expected_fraction' not in s: - err = f'{s["name"]} did not provide an expected_fraction' - log.exception(err) - raise ValueError(err) - - reference = get_nested(self.config, - 'analysis_params.reference') - if reference is None: - err = f'Configuration did not specify a reference strain' - log.exception(err) - raise ValueError(err) - - return True def run_prediction(self, only_poly_sites=True): ''' Run prediction with this predictor object ''' - self.validate_arguments() + self.config.validate_predict_arguments() hmm_builder = HMM_Builder(self.config) hmm_builder.set_expected_values() self.emission_symbols = \ - hmm_builder.update_emission_symbols(len(self.known_states)) + hmm_builder.update_emission_symbols(len(self.config.known_states)) - with open(self.hmm_initial, 'w') as initial, \ - open(self.hmm_trained, 'w') as trained, \ - gzip.open(self.probabilities, 'wt') as probabilities, \ + with open(self.config.hmm_initial, 'w') as initial, \ + open(self.config.hmm_trained, 'w') as trained, \ + gzip.open(self.config.probabilities, 'wt') as probabilities, \ ExitStack() as stack: self.write_hmm_header(initial) self.write_hmm_header(trained) - if self.positions is not None: + if self.config.positions is not None: positions = stack.enter_context( - gzip.open(self.positions, 'wt')) + gzip.open(self.config.positions, 'wt')) else: positions = None block_writers = 
{state: stack.enter_context( - open(self.blocks.format(state=state), 'w')) + open(self.config.blocks.format( + state=state), 'w')) for state in - self.states} + self.config.states} for writer in block_writers.values(): self.write_blocks_header(writer) counter = 0 - total = len(self.chromosomes) * len(self.strains) - # logging to file + total = len(self.config.chromosomes) * len(self.config.strains) progress_bar = None - if get_nested(self.config, 'paths.log_file'): + if self.config.log_file: # logging to file progress_bar = stack.enter_context( click.progressbar( length=total, label='Running prediction')) - for chrom in self.chromosomes: - for strain in self.strains: + for chrom in self.config.chromosomes: + for strain in self.config.strains: counter += 1 log.info(f'working on: {strain} {chrom} ' f'({counter} of {total})') # get sequences and encode - alignment_file = self.alignment.format( + alignment_file = self.config.alignment.format( strain=strain, chrom=chrom) if not os.path.exists(alignment_file): log.info(f'skipping, file {alignment_file} not found') - continue - hmm_initial, hmm_trained, pos = hmm_builder.run_hmm( - alignment_file, only_poly_sites) + else: + hmm_initial, hmm_trained, pos = hmm_builder.run_hmm( + alignment_file, only_poly_sites) - self.write_hmm(hmm_initial, initial, strain, chrom) - self.write_hmm(hmm_trained, trained, strain, chrom) + self.write_hmm(hmm_initial, initial, strain, chrom) + self.write_hmm(hmm_trained, trained, strain, chrom) - # process and threshold hmm result - predicted_states, probs = self.process_path(hmm_trained) - state_blocks = self.convert_to_blocks(predicted_states) + # process and threshold hmm result + predicted_states, probs = self.process_path( + hmm_trained) + state_blocks = self.convert_to_blocks(predicted_states) - if positions is not None: - self.write_positions(pos, positions, strain, chrom) + if positions is not None: + self.write_positions(pos, positions, strain, chrom) - for state, block in state_blocks.items(): - self.write_blocks(block, - pos, - block_writers[state], - strain, - chrom, - state) + for state, block in state_blocks.items(): + self.write_blocks(block, + pos, + block_writers[state], + strain, + chrom, + state) - self.write_state_probs(probs, probabilities, strain, chrom) + self.write_state_probs(probs, probabilities, + strain, chrom) if progress_bar: progress_bar.update(1) @@ -424,7 +177,7 @@ def write_hmm_header(self, writer: TextIO) -> None: writer.write('strain\tchromosome\t') - states = self.known_states + self.unknown_states + states = self.config.states writer.write('\t'.join( [f'init_{s}' for s in states] + # initial @@ -527,7 +280,7 @@ def write_state_probs(self, writer.write('\t'.join( [f'{state}:' + ','.join([f'{site[i]:.5f}' for site in probs]) - for i, state in enumerate(self.states)])) + for i, state in enumerate(self.config.states)])) writer.write('\n') @@ -540,17 +293,20 @@ def process_path(self, hmm: hmm_bw.HMM) -> Tuple[List[str], np.array]: probabilities = hmm.posterior_decoding()[0] # posterior - if type(self.threshold) is float: + if type(self.config.threshold) is float: path, path_probs = sim_process.get_max_path(probabilities, hmm.hidden_states) - path_t = sim_process.threshold_predicted(path, path_probs, - self.threshold, - self.known_states[0]) + path_t = sim_process.threshold_predicted( + path, + path_probs, + self.config.threshold, + self.config.known_states[0]) + return path_t, probabilities else: predicted = sim_predict.convert_predictions(hmm.viterbi(), - self.states) + 
self.config.states) return predicted, probabilities def convert_to_blocks(self, @@ -563,7 +319,7 @@ def convert_to_blocks(self, ''' # single individual state sequence blocks = {} - for state in self.states: + for state in self.config.states: blocks[state] = [] prev_species = state_seq[0] block_start = 0 @@ -585,43 +341,11 @@ def convert_to_blocks(self, class HMM_Builder(): - def __init__(self, configuration): + def __init__(self, configuration: Configuration): self.config = configuration - self.symbols = { - 'match': '+', - 'mismatch': '-', - 'unknown': '?', - 'unsequenced': 'n', - 'gap': '-', - 'unaligned': '?', - 'masked': 'x' - } - config_symbols = get_nested(self.config, 'HMM_symbols') - if config_symbols is not None: - for k, v in config_symbols.items(): - if k not in self.symbols: - log.warning("Unused symbol in configuration: " - f"{k} -> '{v}'") - else: - self.symbols[k] = v - log.debug(f"Overwriting default symbol for {k} with '{v}'") - - for k, v in self.symbols.items(): - if k not in config_symbols: - log.warning(f'Symbol for {k} unset in config, ' - f"using default '{v}'") - - else: - for k, v in self.symbols.items(): - log.warning(f'Symbol for {k} unset in config, ' - f"using default '{v}'") - - self.convergence = get_nested(self.config, - 'analysis_params.convergence_threshold') - if self.convergence is None: - log.warning('No value set for convergence_threshold, using ' - 'default of 0.001') - self.convergence = 0.001 + self.config.set_HMM_symbols() + self.symbols = self.config.symbols + self.config.set_convergence() def update_emission_symbols(self, repeats: int): ''' @@ -671,29 +395,25 @@ def set_expected_values(self): ''' self.expected_lengths = {} self.expected_fractions = {} - known_states = get_nested(self.config, - 'analysis_params.known_states') + known_states = self.config.get('analysis_params.known_states') for state in known_states: self.expected_lengths[state['name']] = state['expected_length'] self.expected_fractions[state['name']] = state['expected_fraction'] - unknown_states = get_nested(self.config, - 'analysis_params.unknown_states') + unknown_states = self.config.get('analysis_params.unknown_states') for state in unknown_states: self.expected_lengths[state['name']] = state['expected_length'] self.expected_fractions[state['name']] = state['expected_fraction'] - reference = get_nested(self.config, - 'analysis_params.reference') + reference = self.config.get('analysis_params.reference') # expected fraction of reference is the remainder after other states # are specified self.expected_fractions[reference['name']] =\ 1 - sum(self.expected_fractions.values()) - self.known_states, self.unknown_states = get_states(self.config) - - self.ref_state = get_nested(self.config, - 'analysis_params.reference.name') + self.ref_state = self.config.get('analysis_params.reference.name') + self.known_states = self.config.known_states + self.unknown_states = self.config.unknown_states # have to remove effect of unknown of these values for later self.ref_fraction = self.expected_fractions[self.ref_state] + \ @@ -805,7 +525,7 @@ def transition_probabilities(self) -> np.array: # of genome? 
maybe theoretically, but that number is a lot more # suspect - states = self.known_states + self.unknown_states + states = self.config.states fractions = np.array([self.expected_fractions[s] for s in states]) lengths = 1/np.array([self.expected_lengths[s] for s in states]) @@ -868,7 +588,7 @@ def run_hmm(self, hmm.set_observations([coded_sequence]) # Baum-Welch parameter estimation - hmm.train(self.convergence) + hmm.train(self.config.convergence) return hmm_init, hmm, positions diff --git a/code/config.yaml b/code/config.yaml index 219b9a6..c7d8db4 100644 --- a/code/config.yaml +++ b/code/config.yaml @@ -40,10 +40,12 @@ paths: suffix: .txt analysis: - analysis_base: __OUTPUT_ROOT__/analysis_chr1_test + analysis_base: __OUTPUT_ROOT__/analysisp4e2 regions: __ANALYSIS_BASE__/regions/ genes: __ANALYSIS_BASE__/genes/ - block_files: __ANALYSIS_BASE__/blocks_{state}.txt + block_files: __ANALYSIS_BASE__/blocks_{state}_p4e2.txt + labeled_block_files: "__ANALYSIS_BASE__/../analysis_test/\ + blocks_{state}__test_labeled.txt" hmm_initial: __ANALYSIS_BASE__/hmm_initial.txt hmm_trained: __ANALYSIS_BASE__/hmm_trained.txt probabilities: __ANALYSIS_BASE__/probabilities.txt.gz @@ -66,9 +68,9 @@ paths: ldselect: __ROOT_INSTALL__/ldSelect/ structure: __ROOT_INSTALL__/structure/ -chromosomes: ['I'] -# chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', -# 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'] +# chromosomes: ['I'] +chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', + 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'] # can optionally list all strains to consider # if blank will glob with TEST_STRAINS paths diff --git a/code/misc/config_utils.py b/code/misc/config_utils.py index 428855a..2efeb5d 100644 --- a/code/misc/config_utils.py +++ b/code/misc/config_utils.py @@ -1,6 +1,6 @@ import re from copy import copy -from typing import Dict, List, Tuple +from typing import Dict, List import logging as log @@ -204,33 +204,6 @@ def check_wildcards(path: str, wildcards: str) -> bool: return True -def get_states(config: Dict) -> Tuple[List, List]: - ''' - From the provided config dict, build lists of known and unknown states - from the analysis params - ''' - - ref = get_nested(config, 'analysis_params.reference.name') - if ref is None: - ref = [] - else: - ref = [ref] - - known = get_nested(config, 'analysis_params.known_states') - if known is None: - known = [] - - known_states = ref + [s['name'] for s in known] - - unknown = get_nested(config, 'analysis_params.unknown_states') - if unknown is None: - unknown = [] - - unknown_states = [s['name'] for s in unknown] - - return known_states, unknown_states - - def validate(config: Dict, path: str, exception: str, diff --git a/code/test/analyze/test_id_regions.py b/code/test/analyze/test_id_regions.py new file mode 100644 index 0000000..62151b4 --- /dev/null +++ b/code/test/analyze/test_id_regions.py @@ -0,0 +1,142 @@ +from analyze import id_regions +import pytest +from analyze.introgression_configuration import Configuration + + +@pytest.fixture +def id_producer(): + config = Configuration() + config.add_config({ + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432'}, + {'name': 'N_45'}, + {'name': 'DBVPG6304'}, + {'name': 'UWOPS91_917_1'}, + ], + 'unknown_states': [{'name': 'unknown'}] + } + }) + config.set_states() + result = id_regions.ID_producer(config) + return result + + +def test_producer(id_producer): + assert id_producer.config.known_states == \ + 'S288c CBS432 N_45 DBVPG6304 
UWOPS91_917_1'.split() + assert id_producer.config.unknown_states == \ + 'unknown'.split() + assert id_producer.config.states == \ + 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1 unknown'.split() + + +def test_add_ids_empty(id_producer, mocker): + id_producer.config.add_config({ + 'chromosomes': ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', + 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'], + 'paths': {'analysis': {'block_files': 'dir/blocks_{state}.txt', + 'labeled_block_files': + 'dir/blocks_{state}_labeled.txt', + }}}) + + id_producer.config.states = 'ref state1 unknown'.split() + id_producer.config.set_blocks_file() + id_producer.config.set_labeled_blocks_file() + id_producer.config.set_chromosomes() + + mocker.patch('analyze.id_regions.read_blocks', + return_value={}) + + mocked_file = mocker.patch('analyze.id_regions.open', + mocker.mock_open()) + + id_producer.add_ids() + + assert mocked_file.call_count == 3 + mocked_file.assert_any_call('dir/blocks_ref_labeled.txt', 'w') + mocked_file.assert_any_call('dir/blocks_state1_labeled.txt', 'w') + mocked_file.assert_any_call('dir/blocks_unknown_labeled.txt', 'w') + + # just headers + mocked_file().write.assert_has_calls([ + mocker.call('region_id\tstrain\tchromosome\tpredicted_species' + '\tstart\tend\tnum_sites_hmm\n') + ]*3) + + +def test_add_ids(id_producer, mocker): + id_producer.config.add_config({ + 'chromosomes': ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', + 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'], + 'paths': {'analysis': {'block_files': 'dir/blocks_{state}.txt', + 'labeled_block_files': + 'dir/blocks_{state}_labeled.txt', + }}}) + + id_producer.config.states = 'ref state1 unknown'.split() + id_producer.config.set_blocks_file() + id_producer.config.set_labeled_blocks_file() + id_producer.config.set_chromosomes() + + regions = [ + { + 'strain1': { + 'I': [(10, 100, 10), (10, 100, 1)], + 'VI': [(10, 100, 10), (10, 100, 1)], + }, + 'strain2': { + 'V': [(10, 100, 10), (10, 100, 1)], + }, + 'strain3': { + 'III': [(10, 100, 10), (10, 100, 1)], + } + }, + { + 'strain1': { + 'IX': [(10, 100, 10), (10, 100, 1)], + }, + 'strain2': { + 'II': [(10, 100, 10), (10, 100, 1)], + }, + 'strain3': { + 'X': [(10, 100, 10), (10, 100, 1)], + } + }, + {} + ] + mocker.patch('analyze.id_regions.read_blocks', + side_effect=regions) + + mocked_file = mocker.patch('analyze.id_regions.open', + mocker.mock_open()) + + id_producer.add_ids() + + assert mocked_file.call_count == 3 + mocked_file.assert_any_call('dir/blocks_ref_labeled.txt', 'w') + mocked_file.assert_any_call('dir/blocks_state1_labeled.txt', 'w') + mocked_file.assert_any_call('dir/blocks_unknown_labeled.txt', 'w') + + # headers + calls = [ + mocker.call('region_id\tstrain\tchromosome\tpredicted_species' + '\tstart\tend\tnum_sites_hmm\n') + ]*3 + [ + mocker.call('r1\tstrain1\tI\tref\t10\t100\t10\n'), + mocker.call('r2\tstrain1\tI\tref\t10\t100\t1\n'), + mocker.call('r3\tstrain2\tII\tstate1\t10\t100\t10\n'), + mocker.call('r4\tstrain2\tII\tstate1\t10\t100\t1\n'), + mocker.call('r5\tstrain3\tIII\tref\t10\t100\t10\n'), + mocker.call('r6\tstrain3\tIII\tref\t10\t100\t1\n'), + mocker.call('r7\tstrain2\tV\tref\t10\t100\t10\n'), + mocker.call('r8\tstrain2\tV\tref\t10\t100\t1\n'), + mocker.call('r9\tstrain1\tVI\tref\t10\t100\t10\n'), + mocker.call('r10\tstrain1\tVI\tref\t10\t100\t1\n'), + mocker.call('r11\tstrain1\tIX\tstate1\t10\t100\t10\n'), + mocker.call('r12\tstrain1\tIX\tstate1\t10\t100\t1\n'), + mocker.call('r13\tstrain3\tX\tstate1\t10\t100\t10\n'), + 
mocker.call('r14\tstrain3\tX\tstate1\t10\t100\t1\n'), + ] + mocked_file().write.assert_has_calls(calls) diff --git a/code/test/analyze/test_id_regions_main.py b/code/test/analyze/test_id_regions_main.py deleted file mode 100644 index 5811bd5..0000000 --- a/code/test/analyze/test_id_regions_main.py +++ /dev/null @@ -1,112 +0,0 @@ -from analyze import id_regions_main as main -from operator import itemgetter - - -def test_main_blank(mocker): - # setup global params to match expectations - mocker.patch('analyze.id_regions_main.gp.chrms', - ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', - 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']) - mocker.patch('analyze.id_regions_main.gp.analysis_out_dir_absolute', - 'dir/') - mocker.patch( - 'analyze.summarize_strain_states_main.predict.process_predict_args', - return_value={ - 'known_states': ['S288c', 'CBS432', 'N_45', - 'DBVPG6304', 'UWOPS91_917_1'], - 'states': ['ref', 'state1', 'unknown'], - 'tag': 'tag' - }) - mocker.patch('sys.argv', - "test.py tag .001 viterbi 1000 .025 unknown 1000 .01".split()) - mocker.patch('analyze.predict.read_blocks', - return_value={}) - - mocked_file = mocker.patch('analyze.id_regions_main.open', - mocker.mock_open()) - - main.main() - - assert mocked_file.call_count == 3 - mocked_file.assert_any_call('dir/tag/blocks_ref_tag_labeled.txt', 'w') - mocked_file.assert_any_call('dir/tag/blocks_state1_tag_labeled.txt', 'w') - mocked_file.assert_any_call('dir/tag/blocks_unknown_tag_labeled.txt', 'w') - - # just headers - mocked_file().write.assert_has_calls([ - mocker.call('region_id\tstrain\tchromosome\tpredicted_species' - '\tstart\tend\tnum_sites_hmm\n') - ]*3) - - -def test_main(mocker): - # setup global params to match expectations - mocker.patch( - 'analyze.summarize_strain_states_main.predict.process_predict_args', - return_value={ - 'states': ['ref', 'state1', 'unknown'], - 'tag': 'tag' - }) - mocker.patch('analyze.id_regions_main.gp.chrms', - ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', - 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']) - mocker.patch('analyze.id_regions_main.gp.analysis_out_dir_absolute', - 'dir/') - - mocker.patch('sys.argv', - "test.py tag .001 viterbi 1000 .025 unknown 1000 .01".split()) - - regions = { - 'ref': { - 'I': [(10, 100, 10), (10, 100, 1)], - 'IX': [(10, 100, 10), (10, 100, 1)], - 'VI': [(10, 100, 10), (10, 100, 1)], - }, - 'state1': { - 'II': [(10, 100, 10), (10, 100, 1)], - 'X': [(10, 100, 10), (10, 100, 1)], - 'V': [(10, 100, 10), (10, 100, 1)], - }, - 'unknown': { - 'II': [(10, 100, 10), (10, 100, 1)], - 'X': [(10, 100, 10), (10, 100, 1)], - 'V': [(10, 100, 10), (10, 100, 1)], - } - } - mocker.patch('analyze.predict.read_blocks', - return_value=regions) - - mocked_file = mocker.patch('analyze.id_regions_main.open', - mocker.mock_open()) - - main.main() - - mocked_file.assert_any_call('dir/tag/blocks_ref_tag_labeled.txt', 'w') - mocked_file.assert_any_call('dir/tag/blocks_state1_tag_labeled.txt', 'w') - mocked_file.assert_any_call('dir/tag/blocks_unknown_tag_labeled.txt', 'w') - assert mocked_file.call_count == 3 - - # headers - calls = [ - mocker.call('region_id\tstrain\tchromosome\tpredicted_species' - '\tstart\tend\tnum_sites_hmm\n') - ]*3 - - rid = 1 - by_chrom = dict(zip(main.gp.chrms, [[] for chrm in main.gp.chrms])) - for spec in ('ref', 'state1', 'unknown'): - for s in sorted(regions): - for c in regions[s]: - for e in regions[s][c]: - start, end, num = e - by_chrom[c].append((start, end, num, s, spec)) - - for c in main.gp.chrms: - for e in 
sorted(by_chrom[c], key=itemgetter(0, 3)): - start, end, num, s, spec = e - calls.append(mocker.call( - 'r{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format( - rid, s, c, spec, start, end, num))) - rid += 1 - - mocked_file().write.assert_has_calls(calls) diff --git a/code/test/analyze/test_introgression_configuration.py b/code/test/analyze/test_introgression_configuration.py new file mode 100644 index 0000000..1e7a032 --- /dev/null +++ b/code/test/analyze/test_introgression_configuration.py @@ -0,0 +1,592 @@ +from analyze.introgression_configuration import Configuration +import pytest + + +@pytest.fixture() +def config(): + return Configuration() + + +def test_set_log_file(config): + config.set_log_file() + assert config.log_file is None + + config.set_log_file('test') + assert config.log_file == 'test' + + config.config = {'paths': {'log_file': 'log'}} + config.set_log_file() + assert config.log_file == 'log' + + config.set_log_file('test') + assert config.log_file == 'test' + + +def test_set_chromosomes(config): + with pytest.raises(ValueError) as e: + config.set_chromosomes() + assert 'No chromosomes specified in config file!' in str(e) + + config.config = {'chromosomes': ['I']} + config.set_chromosomes() + assert config.chromosomes == ['I'] + + +def test_get_states(config): + assert config.get_states() == ([], []) + + config.config = { + 'analysis_params': { + 'known_states': [ + {'name': 'k1'}, + {'name': 'k2'}, + {'name': 'k3'}, + ], + 'unknown_states': [ + {'name': 'u1'}, + {'name': 'u2'}, + ] + } + } + assert config.get_states() == ('k1 k2 k3'.split(), 'u1 u2'.split()) + + config.config = { + 'analysis_params': { + 'reference': {'name': 'ref'}, + 'unknown_states': [ + {'name': 'u1'}, + {'name': 'u2'}, + ] + } + } + assert config.get_states() == ('ref'.split(), 'u1 u2'.split()) + + config.config = { + 'analysis_params': { + 'reference': {'name': 'ref'}, + 'known_states': [ + {'name': 'k1'}, + {'name': 'k2'}, + {'name': 'k3'}, + ], + 'unknown_states': [ + {'name': 'u1'}, + {'name': 'u2'}, + ] + } + } + assert config.get_states() == ('ref k1 k2 k3'.split(), 'u1 u2'.split()) + + +def test_set_states(config): + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432'}, + {'name': 'N_45'}, + {'name': 'DBVPG6304'}, + {'name': 'UWOPS91_917_1'}, + ], + 'unknown_states': [{'name': 'unknown'}] + } + } + + config.set_states() + assert config.known_states ==\ + 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() + assert config.unknown_states ==\ + 'unknown'.split() + assert config.states ==\ + 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1 unknown'.split() + + config.set_states([]) + assert config.known_states ==\ + 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() + assert config.unknown_states ==\ + 'unknown'.split() + assert config.states ==\ + 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1 unknown'.split() + + config.set_states('testing 123'.split()) + assert config.states == ['testing', '123'] + + config.config = {} + + with pytest.raises(ValueError) as e: + config.set_states() + assert 'No states specified' in str(e) + + +def test_set_threshold(config): + with pytest.raises(ValueError) as e: + config.set_threshold() + assert 'No threshold provided' in str(e) + + config.config = {'analysis_params': {'threshold': 'asdf'}} + with pytest.raises(ValueError) as e: + config.set_threshold() + assert 'Unsupported threshold value: asdf' in str(e) + + config.set_threshold(0.05) + assert config.threshold == 0.05 + + config.config = {'analysis_params': + 
{'threshold': 'viterbi'}} + config.set_threshold() + assert config.threshold == 'viterbi' + + +def test_set_labeled_blocks_file(config): + with pytest.raises(ValueError) as e: + config.set_labeled_blocks_file('blocks_file') + assert '{state} not found in blocks_file' in str(e) + + config.set_labeled_blocks_file('blocks_file{state}') + assert config.labeled_blocks == 'blocks_file{state}' + + with pytest.raises(ValueError) as e: + config.set_labeled_blocks_file() + assert 'No labeled block file provided' in str(e) + + config.config = {'paths': {'analysis': + {'labeled_block_files': 'blocks_file'}}} + with pytest.raises(ValueError) as e: + config.set_labeled_blocks_file() + assert '{state} not found in blocks_file' in str(e) + + config.config = {'paths': {'analysis': {'labeled_block_files': + 'blocks_file{state}'}}} + config.set_labeled_blocks_file() + assert config.labeled_blocks == 'blocks_file{state}' + + +def test_set_blocks_file(config): + with pytest.raises(ValueError) as e: + config.set_blocks_file('blocks_file') + assert '{state} not found in blocks_file' in str(e) + + config.set_blocks_file('blocks_file{state}') + assert config.blocks == 'blocks_file{state}' + + with pytest.raises(ValueError) as e: + config.set_blocks_file() + assert 'No block file provided' in str(e) + + config.config = {'paths': {'analysis': {'block_files': 'blocks_file'}}} + with pytest.raises(ValueError) as e: + config.set_blocks_file() + assert '{state} not found in blocks_file' in str(e) + + config.config = {'paths': {'analysis': {'block_files': + 'blocks_file{state}'}}} + config.set_blocks_file() + assert config.blocks == 'blocks_file{state}' + + +def test_set_prefix(config): + config.known_states = ['s1'] + config.set_prefix() + assert config.prefix == 's1' + + config.known_states = 's1 s2'.split() + config.set_prefix() + assert config.prefix == 's1_s2' + + config.set_prefix('prefix') + assert config.prefix == 'prefix' + + config.known_states = [] + with pytest.raises(ValueError) as e: + config.set_prefix() + assert 'Unable to build prefix, no known states provided' in str(e) + + +def test_set_strains(config, mocker): + mock_find = mocker.patch.object(Configuration, 'find_strains') + + config.set_strains() + mock_find.called_with(None) + + with pytest.raises(ValueError) as e: + config.config = {'paths': {'test_strains': ['test']}} + config.set_strains() + assert '{strain} not found in test' in str(e) + + with pytest.raises(ValueError) as e: + config.config = {'paths': {'test_strains': ['test{strain}']}} + config.set_strains() + assert '{chrom} not found in test{strain}' in str(e) + + config.config = {'paths': {'test_strains': + ['test{strain}{chrom}']}} + config.set_strains() + mock_find.called_with(['test{strain}{chrom}']) + + config.set_strains('test{strain}{chrom}') + mock_find.called_with(['test{strain}{chrom}']) + + +def test_find_strains(config, mocker): + with pytest.raises(ValueError) as e: + config.find_strains() + assert ('Unable to find strains in config and ' + 'no test_strains provided') in str(e) + + config.config = {'strains': ['test2', 'test1']} + config.find_strains() + # sorted + assert config.strains == 'test1 test2'.split() + + config.config = {} + config.chromosomes = ['I'] + + # too many chroms for s1 + mock_glob = mocker.patch('analyze.introgression_configuration.glob.iglob', + side_effect=[[ + 'test_prefix_s1_cII.fa', + 'test_prefix_s2_cII.fa', + 'test_prefix_s1_cIII.fa', + 'test_prefix.fa', + ]]) + mock_log = mocker.patch('analyze.introgression_configuration.log') + with 
pytest.raises(ValueError) as e: + config.find_strains(['test_prefix_{strain}_c{chrom}.fa']) + + assert "Strain s1 is missing chromosomes. Unable to find chromosome 'I'"\ + in str(e) + mock_glob.assert_called_with('test_prefix_*_c*.fa') + mock_log.info.assert_called_with('searching for test_prefix_*_c*.fa') + assert mock_log.debug.call_args_list == \ + [mocker.call("matched with ('s1', 'II')"), + mocker.call("matched with ('s2', 'II')"), + mocker.call("matched with ('s1', 'III')"), + ] + + # no matches + mock_glob = mocker.patch('analyze.introgression_configuration.glob.iglob', + side_effect=[[ + 'test_prefix.fa', + ]]) + mock_log = mocker.patch('analyze.introgression_configuration.log') + with pytest.raises(ValueError) as e: + config.find_strains(['test_prefix_{strain}_{chrom}.fa']) + assert ('Found no chromosome sequence files in ' + "['test_prefix_{strain}_{chrom}.fa']") in str(e) + mock_glob.assert_called_with('test_prefix_*_*.fa') + mock_log.info.assert_called_with('searching for test_prefix_*_*.fa') + assert mock_log.debug.call_args_list == [] + + # correct, with second test_strains, extra chromosomes + mock_glob = mocker.patch('analyze.introgression_configuration.glob.iglob', + side_effect=[ + [ + 'test_prefix_s1_cI.fa', + 'test_prefix_s2_cI.fa', + 'test_prefix_s2_cII.fa', + 'test_prefix.fa', + ], + ['test_prefix_cI_s3.fa'] + ]) + mock_log = mocker.patch('analyze.introgression_configuration.log') + config.find_strains(['test_prefix_{strain}_c{chrom}.fa', + 'test_prefix_c{chrom}_{strain}.fa']) + assert mock_glob.call_args_list == \ + [mocker.call('test_prefix_*_c*.fa'), + mocker.call('test_prefix_c*_*.fa')] + assert mock_log.info.call_args_list ==\ + [mocker.call('searching for test_prefix_*_c*.fa'), + mocker.call('searching for test_prefix_c*_*.fa')] + assert mock_log.debug.call_args_list == \ + [mocker.call("matched with ('s1', 'I')"), + mocker.call("matched with ('s2', 'I')"), + mocker.call("matched with ('s2', 'II')"), + mocker.call("matched with ('s3', 'I')"), + ] + assert config.strains == ['s1', 's2', 's3'] + + +def test_set_predict_files(config): + with pytest.raises(ValueError) as e: + config.set_predict_files('', '', '', '', '') + assert 'No initial hmm file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.set_predict_files('init', '', '', '', '') + assert 'No trained hmm file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.set_predict_files('init', 'trained', 'pos', 'prob', '') + assert 'No alignment file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.set_predict_files('init', 'trained', 'pos', 'prob', 'align') + assert '{prefix} not found in align' in str(e) + + with pytest.raises(ValueError) as e: + config.set_predict_files('init', 'trained', 'pos', 'prob', + 'align{prefix}') + assert '{strain} not found in align{prefix}' in str(e) + + with pytest.raises(ValueError) as e: + config.set_predict_files('init', 'trained', 'pos', 'prob', + 'align{prefix}{strain}') + assert '{chrom} not found in align{prefix}{strain}' in str(e) + + config.prefix = 'pre' + config.set_predict_files('init', 'trained', 'pos', 'prob', + 'align{prefix}{strain}{chrom}') + assert config.hmm_initial == 'init' + assert config.hmm_trained == 'trained' + assert config.positions == 'pos' + assert config.probabilities == 'prob' + assert config.alignment == 'alignpre{strain}{chrom}' + + config.set_predict_files('init', 'trained', '', 'prob', + 'align{prefix}{strain}{chrom}') + assert config.hmm_initial == 'init' + assert config.hmm_trained == 
'trained' + assert config.positions is None + assert config.probabilities == 'prob' + assert config.alignment == 'alignpre{strain}{chrom}' + + with pytest.raises(ValueError) as e: + config.config = {'paths': {'analysis': {'hmm_initial': 'init'}}} + config.set_predict_files('', '', '', '', '') + assert 'No trained hmm file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.config = {'paths': {'analysis': {'hmm_initial': 'init', + 'hmm_trained': 'trained', + 'positions': 'pos' + }}} + config.set_predict_files('', '', '', '', '') + assert 'No probabilities file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.config = {'paths': {'analysis': {'hmm_initial': 'init', + 'hmm_trained': 'trained', + 'positions': 'pos', + 'probabilities': 'prob' + }}} + config.set_predict_files('', '', '', '', '') + assert 'No alignment file provided' in str(e) + + config.config = {'paths': {'analysis': { + 'hmm_initial': 'init', + 'hmm_trained': 'trained', + 'positions': 'pos', + 'probabilities': 'prob', + 'alignment': 'align{prefix}{strain}{chrom}' + }}} + config.set_predict_files('', '', '', '', '') + + assert config.hmm_initial == 'init' + assert config.hmm_trained == 'trained' + assert config.positions == 'pos' + assert config.probabilities == 'prob' + assert config.alignment == 'alignpre{strain}{chrom}' + + config.config = {'paths': {'analysis': { + 'hmm_initial': 'init', + 'hmm_trained': 'trained', + 'probabilities': 'prob', + 'alignment': 'align{prefix}{strain}{chrom}' + }}} + config.set_predict_files('', '', '', '', '') + + assert config.hmm_initial == 'init' + assert config.hmm_trained == 'trained' + assert config.positions is None + assert config.probabilities == 'prob' + assert config.alignment == 'alignpre{strain}{chrom}' + + +def test_validate_predict_arguments(config): + config.chromosomes = 1 + config.blocks = 1 + config.prefix = 1 + config.strains = 1 + config.hmm_initial = 1 + config.hmm_trained = 1 + config.probabilities = 1 + config.alignment = 1 + config.known_states = 1 + config.unknown_states = 1 + config.threshold = 1 + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'DBVPG6304', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'UWOPS91_917_1', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + + assert config.validate_predict_arguments() + + args = [ + 'chromosomes', + 'blocks', + 'prefix', + 'strains', + 'hmm_initial', + 'hmm_trained', + 'probabilities', + 'alignment', + 'known_states', + 'unknown_states', + 'threshold' + ] + + for arg in args: + config.__dict__[arg] = None + with pytest.raises(ValueError) as e: + config.validate_predict_arguments() + assert ('Failed to validate Predictor, ' + f"required argument '{arg}' was unset") in str(e) + config.__dict__[arg] = 1 + + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + config.validate_predict_arguments() + assert 'Configuration did not provide any known_states' in str(e) + + config.config = { + 'analysis_params': + {'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 
'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + config.validate_predict_arguments() + assert 'Configuration did not specify a reference strain' in str(e) + + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + config.validate_predict_arguments() + assert 'CBS432 did not provide an expected_length' in str(e) + + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + }, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + config.validate_predict_arguments() + assert 'N_45 did not provide an expected_fraction' in str(e) + + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + config.validate_predict_arguments() + assert 'unknown did not provide an expected_length' in str(e) + + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + }, + ] + } + } + with pytest.raises(ValueError) as e: + config.validate_predict_arguments() + assert 'unknown did not provide an expected_fraction' in str(e) + + +def test_validate_id_regions_arguments(config): + with pytest.raises(ValueError) as e: + config.validate_id_regions_arguments() + assert ('Failed to validate ID Producer, ' + "required argument 'chromosomes' was unset") in str(e) + + config.chromosomes = 1 + config.blocks = 1 + config.labeled_blocks = 1 + config.states = 1 + + assert config.validate_id_regions_arguments() diff --git a/code/test/analyze/test_main.py b/code/test/analyze/test_main.py index ae01a95..532d860 100644 --- a/code/test/analyze/test_main.py +++ b/code/test/analyze/test_main.py @@ -15,8 +15,6 @@ def test_main_cli_configs(runner, mocker): assert result.exit_code == 0 with runner.isolated_filesystem(): - mock_clean = mocker.patch('analyze.main.config_utils.clean_config', - side_effect=lambda x: x) mock_echo = mocker.patch('analyze.main.click.echo_via_pager') mock_log_info = mocker.patch('analyze.main.log.info') mock_log_debug = mocker.patch('analyze.main.log.debug') @@ -31,9 +29,6 @@ def test_main_cli_configs(runner, mocker): main.cli, '--config config1.yaml --config config2.yaml'.split()) assert result.exit_code == 0 - mock_clean.assert_called_with( - {'test': '23', - 'test2': '34'}) # since no subcommand was called mock_echo.assert_called_once() @@ -43,8 +38,10 @@ def test_main_cli_configs(runner, 
mocker): mocker.call('Verbosity set to WARNING'), mocker.call('Read in 2 config files') ] + print(mock_log_debug.call_args_list[0][0]) assert mock_log_debug.call_args_list == [ - mocker.call('Cleaned config:\ntest - 23\ntest2 - 34\n') + mocker.call('Cleaned config:\nConfig file:\ntest - 23\n' + 'test2 - 34\n\nSettings:\nlog_file - None\n') ] diff --git a/code/test/analyze/test_main_id_config.py b/code/test/analyze/test_main_id_config.py new file mode 100644 index 0000000..c3631f0 --- /dev/null +++ b/code/test/analyze/test_main_id_config.py @@ -0,0 +1,143 @@ +import pytest +from click.testing import CliRunner +import analyze.main as main +import yaml +from analyze.id_regions import ID_producer + + +''' +Unit tests for the id_regions command of main.py when all parameters are +provided by the config file +''' + + +@pytest.fixture +def runner(): + return CliRunner() + + +def test_empty(runner): + result = runner.invoke( + main.cli, + 'id-regions') + assert result.exit_code != 0 + assert str(result.exception) == 'No chromosomes specified in config file!' + + +def test_chroms(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split() + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml id-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No states specified' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config') + ] + + +def test_states(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml id-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No block file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call('Found 2 states to process'), + ] + + +def test_block_file(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + 'paths': {'analysis': { + 'block_files': 'block_{state}.txt', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml id-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No labeled block file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call('Found 2 states to process'), + mocker.call("Input blocks file is 'block_{state}.txt'"), + ] + + +def test_labeled_block_file(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + 'paths': {'analysis': { + 'block_files': 'block_{state}.txt', + 
'labeled_block_files': 'labeled_block_{state}.txt', + }} + }, f) + + mock_id = mocker.patch.object(ID_producer, 'add_ids') + + result = runner.invoke( + main.cli, + '--config config.yaml id-regions') + + assert result.exit_code == 0 + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call('Found 2 states to process'), + mocker.call("Input blocks file is 'block_{state}.txt'"), + mocker.call("Output blocks file is 'labeled_block_{state}.txt'"), + ] + + mock_id.called_once() diff --git a/code/test/analyze/test_main_predict_args.py b/code/test/analyze/test_main_predict_args.py index d67d374..42b27a1 100644 --- a/code/test/analyze/test_main_predict_args.py +++ b/code/test/analyze/test_main_predict_args.py @@ -7,8 +7,8 @@ ''' -Unit tests for the predict command of main.py when all parameters are -provided by the config file +Unit tests for the predict command of main.py when parameters are +provided by the arguments primarily ''' @@ -47,6 +47,11 @@ def test_block(runner, mocker): yaml.dump( { 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, }, f) result = runner.invoke( @@ -57,13 +62,14 @@ def test_block(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'Unable to build prefix, no known states provided' + 'Unable to find strains in config and no test_strains provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), ] @@ -74,6 +80,11 @@ def test_prefix(runner, mocker): yaml.dump( { 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, }, f) result = runner.invoke( @@ -102,6 +113,11 @@ def test_test_strains(runner, mocker): yaml.dump( { 'chromosomes': 'I II III'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, }, f) Path('s1_chrI.fa').touch() @@ -155,6 +171,11 @@ def test_outputs(runner, mocker): { 'chromosomes': 'I II III'.split(), 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, }, f) mock_log.reset_mock() @@ -176,6 +197,11 @@ def test_outputs(runner, mocker): { 'chromosomes': 'I II III'.split(), 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, }, f) mock_log.reset_mock() @@ -198,6 +224,11 @@ def test_outputs(runner, mocker): { 'chromosomes': 'I II III'.split(), 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, }, f) mock_log.reset_mock() @@ -247,7 +278,9 @@ def test_outputs(runner, mocker): mocker.call("Hmm_trained file is 'hmm_trained.txt'"), mocker.call("Positions file is 'None'"), mocker.call("Probabilities file is 'probs.txt.gz'"), - mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'")] + mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'"), + mocker.call("Only considering polymorphic sites"), + ] mock_predict = mocker.patch.object(predict.Predictor, 'run_prediction') with runner.isolated_filesystem(): @@ -281,7 +314,9 @@ def test_outputs(runner, mocker): 
mocker.call("Hmm_trained file is 'hmm_trained.txt'"), mocker.call("Positions file is 'pos.txt.gz'"), mocker.call("Probabilities file is 'probs.txt.gz'"), - mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'")] + mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'"), + mocker.call("Only considering polymorphic sites"), + ] mock_predict.called_once_with(True) mock_predict.reset_mock() diff --git a/code/test/analyze/test_main_predict_config.py b/code/test/analyze/test_main_predict_config.py index ccf7070..9d9146e 100644 --- a/code/test/analyze/test_main_predict_config.py +++ b/code/test/analyze/test_main_predict_config.py @@ -79,7 +79,10 @@ def test_block(runner, mocker): { 'chromosomes': 'I II III'.split(), 'analysis_params': { - 'threshold': 'viterbi' + 'threshold': 'viterbi', + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], }, 'paths': {'analysis': { 'block_files': 'blocks_{state}.txt', @@ -92,13 +95,14 @@ def test_block(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'Unable to build prefix, no known states provided' + 'Unable to find strains in config and no test_strains provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), mocker.call('Found 3 chromosomes in config'), mocker.call("Threshold value is 'viterbi'"), mocker.call("Output blocks file is 'blocks_{state}.txt'"), + mocker.call("Prefix is 's1_s2'"), ] @@ -357,7 +361,9 @@ def test_outputs(runner, mocker): mocker.call("Hmm_trained file is 'hmm_trained.txt'"), mocker.call("Positions file is 'None'"), mocker.call("Probabilities file is 'probs.txt.gz'"), - mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'")] + mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'"), + mocker.call("Only considering polymorphic sites") + ] mock_predict = mocker.patch.object(predict.Predictor, 'run_prediction') with runner.isolated_filesystem(): @@ -385,7 +391,46 @@ def test_outputs(runner, mocker): mock_log.reset_mock() result = runner.invoke( main.cli, - '--config config.yaml predict') + '--config config.yaml predict --only-poly-sites') + + assert result.exit_code == 0 + assert mock_log.call_args_list == mock_calls + [ + mocker.call("Hmm_initial file is 'hmm_init.txt'"), + mocker.call("Hmm_trained file is 'hmm_trained.txt'"), + mocker.call("Positions file is 'pos.txt.gz'"), + mocker.call("Probabilities file is 'probs.txt.gz'"), + mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'"), + mocker.call("Only considering polymorphic sites"), + ] + mock_predict.called_once_with(True) + + mock_predict = mocker.patch.object(predict.Predictor, 'run_prediction') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + 'strains': 'str1 str2 str1'.split(), + 'analysis_params': { + 'threshold': 'viterbi', + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + 'paths': {'analysis': { + 'block_files': 'blocks_{state}.txt', + 'hmm_initial': 'hmm_init.txt', + 'hmm_trained': 'hmm_trained.txt', + 'positions': 'pos.txt.gz', + 'probabilities': 'probs.txt.gz', + 'alignment': '{prefix}_{strain}_chr{chrom}.maf', + }}, + }, f) + + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml predict --all-sites') assert result.exit_code == 0 assert mock_log.call_args_list == mock_calls + [ @@ -393,5 +438,7 @@ def test_outputs(runner, mocker): mocker.call("Hmm_trained file is 'hmm_trained.txt'"), 
mocker.call("Positions file is 'pos.txt.gz'"), mocker.call("Probabilities file is 'probs.txt.gz'"), - mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'")] + mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'"), + mocker.call("Considering all sites"), + ] mock_predict.called_once_with(True) diff --git a/code/test/analyze/test_predict_hmm_builder.py b/code/test/analyze/test_predict_hmm_builder.py index 1704cbd..bb70b68 100644 --- a/code/test/analyze/test_predict_hmm_builder.py +++ b/code/test/analyze/test_predict_hmm_builder.py @@ -5,11 +5,17 @@ from collections import defaultdict import random import numpy as np +from analyze.introgression_configuration import Configuration @pytest.fixture -def default_builder(): - builder = predict.HMM_Builder({ +def config(): + return Configuration() + + +@pytest.fixture +def default_builder(config): + config.config = { 'analysis_params': {'reference': {'name': 'S288c'}, 'known_states': [ @@ -31,19 +37,20 @@ def default_builder(): 'expected_fraction': 0.01}, ] } - }) + } + builder = predict.HMM_Builder(config) + config.set_states() builder.set_expected_values() builder.update_expected_length(1e5) return builder @pytest.fixture -def builder(): - return predict.HMM_Builder(None) +def builder(config): + return predict.HMM_Builder(config) def test_builder(builder): - assert builder.config is None assert builder.symbols == { 'match': '+', 'mismatch': '-', @@ -55,9 +62,9 @@ def test_builder(builder): } -def test_init(mocker): +def test_init(mocker, config): mock_log = mocker.patch('analyze.predict.log') - predict.HMM_Builder(None) + predict.HMM_Builder(config) # no config, all warnings mock_log.warning.has_calls([ mocker.call("Symbol for match unset in config, using default '+'"), @@ -72,7 +79,8 @@ def test_init(mocker): # config, same warnings as above along with unused mock_log = mocker.patch('analyze.predict.log') - predict.HMM_Builder({'HMM_symbols': {'unused': 'X'}}) + config.config = {'HMM_symbols': {'unused': 'X'}} + predict.HMM_Builder(config) mock_log.warning.has_calls([ mocker.call("Unused symbol in configuration: unused -> 'X'"), mocker.call("Symbol for mismatch unset in config, using default '-'"), @@ -86,7 +94,8 @@ def test_init(mocker): # overwrite mock_log = mocker.patch('analyze.predict.log') - predict.HMM_Builder({'HMM_symbols': {'masked': 'X'}}) + config.config = {'HMM_symbols': {'masked': 'X'}} + predict.HMM_Builder(config) mock_log.debug.has_calls([ mocker.call("Overwriting default symbol for masked with 'X'") ]) @@ -152,8 +161,8 @@ def symbol_test_helper(sequence, builder): assert weigh == approx(weighted_match_freqs) -def test_set_expected_values(builder): - builder.config = { +def test_set_expected_values(builder, config): + config.config = { 'analysis_params': {'reference': {'name': 'S288c'}, 'known_states': [ @@ -176,6 +185,7 @@ def test_set_expected_values(builder): ] } } + config.set_states() builder.set_expected_values() assert builder.expected_lengths == { 'CBS432': 10, @@ -196,8 +206,8 @@ def test_set_expected_values(builder): assert builder.ref_state == 'S288c' -def test_update_expected_length(builder): - builder.config = { +def test_update_expected_length(builder, config): + config.config = { 'analysis_params': {'reference': {'name': 'S288c'}, 'known_states': [ @@ -220,6 +230,7 @@ def test_update_expected_length(builder): ] } } + config.set_states() builder.set_expected_values() assert builder.expected_lengths == { diff --git a/code/test/analyze/test_predict_predictor.py 
b/code/test/analyze/test_predict_predictor.py index 30fb073..ee583ee 100644 --- a/code/test/analyze/test_predict_predictor.py +++ b/code/test/analyze/test_predict_predictor.py @@ -5,12 +5,13 @@ from collections import defaultdict import random import numpy as np +from analyze.introgression_configuration import Configuration @pytest.fixture -def predictor(): - result = predict.Predictor( - configuration={ +def config(): + config = Configuration() + config.add_config({ 'analysis_params': {'reference': {'name': 'S288c'}, 'known_states': [ @@ -21,477 +22,39 @@ def predictor(): ], 'unknown_states': [{'name': 'unknown'}] } - } - ) - return result - - -def test_predictor(predictor): - assert predictor.known_states ==\ - 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() - assert predictor.unknown_states == ['unknown'] - - -def test_set_chromosomes(predictor): - with pytest.raises(ValueError) as e: - predictor.set_chromosomes() - assert 'No chromosomes specified in config file!' in str(e) - - predictor.config = {'chromosomes': ['I']} - predictor.set_chromosomes() - assert predictor.chromosomes == ['I'] - - -def test_set_blocks_file(predictor): - with pytest.raises(ValueError) as e: - predictor.set_blocks_file('blocks_file') - assert '{state} not found in blocks_file' in str(e) - - predictor.set_blocks_file('blocks_file{state}') - assert predictor.blocks == 'blocks_file{state}' - - with pytest.raises(ValueError) as e: - predictor.set_blocks_file() - assert 'No block file provided' in str(e) - - predictor.config = {'paths': {'analysis': {'block_files': 'blocks_file'}}} - with pytest.raises(ValueError) as e: - predictor.set_blocks_file() - assert '{state} not found in blocks_file' in str(e) - - predictor.config = {'paths': {'analysis': {'block_files': - 'blocks_file{state}'}}} - predictor.set_blocks_file() - assert predictor.blocks == 'blocks_file{state}' - - -def test_set_prefix(predictor): - predictor.known_states = ['s1'] - predictor.set_prefix() - assert predictor.prefix == 's1' - - predictor.known_states = 's1 s2'.split() - predictor.set_prefix() - assert predictor.prefix == 's1_s2' - - predictor.set_prefix('prefix') - assert predictor.prefix == 'prefix' - - predictor.known_states = [] - with pytest.raises(ValueError) as e: - predictor.set_prefix() - assert 'Unable to build prefix, no known states provided' in str(e) - - -def test_set_threshold(predictor): - with pytest.raises(ValueError) as e: - predictor.set_threshold() - assert 'No threshold provided' in str(e) - - predictor.config = {'analysis_params': {'threshold': 'asdf'}} - with pytest.raises(ValueError) as e: - predictor.set_threshold() - assert 'Unsupported threshold value: asdf' in str(e) - - predictor.set_threshold(0.05) - assert predictor.threshold == 0.05 - - predictor.config = {'analysis_params': - {'threshold': 'viterbi'}} - predictor.set_threshold() - assert predictor.threshold == 'viterbi' - - -def test_set_strains(predictor, mocker): - mock_find = mocker.patch.object(predict.Predictor, 'find_strains') - - predictor.set_strains() - mock_find.called_with(None) - - with pytest.raises(ValueError) as e: - predictor.config = {'paths': {'test_strains': ['test']}} - predictor.set_strains() - assert '{strain} not found in test' in str(e) - - with pytest.raises(ValueError) as e: - predictor.config = {'paths': {'test_strains': ['test{strain}']}} - predictor.set_strains() - assert '{chrom} not found in test{strain}' in str(e) - - predictor.config = {'paths': {'test_strains': - ['test{strain}{chrom}']}} - predictor.set_strains() - 
mock_find.called_with(['test{strain}{chrom}']) + }) - predictor.set_strains('test{strain}{chrom}') - mock_find.called_with(['test{strain}{chrom}']) + return config -def test_find_strains(predictor, mocker): - with pytest.raises(ValueError) as e: - predictor.find_strains() - assert ('Unable to find strains in config and ' - 'no test_strains provided') in str(e) - - predictor.config = {'strains': ['test2', 'test1']} - predictor.find_strains() - # sorted - assert predictor.strains == 'test1 test2'.split() - - predictor.config = {} - predictor.chromosomes = ['I'] - - # too many chroms for s1 - mock_glob = mocker.patch('analyze.predict.glob.iglob', - side_effect=[[ - 'test_prefix_s1_cII.fa', - 'test_prefix_s2_cII.fa', - 'test_prefix_s1_cIII.fa', - 'test_prefix.fa', - ]]) - mock_log = mocker.patch('analyze.predict.log') - with pytest.raises(ValueError) as e: - predictor.find_strains(['test_prefix_{strain}_c{chrom}.fa']) - - assert "Strain s1 is missing chromosomes. Unable to find chromosome 'I'"\ - in str(e) - mock_glob.assert_called_with('test_prefix_*_c*.fa') - mock_log.info.assert_called_with('searching for test_prefix_*_c*.fa') - assert mock_log.debug.call_args_list == \ - [mocker.call("matched with ('s1', 'II')"), - mocker.call("matched with ('s2', 'II')"), - mocker.call("matched with ('s1', 'III')"), - ] - - # no matches - mock_glob = mocker.patch('analyze.predict.glob.iglob', - side_effect=[[ - 'test_prefix.fa', - ]]) - mock_log = mocker.patch('analyze.predict.log') - with pytest.raises(ValueError) as e: - predictor.find_strains(['test_prefix_{strain}_{chrom}.fa']) - assert ('Found no chromosome sequence files in ' - "['test_prefix_{strain}_{chrom}.fa']") in str(e) - mock_glob.assert_called_with('test_prefix_*_*.fa') - mock_log.info.assert_called_with('searching for test_prefix_*_*.fa') - assert mock_log.debug.call_args_list == [] - - # correct, with second test_strains, extra chromosomes - mock_glob = mocker.patch('analyze.predict.glob.iglob', - side_effect=[ - [ - 'test_prefix_s1_cI.fa', - 'test_prefix_s2_cI.fa', - 'test_prefix_s2_cII.fa', - 'test_prefix.fa', - ], - ['test_prefix_cI_s3.fa'] - ]) - mock_log = mocker.patch('analyze.predict.log') - predictor.find_strains(['test_prefix_{strain}_c{chrom}.fa', - 'test_prefix_c{chrom}_{strain}.fa']) - assert mock_glob.call_args_list == \ - [mocker.call('test_prefix_*_c*.fa'), - mocker.call('test_prefix_c*_*.fa')] - assert mock_log.info.call_args_list ==\ - [mocker.call('searching for test_prefix_*_c*.fa'), - mocker.call('searching for test_prefix_c*_*.fa')] - assert mock_log.debug.call_args_list == \ - [mocker.call("matched with ('s1', 'I')"), - mocker.call("matched with ('s2', 'I')"), - mocker.call("matched with ('s2', 'II')"), - mocker.call("matched with ('s3', 'I')"), - ] - assert predictor.strains == ['s1', 's2', 's3'] - - -def test_set_output_files(predictor): - with pytest.raises(ValueError) as e: - predictor.set_output_files('', '', '', '', '') - assert 'No initial hmm file provided' in str(e) - - with pytest.raises(ValueError) as e: - predictor.set_output_files('init', '', '', '', '') - assert 'No trained hmm file provided' in str(e) - - with pytest.raises(ValueError) as e: - predictor.set_output_files('init', 'trained', 'pos', 'prob', '') - assert 'No alignment file provided' in str(e) - - with pytest.raises(ValueError) as e: - predictor.set_output_files('init', 'trained', 'pos', 'prob', 'align') - assert '{prefix} not found in align' in str(e) - - with pytest.raises(ValueError) as e: - predictor.set_output_files('init', 'trained', 
'pos', 'prob', - 'align{prefix}') - assert '{strain} not found in align{prefix}' in str(e) - - with pytest.raises(ValueError) as e: - predictor.set_output_files('init', 'trained', 'pos', 'prob', - 'align{prefix}{strain}') - assert '{chrom} not found in align{prefix}{strain}' in str(e) - - predictor.prefix = 'pre' - predictor.set_output_files('init', 'trained', 'pos', 'prob', - 'align{prefix}{strain}{chrom}') - assert predictor.hmm_initial == 'init' - assert predictor.hmm_trained == 'trained' - assert predictor.positions == 'pos' - assert predictor.probabilities == 'prob' - assert predictor.alignment == 'alignpre{strain}{chrom}' - - predictor.set_output_files('init', 'trained', '', 'prob', - 'align{prefix}{strain}{chrom}') - assert predictor.hmm_initial == 'init' - assert predictor.hmm_trained == 'trained' - assert predictor.positions is None - assert predictor.probabilities == 'prob' - assert predictor.alignment == 'alignpre{strain}{chrom}' - - with pytest.raises(ValueError) as e: - predictor.config = {'paths': {'analysis': {'hmm_initial': 'init'}}} - predictor.set_output_files('', '', '', '', '') - assert 'No trained hmm file provided' in str(e) - - with pytest.raises(ValueError) as e: - predictor.config = {'paths': {'analysis': {'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'positions': 'pos' - }}} - predictor.set_output_files('', '', '', '', '') - assert 'No probabilities file provided' in str(e) - - with pytest.raises(ValueError) as e: - predictor.config = {'paths': {'analysis': {'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'positions': 'pos', - 'probabilities': 'prob' - }}} - predictor.set_output_files('', '', '', '', '') - assert 'No alignment file provided' in str(e) - - predictor.config = {'paths': {'analysis': { - 'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'positions': 'pos', - 'probabilities': 'prob', - 'alignment': 'align{prefix}{strain}{chrom}' - }}} - predictor.set_output_files('', '', '', '', '') - - assert predictor.hmm_initial == 'init' - assert predictor.hmm_trained == 'trained' - assert predictor.positions == 'pos' - assert predictor.probabilities == 'prob' - assert predictor.alignment == 'alignpre{strain}{chrom}' - - predictor.config = {'paths': {'analysis': { - 'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'probabilities': 'prob', - 'alignment': 'align{prefix}{strain}{chrom}' - }}} - predictor.set_output_files('', '', '', '', '') - - assert predictor.hmm_initial == 'init' - assert predictor.hmm_trained == 'trained' - assert predictor.positions is None - assert predictor.probabilities == 'prob' - assert predictor.alignment == 'alignpre{strain}{chrom}' - - -def test_validate_arguments(predictor): - predictor.chromosomes = 1 - predictor.blocks = 1 - predictor.prefix = 1 - predictor.strains = 1 - predictor.hmm_initial = 1 - predictor.hmm_trained = 1 - predictor.probabilities = 1 - predictor.alignment = 1 - predictor.known_states = 1 - predictor.unknown_states = 1 - predictor.threshold = 1 - predictor.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'DBVPG6304', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'UWOPS91_917_1', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - - assert predictor.validate_arguments() - - 
args = [ - 'chromosomes', - 'blocks', - 'prefix', - 'strains', - 'hmm_initial', - 'hmm_trained', - 'probabilities', - 'alignment', - 'known_states', - 'unknown_states', - 'threshold' - ] - - for arg in args: - predictor.__dict__[arg] = None - with pytest.raises(ValueError) as e: - predictor.validate_arguments() - assert ('Failed to validate Predictor, ' - f'required argument {arg} was unset') in str(e) - predictor.__dict__[arg] = 1 - - predictor.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - predictor.validate_arguments() - assert 'Configuration did not provide any known_states' in str(e) - - predictor.config = { - 'analysis_params': - {'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - predictor.validate_arguments() - assert 'Configuration did not specify a reference strain' in str(e) - - predictor.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - predictor.validate_arguments() - assert 'CBS432 did not provide an expected_length' in str(e) - - predictor.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - }, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - predictor.validate_arguments() - assert 'N_45 did not provide an expected_fraction' in str(e) +@pytest.fixture +def predictor(config): + result = predict.Predictor(config) + config.set_states() + return result - predictor.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - predictor.validate_arguments() - assert 'unknown did not provide an expected_length' in str(e) - predictor.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - }, - ] - } - } - with pytest.raises(ValueError) as e: - predictor.validate_arguments() - assert 'unknown did not provide an expected_fraction' in str(e) - - -def test_run_prediction_no_pos(predictor, mocker, capsys): - predictor.chromosomes = ['I', 'II'] - predictor.blocks = 'blocks{state}.txt' - predictor.prefix = 'prefix' - predictor.strains = ['s1', 's2'] - predictor.hmm_initial = 'hmm_initial.txt' - predictor.hmm_trained = 'hmm_trained.txt' - 
predictor.probabilities = 'probs.txt' - predictor.alignment = 'prefix_{strain}_chr{chrom}.maf' - predictor.known_states = 'S288c CBS432 N_45 DBVP UWOP'.split() - predictor.unknown_states = ['unknown'] - predictor.states = predictor.known_states + predictor.unknown_states - predictor.threshold = 'viterbi' - predictor.config = { +def test_predictor(predictor): + assert predictor.config.known_states ==\ + 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() + assert predictor.config.unknown_states == ['unknown'] + + +def test_run_prediction_no_pos(predictor, config, mocker, capsys): + config.chromosomes = ['I', 'II'] + config.blocks = 'blocks{state}.txt' + config.prefix = 'prefix' + config.strains = ['s1', 's2'] + config.hmm_initial = 'hmm_initial.txt' + config.hmm_trained = 'hmm_trained.txt' + config.probabilities = 'probs.txt' + config.positions = None + config.alignment = 'prefix_{strain}_chr{chrom}.maf' + config.known_states = 'S288c CBS432 N_45 DBVP UWOP'.split() + config.unknown_states = ['unknown'] + config.states = config.known_states + config.unknown_states + config.threshold = 'viterbi' + config.config = { 'analysis_params': {'reference': {'name': 'S288c'}, 'known_states': [ @@ -563,21 +126,21 @@ def test_run_prediction_no_pos(predictor, mocker, capsys): ] -def test_run_prediction_full(predictor, mocker): - predictor.chromosomes = ['I', 'II'] - predictor.blocks = 'blocks{state}.txt' - predictor.prefix = 'prefix' - predictor.strains = ['s1', 's2'] - predictor.hmm_initial = 'hmm_initial.txt' - predictor.hmm_trained = 'hmm_trained.txt' - predictor.probabilities = 'probs.txt' - predictor.positions = 'pos.txt' - predictor.alignment = 'prefix_{strain}_chr{chrom}.maf' - predictor.known_states = 'S288c CBS432 N_45 DBVP UWOP'.split() - predictor.unknown_states = ['unknown'] - predictor.states = predictor.known_states + predictor.unknown_states - predictor.threshold = 'viterbi' - predictor.config = { +def test_run_prediction_full(predictor, config, mocker): + config.chromosomes = ['I', 'II'] + config.blocks = 'blocks{state}.txt' + config.prefix = 'prefix' + config.strains = ['s1', 's2'] + config.hmm_initial = 'hmm_initial.txt' + config.hmm_trained = 'hmm_trained.txt' + config.probabilities = 'probs.txt' + config.positions = 'pos.txt' + config.alignment = 'prefix_{strain}_chr{chrom}.maf' + config.known_states = 'S288c CBS432 N_45 DBVP UWOP'.split() + config.unknown_states = ['unknown'] + config.states = config.known_states + config.unknown_states + config.threshold = 'viterbi' + config.config = { 'analysis_params': {'reference': {'name': 'S288c'}, 'known_states': [ @@ -790,16 +353,14 @@ def test_run_prediction_full(predictor, mocker): ]) -def test_write_hmm_header(predictor): - predictor.known_states = [] - predictor.unknown_states = [] +def test_write_hmm_header(predictor, config): + config.states = [] predictor.emission_symbols = [] writer = StringIO() predictor.write_hmm_header(writer) assert writer.getvalue() == 'strain\tchromosome\t\n' - predictor.known_states = ['s1', 's2'] - predictor.unknown_states = ['u1'] + config.states = ['s1', 's2', 'u1'] predictor.emission_symbols = ['-', '+'] writer = StringIO() predictor.write_hmm_header(writer) @@ -893,13 +454,13 @@ def test_write_positions(predictor): def test_write_state_probs(predictor): output = StringIO() - predictor.states = [] + predictor.config.states = [] predictor.write_state_probs([{}], output, 'strain', 'I') assert output.getvalue() == 'strain\tI\t\n' output = StringIO() - predictor.states = list('abc') + predictor.config.states = 
list('abc') predictor.write_state_probs([ [0, 0, 1], [1, 0, 0], @@ -913,16 +474,16 @@ def test_write_state_probs(predictor): 'c:1.00000,0.00000,1.00000\n') -def test_process_path(predictor, hm): +def test_process_path(predictor, config, hm): probs = hm.posterior_decoding()[0] - predictor.set_threshold(0.8) - predictor.states = 'N E'.split() - predictor.known_states = 'N E'.split() + config.set_threshold(0.8) + config.states = 'N E'.split() + config.known_states = 'N E'.split() path, probability = predictor.process_path(hm) assert (probability == probs).all() assert path == 'E E N E E N E E N N'.split() - predictor.set_threshold('viterbi') + config.set_threshold('viterbi') path, probability = predictor.process_path(hm) assert (probability == probs).all() @@ -942,7 +503,7 @@ def test_convert_to_blocks(predictor): def help_test_convert_blocks(states, seq, predictor): - predictor.states = states + predictor.config.states = states blocks = predictor.convert_to_blocks(seq) nseq = np.array(seq, int) diff --git a/code/test/helper_scripts/test_id_main.slurm b/code/test/helper_scripts/test_id_main.slurm index ec4dea2..e42e9fa 100755 --- a/code/test/helper_scripts/test_id_main.slurm +++ b/code/test/helper_scripts/test_id_main.slurm @@ -6,18 +6,13 @@ #SBATCH -n 1 #SBATCH -o "/tigress/tcomi/aclark4_temp/results/id_%A" -# ARGS=$(head -n $SLURM_ARRAY_TASK_ID predict_args.txt | tail -n 1) - -export PYTHONPATH=/home/tcomi/projects/aclark4_introgression/code/ - -#ARGS="p4e2 .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" -ARGS="_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" - -#Make sure chrms is set to only I -#ARGS="_chr1_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" -#ARGS="_chr1 .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" +config=/home/tcomi/projects/aclark4_introgression/code/config.yaml module load anaconda3 conda activate introgression3 -python $PYTHONPATH/analyze/id_regions_main.py $ARGS +introgression \ + --config $config \ + -vvvv \ + --log-file test.log \ + id-regions diff --git a/code/test/helper_scripts/test_predict.slurm b/code/test/helper_scripts/test_predict.slurm index d407187..1a55ef9 100755 --- a/code/test/helper_scripts/test_predict.slurm +++ b/code/test/helper_scripts/test_predict.slurm @@ -12,7 +12,7 @@ conda activate introgression3 introgression \ --config $config \ - -vvvv \ + -vv \ --log-file test.log \ predict diff --git a/code/test/misc/test_config_utils.py b/code/test/misc/test_config_utils.py index f0bed03..39a59e3 100644 --- a/code/test/misc/test_config_utils.py +++ b/code/test/misc/test_config_utils.py @@ -2,7 +2,7 @@ from misc.config_utils import (clean_config, clean_list, merge_lists, merge_dicts, get_nested, check_wildcards, - get_states, validate) + validate) def test_simple(): @@ -156,49 +156,6 @@ def test_check_wildcards(mocker): assert '{test} not found in test.txt' in str(e) -def test_get_states(): - assert get_states({}) == ([], []) - assert get_states( - { - 'analysis_params': { - 'known_states': [ - {'name': 'k1'}, - {'name': 'k2'}, - {'name': 'k3'}, - ], - 'unknown_states': [ - {'name': 'u1'}, - {'name': 'u2'}, - ] - } - }) == ('k1 k2 k3'.split(), 'u1 u2'.split()) - assert get_states( - { - 'analysis_params': { - 'reference': {'name': 'ref'}, - 'unknown_states': [ - {'name': 'u1'}, - {'name': 'u2'}, - ] - } - }) == ('ref'.split(), 'u1 u2'.split()) - assert get_states( - { - 'analysis_params': { - 'reference': {'name': 'ref'}, - 
'known_states': [ - {'name': 'k1'}, - {'name': 'k2'}, - {'name': 'k3'}, - ], - 'unknown_states': [ - {'name': 'u1'}, - {'name': 'u2'}, - ] - } - }) == ('ref k1 k2 k3'.split(), 'u1 u2'.split()) - - def test_validate(mocker): assert validate({}, '', '', 'test') == 'test' assert validate({'path': 'test'}, 'path', '') == 'test' From 1bd8cf21ee79617bff937ac59917389f99a332d4 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Tue, 30 Apr 2019 10:51:07 -0400 Subject: [PATCH 20/33] Moved validate code, required position Move the validation code to the corresponding main objects Seeing that positions is required for summarize, went back and made that required. --- code/analyze/id_regions.py | 24 +- code/analyze/introgression_configuration.py | 93 +------ code/analyze/predict.py | 76 +++++- code/analyze/summarize_region_quality_main.py | 2 +- code/test/analyze/test_id_regions.py | 15 ++ .../test_introgression_configuration.py | 232 +----------------- code/test/analyze/test_main_id_args.py | 97 ++++++++ code/test/analyze/test_main_predict_args.py | 6 +- code/test/analyze/test_main_predict_config.py | 6 +- code/test/analyze/test_predict_predictor.py | 225 ++++++++++++----- 10 files changed, 384 insertions(+), 392 deletions(-) create mode 100644 code/test/analyze/test_main_id_args.py diff --git a/code/analyze/id_regions.py b/code/analyze/id_regions.py index 71090d7..f964881 100644 --- a/code/analyze/id_regions.py +++ b/code/analyze/id_regions.py @@ -3,6 +3,7 @@ from analyze.introgression_configuration import Configuration from analyze.predict import read_blocks import click +import logging as log class ID_producer(): @@ -17,7 +18,7 @@ def add_ids(self): ''' Adds a unique region id to block files, producing labeled text files ''' - self.config.validate_id_regions_arguments() + self.validate_arguments() regions = dict(zip(self.config.chromosomes, [[] for _ in self.config.chromosomes])) with ExitStack() as stack: @@ -68,3 +69,24 @@ def add_ids(self): id_counter += 1 if progress_bar: progress_bar.update(1) + + def validate_arguments(self): + ''' + Check that all required instance variables are set to perform a + id producer run. Returns true if valid, raises value error otherwise + ''' + args = [ + 'chromosomes', + 'blocks', + 'labeled_blocks', + 'states', + ] + variables = self.config.__dict__ + for arg in args: + if arg not in variables or variables[arg] is None: + err = ('Failed to validate ID Producer, required argument ' + f"'{arg}' was unset") + log.exception(err) + raise ValueError(err) + + return True diff --git a/code/analyze/introgression_configuration.py b/code/analyze/introgression_configuration.py index fa3d364..41fc2a3 100644 --- a/code/analyze/introgression_configuration.py +++ b/code/analyze/introgression_configuration.py @@ -233,11 +233,10 @@ def set_predict_files(self, 'No trained hmm file provided', hmm_trained) - if positions == '': - self.positions = get_nested(self.config, - 'paths.analysis.positions') - else: - self.positions = positions + self.positions = validate(self.config, + 'paths.analysis.positions', + 'No positions file provided', + positions) self.probabilities = validate(self.config, 'paths.analysis.probabilities', @@ -302,90 +301,6 @@ def get(self, key: str): ''' return get_nested(self.config, key) - def validate_predict_arguments(self): - ''' - Check that all required instance variables are set to perform a - prediction run. 
Returns true if valid, raises value error otherwise - ''' - args = [ - 'chromosomes', - 'blocks', - 'prefix', - 'strains', - 'hmm_initial', - 'hmm_trained', - 'probabilities', - 'alignment', - 'known_states', - 'unknown_states', - 'threshold', - ] - variables = self.__dict__ - for arg in args: - if arg not in variables or variables[arg] is None: - err = ('Failed to validate Predictor, required argument ' - f"'{arg}' was unset") - log.exception(err) - raise ValueError(err) - - # check the parameters for each state are present - known_states = self.get('analysis_params.known_states') - if known_states is None: - err = 'Configuration did not provide any known_states' - log.exception(err) - raise ValueError(err) - - for s in known_states: - if 'expected_length' not in s: - err = f'{s["name"]} did not provide an expected_length' - log.exception(err) - raise ValueError(err) - if 'expected_fraction' not in s: - err = f'{s["name"]} did not provide an expected_fraction' - log.exception(err) - raise ValueError(err) - - unknown_states = self.get('analysis_params.unknown_states') - if unknown_states is not None: - for s in unknown_states: - if 'expected_length' not in s: - err = f'{s["name"]} did not provide an expected_length' - log.exception(err) - raise ValueError(err) - if 'expected_fraction' not in s: - err = f'{s["name"]} did not provide an expected_fraction' - log.exception(err) - raise ValueError(err) - - reference = self.get('analysis_params.reference') - if reference is None: - err = f'Configuration did not specify a reference strain' - log.exception(err) - raise ValueError(err) - - return True - - def validate_id_regions_arguments(self): - ''' - Check that all required instance variables are set to perform a - id producer run. Returns true if valid, raises value error otherwise - ''' - args = [ - 'chromosomes', - 'blocks', - 'labeled_blocks', - 'states', - ] - variables = self.__dict__ - for arg in args: - if arg not in variables or variables[arg] is None: - err = ('Failed to validate ID Producer, required argument ' - f"'{arg}' was unset") - log.exception(err) - raise ValueError(err) - - return True - def __repr__(self): return ('Config file:\n' + print_dict(self.config) + diff --git a/code/analyze/predict.py b/code/analyze/predict.py index 7a544e7..b19f8dd 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -88,7 +88,7 @@ def run_prediction(self, only_poly_sites=True): ''' Run prediction with this predictor object ''' - self.config.validate_predict_arguments() + self.validate_arguments() hmm_builder = HMM_Builder(self.config) hmm_builder.set_expected_values() @@ -98,17 +98,12 @@ def run_prediction(self, only_poly_sites=True): with open(self.config.hmm_initial, 'w') as initial, \ open(self.config.hmm_trained, 'w') as trained, \ gzip.open(self.config.probabilities, 'wt') as probabilities, \ + gzip.open(self.config.positions, 'wt') as positions, \ ExitStack() as stack: self.write_hmm_header(initial) self.write_hmm_header(trained) - if self.config.positions is not None: - positions = stack.enter_context( - gzip.open(self.config.positions, 'wt')) - else: - positions = None - block_writers = {state: stack.enter_context( open(self.config.blocks.format( @@ -151,8 +146,7 @@ def run_prediction(self, only_poly_sites=True): hmm_trained) state_blocks = self.convert_to_blocks(predicted_states) - if positions is not None: - self.write_positions(pos, positions, strain, chrom) + self.write_positions(pos, positions, strain, chrom) for state, block in state_blocks.items(): 
self.write_blocks(block, @@ -168,6 +162,70 @@ def run_prediction(self, only_poly_sites=True): if progress_bar: progress_bar.update(1) + def validate_arguments(self): + ''' + Check that all required instance variables are set to perform a + prediction run. Returns true if valid, raises value error otherwise + ''' + args = [ + 'chromosomes', + 'blocks', + 'prefix', + 'strains', + 'hmm_initial', + 'hmm_trained', + 'probabilities', + 'positions', + 'alignment', + 'known_states', + 'unknown_states', + 'threshold', + ] + variables = self.config.__dict__ + for arg in args: + if arg not in variables or variables[arg] is None: + err = ('Failed to validate Predictor, required argument ' + f"'{arg}' was unset") + log.exception(err) + raise ValueError(err) + + # check the parameters for each state are present + known_states = self.config.get('analysis_params.known_states') + if known_states is None: + err = 'Configuration did not provide any known_states' + log.exception(err) + raise ValueError(err) + + for s in known_states: + if 'expected_length' not in s: + err = f'{s["name"]} did not provide an expected_length' + log.exception(err) + raise ValueError(err) + if 'expected_fraction' not in s: + err = f'{s["name"]} did not provide an expected_fraction' + log.exception(err) + raise ValueError(err) + + unknown_states = self.config.get('analysis_params.unknown_states') + if unknown_states is not None: + for s in unknown_states: + if 'expected_length' not in s: + err = f'{s["name"]} did not provide an expected_length' + log.exception(err) + raise ValueError(err) + if 'expected_fraction' not in s: + err = f'{s["name"]} did not provide an expected_fraction' + log.exception(err) + raise ValueError(err) + + reference = self.config.get('analysis_params.reference') + if reference is None: + err = f'Configuration did not specify a reference strain' + log.exception(err) + raise ValueError(err) + + return True + def write_hmm_header(self, writer: TextIO) -> None: ''' Write the header line for an hmm file to the provided textIO object diff --git a/code/analyze/summarize_region_quality_main.py b/code/analyze/summarize_region_quality_main.py index e2d9348..eba2088 100644 --- a/code/analyze/summarize_region_quality_main.py +++ b/code/analyze/summarize_region_quality_main.py @@ -26,9 +26,9 @@ def main() -> None: -{species}_chr_intervals.txt -{species}_chr_mafft.fa -{species}_chr_mafft.fa + -positions_{tag}.txt.gz Output files: - -positions_{tag}.txt.gz -regions file as {species}.fa.gz -index file for the fz.gz -blocks_{species}_quality.txt diff --git a/code/test/analyze/test_id_regions.py b/code/test/analyze/test_id_regions.py index 62151b4..056231e 100644 --- a/code/test/analyze/test_id_regions.py +++ b/code/test/analyze/test_id_regions.py @@ -140,3 +140,18 @@ def test_add_ids(id_producer, mocker): mocker.call('r14\tstrain3\tX\tstate1\t10\t100\t1\n'), ] mocked_file().write.assert_has_calls(calls) + + +def test_validate_arguments(id_producer): + with pytest.raises(ValueError) as e: + id_producer.validate_arguments() + assert ('Failed to validate ID Producer, ' + "required argument 'chromosomes' was unset") in str(e) + + config = id_producer.config + config.chromosomes = 1 + config.blocks = 1 + config.labeled_blocks = 1 + config.states = 1 + + assert id_producer.validate_arguments() diff --git a/code/test/analyze/test_introgression_configuration.py b/code/test/analyze/test_introgression_configuration.py index 1e7a032..01a1004 100644 --- a/code/test/analyze/test_introgression_configuration.py +++ 
b/code/test/analyze/test_introgression_configuration.py @@ -325,31 +325,23 @@ def test_set_predict_files(config): with pytest.raises(ValueError) as e: config.set_predict_files('init', 'trained', 'pos', 'prob', - 'align{prefix}') + 'align{prefix}') assert '{strain} not found in align{prefix}' in str(e) with pytest.raises(ValueError) as e: config.set_predict_files('init', 'trained', 'pos', 'prob', - 'align{prefix}{strain}') + 'align{prefix}{strain}') assert '{chrom} not found in align{prefix}{strain}' in str(e) config.prefix = 'pre' config.set_predict_files('init', 'trained', 'pos', 'prob', - 'align{prefix}{strain}{chrom}') + 'align{prefix}{strain}{chrom}') assert config.hmm_initial == 'init' assert config.hmm_trained == 'trained' assert config.positions == 'pos' assert config.probabilities == 'prob' assert config.alignment == 'alignpre{strain}{chrom}' - config.set_predict_files('init', 'trained', '', 'prob', - 'align{prefix}{strain}{chrom}') - assert config.hmm_initial == 'init' - assert config.hmm_trained == 'trained' - assert config.positions is None - assert config.probabilities == 'prob' - assert config.alignment == 'alignpre{strain}{chrom}' - with pytest.raises(ValueError) as e: config.config = {'paths': {'analysis': {'hmm_initial': 'init'}}} config.set_predict_files('', '', '', '', '') @@ -357,18 +349,18 @@ def test_set_predict_files(config): with pytest.raises(ValueError) as e: config.config = {'paths': {'analysis': {'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'positions': 'pos' - }}} + 'hmm_trained': 'trained', + 'positions': 'pos' + }}} config.set_predict_files('', '', '', '', '') assert 'No probabilities file provided' in str(e) with pytest.raises(ValueError) as e: config.config = {'paths': {'analysis': {'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'positions': 'pos', - 'probabilities': 'prob' - }}} + 'hmm_trained': 'trained', + 'positions': 'pos', + 'probabilities': 'prob' + }}} config.set_predict_files('', '', '', '', '') assert 'No alignment file provided' in str(e) @@ -386,207 +378,3 @@ def test_set_predict_files(config): assert config.positions == 'pos' assert config.probabilities == 'prob' assert config.alignment == 'alignpre{strain}{chrom}' - - config.config = {'paths': {'analysis': { - 'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'probabilities': 'prob', - 'alignment': 'align{prefix}{strain}{chrom}' - }}} - config.set_predict_files('', '', '', '', '') - - assert config.hmm_initial == 'init' - assert config.hmm_trained == 'trained' - assert config.positions is None - assert config.probabilities == 'prob' - assert config.alignment == 'alignpre{strain}{chrom}' - - -def test_validate_predict_arguments(config): - config.chromosomes = 1 - config.blocks = 1 - config.prefix = 1 - config.strains = 1 - config.hmm_initial = 1 - config.hmm_trained = 1 - config.probabilities = 1 - config.alignment = 1 - config.known_states = 1 - config.unknown_states = 1 - config.threshold = 1 - config.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'DBVPG6304', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'UWOPS91_917_1', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - - assert config.validate_predict_arguments() - - args = [ - 'chromosomes', - 
'blocks', - 'prefix', - 'strains', - 'hmm_initial', - 'hmm_trained', - 'probabilities', - 'alignment', - 'known_states', - 'unknown_states', - 'threshold' - ] - - for arg in args: - config.__dict__[arg] = None - with pytest.raises(ValueError) as e: - config.validate_predict_arguments() - assert ('Failed to validate Predictor, ' - f"required argument '{arg}' was unset") in str(e) - config.__dict__[arg] = 1 - - config.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - config.validate_predict_arguments() - assert 'Configuration did not provide any known_states' in str(e) - - config.config = { - 'analysis_params': - {'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - config.validate_predict_arguments() - assert 'Configuration did not specify a reference strain' in str(e) - - config.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - config.validate_predict_arguments() - assert 'CBS432 did not provide an expected_length' in str(e) - - config.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - }, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - config.validate_predict_arguments() - assert 'N_45 did not provide an expected_fraction' in str(e) - - config.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_fraction': 0.01}, - ] - } - } - with pytest.raises(ValueError) as e: - config.validate_predict_arguments() - assert 'unknown did not provide an expected_length' in str(e) - - config.config = { - 'analysis_params': - {'reference': {'name': 'S288c'}, - 'known_states': [ - {'name': 'CBS432', - 'expected_length': 1, - 'expected_fraction': 0.01}, - {'name': 'N_45', - 'expected_length': 1, - 'expected_fraction': 0.01}, - ], - 'unknown_states': [{'name': 'unknown', - 'expected_length': 1, - }, - ] - } - } - with pytest.raises(ValueError) as e: - config.validate_predict_arguments() - assert 'unknown did not provide an expected_fraction' in str(e) - - -def test_validate_id_regions_arguments(config): - with pytest.raises(ValueError) as e: - config.validate_id_regions_arguments() - assert ('Failed to validate ID Producer, ' - "required argument 'chromosomes' was unset") in str(e) - - config.chromosomes = 1 - config.blocks = 1 - config.labeled_blocks = 1 - config.states = 1 - - assert config.validate_id_regions_arguments() diff --git a/code/test/analyze/test_main_id_args.py 
b/code/test/analyze/test_main_id_args.py new file mode 100644 index 0000000..c35b4f0 --- /dev/null +++ b/code/test/analyze/test_main_id_args.py @@ -0,0 +1,97 @@ +import pytest +from click.testing import CliRunner +import analyze.main as main +import yaml +from analyze.id_regions import ID_producer + + +''' +Unit tests for the id_regions command of main.py when parameters are +provided by args +''' + + +@pytest.fixture +def runner(): + return CliRunner() + + +def test_states(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml id-regions --state s1 --state s2') + + assert result.exit_code != 0 + assert str(result.exception) == 'No block file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call('Found 2 states to process'), + ] + + +def test_block_file(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml id-regions ' + '--state s1 --state s2 ' + '--blocks block_{state}.txt ') + + assert result.exit_code != 0 + assert str(result.exception) == 'No labeled block file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call('Found 2 states to process'), + mocker.call("Input blocks file is 'block_{state}.txt'"), + ] + + +def test_labeled_block_file(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'chromosomes': 'I II III'.split(), + }, f) + + mock_id = mocker.patch.object(ID_producer, 'add_ids') + + result = runner.invoke( + main.cli, + '--config config.yaml id-regions ' + '--state s1 --state s2 ' + '--blocks block_{state}.txt ' + '--labeled labeled_block_{state}.txt' + ) + + assert result.exit_code == 0 + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call('Found 2 states to process'), + mocker.call("Input blocks file is 'block_{state}.txt'"), + mocker.call("Output blocks file is 'labeled_block_{state}.txt'"), + ] + + mock_id.called_once() diff --git a/code/test/analyze/test_main_predict_args.py b/code/test/analyze/test_main_predict_args.py index 42b27a1..c5668f8 100644 --- a/code/test/analyze/test_main_predict_args.py +++ b/code/test/analyze/test_main_predict_args.py @@ -215,7 +215,7 @@ def test_outputs(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No probabilities file provided' + 'No positions file provided' assert mock_log.call_args_list == mock_calls with runner.isolated_filesystem(): @@ -238,6 +238,7 @@ def test_outputs(runner, mocker): '--blocks blocks_{state}.txt --prefix s1_s2 ' '--hmm-initial hmm_init.txt ' '--hmm-trained hmm_trained.txt ' + '--positions pos.txt.gz ' '--probabilities probs.txt.gz ' ) @@ -266,6 +267,7 @@ def test_outputs(runner, mocker): '--blocks blocks_{state}.txt --prefix s1_s2 ' '--hmm-initial 
hmm_init.txt ' '--hmm-trained hmm_trained.txt ' + '--positions pos.txt.gz ' '--probabilities probs.txt.gz ' '--alignment {prefix}_{strain}_chr{chrom}.maf ' ) @@ -276,7 +278,7 @@ def test_outputs(runner, mocker): assert mock_log.call_args_list == mock_calls + [ mocker.call("Hmm_initial file is 'hmm_init.txt'"), mocker.call("Hmm_trained file is 'hmm_trained.txt'"), - mocker.call("Positions file is 'None'"), + mocker.call("Positions file is 'pos.txt.gz'"), mocker.call("Probabilities file is 'probs.txt.gz'"), mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'"), mocker.call("Only considering polymorphic sites"), diff --git a/code/test/analyze/test_main_predict_config.py b/code/test/analyze/test_main_predict_config.py index 9d9146e..eb2a9ef 100644 --- a/code/test/analyze/test_main_predict_config.py +++ b/code/test/analyze/test_main_predict_config.py @@ -294,7 +294,7 @@ def test_outputs(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No probabilities file provided' + 'No positions file provided' assert mock_log.call_args_list == mock_calls with runner.isolated_filesystem(): @@ -313,6 +313,7 @@ def test_outputs(runner, mocker): 'block_files': 'blocks_{state}.txt', 'hmm_initial': 'hmm_init.txt', 'hmm_trained': 'hmm_trained.txt', + 'positions': 'pos.txt.gz', 'probabilities': 'probs.txt.gz', }}, }, f) @@ -344,6 +345,7 @@ def test_outputs(runner, mocker): 'hmm_initial': 'hmm_init.txt', 'hmm_trained': 'hmm_trained.txt', 'probabilities': 'probs.txt.gz', + 'positions': 'pos.txt.gz', 'alignment': '{prefix}_{strain}_chr{chrom}.maf', }}, }, f) @@ -359,7 +361,7 @@ def test_outputs(runner, mocker): assert mock_log.call_args_list == mock_calls + [ mocker.call("Hmm_initial file is 'hmm_init.txt'"), mocker.call("Hmm_trained file is 'hmm_trained.txt'"), - mocker.call("Positions file is 'None'"), + mocker.call("Positions file is 'pos.txt.gz'"), mocker.call("Probabilities file is 'probs.txt.gz'"), mocker.call("Alignment file is 's1_s2_{strain}_chr{chrom}.maf'"), mocker.call("Only considering polymorphic sites") diff --git a/code/test/analyze/test_predict_predictor.py b/code/test/analyze/test_predict_predictor.py index ee583ee..d8f0928 100644 --- a/code/test/analyze/test_predict_predictor.py +++ b/code/test/analyze/test_predict_predictor.py @@ -40,90 +40,183 @@ def test_predictor(predictor): assert predictor.config.unknown_states == ['unknown'] -def test_run_prediction_no_pos(predictor, config, mocker, capsys): - config.chromosomes = ['I', 'II'] - config.blocks = 'blocks{state}.txt' - config.prefix = 'prefix' - config.strains = ['s1', 's2'] - config.hmm_initial = 'hmm_initial.txt' - config.hmm_trained = 'hmm_trained.txt' - config.probabilities = 'probs.txt' - config.positions = None - config.alignment = 'prefix_{strain}_chr{chrom}.maf' - config.known_states = 'S288c CBS432 N_45 DBVP UWOP'.split() - config.unknown_states = ['unknown'] - config.states = config.known_states + config.unknown_states - config.threshold = 'viterbi' +def test_validate_arguments(predictor): + config = predictor.config + config.chromosomes = 1 + config.blocks = 1 + config.prefix = 1 + config.strains = 1 + config.hmm_initial = 1 + config.hmm_trained = 1 + config.probabilities = 1 + config.positions = 1 + config.alignment = 1 + config.known_states = 1 + config.unknown_states = 1 + config.threshold = 1 config.config = { 'analysis_params': {'reference': {'name': 'S288c'}, 'known_states': [ {'name': 'CBS432', - 'expected_length': 10000, - 'expected_fraction': 0.025}, + 'expected_length': 1, + 
'expected_fraction': 0.01}, {'name': 'N_45', - 'expected_length': 10000, - 'expected_fraction': 0.025}, - {'name': 'DBVP', - 'expected_length': 10000, - 'expected_fraction': 0.025}, - {'name': 'UWOP', - 'expected_length': 10000, - 'expected_fraction': 0.025}, + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'DBVPG6304', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'UWOPS91_917_1', + 'expected_length': 1, + 'expected_fraction': 0.01}, ], 'unknown_states': [{'name': 'unknown', - 'expected_length': 1000, + 'expected_length': 1, 'expected_fraction': 0.01}, ] } } - mock_files = [mocker.MagicMock() for i in range(8)] - mocker.patch('analyze.predict.open', - side_effect=mock_files) - mock_gzip = mocker.patch('analyze.predict.gzip.open') - mocker.patch('analyze.predict.log') - mocker.patch('analyze.predict.os.path.exists', return_value=True) - mocker.patch('analyze.predict.read_fasta', - return_value=(None, - [list('NNENNENNEN'), # S288c - list('NNNENEENNN'), # CBS432 - list('NN-NNEENNN'), # N_45 - list('NEENN-ENEN'), # DBVPG6304 - list('ENENNEENEN'), # UWOPS.. - list('NNENNEENEN'), # predicted - ] - )) - mock_log_hmm = mocker.patch('hmm.hmm_bw.log.info') + assert predictor.validate_arguments() + + args = [ + 'chromosomes', + 'blocks', + 'prefix', + 'strains', + 'hmm_initial', + 'hmm_trained', + 'probabilities', + 'positions', + 'alignment', + 'known_states', + 'unknown_states', + 'threshold' + ] - predictor.run_prediction(only_poly_sites=True) + for arg in args: + config.__dict__[arg] = None + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert ('Failed to validate Predictor, ' + f"required argument '{arg}' was unset") in str(e) + config.__dict__[arg] = 1 - # check hmm output - assert mock_log_hmm.call_args_list[-3:] == \ - [mocker.call('Iteration 8'), - mocker.call('Iteration 9'), - mocker.call('finished in 10 iterations')] + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'Configuration did not provide any known_states' in str(e) - assert mock_gzip.call_args_list == [mocker.call('probs.txt', 'wt')] + config.config = { + 'analysis_params': + {'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'Configuration did not specify a reference strain' in str(e) - # probs and pos interspersed - print(mock_gzip.return_value.__enter__().write.call_args_list) - assert mock_gzip.return_value.__enter__().write.call_args_list == \ - [ - mocker.call('s1\tI\t'), - mocker.ANY, - mocker.call('\n'), - mocker.call('s2\tI\t'), - mocker.ANY, - mocker.call('\n'), - mocker.call('s1\tII\t'), - mocker.ANY, - mocker.call('\n'), - mocker.call('s2\tII\t'), - mocker.ANY, - mocker.call('\n'), + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as 
e: + predictor.validate_arguments() + assert 'CBS432 did not provide an expected_length' in str(e) - ] + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + }, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'N_45 did not provide an expected_fraction' in str(e) + + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_fraction': 0.01}, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'unknown did not provide an expected_length' in str(e) + + config.config = { + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432', + 'expected_length': 1, + 'expected_fraction': 0.01}, + {'name': 'N_45', + 'expected_length': 1, + 'expected_fraction': 0.01}, + ], + 'unknown_states': [{'name': 'unknown', + 'expected_length': 1, + }, + ] + } + } + with pytest.raises(ValueError) as e: + predictor.validate_arguments() + assert 'unknown did not provide an expected_fraction' in str(e) def test_run_prediction_full(predictor, config, mocker): From c42b6fd9bb97ef3760234cce3299d42f8be91106 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Thu, 9 May 2019 15:34:41 -0400 Subject: [PATCH 21/33] Refactor summarize region quality Refactored summarize region quality main and supporting module with cleaner implementation. Added to main click method and all supporting unit tests. 
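
For reference, a sketch of how the new subcommand might be invoked
(the command name assumes click's underscore-to-dash convention, and
all flag values below are illustrative placeholders, not real paths):

    python -m analyze.main summarize-regions \
        --state CBS432 \
        --labeled 'blocks_{state}_labeled.txt' \
        --masks '{strain}_chr{chrom}_intervals.txt' \
        --alignment '{strain}_chr{chrom}_mafft.maf' \
        --positions positions.txt.gz \
        --quality 'block_{state}_quality.txt' \
        --region '{state}.fa.gz' \
        --region-index '{state}.pkl'

Any flag left unset falls back to the matching paths.analysis entry
in the configuration.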
--- code/analyze/introgression_configuration.py | 104 +- code/analyze/main.py | 64 +- code/analyze/predict.py | 2 +- code/analyze/summarize_region_quality.py | 1052 +++++++++--- code/analyze/summarize_region_quality_main.py | 296 ---- code/config.yaml | 13 +- code/hmm/hmm_bw.py | 1 + code/misc/region_reader.py | 3 +- code/test/analyze/test_id_regions.py | 8 +- .../test_introgression_configuration.py | 145 +- code/test/analyze/test_main_id_config.py | 6 +- code/test/analyze/test_main_predict_config.py | 20 +- .../test_main_summarize_regions_args.py | 328 ++++ .../test_main_summarize_regions_config.py | 335 ++++ .../analyze/test_summarize_region_quality.py | 1437 +++++++++++++++-- .../test_summarize_region_quality_main.py | 63 - .../run_summarize_region_quality.slurm.sh | 12 +- code/test/helper_scripts/test_predict.slurm | 3 +- 18 files changed, 3145 insertions(+), 747 deletions(-) delete mode 100644 code/analyze/summarize_region_quality_main.py create mode 100644 code/test/analyze/test_main_summarize_regions_args.py create mode 100644 code/test/analyze/test_main_summarize_regions_config.py delete mode 100644 code/test/analyze/test_summarize_region_quality_main.py diff --git a/code/analyze/introgression_configuration.py b/code/analyze/introgression_configuration.py index 41fc2a3..4133c75 100644 --- a/code/analyze/introgression_configuration.py +++ b/code/analyze/introgression_configuration.py @@ -44,6 +44,28 @@ def get_states(self) -> Tuple[List, List]: return known_states, unknown_states + def get_interval_states(self) -> List: + ''' + Build list of interval states, typically just known names + but if the state has an interval name, use that + ''' + ref = get_nested(self.config, 'analysis_params.reference') + + # set with name or empty list + if ref is None: + ref = [] + else: + ref = [ref] + + known = get_nested(self.config, 'analysis_params.known_states') + if known is None: + known = [] + + return [s['interval_name'] + if 'interval_name' in s + else s['name'] + for s in ref + known] + def set_states(self, states: List[str] = None): ''' Set the states for which to perform region naming @@ -54,6 +76,8 @@ def set_states(self, states: List[str] = None): else: self.states = states + self.interval_states = self.get_interval_states() + if self.states == []: err = 'No states specified' log.exception(err) @@ -103,7 +127,7 @@ def set_blocks_file(self, blocks: str = None): ''' self.blocks = validate( self.config, - 'paths.analysis.block_files', + 'paths.analysis.blocks', 'No block file provided', blocks) @@ -116,12 +140,38 @@ def set_labeled_blocks_file(self, blocks: str = None): ''' self.labeled_blocks = validate( self.config, - 'paths.analysis.labeled_block_files', + 'paths.analysis.labeled_blocks', 'No labeled block file provided', blocks) check_wildcards(self.labeled_blocks, 'state') + def set_quality_file(self, quality: str = None): + ''' + Set the quality block wildcard filename. + Checks for appropriate wildcards + ''' + self.quality_blocks = validate( + self.config, + 'paths.analysis.quality', + 'No quality block file provided', + quality) + + check_wildcards(self.quality_blocks, 'state') + + def set_masked_file(self, masks: str = None): + ''' + Set the masked interval block wildcard filename. 
+ Checks for appropriate wildcards + ''' + self.masks = validate( + self.config, + 'paths.analysis.masked_intervals', + 'No masked interval file provided', + masks) + + check_wildcards(self.masks, 'strain,chrom') + def set_prefix(self, prefix: str = ''): ''' Set prefix string of the predictor to the supplied value or @@ -233,22 +283,60 @@ def set_predict_files(self, 'No trained hmm file provided', hmm_trained) - self.positions = validate(self.config, - 'paths.analysis.positions', - 'No positions file provided', - positions) + self.set_positions(positions) self.probabilities = validate(self.config, 'paths.analysis.probabilities', 'No probabilities file provided', probabilities) + self.set_alignment(alignment) + + def set_alignment(self, alignment: str): + ''' + Set the alignment file, checking wildcards prefix, strain and chrom. + If prefix is present, it is substituted, otherwise checks just + strain and chrom + ''' alignment = validate(self.config, 'paths.analysis.alignment', 'No alignment file provided', alignment) - check_wildcards(alignment, 'prefix,strain,chrom') - self.alignment = alignment.replace('{prefix}', self.prefix) + if '{prefix}' in alignment: + check_wildcards(alignment, 'prefix,strain,chrom') + self.alignment = alignment.replace('{prefix}', self.prefix) + else: + check_wildcards(alignment, 'strain,chrom') + self.alignment = alignment + + def set_positions(self, positions: str): + ''' + Sets the position file + ''' + self.positions = validate(self.config, + 'paths.analysis.positions', + 'No positions file provided', + positions) + + def set_regions_files(self, + regions: str = None, + region_index: str = None): + ''' + Set the region and pickle wildcard filename. Checks for state wildcards + ''' + self.regions = validate( + self.config, + 'paths.analysis.regions', + 'No region file provided', + regions) + check_wildcards(self.regions, 'state') + + self.region_index = validate( + self.config, + 'paths.analysis.region_index', + 'No region index file provided', + region_index) + check_wildcards(self.region_index, 'state') def set_HMM_symbols(self): ''' diff --git a/code/analyze/main.py b/code/analyze/main.py index 082f3fd..2a1342f 100644 --- a/code/analyze/main.py +++ b/code/analyze/main.py @@ -4,6 +4,7 @@ import analyze.predict from analyze.introgression_configuration import Configuration from analyze.id_regions import ID_producer +from analyze.summarize_region_quality import Summarizer # TODO also check for snakemake object? @@ -136,7 +137,6 @@ def predict(ctx, predictor.run_prediction(only_poly_sites) -# accept multiple states and pass as list @cli.command() @click.option('--blocks', default='', help='Block file location with {state}') @click.option('--labeled', default='', help='Block file location with {state}') @@ -161,5 +161,67 @@ def id_regions(ctx, blocks, labeled, state): id_producer.add_ids() +# TODO add in summarize region quality here! 
+@cli.command() +@click.option('--state', multiple=True, help='States to summarize') +@click.option('--labeled', default='', + help='Labeled block file with {state} ' + 'Created during id_regions') +@click.option('--masks', default='', + help='Mask file with {strain} and {chrom}') +@click.option('--alignment', default='', + help='Alignment file with {prefix} [optional], ' + '{strain} and {chrom}') +@click.option('--positions', default='', + help='Position file created during prediction') +@click.option('--quality', default='', + help='Output quality file with {state}') +@click.option('--region', default='', + help='Output region file with {state}, gzipped') +@click.option('--region-index', default='', + help='Output region index file with {state}, pickled') +@click.pass_context +def summarize_regions(ctx, + state, + labeled, + quality, + masks, + alignment, + positions, + region, + region_index): + config = ctx.obj + + config.set_states() + + config.set_chromosomes() + log.info(f'Found {len(config.chromosomes)} chromosomes in config') + + config.set_labeled_blocks_file(labeled) + log.info(f'Labeled blocks file is \'{config.labeled_blocks}\'') + + config.set_quality_file(quality) + log.info(f'Quality file is \'{config.quality_blocks}\'') + + config.set_masked_file(masks) + log.info(f'Mask file is \'{config.masks}\'') + + config.set_prefix() + config.set_alignment(alignment) + log.info(f'Alignment file is \'{config.alignment}\'') + + config.set_positions(positions) + log.info(f'Positions file is \'{config.positions}\'') + + config.set_regions_files(region, region_index) + log.info(f'Region file is \'{config.regions}\'') + log.info(f'Region index file is \'{config.region_index}\'') + + config.set_HMM_symbols() + + summarizer = Summarizer(config) + summarizer.run(list(state)) + + if __name__ == '__main__': cli() diff --git a/code/analyze/predict.py b/code/analyze/predict.py index b19f8dd..62e8df4 100644 --- a/code/analyze/predict.py +++ b/code/analyze/predict.py @@ -692,7 +692,7 @@ def ungap_and_code(self, isbase = sequences != self.symbols['gap'] # make boolean for valid characters - isvalid = np.logical_and(sequences != self.symbols['gap'], + isvalid = np.logical_and(isbase, sequences != self.symbols['unsequenced']) # positions are where everything is valid, index where the reference is diff --git a/code/analyze/summarize_region_quality.py b/code/analyze/summarize_region_quality.py index d8f2ddd..43724c4 100644 --- a/code/analyze/summarize_region_quality.py +++ b/code/analyze/summarize_region_quality.py @@ -1,9 +1,17 @@ +from __future__ import annotations import bisect import gzip -import global_params as gp -from misc import binary_search import numpy as np +import pickle +from contextlib import ExitStack +import click +import logging as log +from collections import Counter +from misc import read_fasta +from misc import read_table +from misc import seq_functions from typing import List, Tuple, Dict +from analyze.introgression_configuration import Configuration cen_starts = [151465, 238207, 114385, 449711, 151987, 148510, @@ -39,12 +47,16 @@ tel_right_starts = [tel_coords[i] for i in range(2, len(tel_coords), 4)] tel_right_ends = [tel_coords[i] for i in range(3, len(tel_coords), 4)] +chromosomes = ('I II III IV V ' + 'VI VII VIII IX X ' + 'XI XII XIII XIV XV XVI').split() + def distance_from_telomere(start, end, chrm): assert start <= end, str(start) + ' ' + str(end) - i = gp.chrms.index(chrm) + i = chromosomes.index(chrm) # region entirely on left arm if end <= cen_starts[i]: return start - 
tel_left_ends[i] @@ -59,7 +71,7 @@ def distance_from_centromere(start, end, chrm): assert start <= end, str(start) + ' ' + str(end) - i = gp.chrms.index(chrm) + i = chromosomes.index(chrm) # region entirely on left arm if end <= cen_starts[i]: return cen_starts[i] - end @@ -86,7 +98,7 @@ def gap_columns(seqs): g = 0 for i in range(len(seqs[0])): for seq in seqs: - if seq[i] == gp.gap_symbol: + if seq[i] == '-': # gp.gap_symbol: g += 1 break return g @@ -120,9 +132,9 @@ def masked_columns(seqs): mask = False gap = False for s in range(num_seqs): - if seqs[s][ps] == gp.gap_symbol: + if seqs[s][ps] == '-': # gp.gap_symbol: gap = True - elif seqs[s][ps] == gp.masked_symbol: + elif seqs[s][ps] == 'x': # gp.masked_symbol: mask = True if mask: mask_total += 1 @@ -139,23 +151,14 @@ def index_by_reference(ref_seq, seq): ri = 0 si = 0 for i in range(len(ref_seq)): - if ref_seq[i] != gp.gap_symbol: + if ref_seq[i] != '-': # gp.gap_symbol: d[ri] = si ri += 1 - if seq[i] != gp.gap_symbol: + if seq[i] != '-': # gp.gap_symbol: si += 1 return d -def index_alignment_by_reference(ref_seq: np.array) -> np.array: - ''' - Find locations of non-gapped sites in reference sequence - want a way to go from reference sequence coordinate to index in - alignment - ''' - return np.where(ref_seq != gp.gap_symbol)[0] - - def num_sites_between(sites, start, end): # sites are sorted i = bisect.bisect_left(sites, start) @@ -163,225 +166,850 @@ def num_sites_between(sites, start, end): return j - i, sites[i:j] -def read_masked_intervals(filename: str) -> List[Tuple[int, int]]: +class Summarizer(): ''' - Read the interval file provided and return start and end sequences - as a list of tuples of 2 ints + Summarize region quality of each region ''' - with open(filename, 'r') as reader: - reader.readline() # header - intervals = [] - for line in reader: - line = line.split() - intervals.append((int(line[0]), int(line[2]))) + def __init__(self, configuration: Configuration): + self.config = configuration + + def validate_arguments(self): + ''' + Check that all required instance variables are set to perform a + summarize run. 
Returns true if valid, raises value error otherwise + ''' + args = [ + 'chromosomes', + 'labeled_blocks', + 'quality_blocks', + 'masks', + 'alignment', + 'positions', + 'regions', + 'region_index', + 'known_states', + 'unknown_states', + 'states', + 'symbols' + ] + variables = self.config.__dict__ + for arg in args: + if arg not in variables or variables[arg] is None: + err = ('Failed to validate Summarizer, required argument ' + f"'{arg}' was unset") + log.exception(err) + raise ValueError(err) + + reference = self.config.get('analysis_params.reference') + if reference is None: + err = f'Configuration did not specify a reference strain' + log.exception(err) + raise ValueError(err) + + return True + + def run(self, states: List[str] = None): + ''' + Summarize region quality of each region for the states specified + ''' + ref_ind, states = self.states_to_process(states) + + log.debug(f'reference index: {ref_ind}') + log.debug(f'states to analyze: {states}') + + known_states = self.config.known_states + log.debug(f'known_states {known_states}') + + analyzer = Sequence_Analyzer( + self.config.masks, + self.config.alignment, + self.config.known_states, + self.config.interval_states, + self.config.chromosomes, + self.config.symbols) + + log.debug(f'Sequence_Analyzer init with:') + log.debug(f'masks: {self.config.masks}') + log.debug(f'alignment: {self.config.alignment}') + + analyzer.build_masked_sites() + + for ind, state in enumerate(states): + log.info(f'Working on state {state}') + state_ind = self.config.states.index(state) + + with Position_Reader( + self.config.positions + ) as positions,\ + Region_Writer( + self.config.regions.format(state=state), + self.config.region_index.format(state=state), + known_states + ) as region_writer,\ + Quality_Writer( + self.config.quality_blocks.format(state=state) + ) as quality_writer,\ + ExitStack() as stack: + + progress_bar = None + if self.config.log_file: + progress_bar = stack.enter_context( + click.progressbar( + length=len(self.config.chromosomes), + label=f'State {ind+1} of {len(states)}')) + + for chrm in self.config.chromosomes: + log.info(f'Working on chromosome {chrm}') + region = Region_Database( + self.config.labeled_blocks.format(state=state), + chrm, + known_states) + + for strain, ps in positions.get_positions(region, chrm): + log.debug(f'{strain} {chrm}') + + analyzer.process_alignment(ref_ind, + state_ind, + chrm, + strain, + ps, + region, + region_writer) + + quality_writer.write_quality(region) + + if progress_bar: + progress_bar.update(1) + + def states_to_process(self, + states: List[str] = None) -> Tuple[int, + List[str]]: + ''' + Set the states to summarize to the values passed in. + If no values are specified, run all states in config + Checks if states are in config, warning if a state is not + found and raising an error if none of the states are in config. 
+ ''' + reference = self.config.get('analysis_params.reference.name') + ref_ind = self.config.states.index(reference) + + if states is None or states == []: + to_process = self.config.states - return intervals + else: + to_process = [] + for s in states: + if s in self.config.states: + to_process.append(s) + else: + log.warning(f"state '{s}' was not found as a state") + if to_process == []: + err = 'No valid states were found to process' + log.exception(err) + raise ValueError(err) -def convert_intervals_to_sites(intervals: List[Tuple[int, int]]) -> np.array: - ''' - Given a list of start, end positions, returns a 1D np.array of all sites - contined in the intervals List - convert_intervals_to_sites([(1, 2), (4, 6)]) -> [1, 2, 4, 5, 6] - ''' - sites = [] - for start, end in intervals: - sites += range(start, end + 1) - return np.array(sites) + return ref_ind, to_process -def seq_id_hmm(seq1: np.array, - seq2: np.array, - offset: int, - include_sites: List[int]) -> Tuple[ - int, int, Dict[str, List[bool]]]: +class Sequence_Analyzer(): + ''' + Performs handling of masking, reading, and analyzing sequence data for + summarizing the sequences ''' - Compare two sequences and provide statistics of their overlap considering - only the included sites. - Takes the two sequences to consider, an offset of the included sites, - and a list of the included sites. - Returns: - -the total number of matching sites, where seq1[i] == seq2[i] and - i is an element in included_sites - offset - -the total number of sites considered in the included sites, e.g. where - included_sites - offset >= 0 and < len(seq) - -a dict with the following keys: - -gap_flag: true where seq1 or seq1 == gap_symbol - -unseq_flag: true where seq1 or seq1 == unsequenced_symbol - -hmm_flag: true where hmm_flag[i] is in included_sites - offset - -match: true where seq1 == seq2, regardless of symbol + def __init__(self, + mask_file: str, + alignment_file: str, + known_states: List, + interval_states: List, + chromosomes: List, + symbols: Dict): + self.masks = mask_file + self.alignments = alignment_file + self.known_states = known_states + self.interval_states = interval_states + self.chromosomes = chromosomes + self.symbols = symbols + + def build_masked_sites(self): + ''' + Read in all intervals files and return dictionary of intervals, + keyed first by chromosome, then state + ''' + result = {} + for chrom in self.chromosomes: + result[chrom] = {} + for state, name in zip(self.known_states, self.interval_states): + result[chrom][state] = self.read_masked_sites(chrom, name) + + self.masked_sites = result + + def read_masked_sites(self, chrom: str, strain: str) -> np.array: + filename = self.masks.format(chrom=chrom, strain=strain) + intervals = self.read_masked_intervals(filename) + sites = self.convert_intervals_to_sites(intervals) + return sites + + def convert_intervals_to_sites(self, + intervals: List[Tuple]) -> np.array: + ''' + Given a list of start, end positions, returns a 1D np.array of sites + contained in the intervals List + convert_intervals_to_sites([(1, 2), (4, 6)]) -> [1, 2, 4, 5, 6] + ''' + sites = [] + for start, end in intervals: + sites += range(start, end + 1) + return np.array(sites, dtype=int) + + def read_masked_intervals(self, + filename: str) -> List[Tuple[int, int]]: + ''' + Read the interval file provided and return start and end sequences + as a list of tuples of 2 ints + ''' + with open(filename, 'r') as reader: + reader.readline() # header + intervals = [] + for line in reader: + line = line.split() + 
intervals.append((int(line[0]), int(line[2])))
+
+        return intervals
+
+    def get_stats(self,
+                  current_sequence,
+                  other_sequence,
+                  slice_start,
+                  aligned_index_positions,
+                  masked_site):
+        '''
+        Helper function to perform analyses on the sequences, returning
+        the results of seq_id_hmm, seq_id, and seq_id_unmasked
+        '''
+
+        # only alignment columns used by HMM (polymorphic, no
+        # gaps in any strain)
+        hmm_stats = self.seq_id_hmm(other_sequence,
+                                    current_sequence,
+                                    slice_start,
+                                    aligned_index_positions)
+
+        # all alignment columns, excluding ones with gaps in
+        # these two sequences
+        nongap_stats = seq_functions.seq_id(other_sequence,
+                                            current_sequence)
+
+        # all alignment columns, excluding ones with gaps or
+        # masked bases or unsequenced in *these two sequences*
+        nonmask_stats = self.seq_id_unmasked(other_sequence,
+                                             current_sequence,
+                                             slice_start,
+                                             masked_site[0],
+                                             masked_site[1])
+
+        return hmm_stats, nongap_stats, nonmask_stats
+
+    def seq_id_hmm(self,
+                   seq1: np.array,
+                   seq2: np.array,
+                   offset: int,
+                   include_sites: List[int]) -> Tuple[
+                       int, int, Flag_Info]:
+        '''
+        Compare two sequences and provide statistics of their overlap
+        considering only the included sites.
+        Takes the two sequences to consider, an offset of the included sites,
+        and a list of the included sites.
+        Returns:
+        -the total number of matching sites, where seq1[i] == seq2[i] and
+        i is an element in included_sites - offset
+        -the total number of sites considered in the included sites, i.e. where
+        included_sites - offset >= 0 and < len(seq)
+        -a Flag_Info object with:
+        -gap: true where seq1 or seq2 == gap_symbol
+        -unseq: true where seq1 or seq2 == unsequenced_symbol
+        -hmm: true at indices i in included_sites - offset
+        -match: true where seq1 == seq2, regardless of symbol
+        '''
+        sites = np.array(include_sites) - offset
+
+        info = Flag_Info()
+        info.gap = np.logical_or(seq1 == self.symbols['gap'],
+                                 seq2 == self.symbols['gap'])
+        info.unseq = np.logical_or(seq1 == self.symbols['unsequenced'],
+                                   seq2 == self.symbols['unsequenced'])
+        info.match = seq1 == seq2
+        info.hmm = np.zeros(info.match.shape, bool)
+        sites = sites[np.logical_and(sites < len(info.match), sites >= 0)]
+        info.hmm[sites] = True
+
+        total_sites = np.sum(info.hmm)
+        total_match = np.sum(np.logical_and(info.hmm, info.match))
+
+        # check all included are not gapped or skipped
+        include_in_skip = np.logical_and(
+            info.hmm, np.logical_or(
+                info.unseq, info.gap))
+        if np.any(include_in_skip):
+            ind = np.where(include_in_skip)[0][0]
+            err = ('Need to skip site specified as included '
+                   f'seq1: {seq1[ind]}, seq2: {seq2[ind]}, index: {ind}')
+            log.exception(err)
+            raise ValueError(err)
+
+        return total_match, total_sites, info
+
+    def seq_id_unmasked(self,
+                        seq1: np.array,
+                        seq2: np.array,
+                        offset: int,
+                        exclude_sites1: List[int],
+                        exclude_sites2: List[int]) -> Tuple[
+                            int, int, Flag_Info]:
+        '''
+        Compare two sequences and provide statistics of their overlap,
+        considering only sites that are not excluded.
+        Takes two sequences and an offset applied to each excluded sites list.
+        Returns:
+        -total number of matching sites in non-excluded sites. A position is
+        excluded if it is an element of either excluded site list - offset,
+        or it is a gap or unsequenced symbol in either sequence. 
+ -total number of non-excluded sites + A Flag_Info object with: + -mask_flag: a boolean array that is true if the position is in + either excluded list - offset + ''' + info = Flag_Info() + info.gap = np.logical_or(seq1 == self.symbols['gap'], + seq2 == self.symbols['gap']) + info.unseq = np.logical_or(seq1 == self.symbols['unsequenced'], + seq2 == self.symbols['unsequenced']) + exclude_sites1 = np.array(exclude_sites1) + exclude_sites2 = np.array(exclude_sites2) + + # convert offset excluded sites to boolean array + info.mask = np.zeros(seq1.shape, bool) + if exclude_sites1.size != 0: + sites1 = exclude_sites1 - offset + sites1 = sites1[np.logical_and(sites1 < len(info.gap), + sites1 >= 0)] + info.mask[sites1] = True + + if exclude_sites2.size != 0: + sites2 = exclude_sites2 - offset + sites2 = sites2[np.logical_and(sites2 < len(info.gap), + sites2 >= 0)] + info.mask[sites2] = True + + # find sites that are not masked, gapped, or unsequenced + sites = np.logical_not( + np.logical_or( + info.mask, + np.logical_or( + info.gap, info.unseq))) + + # determine totals + total_sites = np.sum(sites) + total_match = np.sum( + np.logical_and( + seq1 == seq2, + sites)) + + return total_match, total_sites, info + + def process_alignment(self, + reference_index: int, + state_index: int, + chromosome: str, + strain: str, + positions: np.array, + region: Region_Database, + region_writer: Region_Writer): + ''' + Analyze the alignment of a given strain, chromosome, and position. + Result is stored in the provided region database + ''' + sequences, alignments, masked_sites = self.get_indices(chromosome, + strain) + + # convert position indices from indices in master reference to + # indices in alignment + ps_align = alignments[reference_index][positions] + + for i, (r_id, start, end) in enumerate(region.get_entries(strain)): + start, end = self.get_slice(start, end, + alignments[reference_index], + ps_align) + + info = Flag_Info() + info.initialize_flags( + end - start + 1, + len(self.known_states)) + + for ind, state in enumerate(self.known_states): + hmm, nongap, nonmask = self.get_stats( + sequences[-1][start:end + 1], + sequences[ind][start:end + 1], + start, + ps_align, + (masked_sites[ind], + masked_sites[-1])) + + region.set_region(strain, i, state, + hmm, + nongap, + nonmask) + + info.add_sequence_flags(hmm[2], ind) + info.add_mask_flags(nonmask[2], ind) + + info_string = info.encode_info(reference_index, state_index) + + region_writer.write_header(r_id) + region_writer.write_sequences( + strain, + alignments, + sequences, + (start, end)) + region_writer.write_info_string(info_string) + + # and keep track of each symbol count + region.update_counts(strain, i, info_string) + + def get_indices(self, chromosome: str, strain: str) -> Tuple: + ''' + Get the sequences and different indices for the provided + chromosome and strain + Returned tuple contains: + -sequences as np.array + -index alignment list of indices for each sequence + -masked_sites, index aligned for each sequence + ''' + _, sequences = read_fasta.read_fasta( + self.alignments.format(chrom=chromosome, strain=strain)) + + # to go from index in reference seq to index in alignment + alignments = [ + self.index_alignment_by_reference(seq) + for seq in sequences + ] + + masked = self.read_masked_sites(chromosome, strain) + + masked_sites = [ + alignments[ind][self.masked_sites[chromosome][state]] + for ind, state in enumerate(self.known_states) + ] + [alignments[-1][masked]] # for strain + + return sequences, alignments, masked_sites + + def 
index_alignment_by_reference(self, sequence: np.array) -> np.array: + ''' + Find locations of non-gapped sites in sequence + want a way to go from reference sequence coordinate to index in + alignment + ''' + return np.where(sequence != self.symbols['gap'])[0] + + def get_slice(self, + start: int, + end: int, + alignment: np.array, + ps_align: np.array) -> Tuple[int, int]: + ''' + Get start and end positions of index aligned sequence. + Checks that positions are valid (in ps_align), and raises + value errors otherwise + ''' + # index of start and end of region in aligned sequences + slice_start, slice_end = alignment[[start, end]] + + if not np.in1d([slice_start, slice_end], ps_align).all(): + err = 'Slice not found in position alignment' + log.exception(err) + raise ValueError(err) + + return slice_start, slice_end + + +class Flag_Info(): ''' - sites = np.array(include_sites) - offset - - info_gap = np.logical_or(seq1 == gp.gap_symbol, - seq2 == gp.gap_symbol) - info_unseq = np.logical_or(seq1 == gp.unsequenced_symbol, - seq2 == gp.unsequenced_symbol) - info_match = seq1 == seq2 - info_hmm = np.zeros(info_match.shape, bool) - sites = sites[np.logical_and(sites < len(info_match), sites >= 0)] - info_hmm[sites] = True - - total_sites = np.sum(info_hmm) - total_match = np.sum(np.logical_and(info_hmm, info_match)) - - # check all included are not gapped or skipped - include_in_skip = np.logical_and( - info_hmm, np.logical_or( - info_unseq, info_gap)) - if np.any(include_in_skip): - ind = np.where(include_in_skip)[0][0] - raise AssertionError(f'{seq1[ind]} {seq2[ind]} {ind}') - - return total_match, total_sites, \ - {'gap_flag': info_gap, 'unseq_flag': info_unseq, - 'hmm_flag': info_hmm, 'match': info_match} - - -def seq_id_unmasked(seq1: np.array, - seq2: np.array, - offset: int, - exclude_sites1: List[int], - exclude_sites2: List[int]) -> Tuple[ - int, int, Dict[str, List[bool]]]: + Collection of boolean flags for sequence summary ''' - Compare two sequences and provide statistics of their overlap considering - only the included sites. - Takes two sequences, an offset applied to each excluded sites list - Returns: - -total number of matching sites in non-excluded sites. A position is - excluded if it is an element of either excluded site list - offset, - or it is a gap or unsequenced symbol in either sequence. 
-    -total number of non-excluded sites
-    A dict with the following keys:
-    -mask_flag: a boolean array that is true if the position is in
-    either excluded list - offset
+    '''
-    info_gap = np.logical_or(seq1 == gp.gap_symbol,
-                             seq2 == gp.gap_symbol)
-    info_unseq = np.logical_or(seq1 == gp.unsequenced_symbol,
-                               seq2 == gp.unsequenced_symbol)
-
-    # convert offset excluded sites to boolean array
-    info_mask = np.zeros(seq1.shape, bool)
-    if exclude_sites1 != []:
-        sites1 = np.array(exclude_sites1) - offset
-        sites1 = sites1[np.logical_and(sites1 < len(info_gap),
-                                       sites1 >= 0)]
-        info_mask[sites1] = True
-    if exclude_sites2 != []:
-        sites2 = np.array(exclude_sites2) - offset
-        sites2 = sites2[np.logical_and(sites2 < len(info_gap),
-                                       sites2 >= 0)]
-        info_mask[sites2] = True
-
-    # find sites that are not masked, gapped, or unsequenced
-    sites = np.logical_not(
-        np.logical_or(
-            info_mask,
-            np.logical_or(
-                info_gap, info_unseq)))
+    def __init__(self):
+        self.gap_any = None
+        self.mask_any = None
+        self.unseq_any = None
+        self.hmm = None
+        self.gap = None
+        self.mask = None
+        self.unseq = None
+        self.match = None
+
+    def initialize_flags(self, number_sequences: int, number_states: int):
+        '''
+        Initialize internal flags to np arrays of false
+        '''
+        self.gap_any = np.zeros((number_sequences), bool)
+        self.mask_any = np.zeros((number_sequences), bool)
+        self.unseq_any = np.zeros((number_sequences), bool)
+        self.gap = np.zeros((number_sequences, number_states), bool)
+        self.mask = np.zeros((number_sequences, number_states), bool)
+        self.unseq = np.zeros((number_sequences, number_states), bool)
+        self.match = np.zeros((number_sequences, number_states), bool)
+
+    def add_sequence_flags(self, other: Flag_Info, state: int):
+        '''
+        Join the other flag info with this one by replacing values
+        in the gap, unseq, and match arrays and OR-ing into the _any arrays
+        '''
+        # only write the first time
+        if state == 0:
+            self.hmm = other.hmm
+
+        self.gap_any = np.logical_or(self.gap_any, other.gap)
+        self.unseq_any = np.logical_or(self.unseq_any, other.unseq)
+
+        self.gap[:, state] = other.gap
+        self.unseq[:, state] = other.unseq
+        self.match[:, state] = other.match
+
+    def add_mask_flags(self, other: Flag_Info, state: int):
+        '''
+        Join the other flag info with this one by replacing values in mask
+        and OR-ing into mask_any
+        '''
+        self.mask_any = np.logical_or(self.mask_any, other.mask)
+        self.mask[:, state] = other.mask
+
+    def encode_info(self,
+                    master_ind: int,
+                    predict_ind: int) -> str:
+        '''
+        Summarize info flags into a string. master_ind is the index of
+        the master reference state. predict_ind is the index of the predicted
+        state. The return string is encoded for each position as:
+        '-': if either master or predict has a gap
+        '_': if either master or predict is masked
+        '.': all states match
+        'b': both predict and master match
+        'c': master matches but not predict
+        'p': predict matches but not master
+        'x': no other condition applies
+        If the position is used by the HMM (the hmm flag is set),
+        x, p, c, or b will be capitalized.
+        Conditions apply in order of precedence, e.g. if a position
+        satisfies both '-' and '.', it will be '-'.
+        '''
+
+        if predict_ind >= self.match.shape[1]:
+            return self.encode_unknown_info(master_ind)
+
+        decoder = np.array(list('xXpPcCbB._-'))
+        indices = np.zeros(self.match.shape[0], int)
+
+        indices[self.match[:, predict_ind]] += 2  # x to p if true
+        indices[self.match[:, master_ind]] += 4  # x to c, p to b
+        indices[self.hmm] += 1  # to upper
+
+        matches = np.all(self.match, axis=1)
+        indices[matches] = 8  # .
+        indices[np.any(
+            self.mask[:, [master_ind, predict_ind]],
+            axis=1)] = 9  # _
+        indices[np.any(
+            self.gap[:, [master_ind, predict_ind]],
+            axis=1)] = 10  # -
+
+        return ''.join(decoder[indices])
+
+    def encode_unknown_info(self,
+                            master_ind: int) -> str:
+        '''
+        Summarize info flags into a string for the unknown state.
+        master_ind is the index of the master reference state.
+        The return string encodes each position as:
+        '-': if any state has a gap
+        '_': if any state has a mask
+        '.': all states match
+        'x': master matches
+        'X': no other condition applies
+        Conditions apply in order of precedence, e.g. if a position
+        satisfies both '-' and '.', it will be '-'. 
+ ''' + + # used with indices to decode result + decoder = np.array(list('Xx._-')) + indices = np.zeros(self.gap_any.shape, int) + + indices[self.match[:, master_ind]] = 1 # x + matches = np.all(self.match, axis=1) + indices[matches] = 2 # . + indices[self.mask_any] = 3 # _ + indices[self.gap_any] = 4 # - + + return ''.join(decoder[indices]) + + +class Region_Database(): ''' - info_gap = np.logical_or(seq1 == gp.gap_symbol, - seq2 == gp.gap_symbol) - info_unseq = np.logical_or(seq1 == gp.unsequenced_symbol, - seq2 == gp.unsequenced_symbol) - - # convert offset excluded sites to boolean array - info_mask = np.zeros(seq1.shape, bool) - if exclude_sites1 != []: - sites1 = np.array(exclude_sites1) - offset - sites1 = sites1[np.logical_and(sites1 < len(info_gap), - sites1 >= 0)] - info_mask[sites1] = True - if exclude_sites2 != []: - sites2 = np.array(exclude_sites2) - offset - sites2 = sites2[np.logical_and(sites2 < len(info_gap), - sites2 >= 0)] - info_mask[sites2] = True - - # find sites that are not masked, gapped, or unsequenced - sites = np.logical_not( - np.logical_or( - info_mask, - np.logical_or( - info_gap, info_unseq))) + Contains data and logic for regions data during summarizing + ''' + def __init__(self, + labeled_file: str, + chromosome: str, + known_states: List[str]): + ''' + Read in labeled file and store resulting table and labels + ''' + self.info_string_symbols = list('.-_npbcxNPBCX') + + self.label_prefixes = ['match_nongap', + 'num_sites_nongap', + 'match_hmm', + 'match_nonmask', + 'num_sites_nonmask'] + + self.data, self.labels = read_table.read_table_columns( + labeled_file, + sep='\t', + group_by='strain', + chromosome=chromosome) + + if self.labels[0] != 'region_id': + err = 'Unexpected labeled format' + log.exception(err) + raise ValueError(err) + + for strain, data in self.data.items(): + n = len(data['region_id']) + + for s in known_states: + for lbl in self.label_prefixes: + data[f'{lbl}_{s}'] = [0] * n + + for s in self.info_string_symbols: + data['count_' + s] = [0] * n + + self.labels += [f'{lbl}_{st}' for lbl in self.label_prefixes + for st in known_states] + self.labels += ['count_' + x for x in self.info_string_symbols] + + def has_strain(self, strain: str) -> bool: + ''' + Checks if the strain is in this database + ''' + return strain in self.data + + def get_entries(self, strain: str) -> Tuple[str, int, int]: + ''' + returns an iterator for the region entries of the strain + with region id (string), start (int) and end (int) positions + ''' + if not self.has_strain(strain): + err = f'Region Database does not contain strain {strain}' + log.exception(err) + raise ValueError(err) + + r_ids = self.data[strain]['region_id'] + starts = self.data[strain]['start'] + ends = self.data[strain]['end'] + for i in range(len(r_ids)): + yield (r_ids[i], int(starts[i]), int(ends[i])) + + def set_region(self, + strain: str, + index: int, + state: str, + hmm, nongap, nonmask): + ''' + Set the region state with the provided values. 
+ hmm, nongap and nonmask are tuples of the (match, total) values + ''' + ds = self.data[strain] + MATCH, TOTAL = 0, 1 + if hmm[TOTAL] is not None: + ds['num_sites_hmm'][index] = hmm[TOTAL] + + ds[f'match_hmm_{state}'][index] = hmm[MATCH] + + ds[f'match_nongap_{state}'][index] = nongap[MATCH] + ds[f'num_sites_nongap_{state}'][index] = nongap[TOTAL] + + ds[f'match_nonmask_{state}'][index] = nonmask[MATCH] + ds[f'num_sites_nonmask_{state}'][index] = nonmask[TOTAL] + + def update_counts(self, + strain: str, + index: int, + info_string: str): + ''' + Update the counts variables based on the provided info string + ''' + counts = Counter(info_string) + for sym in self.info_string_symbols: + self.data[strain]['count_' + sym][index] = counts[sym] + + def generate_output(self): + ''' + Yield lines for writing to the quality output file. + To save memory, this effectively deletes the data structure! + Outputs are tab delimited, sorted by region_id + ''' + # reorganize output as list of tuples ordered by label + output = [] + # have to store this as dict changes during iterations + strains = list(self.data.keys()) + for strain in strains: + # pop to limit memory usage + d = self.data.pop(strain) + output += list(zip(*[d[l] for l in self.labels])) + + # sort by region id (index 0, remove r #[1:]) + for entry in sorted(output, key=lambda e: int(e[0][1:])): + yield '\t'.join([str(e) for e in entry]) + '\n' + + def generate_header(self): + ''' + Generate a header line for the region database + ''' + return '\t'.join(self.labels) + '\n' + + +class Region_Writer(): + ''' + Controls the writing of region files and indices + ''' + def __init__(self, + region_file: str, + index_file: str, + known_states: List[str]): + self.region_file = region_file + self.index_file = index_file + self.index = {} + self.known_states = known_states - # determine totals - total_sites = np.sum(sites) - total_match = np.sum( - np.logical_and( - seq1 == seq2, - sites)) + def __enter__(self): + self.region_writer = gzip.open(self.region_file, 'wt') - return total_match, total_sites, {'mask_flag': info_mask} + return self + def __exit__(self, type, value, traceback): + self.region_writer.close() -def make_info_string(info: Dict[str, List[bool]], - master_ind: int, - predict_ind: int) -> str: + if traceback is None: + # write index + with open(self.index_file, 'wb') as index_writer: + pickle.dump(self.index, index_writer) + return True + + else: + return False + + def write_header(self, region_id: str): + ''' + Add a header line with the region id + ''' + self.index[int(region_id[1:])] = self.region_writer.tell() + self.region_writer.write(f'#{region_id}\n') + + def write_sequences(self, + strain: str, + alignments: List, + sequences: np.array, + indices: Tuple): + ''' + Write sequences to region file + ''' + start, end = indices + names = self.known_states + [strain] + for sj, name in enumerate(names): + startj = bisect.bisect_left(alignments[sj], start) + endj = bisect.bisect_left(alignments[sj], end) + + self.region_writer.write(f'> {name} {startj} {endj}\n') + + self.region_writer.write(''.join( + sequences[sj][start:end+1]) + '\n') + + def write_info_string(self, info_string: str): + ''' + Write info string with header to region file + ''' + # write info string + self.region_writer.write('> info\n') + self.region_writer.write(info_string + '\n') + + +class Position_Reader(): ''' - Summarize info dictionary into a string. master_ind is the index of - the master reference state. predict_ind is the index of the predicted - state. 
The return string is encoded as each position as: - '-': if either master or predict has a gap - '_': if either master or predict is masked - '.': if any state has a match - 'b': both predict and master match - 'c': master matches but not predict - 'p': predict matches but not master - 'x': no other condition applies - if the position is in the hmm_flag it will be capitalized for x, p, c, or - b - in order of precidence, e.g. if a position satisfies both '-' and '.', - it will be '-'. + Read in position file, yielding positions until no longer on current + chromosome ''' - if predict_ind >= info['match_flag'].shape[1]: - return make_info_string_unknown(info, master_ind) + def __init__(self, position_file): + self.position_file = position_file + self.last_position = 0 - decoder = np.array(list('xXpPcCbB._-')) - indices = np.zeros(info['match_flag'].shape[0], int) + def __enter__(self): + self.reader = gzip.open(self.position_file, 'rt') + return self - indices[info['match_flag'][:, predict_ind]] += 2 # x to p if true - indices[info['match_flag'][:, master_ind]] += 4 # x to c, p to b - indices[info['hmm_flag']] += 1 # to upper + def __exit__(self, type, value, traceback): + self.reader.close() + return traceback is None - matches = np.all(info['match_flag'], axis=1) - indices[matches] = 8 # . - indices[np.any( - info['mask_flag'][:, [master_ind, predict_ind]], - axis=1)] = 9 # _ - indices[np.any( - info['gap_flag'][:, [master_ind, predict_ind]], - axis=1)] = 10 # - + def get_positions(self, + region: Region_Database, + chromosome: str) -> Tuple[str, np.array]: + self.reader.seek(self.last_position) + line = self.next_line() + while line != '': + line = line.split('\t') - return ''.join(decoder[indices]) + chrm = line[1] + if chrm != chromosome: + break + strain = line[0] + if not region.has_strain(strain): + line = self.next_line() + continue -def make_info_string_unknown(info: Dict[str, List[bool]], - master_ind: int) -> str: - ''' - Summarize info dictionary into a string for unknown state. - master_ind is the index of the master reference state. - The return string is encoded as each position as: - '-': if any state has a gap - '_': if any state has a mask - '.': all states match - 'x': master matches - 'X': no other condition applies - in order of precidence, e.g. if a position satisfies both '-' and '.', - it will be '-'. - ''' - - # used with indices to decode result - decoder = np.array(list('Xx._-')) - indices = np.zeros(info['gap_any_flag'].shape, int) + yield strain, np.array(line[2:], dtype=int) - indices[info['match_flag'][:, master_ind]] = 1 # x - matches = np.all(info['match_flag'], axis=1) - indices[matches] = 2 # . 
- indices[info['mask_any_flag']] = 3 # _ - indices[info['gap_any_flag']] = 4 # - + line = self.next_line() - return ''.join(decoder[indices]) + def next_line(self) -> str: + self.last_position = self.reader.tell() + line = self.reader.readline() + return line -def read_region_file(fn): - f = gzip.open(fn, 'rb') - d = {} - line = f.readline().decode() - while line != '': - region_id = line[1:-1] - line = f.readline().decode() - seqs = {} - while line[0] != '#': - line = line[:-1].split(' ') - strain = line[1] - seqs[strain] = {} - if len(line) > 2: - seqs[strain]['start'] = int(line[2]) - seqs[strain]['end'] = int(line[3]) - seqs[strain]['seq'] = f.readline().decode()[:-1] - line = f.readline().decode() - if line == '': - break - d[region_id] = seqs - - f.close() - return d +class Quality_Writer(): + ''' + Control writing of quality file from region database + ''' + def __init__(self, quality_filename): + self.filename = quality_filename + self.first_write = True + + def __enter__(self): + self.writer = open(self.filename, 'w') + return self + + def __exit__(self, type, value, traceback): + self.writer.close() + return traceback is None + + def write_quality(self, region: Region_Database): + ''' + Writes header if needed and region database values + ''' + if self.first_write is True: + self.writer.write(region.generate_header()) + self.first_write = False + + for line in region.generate_output(): + self.writer.write(line) diff --git a/code/analyze/summarize_region_quality_main.py b/code/analyze/summarize_region_quality_main.py deleted file mode 100644 index eba2088..0000000 --- a/code/analyze/summarize_region_quality_main.py +++ /dev/null @@ -1,296 +0,0 @@ -import sys -import os -import gzip -from analyze import predict -from analyze.summarize_region_quality import (convert_intervals_to_sites, - read_masked_intervals, - index_alignment_by_reference, - seq_id_hmm, - seq_id_unmasked, - make_info_string) -import global_params as gp -from misc import read_fasta -from misc import read_table -from misc import seq_functions -import numpy as np -import bisect -import pickle - - -def main() -> None: - ''' - Summarize region quality of each region - First parameter is the species to process - Input files: - -blocks_{species}_labeled.txt - -{species}_chr_intervals.txt - -{species}_chr_mafft.fa - -{species}_chr_mafft.fa - -positions_{tag}.txt.gz - - Output files: - -regions file as {species}.fa.gz - -index file for the fz.gz - -blocks_{species}_quality.txt - ''' - - args = predict.process_predict_args(sys.argv[2:]) - - task_ind = int(sys.argv[1]) - species_ind = task_ind - - species_from = args['states'][species_ind] - - base_dir = gp.analysis_out_dir_absolute + args['tag'] - - regions_dir = f'{base_dir}/regions/' - if not os.path.isdir(regions_dir): - os.mkdir(regions_dir) - - quality_writer = None - positions = gzip.open(f'{base_dir}/positions_{args["tag"]}.txt.gz', 'rt') - line_number = 0 - - region_writer = gzip.open( - f'{regions_dir}{species_from}{gp.fasta_suffix}.gz', 'wt') - region_index = {} - - for chrm in gp.chrms: - # region_id strain chromosome predicted_species start end num_non_gap - regions_chrm, labels = read_table.read_table_columns( - f'{base_dir}/blocks_{species_from}_{args["tag"]}_labeled.txt', - '\t', - group_by='strain', - chromosome=chrm - ) - - for strain in regions_chrm: - n = len(regions_chrm[strain]['region_id']) - - for s in args['known_states']: - regions_chrm[strain]['match_nongap_' + s] = [0] * n - regions_chrm[strain]['num_sites_nongap_' + s] = [0] * n - 
regions_chrm[strain]['match_hmm_' + s] = [0] * n - regions_chrm[strain]['match_nonmask_' + s] = [0] * n - regions_chrm[strain]['num_sites_nonmask_' + s] = [0] * n - - info_string_symbols = list('.-_npbcxNPBCX') - for s in info_string_symbols: - regions_chrm[strain]['count_' + s] = [0] * n - - # get masked sites for all references, not just the current - # species_from we're considering regions from - masked_sites_refs = {} - for s, state in enumerate(args['known_states']): - masked_sites_refs[s] = \ - convert_intervals_to_sites( - read_masked_intervals( - f'{gp.mask_dir}{state}' - f'_chr{chrm}_intervals.txt')) - - # loop through chromosomes and strains, followed by species of - # introgression so that we only have to read each alignment in once - # move to last read chromosome - positions.seek(line_number) - line = positions.readline() - while line != '': - line = line.split('\t') - - current_chrm = line[1] - if current_chrm != chrm: - break - - strain = line[0] - if strain not in regions_chrm: - # record current position in case need to re read line - line_number = positions.tell() - line = positions.readline() - continue - - print(strain, chrm) - - # indices of alignment columns used by HMM - ps = np.array([int(x) for x in line[2:]]) - - headers, seqs = read_fasta.read_fasta( - args['setup_args']['alignments_directory'] + - '_'.join(args['known_states']) - + f'_{strain}_chr{chrm}_mafft{gp.alignment_suffix}') - - # to go from index in reference seq to index in alignment - ind_align = [] - for seq in seqs: - ind_align.append(index_alignment_by_reference(seq)) - - masked_sites = convert_intervals_to_sites( - read_masked_intervals( - f'{gp.mask_dir}{strain}_chr{chrm}_intervals.txt')) - - masked_sites_ind_align = [] - for s in range(len(args['known_states'])): - masked_sites_ind_align.append( - ind_align[s][masked_sites_refs[s]]) - - # add in sequence of query strain - masked_sites_ind_align.append( - ind_align[-1][masked_sites]) - - # convert position indices from indices in master reference to - # indices in alignment - ps_ind_align = ind_align[0][ps] - - # loop through all regions for the specified chromosome and the - # current strain - for i in range(len(regions_chrm[strain]['region_id'])): - r_id = regions_chrm[strain]['region_id'][i] - start = regions_chrm[strain]['start'][i] - end = regions_chrm[strain]['end'][i] - - # calculate: - # - identity with each reference - # - fraction of region that is gapped/masked - - # index of start and end of region in aligned sequences - slice_start = ind_align[0][int(start)] - slice_end = ind_align[0][int(end)] - assert slice_start in ps_ind_align, \ - f'{slice_start} {start} {r_id}' - assert slice_end in ps_ind_align, \ - f'{slice_end} {end} {r_id}' - - seqx = seqs[-1][slice_start:slice_end + 1] - len_seqx = slice_end - slice_start + 1 - len_states = len(args['known_states']) - - # . 
= all match - # - = gap in one or more sequences - # p = matches predicted reference - - info = {'gap_any_flag': np.zeros((len_seqx), bool), - 'mask_any_flag': np.zeros((len_seqx), bool), - 'unseq_any_flag': np.zeros((len_seqx), bool), - 'hmm_flag': np.zeros((len_seqx), bool), - 'gap_flag': np.zeros((len_seqx, len_states), bool), - 'mask_flag': np.zeros((len_seqx, len_states), bool), - 'unseq_flag': np.zeros((len_seqx, len_states), bool), - 'match_flag': np.zeros((len_seqx, len_states), bool)} - - for sj, statej in enumerate(args['known_states']): - seqj = seqs[sj][slice_start:slice_end+1] - - # only alignment columns used by HMM (polymorphic, no - # gaps in any strain) - total_match_hmm, total_sites_hmm, infoj = \ - seq_id_hmm(seqj, seqx, slice_start, ps_ind_align) - - if statej == species_from \ - or species_ind >= len(args['known_states']): - regions_chrm[strain]['num_sites_hmm'][i] = \ - total_sites_hmm - - # only write once, the first index - if sj == 0: - info['hmm_flag'] = infoj['hmm_flag'] - - info['gap_any_flag'] = np.logical_or( - info['gap_any_flag'], infoj['gap_flag']) - info['unseq_any_flag'] = np.logical_or( - info['unseq_any_flag'], infoj['unseq_flag']) - info['gap_flag'][:, sj] = infoj['gap_flag'] - info['unseq_flag'][:, sj] = infoj['unseq_flag'] - info['match_flag'][:, sj] = infoj['match'] - - regions_chrm[strain][f'match_hmm_{statej}'][i] = \ - total_match_hmm - - # all alignment columns, excluding ones with gaps in - # these two sequences - total_match_nongap, total_sites_nongap = \ - seq_functions.seq_id(seqj, seqx) - - regions_chrm[strain][f'match_nongap_{statej}'][i] =\ - total_match_nongap - regions_chrm[strain][f'num_sites_nongap_{statej}'][i] =\ - total_sites_nongap - - # all alignment columns, excluding ones with gaps or - # masked bases or unsequenced in *these two sequences* - total_match_nonmask, total_sites_nonmask, infoj = \ - seq_id_unmasked(seqj, seqx, slice_start, - masked_sites_ind_align[sj], - masked_sites_ind_align[-1]) - - info['mask_any_flag'] = np.logical_or( - info['mask_any_flag'], infoj['mask_flag']) - info['mask_flag'][:, sj] = infoj['mask_flag'] - - regions_chrm[strain][f'match_nonmask_{statej}'][i] = \ - total_match_nonmask - regions_chrm[strain][f'num_sites_nonmask_{statej}'][i] = \ - total_sites_nonmask - - region_index[int(r_id[1:])] = region_writer.tell() - region_writer.write(f'#{r_id}\n') - names = args['known_states'] + [strain] - for sj in range(len(names)): - # write sequence to region alignment file, along with - # start and end coordinates - startj = bisect.bisect_left(ind_align[sj], slice_start) - endj = bisect.bisect_left(ind_align[sj], slice_end) - - region_writer.write(f'> {names[sj]} {startj} {endj}\n') - region_writer.write( - ''.join(seqs[sj][slice_start:slice_end+1]) + '\n') - - # also write string with info about each site - info_string = make_info_string(info, 0, species_ind) - region_writer.write('> info\n') - region_writer.write(info_string + '\n') - - # TODO this can be made faster with numpy - # and keep track of each symbol count - for sym in info_string_symbols: - regions_chrm[strain]['count_' + sym][i] = \ - info_string.count(sym) - - # record current position in case need to re read line - line_number = positions.tell() - line = positions.readline() - sys.stdout.flush() - - labels += ['match_nongap_' + x for x in args['known_states']] - labels += ['num_sites_nongap_' + x for x in args['known_states']] - labels += ['match_hmm_' + x for x in args['known_states']] - labels += ['match_nonmask_' + x for x in 
args['known_states']] - labels += ['num_sites_nonmask_' + x for x in args['known_states']] - labels += ['count_' + x for x in info_string_symbols] - - assert labels[0] == 'region_id', 'Unexpected labeled format' - - # write on first execution - if quality_writer is None: - quality_writer = open(f'{base_dir}/blocks_{species_from}' - f'_{args["tag"]}_quality.txt', 'w') - - quality_writer.write('\t'.join(labels) + '\n') - - # reorganize output as list of tuples ordered by label - output = [] - strains = list(regions_chrm.keys()) - for strain in strains: - # pop to limit memory usage - d = regions_chrm.pop(strain) - output += list(zip(*[d[l] for l in labels])) - - # sort by region id (index 0, remove r) - for entry in sorted(output, key=lambda e: int(e[0][1:])): - quality_writer.write('\t'.join([str(e) for e in entry]) + '\n') - - quality_writer.close() - region_writer.close() - with open(f'{regions_dir}{species_from}.pkl', 'wb') as index: - pickle.dump(region_index, index) - - -if __name__ == '__main__': - main() diff --git a/code/config.yaml b/code/config.yaml index c7d8db4..24d4698 100644 --- a/code/config.yaml +++ b/code/config.yaml @@ -41,17 +41,18 @@ paths: analysis: analysis_base: __OUTPUT_ROOT__/analysisp4e2 - regions: __ANALYSIS_BASE__/regions/ + regions: __ANALYSIS_BASE__/regions/{state}.fa.gz + region_index: __ANALYSIS_BASE__/regions/{state}.pkl genes: __ANALYSIS_BASE__/genes/ - block_files: __ANALYSIS_BASE__/blocks_{state}_p4e2.txt - labeled_block_files: "__ANALYSIS_BASE__/../analysis_test/\ - blocks_{state}__test_labeled.txt" + blocks: __ANALYSIS_BASE__/blocks_{state}_p4e2.txt + labeled_blocks: "__ANALYSIS_BASE__/blocks_{state}_p4e2_labeled.txt" + quality: __ANALYSIS_BASE__/block_{state}_quality.txt hmm_initial: __ANALYSIS_BASE__/hmm_initial.txt hmm_trained: __ANALYSIS_BASE__/hmm_trained.txt probabilities: __ANALYSIS_BASE__/probabilities.txt.gz alignment: __ALIGNMENTS__/{prefix}_{strain}_chr{chrom}_mafft.maf - # positions are optional positions: __ANALYSIS_BASE__/positions.txt.gz + masked_intervals: __MASKS__/{strain}_chr{chrom}_intervals.txt # software install locations software: @@ -87,10 +88,12 @@ analysis_params: input_root: /tigress/AKEY/akey_vol2/aclark4/nobackup # master known state, prepeded to list of known states + # TODO need to use the other name for S288c! 
reference: name: S288c base_dir: __INPUT_ROOT__/100_genomes/genomes/S288c_SGD-R64/ gene_bank_dir: __INPUT_ROOT__/S288c/ + interval_name: S288c_SGD-R64 # if different than name known_states: - name: CBS432 diff --git a/code/hmm/hmm_bw.py b/code/hmm/hmm_bw.py index ea667d6..2d5982d 100644 --- a/code/hmm/hmm_bw.py +++ b/code/hmm/hmm_bw.py @@ -344,6 +344,7 @@ def calculate_max_states(self) -> Tuple[np.array, np.array]: emissions[:, None, :] probabilities[0, :] = np.log(self.initial_p) + emissions[0] + states[0, :] = -1 for i in range(1, len(emissions)): diff --git a/code/misc/region_reader.py b/code/misc/region_reader.py index 53aac84..829bd13 100644 --- a/code/misc/region_reader.py +++ b/code/misc/region_reader.py @@ -3,7 +3,7 @@ import os import sys import numpy as np -from typing import Dict, List, Tuple +from typing import List, Tuple class Region_Reader(): @@ -42,6 +42,7 @@ def __enter__(self): def __exit__(self, type, value, traceback): self.region_reader.close() + return traceback is None def __repr__(self): return ( diff --git a/code/test/analyze/test_id_regions.py b/code/test/analyze/test_id_regions.py index 056231e..f811285 100644 --- a/code/test/analyze/test_id_regions.py +++ b/code/test/analyze/test_id_regions.py @@ -36,8 +36,8 @@ def test_add_ids_empty(id_producer, mocker): id_producer.config.add_config({ 'chromosomes': ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'], - 'paths': {'analysis': {'block_files': 'dir/blocks_{state}.txt', - 'labeled_block_files': + 'paths': {'analysis': {'blocks': 'dir/blocks_{state}.txt', + 'labeled_blocks': 'dir/blocks_{state}_labeled.txt', }}}) @@ -70,8 +70,8 @@ def test_add_ids(id_producer, mocker): id_producer.config.add_config({ 'chromosomes': ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI'], - 'paths': {'analysis': {'block_files': 'dir/blocks_{state}.txt', - 'labeled_block_files': + 'paths': {'analysis': {'blocks': 'dir/blocks_{state}.txt', + 'labeled_blocks': 'dir/blocks_{state}_labeled.txt', }}}) diff --git a/code/test/analyze/test_introgression_configuration.py b/code/test/analyze/test_introgression_configuration.py index 01a1004..e8ce776 100644 --- a/code/test/analyze/test_introgression_configuration.py +++ b/code/test/analyze/test_introgression_configuration.py @@ -78,6 +78,55 @@ def test_get_states(config): assert config.get_states() == ('ref k1 k2 k3'.split(), 'u1 u2'.split()) +def test_get_interval_states(config): + assert config.get_interval_states() == [] + + config.config = { + 'analysis_params': { + 'reference': {'name': 'ref'}, + 'known_states': [ + {'name': 'k1'}, + {'name': 'k2'}, + {'name': 'k3'}, + ], + } + } + assert config.get_interval_states() == 'ref k1 k2 k3'.split() + + config.config = { + 'analysis_params': { + 'known_states': [ + {'name': 'k1'}, + {'name': 'k2'}, + {'name': 'k3'}, + ], + } + } + assert config.get_interval_states() == 'k1 k2 k3'.split() + + config.config = { + 'analysis_params': { + 'reference': {'name': 'ref'}, + } + } + assert config.get_interval_states() == 'ref'.split() + + config.config = { + 'analysis_params': { + 'reference': {'name': 'ref', + 'interval_name': 'int_ref'}, + 'known_states': [ + {'name': 'k1', + 'interval_name': 'i1'}, + {'name': 'k2'}, + {'name': 'k3', + 'interval_name': 'i3'}, + ], + } + } + assert config.get_interval_states() == 'int_ref i1 k2 i3'.split() + + def test_set_states(config): config.config = { 'analysis_params': @@ -150,12 +199,12 @@ def 
test_set_labeled_blocks_file(config): assert 'No labeled block file provided' in str(e) config.config = {'paths': {'analysis': - {'labeled_block_files': 'blocks_file'}}} + {'labeled_blocks': 'blocks_file'}}} with pytest.raises(ValueError) as e: config.set_labeled_blocks_file() assert '{state} not found in blocks_file' in str(e) - config.config = {'paths': {'analysis': {'labeled_block_files': + config.config = {'paths': {'analysis': {'labeled_blocks': 'blocks_file{state}'}}} config.set_labeled_blocks_file() assert config.labeled_blocks == 'blocks_file{state}' @@ -173,12 +222,12 @@ def test_set_blocks_file(config): config.set_blocks_file() assert 'No block file provided' in str(e) - config.config = {'paths': {'analysis': {'block_files': 'blocks_file'}}} + config.config = {'paths': {'analysis': {'blocks': 'blocks_file'}}} with pytest.raises(ValueError) as e: config.set_blocks_file() assert '{state} not found in blocks_file' in str(e) - config.config = {'paths': {'analysis': {'block_files': + config.config = {'paths': {'analysis': {'blocks': 'blocks_file{state}'}}} config.set_blocks_file() assert config.blocks == 'blocks_file{state}' @@ -321,7 +370,7 @@ def test_set_predict_files(config): with pytest.raises(ValueError) as e: config.set_predict_files('init', 'trained', 'pos', 'prob', 'align') - assert '{prefix} not found in align' in str(e) + assert '{strain} not found in align' in str(e) with pytest.raises(ValueError) as e: config.set_predict_files('init', 'trained', 'pos', 'prob', @@ -378,3 +427,89 @@ def test_set_predict_files(config): assert config.positions == 'pos' assert config.probabilities == 'prob' assert config.alignment == 'alignpre{strain}{chrom}' + + +def test_set_alignment(config): + config.set_alignment('align{strain}{chrom}') + assert config.alignment == 'align{strain}{chrom}' + + with pytest.raises(AttributeError) as e: + config.set_alignment('align{prefix}{strain}{chrom}') + assert "'Configuration' object has no attribute 'prefix'" in str(e) + + config.prefix = 'prefix' + config.set_alignment('align{prefix}{strain}{chrom}') + assert config.alignment == 'alignprefix{strain}{chrom}' + + +def test_set_regions_file(config): + with pytest.raises(ValueError) as e: + config.set_regions_files() + assert 'No region file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.set_regions_files('region') + assert '{state} not found in region' in str(e) + + with pytest.raises(ValueError) as e: + config.set_regions_files('region{state}') + assert 'No region index file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.set_regions_files('region{state}', 'index') + assert '{state} not found in index' in str(e) + + config.set_regions_files('region{state}', 'index{state}') + assert config.regions == 'region{state}' + assert config.region_index == 'index{state}' + + config.config = {'paths': {'analysis': {'regions': 'region{state}', + 'region_index': 'index{state}', + }}} + config.set_regions_files() + assert config.regions == 'region{state}' + assert config.region_index == 'index{state}' + + # args overwrite config + config.set_regions_files('reg{state}', 'ind{state}') + assert config.regions == 'reg{state}' + assert config.region_index == 'ind{state}' + + +def test_set_quality_file(config): + with pytest.raises(ValueError) as e: + config.set_quality_file() + assert 'No quality block file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.set_quality_file('qual') + assert '{state} not found in qual' in str(e) + + 
config.set_quality_file('qual{state}') + assert config.quality_blocks == 'qual{state}' + + config.config = {'paths': {'analysis': {'quality': 'qua{state}'}}} + config.set_quality_file() + assert config.quality_blocks == 'qua{state}' + + +def test_set_masked_file(config): + with pytest.raises(ValueError) as e: + config.set_masked_file() + assert 'No masked interval file provided' in str(e) + + with pytest.raises(ValueError) as e: + config.set_masked_file('mask') + assert '{strain} not found in mask' in str(e) + + with pytest.raises(ValueError) as e: + config.set_masked_file('mask{strain}') + assert '{chrom} not found in mask{strain}' in str(e) + + config.set_masked_file('mask{strain}{chrom}') + assert config.masks == 'mask{strain}{chrom}' + + config.config = {'paths': {'analysis': + {'masked_intervals': 'msk{strain}{chrom}'}}} + config.set_masked_file() + assert config.masks == 'msk{strain}{chrom}' diff --git a/code/test/analyze/test_main_id_config.py b/code/test/analyze/test_main_id_config.py index c3631f0..eb41378 100644 --- a/code/test/analyze/test_main_id_config.py +++ b/code/test/analyze/test_main_id_config.py @@ -87,7 +87,7 @@ def test_block_file(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'block_{state}.txt', + 'blocks': 'block_{state}.txt', }} }, f) @@ -119,8 +119,8 @@ def test_labeled_block_file(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'block_{state}.txt', - 'labeled_block_files': 'labeled_block_{state}.txt', + 'blocks': 'block_{state}.txt', + 'labeled_blocks': 'labeled_block_{state}.txt', }} }, f) diff --git a/code/test/analyze/test_main_predict_config.py b/code/test/analyze/test_main_predict_config.py index eb2a9ef..f53bd40 100644 --- a/code/test/analyze/test_main_predict_config.py +++ b/code/test/analyze/test_main_predict_config.py @@ -85,7 +85,7 @@ def test_block(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'blocks_{state}.txt', + 'blocks': 'blocks_{state}.txt', }}, }, f) @@ -120,7 +120,7 @@ def test_prefix(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'blocks_{state}.txt', + 'blocks': 'blocks_{state}.txt', }}, }, f) @@ -156,7 +156,7 @@ def test_strains(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'blocks_{state}.txt', + 'blocks': 'blocks_{state}.txt', }}, }, f) @@ -194,7 +194,7 @@ def test_test_strains(runner, mocker): }, 'paths': { 'analysis': { - 'block_files': 'blocks_{state}.txt', + 'blocks': 'blocks_{state}.txt', }, 'test_strains': ['{strain}_chr{chrom}.fa']}, }, f) @@ -253,7 +253,7 @@ def test_outputs(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'blocks_{state}.txt', + 'blocks': 'blocks_{state}.txt', 'hmm_initial': 'hmm_init.txt', }}, }, f) @@ -281,7 +281,7 @@ def test_outputs(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'blocks_{state}.txt', + 'blocks': 'blocks_{state}.txt', 'hmm_initial': 'hmm_init.txt', 'hmm_trained': 'hmm_trained.txt', }}, @@ -310,7 +310,7 @@ def test_outputs(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'blocks_{state}.txt', + 'blocks': 'blocks_{state}.txt', 'hmm_initial': 'hmm_init.txt', 'hmm_trained': 'hmm_trained.txt', 'positions': 'pos.txt.gz', @@ -341,7 +341,7 @@ def test_outputs(runner, mocker): {'name': 's2'}], }, 'paths': {'analysis': { - 'block_files': 'blocks_{state}.txt', + 'blocks': 'blocks_{state}.txt', 'hmm_initial': 'hmm_init.txt', 'hmm_trained': 'hmm_trained.txt', 
             'probabilities': 'probs.txt.gz',
@@ -381,7 +381,7 @@ def test_outputs(runner, mocker):
                         {'name': 's2'}],
                     },
                     'paths': {'analysis': {
-                        'block_files': 'blocks_{state}.txt',
+                        'blocks': 'blocks_{state}.txt',
                         'hmm_initial': 'hmm_init.txt',
                         'hmm_trained': 'hmm_trained.txt',
                         'positions': 'pos.txt.gz',
@@ -420,7 +420,7 @@ def test_outputs(runner, mocker):
                         {'name': 's2'}],
                     },
                     'paths': {'analysis': {
-                        'block_files': 'blocks_{state}.txt',
+                        'blocks': 'blocks_{state}.txt',
                         'hmm_initial': 'hmm_init.txt',
                         'hmm_trained': 'hmm_trained.txt',
                         'positions': 'pos.txt.gz',
diff --git a/code/test/analyze/test_main_summarize_regions_args.py b/code/test/analyze/test_main_summarize_regions_args.py
new file mode 100644
index 0000000..cc03ea1
--- /dev/null
+++ b/code/test/analyze/test_main_summarize_regions_args.py
@@ -0,0 +1,328 @@
+import pytest
+from click.testing import CliRunner
+import analyze.main as main
+import yaml
+from analyze.summarize_region_quality import Summarizer
+
+
+'''
+Unit tests for the summarize_regions command of main.py when parameters are
+provided by args
+'''
+
+
+@pytest.fixture
+def runner():
+    return CliRunner()
+
+
+def test_empty(runner, mocker):
+    mock_log = mocker.patch('analyze.main.log.info')
+    with runner.isolated_filesystem():
+        with open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    'analysis_params': {
+                        'reference':
+                        {'name': 'r1'},
+                        'known_states': [
+                            {'name': 's1'},
+                            {'name': 's2'}],
+                        'unknown_states': [
+                            {'name': 'u1'}]
+                    },
+                    'chromosomes': 'I II III'.split(),
+                }, f)
+
+        result = runner.invoke(
+            main.cli,
+            '--config config.yaml summarize-regions')
+
+        assert result.exit_code != 0
+        assert str(result.exception) == 'No labeled block file provided'
+        assert mock_log.call_args_list == [
+            mocker.call('Verbosity set to WARNING'),
+            mocker.call('Read in 1 config file'),
+            mocker.call('Found 3 chromosomes in config'),
+        ]
+
+
+def test_labeled(runner, mocker):
+    mock_log = mocker.patch('analyze.main.log.info')
+    with runner.isolated_filesystem():
+        with open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    'analysis_params': {
+                        'reference':
+                        {'name': 'r1'},
+                        'known_states': [
+                            {'name': 's1'},
+                            {'name': 's2'}],
+                        'unknown_states': [
+                            {'name': 'u1'}]
+                    },
+                    'chromosomes': 'I II III'.split(),
+                }, f)
+
+        result = runner.invoke(
+            main.cli,
+            '--config config.yaml summarize-regions --labeled {state}lbl.txt')
+
+        assert result.exit_code != 0
+        assert str(result.exception) == 'No quality block file provided'
+        assert mock_log.call_args_list == [
+            mocker.call('Verbosity set to WARNING'),
+            mocker.call('Read in 1 config file'),
+            mocker.call('Found 3 chromosomes in config'),
+            mocker.call("Labeled blocks file is '{state}lbl.txt'"),
+        ]
+
+
+def test_quality(runner, mocker):
+    mock_log = mocker.patch('analyze.main.log.info')
+    with runner.isolated_filesystem():
+        with open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    'analysis_params': {
+                        'reference':
+                        {'name': 'r1'},
+                        'known_states': [
+                            {'name': 's1'},
+                            {'name': 's2'}],
+                        'unknown_states': [
+                            {'name': 'u1'}]
+                    },
+                    'chromosomes': 'I II III'.split(),
+                }, f)
+
+        result = runner.invoke(
+            main.cli,
+            '--config config.yaml summarize-regions --labeled {state}lbl.txt '
+            '--quality {state}qual.txt')
+
+        assert result.exit_code != 0
+        assert str(result.exception) == 'No masked interval file provided'
+        assert mock_log.call_args_list == [
+            mocker.call('Verbosity set to WARNING'),
+            mocker.call('Read in 1 config file'),
+            mocker.call('Found 3 chromosomes in config'),
+            mocker.call("Labeled blocks file is '{state}lbl.txt'"),
+ mocker.call("Quality file is '{state}qual.txt'"), + ] + + +def test_masked(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions --labeled {state}lbl.txt ' + '--quality {state}qual.txt --masks {strain}_{chrom}mask.txt') + + assert result.exit_code != 0 + assert str(result.exception) == 'No alignment file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + ] + + +def test_alignment(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + }, f) + + # no prefix + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions --labeled {state}lbl.txt ' + '--quality {state}qual.txt --masks {strain}_{chrom}mask.txt ' + '--alignment {strain}_{chrom}_align.txt') + + assert result.exit_code != 0 + assert str(result.exception) == 'No positions file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is '{strain}_{chrom}_align.txt'"), + ] + + # with prefix + mock_log.reset_mock() + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions --labeled {state}lbl.txt ' + '--quality {state}qual.txt --masks {strain}_{chrom}mask.txt ' + '--alignment {prefix}_{strain}_{chrom}_align.txt') + + assert result.exit_code != 0 + assert str(result.exception) == 'No positions file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is " + "'r1_s1_s2_{strain}_{chrom}_align.txt'"), + ] + + +def test_positions(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions --labeled {state}lbl.txt ' + '--quality {state}qual.txt --masks {strain}_{chrom}mask.txt ' + '--alignment {strain}_{chrom}_align.txt 
--positions pos.txt') + + assert result.exit_code != 0 + assert str(result.exception) == 'No region file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is '{strain}_{chrom}_align.txt'"), + mocker.call("Positions file is 'pos.txt'"), + ] + + +def test_region(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions --labeled {state}lbl.txt ' + '--quality {state}qual.txt --masks {strain}_{chrom}mask.txt ' + '--alignment {strain}_{chrom}_align.txt --positions pos.txt ' + '--region region{state}.gz' + ) + + assert result.exit_code != 0 + assert str(result.exception) == 'No region index file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is '{strain}_{chrom}_align.txt'"), + mocker.call("Positions file is 'pos.txt'"), + ] + + +def test_run(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + mock_summarize = mocker.patch.object(Summarizer, 'run') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions --labeled {state}lbl.txt ' + '--quality {state}qual.txt --masks {strain}_{chrom}mask.txt ' + '--alignment {strain}_{chrom}_align.txt --positions pos.txt ' + '--region region{state}.gz --region-index ind{state}.pkl' + ) + + assert result.exit_code == 0 + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is '{strain}_{chrom}_align.txt'"), + mocker.call("Positions file is 'pos.txt'"), + mocker.call("Region file is 'region{state}.gz'"), + mocker.call("Region index file is 'ind{state}.pkl'"), + ] + mock_summarize.assert_called_once_with([]) diff --git a/code/test/analyze/test_main_summarize_regions_config.py b/code/test/analyze/test_main_summarize_regions_config.py new file mode 100644 index 0000000..5419d35 --- /dev/null +++ b/code/test/analyze/test_main_summarize_regions_config.py @@ -0,0 +1,335 @@ +import pytest +from click.testing import CliRunner +import analyze.main as main +import yaml +from analyze.summarize_region_quality 
import Summarizer + + +''' +Unit tests for the summarize_regions command of main.py when parameters are +provided by config +''' + + +@pytest.fixture +def runner(): + return CliRunner() + + +def test_empty(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No labeled block file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + ] + + +def test_labeled(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + 'paths': {'analysis': { + 'labeled_blocks': '{state}lbl.txt', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No quality block file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + ] + + +def test_quality(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + 'paths': {'analysis': { + 'labeled_blocks': '{state}lbl.txt', + 'quality': '{state}qual.txt', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No masked interval file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + ] + + +def test_masked(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + 'paths': {'analysis': { + 'labeled_blocks': '{state}lbl.txt', + 'quality': '{state}qual.txt', + 'masked_intervals': '{strain}_{chrom}mask.txt', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No alignment file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + 
mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + ] + + +def test_alignment(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + 'paths': {'analysis': { + 'labeled_blocks': '{state}lbl.txt', + 'quality': '{state}qual.txt', + 'masked_intervals': '{strain}_{chrom}mask.txt', + 'alignment': '{strain}_{chrom}_align.txt', + }} + }, f) + + # no prefix + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No positions file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is '{strain}_{chrom}_align.txt'"), + ] + + +def test_positions(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + 'paths': {'analysis': { + 'labeled_blocks': '{state}lbl.txt', + 'quality': '{state}qual.txt', + 'masked_intervals': '{strain}_{chrom}mask.txt', + 'alignment': '{strain}_{chrom}_align.txt', + 'positions': 'pos.txt', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No region file provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is '{strain}_{chrom}_align.txt'"), + mocker.call("Positions file is 'pos.txt'"), + ] + + +def test_region(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + 'paths': {'analysis': { + 'labeled_blocks': '{state}lbl.txt', + 'quality': '{state}qual.txt', + 'masked_intervals': '{strain}_{chrom}mask.txt', + 'alignment': '{strain}_{chrom}_align.txt', + 'positions': 'pos.txt', + 'regions': 'region{state}.gz', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions') + + assert result.exit_code != 0 + assert str(result.exception) == 'No region index file provided' + assert mock_log.call_args_list == [ + 
mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is '{strain}_{chrom}_align.txt'"), + mocker.call("Positions file is 'pos.txt'"), + ] + + +def test_run(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + mock_summarize = mocker.patch.object(Summarizer, 'run') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'reference': + {'name': 'r1'}, + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'unknown_states': [ + {'name': 'u1'}] + }, + 'chromosomes': 'I II III'.split(), + 'paths': {'analysis': { + 'labeled_blocks': '{state}lbl.txt', + 'quality': '{state}qual.txt', + 'masked_intervals': '{strain}_{chrom}mask.txt', + 'alignment': '{strain}_{chrom}_align.txt', + 'positions': 'pos.txt', + 'regions': 'region{state}.gz', + 'region_index': 'ind{state}.pkl', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml summarize-regions') + + assert result.exit_code == 0 + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Found 3 chromosomes in config'), + mocker.call("Labeled blocks file is '{state}lbl.txt'"), + mocker.call("Quality file is '{state}qual.txt'"), + mocker.call("Mask file is '{strain}_{chrom}mask.txt'"), + mocker.call("Alignment file is '{strain}_{chrom}_align.txt'"), + mocker.call("Positions file is 'pos.txt'"), + mocker.call("Region file is 'region{state}.gz'"), + mocker.call("Region index file is 'ind{state}.pkl'"), + ] + mock_summarize.assert_called_once_with([]) diff --git a/code/test/analyze/test_summarize_region_quality.py b/code/test/analyze/test_summarize_region_quality.py index 80721c6..c7db4f7 100644 --- a/code/test/analyze/test_summarize_region_quality.py +++ b/code/test/analyze/test_summarize_region_quality.py @@ -3,20 +3,1095 @@ import pytest from pytest import approx import numpy as np +from numpy.testing import assert_array_equal as aae +from analyze.introgression_configuration import Configuration -def test_read_masked_intervals(mocker): +@pytest.fixture +def summarizer(): + return summarize.Summarizer(Configuration()) + + +def test_states_to_process(summarizer, mocker): + summarizer.config.add_config({ + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432'}, + {'name': 'N_45'}, + ], + 'unknown_states': [{'name': 'unknown'}] + } + }) + summarizer.config.set_states() + + assert summarizer.states_to_process() == \ + (0, 'S288c CBS432 N_45 unknown'.split()) + + mock_warn = mocker.patch('analyze.summarize_region_quality.log.warning') + assert summarizer.states_to_process('N_45 asdf S288c'.split()) == \ + (0, 'N_45 S288c'.split()) + mock_warn.assert_called_with("state 'asdf' was not found as a state") + + with pytest.raises(ValueError) as e: + summarizer.states_to_process('asdf qwer'.split()) + assert 'No valid states were found to process' in str(e) + + summarizer.config.add_config({ + 'analysis_params': {'reference': {'name': 'N_45'}}}) + + assert summarizer.states_to_process() == \ + (2, 'S288c CBS432 N_45 unknown'.split()) + + +def test_run(summarizer, mocker): + summarizer.config.add_config({ + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ 
+ {'name': 'CBS432'}, + {'name': 'N_45'}, + {'name': 'DBVPG6304'}, + {'name': 'UWOPS91_917_1'} + ], + 'unknown_states': [{'name': 'unknown'}] + } + }) + summarizer.config.set_states() + summarizer.config.set_HMM_symbols() + summarizer.config.set_positions('positions.txt.gz') + summarizer.config.set_labeled_blocks_file( + 'dir/tag/blocks_{state}_labeled.txt') + summarizer.config.set_quality_file('dir/tag/blocks_{state}_quality.txt') + summarizer.config.set_alignment('dir/tag/blocks_{chrom}_{strain}.txt') + summarizer.config.set_regions_files('dir/tag/regions/{state}.fa.gz', + 'dir/tag/regions/{state}.pkl') + summarizer.config.set_masked_file('dir/masked/{strain}_chr{chrom}.txt') + summarizer.config.chromosomes = ['I', 'II'] + summarizer.validate_arguments() + # for region database + mock_table = mocker.patch( + 'misc.read_table.open', + mocker.mock_open( + read_data='region_id\tstrain\tchromosome\t' + 'predicted_species\tstart\tend\tnum_sites_hmm\n' + 'r4\tyjm1381\tI\tS288c\t2\t5\t60\n' + 'r5\tyjm689\tI\tS288c\t3\t6\t56\n' + 'r6\tyjm1381\tI\tS288c\t3\t7\t18\n' + 'r7\tyjm689\tI\tS288c\t3\t5\t13728\n' + 'r8\tyjm1208\tI\tS288c\t3\t4\t20\n' + 'r9\tyjm1304\tII\tS288c\t3\t7\t16\n' + )) + + # sequence analyzer masked sites + mock_masked = mocker.patch.object( + summarize.Sequence_Analyzer, + 'read_masked_intervals', + return_value=[ + (0, 2), + (4, 5), + ]) + position_in = StringIO( + 'yjm1381\tI\t2\t3\t5\t7\n' + 'yjm689\tI\t3\t4\t5\t6\n' + 'yjm1464\tI\t1\t2\t3\t3\n' + ) + region_out = StringIO() + + def new_close(): + pass + + mocker.patch.object(region_out, 'close', new_close) + + mocked_gzip = mocker.patch( + 'analyze.summarize_region_quality.gzip.open', + side_effect=[position_in, region_out]) + + mocked_file = mocker.patch('analyze.summarize_region_quality.open', + mocker.mock_open()) + mock_log = mocker.patch('analyze.summarize_region_quality.log') + + mocker.patch('analyze.summarize_region_quality.read_fasta.read_fasta', + return_value=('', + np.asarray([ + list('--gatcctag--'), + list('-agatgcaag-c'), + list('-agatgcaag-c'), + list('-agatgcaag-c'), + list('-a-attacagt-'), + list('-a-atttcagt-'), + ]))) + + summarizer.run(['unknown']) + + mock_masked.assert_any_call('dir/masked/UWOPS91_917_1_chrII.txt') + mock_masked.assert_any_call('dir/masked/yjm1381_chrI.txt') + + assert mocked_gzip.call_args_list == [ + mocker.call('positions.txt.gz', 'rt'), + mocker.call('dir/tag/regions/unknown.fa.gz', 'wt'), + ] + assert mock_log.debug.call_args_list == [ + mocker.call('reference index: 0'), + mocker.call("states to analyze: ['unknown']"), + mocker.call("known_states ['S288c', 'CBS432', 'N_45', " + "'DBVPG6304', 'UWOPS91_917_1']"), + mocker.call('Sequence_Analyzer init with:'), + mocker.call('masks: dir/masked/{strain}_chr{chrom}.txt'), + mocker.call('alignment: dir/tag/blocks_{chrom}_{strain}.txt'), + mocker.call('yjm1381 I'), + mocker.call('yjm689 I') + ] + assert mock_log.info.call_args_list == [ + mocker.call('Working on state unknown'), + mocker.call('Working on chromosome I'), + mocker.call('Working on chromosome II'), + ] + + assert mocked_file.call_count == 2 + mocked_file.assert_any_call( + 'dir/tag/blocks_unknown_quality.txt', 'w') + mocked_file.assert_any_call( + 'dir/tag/regions/unknown.pkl', 'wb') + + mock_table.assert_any_call( + 'dir/tag/blocks_unknown_labeled.txt', 'r') + + # just headers + states = ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1'] + symbols = list('.-_npbcxNPBCX') + assert mocked_file().write.call_args_list == [ + mocker.call( + '\t'.join( + 
('region_id\tstrain\tchromosome\tpredicted_species\tstart' + '\tend\tnum_sites_hmm').split() + + ['match_nongap_' + x for x in states] + + ['num_sites_nongap_' + x for x in states] + + ['match_hmm_' + x for x in states] + + ['match_nonmask_' + x for x in states] + + ['num_sites_nonmask_' + x for x in states] + + ['count_' + x for x in symbols] + ) + '\n'), + + mocker.call('r4\tyjm1381\tI\tS288c\t2\t5\t3\t1\t1\t1\t1\t3\t4\t4\t4\t4' + '\t4\t1\t1\t1\t1\t3\t0\t0\t0\t0\t1\t1\t0\t0\t0\t1\t0\t0' + '\t4\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n'), + mocker.call('r5\tyjm689\tI\tS288c\t3\t6\t4\t1\t1\t1\t1\t3\t4\t4\t4\t4' + '\t4\t1\t1\t1\t1\t3\t1\t1\t1\t1\t2\t2\t1\t1\t1\t2\t1\t0' + '\t3\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n'), + mocker.call('r6\tyjm1381\tI\tS288c\t3\t7\t3\t2\t2\t2\t2\t4\t5\t5\t5' + '\t5\t5\t1\t1\t1\t1\t3\t2\t2\t2\t2\t3\t3\t2\t2\t2\t3\t2' + '\t0\t3\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n'), + mocker.call('r7\tyjm689\tI\tS288c\t3\t5\t3\t0\t0\t0\t0\t2\t3\t3\t3\t3' + '\t3\t0\t0\t0\t0\t2\t0\t0\t0\t0\t1\t1\t0\t0\t0\t1\t0\t0' + '\t3\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n'), + mocker.call('r8\tyjm1208\tI\tS288c\t3\t4\t20\t0\t0\t0\t0\t0\t0\t0\t0' + '\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0' + '\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n'), + mocker.call('r9\tyjm1304\tII\tS288c\t3\t7\t16\t0\t0\t0\t0\t0\t0\t0\t0' + '\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0' + '\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n'), + mocker.ANY + ] + + region_output = region_out.getvalue() + region_out.close() + + assert region_output == ( + '#r4\n' + '> S288c 2 5\ntcct\n' + '> CBS432 3 6\ntgca\n' + '> N_45 3 6\ntgca\n' + '> DBVPG6304 3 6\ntgca\n' + '> UWOPS91_917_1 2 5\nttac\n' + '> yjm1381 2 5\ntttc\n' + '> info\n____\n' + '#r6\n' + '> S288c 3 7\ncctag\n' + '> CBS432 4 8\ngcaag\n' + '> N_45 4 8\ngcaag\n' + '> DBVPG6304 4 8\ngcaag\n' + '> UWOPS91_917_1 3 7\ntacag\n' + '> yjm1381 3 7\nttcag\n' + '> info\n___..\n' + '#r5\n' + '> S288c 3 6\nccta\n' + '> CBS432 4 7\ngcaa\n' + '> N_45 4 7\ngcaa\n' + '> DBVPG6304 4 7\ngcaa\n' + '> UWOPS91_917_1 3 6\ntaca\n' + '> yjm689 3 6\nttca\n' + '> info\n___.\n' + '#r7\n' + '> S288c 3 5\ncct\n' + '> CBS432 4 6\ngca\n' + '> N_45 4 6\ngca\n' + '> DBVPG6304 4 6\ngca\n' + '> UWOPS91_917_1 3 5\ntac\n' + '> yjm689 3 5\nttc\n' + '> info\n___\n' + ) + + +def test_run_all_states(summarizer, mocker): + summarizer.config.add_config({ + 'analysis_params': + {'reference': {'name': 'S288c'}, + 'known_states': [ + {'name': 'CBS432'}, + {'name': 'N_45'}, + {'name': 'DBVPG6304'}, + {'name': 'UWOPS91_917_1'} + ], + 'unknown_states': [{'name': 'unknown'}] + } + }) + summarizer.config.set_states() + summarizer.config.set_HMM_symbols() + summarizer.config.set_positions('positions.txt.gz') + summarizer.config.set_labeled_blocks_file( + 'dir/tag/blocks_{state}_labeled.txt') + summarizer.config.set_quality_file('dir/tag/blocks_{state}_quality.txt') + summarizer.config.set_alignment('dir/tag/blocks_{chrom}_{strain}.txt') + summarizer.config.set_regions_files('dir/tag/regions/{state}.fa.gz', + 'dir/tag/regions/{state}.pkl') + summarizer.config.set_masked_file('dir/masked/{strain}_chr{chrom}.txt') + summarizer.config.chromosomes = ['I', 'II'] + assert summarizer.validate_arguments() + + mock_log = mocker.patch('analyze.summarize_region_quality.log') + + with pytest.raises(FileNotFoundError) as e: + summarizer.run() + assert "No such file or directory: 'dir/masked/S288c_chrI.txt'" in str(e) + + assert mock_log.debug.call_args_list == [ + mocker.call('reference index: 0'), + mocker.call("states to analyze: ['S288c', 'CBS432', 
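+
+    # legend for the expected info string: '-' gap, '_' masked, '.' match to
+    # all states, 'b' match to master and predicted, 'c' match to master
+    # only, 'p' match to predicted only, 'x' match to neither; symbols are
+    # uppercased at sites where the hmm flag is set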
'N_45', " + "'DBVPG6304', 'UWOPS91_917_1', 'unknown']"), + mocker.call("known_states ['S288c', 'CBS432', 'N_45', 'DBVPG6304', " + "'UWOPS91_917_1']"), + mocker.call('Sequence_Analyzer init with:'), + mocker.call('masks: dir/masked/{strain}_chr{chrom}.txt'), + mocker.call('alignment: dir/tag/blocks_{chrom}_{strain}.txt')] + + +@pytest.fixture +def flag(): + return summarize.Flag_Info() + + +def test_flag_info_init(flag): + assert flag.__dict__ == { + 'gap_any': None, + 'mask_any': None, + 'unseq_any': None, + 'hmm': None, + 'gap': None, + 'mask': None, + 'unseq': None, + 'match': None + } + + +def test_intialize_flags(flag): + flag.initialize_flags(3, 2) + + aae(flag.gap_any, np.zeros((3), bool)) + aae(flag.mask_any, np.zeros((3), bool)) + aae(flag.unseq_any, np.zeros((3), bool)) + aae(flag.gap, np.zeros((3, 2), bool)) + aae(flag.mask, np.zeros((3, 2), bool)) + aae(flag.unseq, np.zeros((3, 2), bool)) + aae(flag.match, np.zeros((3, 2), bool)) + + flag.initialize_flags(5, 3) + + aae(flag.gap_any, np.zeros((5), bool)) + aae(flag.mask_any, np.zeros((5), bool)) + aae(flag.unseq_any, np.zeros((5), bool)) + aae(flag.gap, np.zeros((5, 3), bool)) + aae(flag.mask, np.zeros((5, 3), bool)) + aae(flag.unseq, np.zeros((5, 3), bool)) + aae(flag.match, np.zeros((5, 3), bool)) + + +def test_add_sequence_flags(flag): + flag.initialize_flags(3, 2) + + other = summarize.Flag_Info() + other.gap = np.array([0, 0, 1], bool) + other.unseq = np.array([1, 0, 0], bool) + other.hmm = np.array([1, 0, 1], bool) + other.match = np.array([0, 1, 1], bool) + + flag.add_sequence_flags(other, 0) + aae(flag.hmm, np.array([1, 0, 1], bool)) + aae(flag.gap_any, np.array([0, 0, 1], bool)) + aae(flag.unseq_any, np.array([1, 0, 0], bool)) + + aae(flag.gap, np.array( + [ + [0, 0], + [0, 0], + [1, 0], + ], bool)) + + aae(flag.unseq, np.array( + [ + [1, 0], + [0, 0], + [0, 0], + ], bool)) + + aae(flag.match, np.array( + [ + [0, 0], + [1, 0], + [1, 0], + ], bool)) + + other = summarize.Flag_Info() + other.gap = np.array([1, 0, 0], bool) + other.unseq = np.array([0, 0, 1], bool) + other.hmm = np.array([0, 1, 0], bool) + other.match = np.array([0, 1, 1], bool) + + flag.add_sequence_flags(other, 1) + aae(flag.hmm, np.array([1, 0, 1], bool)) + aae(flag.gap_any, np.array([1, 0, 1], bool)) + aae(flag.unseq_any, np.array([1, 0, 1], bool)) + + aae(flag.gap, np.array( + [ + [0, 1], + [0, 0], + [1, 0], + ], bool)) + + aae(flag.unseq, np.array( + [ + [1, 0], + [0, 0], + [0, 1], + ], bool)) + + aae(flag.match, np.array( + [ + [0, 0], + [1, 1], + [1, 1], + ], bool)) + + +def test_add_mask_flags(flag): + flag.initialize_flags(3, 2) + + other = summarize.Flag_Info() + other.mask = np.array([0, 0, 1], bool) + + flag.add_mask_flags(other, 0) + aae(flag.mask_any, np.array([0, 0, 1], bool)) + + aae(flag.mask, np.array( + [ + [0, 0], + [0, 0], + [1, 0], + ], bool)) + + other = summarize.Flag_Info() + other.mask = np.array([0, 1, 0], bool) + + flag.add_mask_flags(other, 1) + aae(flag.mask_any, np.array([0, 1, 1], bool)) + + aae(flag.mask, np.array( + [ + [0, 0], + [0, 1], + [1, 0], + ], bool)) + + +def test_encode_info(flag): + flag.initialize_flags(14, 3) + flag.hmm = np.zeros((14), bool) + + flag.gap[0, 0] = True # - + flag.gap[11, 1] = True # - + flag.mask[1, 0] = True # _ + flag.mask[12, 1] = True # _ + flag.match[2, :] = True # . + flag.match[13, :] = True # . 
+ flag.match[(3, 4, 5, 6), 0] = True # b and c + flag.match[(3, 4, 7, 8), 1] = True # b and p + # x is default + flag.hmm[[4, 6, 8, 10]] = True # capitalize + + s = flag.encode_info(master_ind=0, predict_ind=1) + assert s == '-_.bBcCpPxX-_.' + # 01234567890123 + + flag.initialize_flags(0, 3) + flag.hmm = np.zeros((0), bool) + + s = flag.encode_info(master_ind=0, predict_ind=1) + assert s == '' + + +def test_encode_unknown_info(flag): + flag.initialize_flags(5, 2) + + flag.gap_any[0] = True # - + flag.mask_any[1] = True # _ + flag.match[2, :] = True # . + flag.match[3, 0] = True # x + flag.match[4, 1] = True # X + + s = flag.encode_unknown_info(master_ind=0) + assert s == '-_.xX' + s = flag.encode_info(master_ind=0, predict_ind=3) + assert s == '-_.xX' + + flag.initialize_flags(0, 2) + s = flag.encode_unknown_info(master_ind=0) + assert s == '' + + +@pytest.fixture +def region_db(mocker): + mocker.patch( + 'misc.read_table.open', + mocker.mock_open( + read_data='region_id\tstrain\tchromosome\t' + 'predicted_species\tstart\tend\tnum_sites_hmm\n' + 'r4\tyjm1381\tI\tS288c\t24327\t26193\t60\n' + 'r5\tyjm689\tI\tS288c\t24327\t24444\t56\n' + 'r6\tyjm1381\tI\tS288c\t24612\t25439\t18\n' + 'r7\tyjm689\tI\tS288c\t24612\t138647\t13728\n' + 'r8\tyjm1208\tI\tS288c\t25395\t25448\t20\n' + )) + + return summarize.Region_Database('labeled.txt', + 'I', + ['s1', 's2']) + + +def test_region_init(mocker): + mock_open = mocker.patch( + 'misc.read_table.open', + mocker.mock_open( + read_data='region_id\tstrain\tchromosome\t' + 'predicted_species\tstart\tend\tnum_sites_hmm\n' + 'r4\tyjm1381\tI\tS288c\t24327\t26193\t60\n' + 'r5\tyjm689\tI\tS288c\t24327\t24444\t56\n' + 'r6\tyjm1381\tI\tS288c\t24612\t25439\t18\n' + 'r7\tyjm689\tI\tS288c\t24612\t138647\t13728\n' + 'r8\tyjm1208\tI\tS288c\t25395\t25448\t20\n' + 'r9\tyjm1304\tII\tS288c\t25395\t25436\t16\n' + )) + + db = summarize.Region_Database('labeled.txt', + 'I', + ['s1', 's2']) + mock_open.assert_called_with('labeled.txt', 'r') + + assert db.labels == ['region_id', 'strain', 'chromosome', + 'predicted_species', 'start', 'end', 'num_sites_hmm', + 'match_nongap_s1', 'match_nongap_s2', + 'num_sites_nongap_s1', 'num_sites_nongap_s2', + 'match_hmm_s1', 'match_hmm_s2', 'match_nonmask_s1', + 'match_nonmask_s2', 'num_sites_nonmask_s1', + 'num_sites_nonmask_s2', 'count_.', 'count_-', + 'count__', 'count_n', 'count_p', 'count_b', + 'count_c', 'count_x', 'count_N', 'count_P', + 'count_B', 'count_C', 'count_X'] + + assert db.data == { + 'yjm1381': { + 'region_id': ['r4', 'r6'], + 'strain': ['yjm1381', 'yjm1381'], + 'chromosome': ['I', 'I'], + 'predicted_species': ['S288c', 'S288c'], + 'start': ['24327', '24612'], + 'end': ['26193', '25439'], + 'num_sites_hmm': ['60', '18'], + 'match_nongap_s1': [0, 0], + 'num_sites_nongap_s1': [0, 0], + 'match_hmm_s1': [0, 0], + 'match_nonmask_s1': [0, 0], + 'num_sites_nonmask_s1': [0, 0], + 'match_nongap_s2': [0, 0], + 'num_sites_nongap_s2': [0, 0], + 'match_hmm_s2': [0, 0], + 'match_nonmask_s2': [0, 0], + 'num_sites_nonmask_s2': [0, 0], + 'count_.': [0, 0], 'count_-': [0, 0], 'count__': [0, 0], + 'count_n': [0, 0], 'count_p': [0, 0], 'count_b': [0, 0], + 'count_c': [0, 0], 'count_x': [0, 0], 'count_N': [0, 0], + 'count_P': [0, 0], 'count_B': [0, 0], 'count_C': [0, 0], + 'count_X': [0, 0]}, + 'yjm689': { + 'region_id': ['r5', 'r7'], + 'strain': ['yjm689', 'yjm689'], + 'chromosome': ['I', 'I'], + 'predicted_species': ['S288c', 'S288c'], + 'start': ['24327', '24612'], + 'end': ['24444', '138647'], + 'num_sites_hmm': ['56', '13728'], + 
'match_nongap_s1': [0, 0], + 'num_sites_nongap_s1': [0, 0], + 'match_hmm_s1': [0, 0], + 'match_nonmask_s1': [0, 0], + 'num_sites_nonmask_s1': [0, 0], + 'match_nongap_s2': [0, 0], + 'num_sites_nongap_s2': [0, 0], + 'match_hmm_s2': [0, 0], + 'match_nonmask_s2': [0, 0], + 'num_sites_nonmask_s2': [0, 0], + 'count_.': [0, 0], 'count_-': [0, 0], 'count__': [0, 0], + 'count_n': [0, 0], 'count_p': [0, 0], 'count_b': [0, 0], + 'count_c': [0, 0], 'count_x': [0, 0], 'count_N': [0, 0], + 'count_P': [0, 0], 'count_B': [0, 0], 'count_C': [0, 0], + 'count_X': [0, 0]}, + 'yjm1208': { + 'region_id': ['r8'], + 'strain': ['yjm1208'], + 'chromosome': ['I'], + 'predicted_species': ['S288c'], + 'start': ['25395'], + 'end': ['25448'], + 'num_sites_hmm': ['20'], + 'match_nongap_s1': [0], + 'num_sites_nongap_s1': [0], + 'match_hmm_s1': [0], + 'match_nonmask_s1': [0], + 'num_sites_nonmask_s1': [0], + 'match_nongap_s2': [0], + 'num_sites_nongap_s2': [0], + 'match_hmm_s2': [0], + 'match_nonmask_s2': [0], + 'num_sites_nonmask_s2': [0], + 'count_.': [0], 'count_-': [0], 'count__': [0], + 'count_n': [0], 'count_p': [0], 'count_b': [0], + 'count_c': [0], 'count_x': [0], 'count_N': [0], + 'count_P': [0], 'count_B': [0], 'count_C': [0], + 'count_X': [0]}} + + assert db.info_string_symbols == list('.-_npbcxNPBCX') + assert db.label_prefixes == [ + 'match_nongap', + 'num_sites_nongap', + 'match_hmm', + 'match_nonmask', + 'num_sites_nonmask'] + + +def test_has_strain(region_db): + assert region_db.has_strain('yjm689') + assert not region_db.has_strain('yjm688') + + +def test_get_entries(region_db): + result = [('r4', 24327, 26193), + ('r6', 24612, 25439)] + for i, (r_id, start, end) in enumerate(region_db.get_entries('yjm1381')): + entry = result[i] + assert r_id == entry[0] + assert start == entry[1] + assert end == entry[2] + + result = [('r5', 24327, 24444), + ('r7', 24612, 138647)] + for i, (r_id, start, end) in enumerate(region_db.get_entries('yjm689')): + entry = result[i] + assert r_id == entry[0] + assert start == entry[1] + assert end == entry[2] + + result = [('r8', 25395, 25448)] + for i, (r_id, start, end) in enumerate(region_db.get_entries('yjm1208')): + entry = result[i] + assert r_id == entry[0] + assert start == entry[1] + assert end == entry[2] + + with pytest.raises(ValueError) as e: + list(region_db.get_entries('asdf')) + assert 'Region Database does not contain strain asdf' in str(e) + + +def test_set_region(region_db): + region_db.set_region('yjm1381', + 0, + 's1', + (10, 20), + (30, 40), + (50, 60)) + + ds = region_db.data['yjm1381'] + + assert ds['num_sites_hmm'][0] == 20 + assert ds[f'match_hmm_s1'][0] == 10 + assert ds[f'match_nongap_s1'][0] == 30 + assert ds[f'num_sites_nongap_s1'][0] == 40 + assert ds[f'match_nonmask_s1'][0] == 50 + assert ds[f'num_sites_nonmask_s1'][0] == 60 + + region_db.set_region('yjm1381', + 1, + 's2', + (11, None), + (31, 41), + (51, 61)) + + ds = region_db.data['yjm1381'] + + # retained from initial value + assert ds['num_sites_hmm'][1] == '18' + assert ds[f'match_hmm_s2'][1] == 11 + assert ds[f'match_nongap_s2'][1] == 31 + assert ds[f'num_sites_nongap_s2'][1] == 41 + assert ds[f'match_nonmask_s2'][1] == 51 + assert ds[f'num_sites_nonmask_s2'][1] == 61 + + +def test_generate_output(region_db): + initial_lines = [ + 'r4\tyjm1381\tI\tS288c\t24327\t26193\t60', + 'r5\tyjm689\tI\tS288c\t24327\t24444\t56', + 'r6\tyjm1381\tI\tS288c\t24612\t25439\t18', + 'r7\tyjm689\tI\tS288c\t24612\t138647\t13728', + 'r8\tyjm1208\tI\tS288c\t25395\t25448\t20', + ] + for i, line in 
enumerate(region_db.generate_output()): + assert line == initial_lines[i] + '\t' + '\t'.join(['0']*23) + '\n' + + +def test_generate_header(region_db, mocker): + assert region_db.generate_header() == ( + 'region_id\tstrain\tchromosome\tpredicted_species\tstart\tend' + '\tnum_sites_hmm\tmatch_nongap_s1\tmatch_nongap_s2' + '\tnum_sites_nongap_s1\tnum_sites_nongap_s2\tmatch_hmm_s1' + '\tmatch_hmm_s2\tmatch_nonmask_s1\tmatch_nonmask_s2' + '\tnum_sites_nonmask_s1\tnum_sites_nonmask_s2\tcount_.\tcount_-' + '\tcount__\tcount_n\tcount_p\tcount_b\tcount_c\tcount_x\tcount_N' + '\tcount_P\tcount_B\tcount_C\tcount_X\n') + + mocker.patch( + 'misc.read_table.open', + mocker.mock_open( + read_data='region_id\tstrain\tchromosome\n' + 'r4\tyjm1381\tII\n' + )) + + # fewer headers, states, same counts + db = summarize.Region_Database('labeled.txt', 'II', ['st1']) + assert db.generate_header() == ( + 'region_id\tstrain\tchromosome\tmatch_nongap_st1' + '\tnum_sites_nongap_st1\tmatch_hmm_st1\tmatch_nonmask_st1' + '\tnum_sites_nonmask_st1\tcount_.\tcount_-\tcount__\tcount_n\tcount_p' + '\tcount_b\tcount_c\tcount_x\tcount_N\tcount_P\tcount_B\tcount_C' + '\tcount_X\n') + + +def test_update_counts(region_db): + syms = list('.-_npbcxNPBCX') + + region_db.update_counts('yjm1381', 0, '.-_npbcxNPBCX') + for s in syms: + assert region_db.data['yjm1381'][f'count_{s}'][0] == 1 + + region_db.update_counts('yjm1381', 1, '.-_npbcxNPBCX'*20) + for s in syms: + assert region_db.data['yjm1381'][f'count_{s}'][1] == 20 + + region_db.update_counts('yjm689', 1, '._pcNBX'*40) + for i, s in enumerate(syms): + if i % 2 == 0: + assert region_db.data['yjm689'][f'count_{s}'][1] == 40 + else: + assert region_db.data['yjm689'][f'count_{s}'][1] == 0 + + # same values, gets overwritten + region_db.update_counts('yjm689', 1, '-nbxPC'*10) + for i, s in enumerate(syms): + if i % 2 == 0: + assert region_db.data['yjm689'][f'count_{s}'][1] == 0 + else: + assert region_db.data['yjm689'][f'count_{s}'][1] == 10 + + +@pytest.fixture +def region_context(mocker): + writer = summarize.Region_Writer( + 'test_region.gz', 'test_index.pkl', 's1 s2'.split()) + mock_gzip = mocker.patch('analyze.summarize_region_quality.gzip') + mock_open = mocker.patch('analyze.summarize_region_quality.open') + mock_pickle = mocker.patch('analyze.summarize_region_quality.pickle.dump') + + return (writer, mock_gzip, mock_open, mock_pickle) + + +def test_region_context(region_context, mocker): + writer, mock_gzip, mock_open, mock_pickle = region_context + + assert writer.region_file == 'test_region.gz' + assert writer.index_file == 'test_index.pkl' + assert writer.index == {} + assert writer.known_states == 's1 s2'.split() + + mock_gzip.assert_not_called() + mock_open.assert_not_called() + + writer = writer.__enter__() + assert writer.region_writer is not None + mock_gzip.open.assert_called_once_with('test_region.gz', 'wt') + mock_open.assert_not_called() + + writer.__exit__(None, None, None) + mock_gzip.open.return_value.close.assert_called_once() + mock_open.assert_called_once_with('test_index.pkl', 'wb') + mock_pickle.assert_called_once_with({}, mocker.ANY) + + +def test_region_write_header(region_context): + writer, _, _, _ = region_context + output = StringIO() + writer.region_writer = output + writer.write_header('r123') + writer.write_header('f13') + writer.write_header('q3') + writer.write_header('213') + + # note, 213 is changed to 13 when storing in index, overwriting f13 + assert output.getvalue() == '#r123\n#f13\n#q3\n#213\n' + assert writer.index == {3: 11, 13: 15, 
123: 0} + + +def test_region_write_sequences(region_context): + writer, _, _, _ = region_context + output = StringIO() + writer.region_writer = output + + writer.write_sequences('strain', + [ + list(range(10)), + list(range(2, 12)), + list(range(4, 14)) + ], + np.array([ + list('0123456789')*3, + list('123456789')*3, + list('23456789')*3, + ]), + (5, 9)) + assert output.getvalue() == ( + '> s1 5 9\n' + '56789\n' + '> s2 3 7\n' + '67891\n' + '> strain 1 5\n' + '78923\n') + + +def test_region_write_info_string(region_context): + writer, _, _, _ = region_context + output = StringIO() + writer.region_writer = output + + writer.write_info_string('this is my string') + assert output.getvalue() == ( + '> info\n' + 'this is my string\n') + + +def test_position_reader_context(mocker): + mock_gzip = mocker.patch('analyze.summarize_region_quality.gzip.open') + reader = summarize.Position_Reader('test_file.gz') + assert reader.position_file == 'test_file.gz' + assert reader.last_position == 0 + + reader = reader.__enter__() + assert reader.reader is not None + mock_gzip.assert_called_once_with('test_file.gz', 'rt') + + reader = reader.__exit__(None, None, None) + mock_gzip.return_value.close.assert_called_once() + + +def test_position_reader_next_line(): + reader = summarize.Position_Reader('mock') + positions = StringIO( + 'yjm1460\tI\t25957\t25958\t25961\t25963\n' + 'yjm1463\tI\t25665\t25668\t25670\t25676\n' + 'yjm1464\tI\t25665\t25668\t25670\t25676\n' + 'yjm1464\tII\t25665\t25668\t25670\t25676\n' + 'yjm1460\tIII\t25957\t25958\t25961\t25963\n' + ) + reader.reader = positions + + assert reader.next_line() == \ + 'yjm1460\tI\t25957\t25958\t25961\t25963\n' + assert reader.last_position == 0 + + assert reader.next_line() == \ + 'yjm1463\tI\t25665\t25668\t25670\t25676\n' + assert reader.last_position == 34 + + assert reader.next_line() == \ + 'yjm1464\tI\t25665\t25668\t25670\t25676\n' + assert reader.last_position == 68 + + assert reader.next_line() == \ + 'yjm1464\tII\t25665\t25668\t25670\t25676\n' + assert reader.last_position == 102 + + assert reader.next_line() == \ + 'yjm1460\tIII\t25957\t25958\t25961\t25963\n' + assert reader.last_position == 137 + + assert reader.next_line() == '' + assert reader.last_position == 173 + + +def test_position_reader_get_positions(region_db): + # region_db contains yjm1381, 689, 1208 + reader = summarize.Position_Reader('mock') + positions = StringIO( + 'yjm1381\tI\t25957\t25958\t25961\t25963\n' + 'yjm689\tI\t25665\t25668\t25670\t25676\n' + 'yjm1464\tI\t25665\t25668\t25670\t25676\n' + 'yjm1464\tII\t25665\t25668\t25670\t25676\n' + 'yjm1381\tIII\t25957\t25958\t25961\t25963\n' + ) + reader.reader = positions + + results = [ + ('yjm1381', np.array([25957, 25958, 25961, 25963]), 0), + ('yjm689', np.array([25665, 25668, 25670, 25676]), 34), + ] + + for i, (strain, ps) in enumerate(reader.get_positions(region_db, 'I')): + assert strain == results[i][0] + aae(ps, results[i][1]) + assert reader.last_position == results[i][2] + + assert i == 1 + assert reader.last_position == 101 + + i = None + # won't run because chromosome is not in order (on II) + for i, (strain, ps) in enumerate(reader.get_positions(region_db, 'III')): + pass + + assert i is None + assert reader.last_position == 101 + + # won't return since strain not in regions, will change last position + for i, (strain, ps) in enumerate(reader.get_positions(region_db, 'II')): + pass + + # if loop has not run + assert i is None + assert reader.last_position == 136 + + for i, (strain, ps) in 
enumerate(reader.get_positions(region_db, 'III')): + assert strain == 'yjm1381' + aae(ps, np.array([25957, 25958, 25961, 25963])) + assert reader.last_position == 136 + + assert i == 0 + assert reader.last_position == 172 + + +def test_quality_writer_context(mocker): + mock_open = mocker.patch('analyze.summarize_region_quality.open') + writer = summarize.Quality_Writer('test_file.txt') + assert writer.filename == 'test_file.txt' + assert writer.first_write is True + + writer = writer.__enter__() + assert writer.writer is not None + mock_open.assert_called_once_with('test_file.txt', 'w') + + writer = writer.__exit__(None, None, None) + mock_open.return_value.close.assert_called_once() + + +def test_quality_writer_write_quality(mocker): + writer = summarize.Quality_Writer('test') + output = StringIO() + writer.writer = output + + assert writer.first_write is True + mock_region = mocker.MagicMock() + mock_region.generate_header.return_value = 'header1\n' + mock_region.generate_output.return_value = [ + 'region1\n', + 'region2\n', + ] + + writer.write_quality(mock_region) + assert writer.first_write is False + + mock_region = mocker.MagicMock() + mock_region.generate_header.return_value = 'header2\n' + mock_region.generate_output.return_value = [ + 'a\tdifferent\tformat\n', + ] + + writer.write_quality(mock_region) + + assert writer.first_write is False + assert output.getvalue() == ( + 'header1\n' + 'region1\n' + 'region2\n' + 'a\tdifferent\tformat\n' + ) + + +def test_sequence_analyzer_init(): + sa = summarize.Sequence_Analyzer('mask', + 'alignment', + ['s1', 's2'], + ['s1', 's2'], + ['I', 'II'], + {}) + + assert sa.masks == 'mask' + assert sa.alignments == 'alignment' + assert sa.known_states == ['s1', 's2'] + assert sa.chromosomes == ['I', 'II'] + assert sa.symbols == {} + + +@pytest.fixture +def sa(): + symbols = { + 'match': '+', + 'mismatch': '-', + 'unknown': '?', + 'unsequenced': 'n', + 'gap': '-', + 'unaligned': '?', + 'masked': 'x' + } + return summarize.Sequence_Analyzer('', '', [], [], [], symbols) + + +def test_SA_build_masked_sites(sa, mocker): + mock_open = mocker.patch( + 'analyze.summarize_region_quality.open', + mocker.mock_open( + read_data='>header\n' + '0 - 2\n' + '22 - 25\n' + '32 - 33\n' + )) + + sa.chromosomes = ['I', 'II'] + sa.known_states = ['s1', 's2'] + sa.interval_states = ['i1', 'i2'] + sa.masks = '{strain}_chr{chrom}_intervals.txt' + sa.build_masked_sites() + sites = sa.masked_sites + + assert mock_open.call_args_list == [ + mocker.call('i1_chrI_intervals.txt', 'r'), + mocker.call('i2_chrI_intervals.txt', 'r'), + mocker.call('i1_chrII_intervals.txt', 'r'), + mocker.call('i2_chrII_intervals.txt', 'r'), + ] + + expected = { + 'I': { + 's1': np.array([0, 1, 2, 22, 23, 24, 25, 32, 33]), + 's2': np.array([0, 1, 2, 22, 23, 24, 25, 32, 33]), + }, + 'II': { + 's1': np.array([0, 1, 2, 22, 23, 24, 25, 32, 33]), + 's2': np.array([0, 1, 2, 22, 23, 24, 25, 32, 33]), + } + } + for chrom in sites: + for state in sites[chrom]: + aae(sites[chrom][state], expected[chrom][state]) + + mock_open = mocker.patch( + 'analyze.summarize_region_quality.open', + mocker.mock_open( + read_data='>header\n' + )) + + sa.build_masked_sites() + sites = sa.masked_sites + + assert mock_open.call_args_list == [ + mocker.call('i1_chrI_intervals.txt', 'r'), + mocker.call('i2_chrI_intervals.txt', 'r'), + mocker.call('i1_chrII_intervals.txt', 'r'), + mocker.call('i2_chrII_intervals.txt', 'r'), + ] + + expected = { + 'I': { + 's1': np.array([]), + 's2': np.array([]), + }, + 'II': { + 's1': np.array([]), 
+ 's2': np.array([]), + } + } + for chrom in sites: + for state in sites[chrom]: + aae(sites[chrom][state], expected[chrom][state]) + + +def test_read_masked_sites(sa, mocker): + mock_open = mocker.patch( + 'analyze.summarize_region_quality.open', + mocker.mock_open( + read_data='>header\n' + '0 - 2\n' + '22 - 25\n' + '32 - 33\n' + )) + + sa.masks = '{chrom}_{strain}_mock' + result = sa.read_masked_sites('I', 'str') + assert mock_open.call_args_list == [ + mocker.call('I_str_mock', 'r') + ] + aae(result, np.array('0 1 2 22 23 24 25 32 33'.split(), dtype=int)) + + +def test_convert_intervals_to_sites(sa): + sites = sa.convert_intervals_to_sites([]) + assert sites == approx([]) + + sites = sa.convert_intervals_to_sites([(1, 2)]) + assert sites == approx([1, 2]) + + sites = sa.convert_intervals_to_sites([(1, 2), (4, 6)]) + assert sites == approx([1, 2, 4, 5, 6]) + + +def test_read_masked_intervals(sa, mocker): lines = StringIO('') mocked_file = mocker.patch('analyze.summarize_region_quality.open', return_value=lines) - intervals = summarize.read_masked_intervals('mocked') + intervals = sa.read_masked_intervals('mocked') mocked_file.assert_called_with('mocked', 'r') assert intervals == [] lines = StringIO('I am a header') mocked_file = mocker.patch('analyze.summarize_region_quality.open', return_value=lines) - intervals = summarize.read_masked_intervals('mocked') + intervals = sa.read_masked_intervals('mocked') mocked_file.assert_called_with('mocked', 'r') assert intervals == [] @@ -25,167 +1100,265 @@ def test_read_masked_intervals(mocker): mocked_file = mocker.patch('analyze.summarize_region_quality.open', return_value=lines) with pytest.raises(ValueError): - intervals = summarize.read_masked_intervals('mocked') + intervals = sa.read_masked_intervals('mocked') mocked_file.assert_called_with('mocked', 'r') lines = StringIO('I am a header\n' '1 and 2') mocked_file = mocker.patch('analyze.summarize_region_quality.open', return_value=lines) - intervals = summarize.read_masked_intervals('mocked') + intervals = sa.read_masked_intervals('mocked') assert intervals == [(1, 2)] -def test_convert_intervals_to_sites(): - sites = summarize.convert_intervals_to_sites([]) - assert sites == approx([]) +def test_get_stats(sa): + hmm, nongap, nonmask = sa.get_stats( + np.array(list('abc')), + np.array(list('abd')), + 0, + [1, 2, 5], + ([0], [])) - sites = summarize.convert_intervals_to_sites([(1, 2)]) - assert sites == approx([1, 2]) + assert hmm[0] == 1 + assert hmm[1] == 2 + assert hmm[2].gap == approx([False] * 3) + assert hmm[2].hmm == approx([False, True, True]) + assert hmm[2].match == approx([True, True, False]) + assert hmm[2].unseq == approx([False] * 3) - sites = summarize.convert_intervals_to_sites([(1, 2), (4, 6)]) - assert sites == approx([1, 2, 4, 5, 6]) + assert nongap[0] == 1 + assert nongap[1] == 1 + + assert nonmask[0] == 1 + assert nonmask[1] == 2 + assert nonmask[2].mask == approx([True, False, False]) + + +def test_process_alignment(sa, region_context, mocker): + sa.known_states = ['k1', 'k2'] + sa.masked_sites = { + 'I': { + 'k1': np.array([0, 1, 2]), + 'k2': np.array([2, 4, 6]), + }} + region_writer, _, _, _ = region_context + output = StringIO() + region_writer.region_writer = output + region_writer.known_states = ['k1', 'k2'] + mocker.patch('analyze.summarize_region_quality.read_fasta.read_fasta', + return_value=('', + np.asarray([ + list('--gatcctag--'), + list('-agatgcaag-c'), + list('-a-att-cagt-'), + ]))) + mocker.patch.object(summarize.Sequence_Analyzer, 'read_masked_sites', + 
return_value=np.array( + '2 3 4'.split(), int)) + + # for region database + mocker.patch( + 'misc.read_table.open', + mocker.mock_open( + read_data='region_id\tstrain\tchromosome\t' + 'predicted_species\tstart\tend\tnum_sites_hmm\n' + 'r4\ts1\tI\tS288c\t1\t3\t60\n' + 'r5\tyjm689\tI\tS288c\t24327\t24444\t56\n' + 'r6\ts1\tI\tS288c\t5\t7\t18\n' + 'r7\tyjm689\tI\tS288c\t24612\t138647\t13728\n' + 'r8\tyjm1208\tI\tS288c\t25395\t25448\t20\n' + )) + rd = summarize.Region_Database('labeled_file.txt', + 'I', + ['k1', 'k2']) + + sa.process_alignment(0, + 1, + 'I', + 's1', + np.array([1, 3, 5, 7]), + rd, + region_writer) + + # these are inputs with modified hmm (last column) + region_input = ['r4\ts1\tI\tS288c\t1\t3\t2\t', + 'r5\tyjm689\tI\tS288c\t24327\t24444\t56\t', + 'r6\ts1\tI\tS288c\t5\t7\t2\t', + 'r7\tyjm689\tI\tS288c\t24612\t138647\t13728\t', + 'r8\tyjm1208\tI\tS288c\t25395\t25448\t20\t'] + # these are the counts + region_output = [ + '\t'.join('2 2 3 3 1 1 0 0 0 0 0 0 3 ' + '0 0 0 0 0 0 0 0 0 0'.split()) + '\n', + '\t'.join(['0' for _ in range(23)]) + '\n', + '\t'.join('2 2 3 3 1 1 2 2 2 2 2 0 1 ' + '0 0 0 0 0 0 0 0 0 0'.split()) + '\n', + '\t'.join(['0' for _ in range(23)]) + '\n', + '\t'.join(['0' for _ in range(23)]) + '\n', + ] + + for i, line in enumerate(rd.generate_output()): + assert line == region_input[i] + region_output[i] + + region_writer_output = ( + '#r4\n' + '> k1 1 3\n' + 'atc\n' + '> k2 2 4\n' + 'atg\n' + '> s1 1 3\n' + 'att\n' + '> info\n' + '___\n' + '#r6\n' + '> k1 5 7\n' + 'tag\n' + '> k2 6 8\n' + 'aag\n' + '> s1 4 6\n' + 'cag\n' + '> info\n' + '_..\n' + ) + assert output.getvalue() == region_writer_output + + +def test_get_indices(sa, mocker): + sa.known_states = ['k1', 'k2'] + sa.masked_sites = { + 'I': { + 'k1': np.array([0, 1, 2]), + 'k2': np.array([2, 4, 6]), + }} + sa.alignments = 'align_{chrom}_{strain}' + sa.masks = 'mask_{chrom}_{strain}' + mock_fasta = mocker.patch( + 'analyze.summarize_region_quality.read_fasta.read_fasta', + return_value=('', + np.asarray([ + list('--gatcctag--'), + list('-agatgcaag-c'), + list('-a-att-cagt-'), + ]))) + mock_mask = mocker.patch.object(summarize.Sequence_Analyzer, + 'read_masked_sites', + return_value=np.array( + '2 3 4'.split(), int)) + seq, align, mask = sa.get_indices('I', 's1') -def test_index_alignment_by_reference(): - assert summarize.gp.gap_symbol == '-' + mock_fasta.assert_called_once_with('align_I_s1') + mock_mask.assert_called_once_with('I', 's1') - output = summarize.index_alignment_by_reference(np.array(list('abc'))) + aae(seq, np.asarray([ + list('--gatcctag--'), + list('-agatgcaag-c'), + list('-a-att-cagt-'), + ])) + + result = [ + np.array(range(2, 10)), + np.array(list(range(1, 10)) + [11]), + np.array([1, 3, 4, 5, 7, 8, 9, 10]) + ] + for i, a in enumerate(align): + aae(a, result[i]) + + result = [ + np.array([2, 3, 4]), + np.array([3, 5, 7]), + np.array([4, 5, 7]), + ] + for i, m in enumerate(mask): + aae(m, result[i]) + + +def test_get_slice(sa): + alignment = np.array([1, 2, 3, 4, 5]) + ps_align = np.array([2, 4]) + + start, end = sa.get_slice(1, 3, alignment, ps_align) + assert start == 2 + assert end == 4 + + with pytest.raises(ValueError) as e: + sa.get_slice(0, 3, alignment, ps_align) + assert 'Slice not found in position alignment' in str(e) + + with pytest.raises(ValueError) as e: + sa.get_slice(1, 4, alignment, ps_align) + assert 'Slice not found in position alignment' in str(e) + + +def test_index_alignment_by_reference(sa): + output = sa.index_alignment_by_reference(np.array(list('abc'))) assert 
output == approx([0, 1, 2]) - output = summarize.index_alignment_by_reference(np.array(list('a-b-c'))) + output = sa.index_alignment_by_reference(np.array(list('a-b-c'))) assert output == approx([0, 2, 4]) -def test_seq_id_hmm(): - assert summarize.gp.gap_symbol == '-' - assert summarize.gp.unsequenced_symbol == 'n' - - match, sites, d = summarize.seq_id_hmm(np.array(list('abd')), - np.array(list('abc')), - 0, [1, 2, 5]) - assert match == 1 # only count matches in included sites - assert sites == 2 # included, not matching - assert d['gap_flag'] == approx([False] * 3) - assert d['hmm_flag'] == approx([False, True, True]) - assert d['match'] == approx([True, True, False]) - assert d['unseq_flag'] == approx([False] * 3) - assert len(d) == 4 - - match, sites, d = summarize.seq_id_hmm(np.array(list('n-d')), - np.array(list('--c')), - 1, [3, 5]) - assert match == 0 - assert sites == 1 - assert d['gap_flag'] == approx([True, True, False]) - assert d['hmm_flag'] == approx([False, False, True]) - assert d['match'] == approx([False, True, False]) - assert d['unseq_flag'] == approx([True, False, False]) - assert len(d) == 4 - - with pytest.raises(AssertionError) as e: - match, sites, d = summarize.seq_id_hmm(np.array(list('n-d')), - np.array(list('--c')), - 1, [2, 5]) - assert '- - 1' in str(e) - - with pytest.raises(AssertionError) as e: - match, sites, d = summarize.seq_id_hmm(np.array(list('n-d')), - np.array(list('--c')), - 1, [1, 5]) - assert 'n - 0' in str(e) - - -def test_seq_id_unmasked(): - assert summarize.gp.gap_symbol == '-' - assert summarize.gp.unsequenced_symbol == 'n' - - match, sites, d = summarize.seq_id_unmasked(np.array(list('abd')), - np.array(list('abc')), - 0, [], []) +def test_seq_id_unmasked(sa): + match, sites, info = sa.seq_id_unmasked(np.array(list('abd')), + np.array(list('abc')), + 0, [], []) assert match == 2 assert sites == 3 - assert d['mask_flag'] == approx([False, False, False]) + assert info.mask == approx([False, False, False]) - match, sites, d = summarize.seq_id_unmasked(np.array(list('abd')), - np.array(list('abc')), - 0, [0], []) + match, sites, info = sa.seq_id_unmasked(np.array(list('abd')), + np.array(list('abc')), + 0, [0], []) assert match == 1 assert sites == 2 - assert d['mask_flag'] == approx([True, False, False]) + assert info.mask == approx([True, False, False]) - match, sites, d = summarize.seq_id_unmasked(np.array(list('abd')), - np.array(list('abc')), - 2, [0], [1]) + match, sites, info = sa.seq_id_unmasked(np.array(list('abd')), + np.array(list('abc')), + 2, [0], [1]) assert match == 2 assert sites == 3 - assert d['mask_flag'] == approx([False, False, False]) + assert info.mask == approx([False, False, False]) - match, sites, d = summarize.seq_id_unmasked(np.array(list('abd')), - np.array(list('abc')), - 0, [0], [1]) + match, sites, info = sa.seq_id_unmasked(np.array(list('abd')), + np.array(list('abc')), + 0, [0], [1]) assert match == 0 assert sites == 1 - assert d['mask_flag'] == approx([True, True, False]) - - -def test_make_info_string(): - len_seqx = 14 - len_states = 3 - info = {'hmm_flag': np.zeros((len_seqx), bool), - 'gap_flag': np.zeros((len_seqx, len_states), bool), - 'mask_flag': np.zeros((len_seqx, len_states), bool), - 'match_flag': np.zeros((len_seqx, len_states), bool)} - - info['gap_flag'][0, 0] = True # - - info['gap_flag'][11, 1] = True # - - info['mask_flag'][1, 0] = True # _ - info['mask_flag'][12, 1] = True # _ - info['match_flag'][2, :] = True # . - info['match_flag'][13, :] = True # . 
- info['match_flag'][(3, 4, 5, 6), 0] = True # b and c - info['match_flag'][(3, 4, 7, 8), 1] = True # b and p - # x is default - info['hmm_flag'][[4, 6, 8, 10]] = True # capitalize + assert info.mask == approx([True, True, False]) - s = summarize.make_info_string(info, master_ind=0, predict_ind=1) - assert s == '-_.bBcCpPxX-_.' - # 01234567890123 - - len_seqx = 0 - len_states = 3 - info = {'hmm_flag': np.zeros((len_seqx), bool), - 'gap_flag': np.zeros((len_seqx, len_states), bool), - 'mask_flag': np.zeros((len_seqx, len_states), bool), - 'match_flag': np.zeros((len_seqx, len_states), bool)} - - s = summarize.make_info_string(info, master_ind=0, predict_ind=1) - assert s == '' - - -def test_info_string_unknown(): - len_seqx = 5 - len_states = 2 - info = {'gap_any_flag': np.zeros((len_seqx), bool), - 'mask_any_flag': np.zeros((len_seqx), bool), - 'match_flag': np.zeros((len_seqx, len_states), bool)} - info['gap_any_flag'][0] = True # - - info['mask_any_flag'][1] = True # _ - info['match_flag'][2, :] = True # . - info['match_flag'][3, 0] = True # x - info['match_flag'][4, 1] = True # X +def test_seq_id_hmm(sa): + match, sites, info = sa.seq_id_hmm(np.array(list('abd')), + np.array(list('abc')), + 0, [1, 2, 5]) + assert match == 1 # only count matches in included sites + assert sites == 2 # included, not matching + assert info.gap == approx([False] * 3) + assert info.hmm == approx([False, True, True]) + assert info.match == approx([True, True, False]) + assert info.unseq == approx([False] * 3) - s = summarize.make_info_string_unknown(info, master_ind=0) - assert s == '-_.xX' - s = summarize.make_info_string(info, master_ind=0, predict_ind=3) - assert s == '-_.xX' + match, sites, info = sa.seq_id_hmm(np.array(list('n-d')), + np.array(list('--c')), + 1, [3, 5]) + assert match == 0 + assert sites == 1 + assert info.gap == approx([True, True, False]) + assert info.hmm == approx([False, False, True]) + assert info.match == approx([False, True, False]) + assert info.unseq == approx([True, False, False]) - len_seqx = 0 - len_states = 2 - info = {'gap_any_flag': np.zeros((len_seqx), bool), - 'mask_any_flag': np.zeros((len_seqx), bool), - 'match_flag': np.zeros((len_seqx, len_states), bool)} + with pytest.raises(ValueError) as e: + match, sites, d = sa.seq_id_hmm(np.array(list('n-d')), + np.array(list('--c')), + 1, [2, 5]) + assert ('Need to skip site specified as included ' + f'seq1: -, seq2: -, index: 1') in str(e) - s = summarize.make_info_string_unknown(info, master_ind=0) - assert s == '' + with pytest.raises(ValueError) as e: + match, sites, d = sa.seq_id_hmm(np.array(list('n-d')), + np.array(list('--c')), + 1, [1, 5]) + assert ('Need to skip site specified as included ' + f'seq1: n, seq2: -, index: 0') in str(e) diff --git a/code/test/analyze/test_summarize_region_quality_main.py b/code/test/analyze/test_summarize_region_quality_main.py deleted file mode 100644 index 065a293..0000000 --- a/code/test/analyze/test_summarize_region_quality_main.py +++ /dev/null @@ -1,63 +0,0 @@ -import analyze.summarize_region_quality_main as main -from io import StringIO - - -def test_main(mocker): - # setup global params to match expectations - mocker.patch( - 'analyze.summarize_region_quality_main.gp.analysis_out_dir_absolute', - 'dir/') - mocker.patch( - 'analyze.summarize_strain_states_main.predict.process_predict_args', - return_value={ - 'known_states': ['S288c', 'CBS432', 'N_45', - 'DBVPG6304', 'UWOPS91_917_1'], - 'states': ['S288c', 'CBS432', 'N_45', - 'DBVPG6304', 'UWOPS91_917_1', 'unknown'], - 'tag': 
'tag' - }) - mocker.patch('analyze.summarize_region_quality_main.gp.chrms', - ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', - 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']) - - mocker.patch('sys.argv', - "test.py 5 tag .001 viterbi 10000 .025 10000 .025 \ - 10000 .025 10000 .025 unknown 1000 .01".split()) - mocker.patch('analyze.summarize_region_quality_main.os.path.isdir', - return_value=True) - # TODO check call arguments - mocker.patch('misc.read_table.read_table_columns', - return_value=({'s': {'region_id': []}}, ['region_id'])) - mocker.patch('analyze.summarize_region_quality_main.read_masked_intervals', - return_value=[(1, 2)]) - lines = StringIO('') - mocked_file = mocker.patch( - 'analyze.summarize_region_quality_main.gzip.open', - return_value=lines) - - mocked_file = mocker.patch('analyze.summarize_region_quality_main.open', - mocker.mock_open()) - - main.main() - - assert mocked_file.call_count == 2 - mocked_file.assert_any_call( - 'dir/tag/blocks_unknown_tag_quality.txt', 'w') - mocked_file.assert_any_call( - 'dir/tag/regions/unknown.pkl', 'wb') - - # just headers - states = ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1'] - symbols = list('.-_npbcxNPBCX') - mocked_file().write.assert_has_calls([ - mocker.call('\t'.join(['region_id'] + - ['match_nongap_' + x for x in states] + - ['num_sites_nongap_' + x for x in states] + - ['match_hmm_' + x for x in states] + - ['match_nonmask_' + x for x in states] + - ['num_sites_nonmask_' + x for x in states] + - ['count_' + x for x in symbols] - ) - + '\n') - ]) - assert True diff --git a/code/test/helper_scripts/run_summarize_region_quality.slurm.sh b/code/test/helper_scripts/run_summarize_region_quality.slurm.sh index 4ca8ed9..b9a433b 100755 --- a/code/test/helper_scripts/run_summarize_region_quality.slurm.sh +++ b/code/test/helper_scripts/run_summarize_region_quality.slurm.sh @@ -5,11 +5,15 @@ #SBATCH -n 1 #SBATCH -o "/tigress/tcomi/aclark4_temp/results/summarize_%A_%a" -export PYTHONPATH=/home/tcomi/projects/aclark4_introgression/code/ - module load anaconda3 conda activate introgression3 -ARGS="_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" +config=/home/tcomi/projects/aclark4_introgression/code/config.yaml -python ${PYTHONPATH}analyze/summarize_region_quality_main.py $SLURM_ARRAY_TASK_ID $ARGS +introgression \ + --config $config \ + --log-file test.log \ + -vvvv \ + summarize-regions \ + --state N_45 \ + --state CBS432 diff --git a/code/test/helper_scripts/test_predict.slurm b/code/test/helper_scripts/test_predict.slurm index 1a55ef9..8f6f926 100755 --- a/code/test/helper_scripts/test_predict.slurm +++ b/code/test/helper_scripts/test_predict.slurm @@ -12,7 +12,6 @@ conda activate introgression3 introgression \ --config $config \ - -vv \ --log-file test.log \ + -vv \ predict - From 1ac85ed6381b74e58e7403f037ed8f687eb87503 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Mon, 13 May 2019 13:17:08 -0400 Subject: [PATCH 22/33] Combined filter methods Combined two filter steps, helper functions, and threshold sweep into a single file for further refactoring. 
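The two filtering stages now live in one module. A minimal sketch of
the combined flow (function names and signatures are taken from the
new filter_regions.py below; the driver function itself is
illustrative, not the actual main()):

    from analyze.filter_regions import (
        filter_introgressed, filter_ambiguous)

    def filter_region(region, seqs, info_string, threshold, known_states):
        # stage 1: reject regions we cannot confidently call
        # non-master (S288c); 'reason' records the failing test
        keep, reason = filter_introgressed(
            region, info_string, known_states[0])
        region['reason'] = reason
        if not keep:
            return False
        # stage 2: reject regions that match several references
        # about equally well (within threshold of the best match)
        keep, alt_states, alt_ids, alt_counts = filter_ambiguous(
            region, seqs, threshold, known_states)
        region['alternative_states'] = ','.join(alt_states)
        return keep

The threshold sweep reuses stage 2: filter_ambiguous is re-run for
each candidate threshold and the surviving alternative-state sets are
tallied with record_data_hit.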
--- code/analyze/filter_1_main.py | 1 + code/analyze/filter_regions.py | 272 ++++++++++++++ code/test/analyze/test_filter_helpers.py | 1 - code/test/analyze/test_filter_regions.py | 450 +++++++++++++++++++++++ 4 files changed, 723 insertions(+), 1 deletion(-) create mode 100644 code/analyze/filter_regions.py create mode 100644 code/test/analyze/test_filter_regions.py diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py index 50194f2..ea72c89 100644 --- a/code/analyze/filter_1_main.py +++ b/code/analyze/filter_1_main.py @@ -18,6 +18,7 @@ from misc.region_reader import Region_Reader +# TODO make this, filter 2 and thresholds a single method then refactor def main() -> None: ''' Perform first step of filtering diff --git a/code/analyze/filter_regions.py b/code/analyze/filter_regions.py new file mode 100644 index 0000000..50f6869 --- /dev/null +++ b/code/analyze/filter_regions.py @@ -0,0 +1,272 @@ +import global_params as gp +from misc import seq_functions +import numpy as np +from typing import List, Dict, TextIO, Tuple +import sys +from contextlib import ExitStack +from analyze import predict +from misc import read_table +from misc.region_reader import Region_Reader + + +def write_filtered_line(writer: TextIO, + region_id: str, + region: Dict, + fields: List) -> None: + ''' + Write the region id and values in "region" dict to open file writer + ''' + writer.write(f'{region_id}\t' + + '\t'.join([str(region[field]) + for field in fields[1:]]) + + '\n') + + +def filter_introgressed(region: Dict, + info: str, + reference_species: str) -> Tuple[bool, str]: + ''' + filtering out things that we can't call introgressed in general + with confidence (i.e. doesn't seem like a strong case against + being S288c) + Return true if the region passes the filter, or false with a string + specifying which filter failed + Tests: + -fraction of gaps masked in reference > 0.5 + -fraction of gaps masked in predicted species > 0.5 + -number of matches to predicted > 7 + -number of matches to predicted > number matches to reference + -divergence with predicted species + ''' + + predicted_species = region['predicted_species'] + + aligned_length = (int(region['end']) - int(region['start']) + 1) + + # FILTER: fraction gaps + masked + fraction_gaps_masked_threshold = .5 + # num_sites_nonmask_x is number of sites at which neither + # reference x nor the test sequence is masked or has a gap or + # unsequenced character + fraction_gaps_masked_r = \ + 1 - region['num_sites_nonmask_' + reference_species] / aligned_length + fraction_gaps_masked_s = \ + 1 - region['num_sites_nonmask_' + predicted_species] / aligned_length + + if fraction_gaps_masked_r > fraction_gaps_masked_threshold: + return False, f'fraction gaps/masked in master = '\ + f'{fraction_gaps_masked_r}' + if fraction_gaps_masked_s > fraction_gaps_masked_threshold: + return False, f'fraction gaps/masked in predicted = '\ + f'{fraction_gaps_masked_s}' + + # FILTER: number sites analyzed by HMM that match predicted (P) + # reference (C) + count_P = info.count('P') + count_C = info.count('C') + number_match_only_threshold = 7 + if count_P < number_match_only_threshold: + return False, f'count_P = {count_P}' + if count_P <= count_C: + return False, f'count_P = {count_P} and count_C = {count_C}' + + # FILTER: divergence with predicted reference and master reference + # (S288c) + id_predicted = float(region['match_nongap_' + predicted_species]) / \ + float(region['num_sites_nongap_' + predicted_species]) + id_master = float(region['match_nongap_' + 
reference_species]) / \ + float(region['num_sites_nongap_' + reference_species]) + + if id_master >= id_predicted: + return False, f'id with master = {id_master} '\ + f'and id with predicted = {id_predicted}' + if id_master < .7: + return False, f'id with master = {id_master}' + + return True, '' + + +def filter_ambiguous(region: Dict, + seqs: np.array, + threshold: float, + refs: List[str]) -> Tuple[bool, + List[str], + List[float], + List[int]]: + ''' + filter out things we can't assign to one species specifically; + return the other reasonable alternatives if we're filtering + it out + Returns a tuple of: + True if the region passes the filter + A list of likely species for the region + A list of fraction of matching sequence for each species + A list of total matching sites + Fails the filter if number of matches and fraction matching are >= more + than one state for the region + ''' + + s = region['predicted_species'] + + ids = {} + P_counts = {} + + seqs = np.asarray(seqs) + # skip any gap or unsequenced in ref or test + # also skip if ref and test equal (later test ri == test but not ref) + skip = np.any( + (seqs[0] == gp.gap_symbol, + seqs[0] == gp.unsequenced_symbol, + seqs[-1] == gp.gap_symbol, + seqs[-1] == gp.unsequenced_symbol, + seqs[0] == seqs[-1]), + axis=0) + + for ri, ref in enumerate(refs): + if ri == 0: + continue + r_match, r_total = seq_functions.seq_id(seqs[-1], seqs[ri]) + if r_total != 0: + ids[ref] = r_match / r_total + P_counts[ref] = np.sum( + np.logical_and( + np.logical_not(skip), + seqs[ri] == seqs[-1])) + + alts = {} + for r in ids.keys(): + # TODO should threshold be the same for both? + if ids[r] >= threshold * ids[s] and \ + P_counts[r] >= threshold * P_counts[s]: + alts[r] = (ids[r], P_counts[r]) + + alt_states = sorted(alts.keys(), key=lambda x: alts[x][0], reverse=True) + alt_ids = [alts[state][0] for state in alt_states] + alt_P_counts = [alts[state][1] for state in alt_states] + + if len(alts) > 1: + return False, alt_states, alt_ids, alt_P_counts + + return True, alt_states, alt_ids, alt_P_counts + + +def main(thresholds=[]): + ''' + Perform first step of filtering + Input files: + -blocks_{species}_quality.txt + + Output files: + -blocks_{species}_filtered1intermediate.txt + -blocks_{species}_filtered1.txt + -regions/{species}.fa.gz + -regions/{species}.pkl + ''' + # thresholds = [.999, .995, .985, .975, .965, .955, .945, + # .935, .925, .915, .905, .89, .87, .86] + args = predict.process_predict_args(sys.argv[2:]) + out_dir = gp.analysis_out_dir_absolute + args['tag'] + threshold = float(sys.argv[1]) + + with ExitStack() as stack: + if thresholds != []: + threshold_writer = stack.enter_context( + open(f'{out_dir}/filter_2_thresholds_{args["tag"]}.txt', 'w')) + threshold_writer.write( + 'threshold\tpredicted_state\talternative_states\tcount\n') + + data_table = {} + + for species_from in args['known_states'][1:]: + + print(species_from) + + region_summary, fields = read_table.read_table_rows( + f'{out_dir}/blocks_{species_from}_{args["tag"]}_quality.txt', + '\t') + + fields1i = fields + ['reason'] + fields1 = fields + fields2i = fields + ['alternative_states', 'alternative_ids', + 'alternative_P_counts'] + fields2 = fields + + with open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' + '_filtered1intermediate.txt', 'w') as f_out1i, \ + open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' + '_filtered1.txt', 'w') as f_out1, \ + open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' + '_filtered2intermediate.txt', 'w') as f_out2i, \ + 
open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' + '_filtered2.txt', 'w') as f_out2, \ + Region_Reader(f'{out_dir}/regions/{species_from}.fa.gz', + as_fa=True) as region_reader: + + f_out1i.write('\t'.join(fields1i) + '\n') + f_out1.write('\t'.join(fields1) + '\n') + f_out2i.write('\t'.join(fields2i) + '\n') + f_out2.write('\t'.join(fields2) + '\n') + + for region_id, header, seqs in region_reader.yield_fa(): + region = region_summary[region_id] + info_string = seqs[-1] + seqs = seqs[:-1] + + # filtering stage 1: things that we're confident in + # calling not S288c + p, reason = filter_introgressed(region, + info_string, + args['known_states'][0]) + region['reason'] = reason + write_filtered_line(f_out1i, region_id, region, fields1i) + + if p: + write_filtered_line(f_out1, region_id, region, fields1) + + for thresh in thresholds: + _, alt_states, _, _ = \ + filter_ambiguous(region, seqs, thresh, + args['known_states']) + + record_data_hit(data_table, + thresh, + species_from, + ','.join(sorted(alt_states))) + + (p, alt_states, + alt_ids, alt_P_counts) = filter_ambiguous( + region, seqs, threshold, args['known_states']) + region['alternative_states'] = ','.join(alt_states) + region['alternative_ids'] = ','.join( + [str(x) for x in alt_ids]) + region['alternative_P_counts'] = ','.join( + [str(x) for x in alt_P_counts]) + write_filtered_line(f_out2i, region_id, + region, fields2i) + + if p: + write_filtered_line(f_out2, region_id, + region, fields2) + + for thresh in thresholds: + for species in args['known_states'][1:]: + d = data_table[thresh][species] + for key in d.keys(): + threshold_writer.write( + f'{thresh}\t{species}\t{key}\t{d[key]}\n') + + +def record_data_hit(data_dict, threshold, species, key): + ''' + adds an entry to the data table or increments if exists + ''' + if threshold not in data_dict: + data_dict[threshold] = {} + + if species not in data_dict[threshold]: + data_dict[threshold][species] = {} + + if key not in data_dict[threshold][species]: + data_dict[threshold][species][key] = 0 + + data_dict[threshold][species][key] += 1 diff --git a/code/test/analyze/test_filter_helpers.py b/code/test/analyze/test_filter_helpers.py index 9d8ea48..4838156 100644 --- a/code/test/analyze/test_filter_helpers.py +++ b/code/test/analyze/test_filter_helpers.py @@ -1,6 +1,5 @@ from analyze import filter_helpers from io import StringIO -import numpy as np from misc import read_fasta import os import warnings diff --git a/code/test/analyze/test_filter_regions.py b/code/test/analyze/test_filter_regions.py new file mode 100644 index 0000000..b56681f --- /dev/null +++ b/code/test/analyze/test_filter_regions.py @@ -0,0 +1,450 @@ +from analyze import filter_regions +from io import StringIO +from misc import read_fasta +import os +import warnings +from pytest import approx + + +def test_main_no_thresh(mocker, capsys): + mocker.patch('sys.argv', ['', '0.1']) + mocker.patch('analyze.filter_regions.predict.process_predict_args', + return_value={ + 'known_states': ['state1', 'state2'], + 'tag': 'tag' + }) + mocker.patch('analyze.filter_regions.gp.analysis_out_dir_absolute', + '/dir') + mocker.patch('analyze.filter_regions.read_table.read_table_rows', + return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) + mocked_file = mocker.patch('analyze.filter_regions.open') + + mock_read = mocker.patch('analyze.filter_regions.Region_Reader') + mock_read().__enter__().yield_fa.return_value = iter([ + ('r1', ['> seq', '> info'], ['atcg', 'x..']), + ('r2', ['> seq', '> info'], ['atcg', 'x..'])]) + + 
mock_filter1 = mocker.patch('analyze.filter_regions.filter_introgressed', + side_effect=[(False, 'test'), # r1 + (True, '')]) # r2 + mock_filter2 = mocker.patch( + 'analyze.filter_regions.filter_ambiguous', + side_effect=[ + (True, ['1'], [0.8], [2]), + (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]) + ]) + mock_write = mocker.patch('analyze.filter_regions.write_filtered_line') + + filter_regions.main() + + captured = capsys.readouterr().out + assert captured == 'state2\n' + + assert mock_read.call_count == 2 # called once during setup + mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True) + + assert mocked_file.call_args_list == [ + mocker.call('/dirtag/blocks_state2_tag_filtered1intermediate.txt', + 'w'), + mocker.call('/dirtag/blocks_state2_tag_filtered1.txt', 'w'), + mocker.call('/dirtag/blocks_state2_tag_filtered2intermediate.txt', + 'w'), + mocker.call('/dirtag/blocks_state2_tag_filtered2.txt', 'w'), + ] + + # just headers, capture others + assert mocked_file().__enter__().write.call_args_list == [ + mocker.call('regions\treason\n'), + mocker.call('regions\n'), + mocker.call('regions\talternative_states\t' + 'alternative_ids\talternative_P_counts\n'), + mocker.call('regions\n'), + ] + + assert mock_filter1.call_count == 2 + # seems like this references the object, which changes after call + assert mock_filter1.call_args_list == [ + mocker.call({'reason': 'test'}, 'x..', 'state1'), + mocker.call({'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2'}, 'x..', 'state1') + ] + + assert mock_filter2.call_args_list == [ + mocker.call({'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2'}, + ['atcg'], 0.1, ['state1', 'state2']), + ] + assert mock_write.call_args_list == [ + mocker.call(mocker.ANY, 'r1', {'reason': 'test'}, + ['regions', 'reason']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2' + }, + ['regions', 'reason']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2' + }, + ['regions']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2' + }, + ['regions', 'alternative_states', 'alternative_ids', + 'alternative_P_counts']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2' + }, + ['regions']), + ] + + +def test_main(mocker, capsys): + mocker.patch('sys.argv', ['', '0.1']) + mocker.patch('analyze.filter_regions.predict.process_predict_args', + return_value={ + 'known_states': ['state1', 'state2'], + 'tag': 'tag' + }) + mocker.patch('analyze.filter_regions.gp.analysis_out_dir_absolute', + '/dir') + mocker.patch('analyze.filter_regions.read_table.read_table_rows', + return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) + mocked_file = mocker.patch('analyze.filter_regions.open') + + mock_read = mocker.patch('analyze.filter_regions.Region_Reader') + mock_read().__enter__().yield_fa.return_value = iter([ + ('r1', ['> seq', '> info'], ['atcg', 'x..']), + ('r2', ['> seq', '> info'], ['atcg', 'x..'])]) + + mock_filter1 = mocker.patch('analyze.filter_regions.filter_introgressed', + side_effect=[(False, 'test'), # r1 + (True, '')]) # r2 + mock_filter2 = mocker.patch( + 
'analyze.filter_regions.filter_ambiguous', + side_effect=[ + (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]), + (True, ['1'], [0.8], [2]), + (True, ['1'], [0.8], [2]), + (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]) + ]) + mock_write = mocker.patch('analyze.filter_regions.write_filtered_line') + + filter_regions.main([0.99]) + + captured = capsys.readouterr().out + assert captured == 'state2\n' + + assert mock_read.call_count == 2 # called once during setup + mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True) + + assert mocked_file.call_count == 5 + assert mocked_file.call_args_list == [ + mocker.call('/dirtag/filter_2_thresholds_tag.txt', 'w'), + mocker.call('/dirtag/blocks_state2_tag_filtered1intermediate.txt', + 'w'), + mocker.call('/dirtag/blocks_state2_tag_filtered1.txt', 'w'), + mocker.call('/dirtag/blocks_state2_tag_filtered2intermediate.txt', + 'w'), + mocker.call('/dirtag/blocks_state2_tag_filtered2.txt', 'w'), + ] + + # just headers, capture others + assert mocked_file().__enter__().write.call_args_list == [ + mocker.call('threshold\tpredicted_state\talternative_states\tcount\n'), + mocker.call('regions\treason\n'), + mocker.call('regions\n'), + mocker.call('regions\talternative_states\t' + 'alternative_ids\talternative_P_counts\n'), + mocker.call('regions\n'), + mocker.call('0.99\tstate2\t1,2\t1\n') + ] + + assert mock_filter1.call_count == 2 + # seems like this references the object, which changes after call + assert mock_filter1.call_args_list == [ + mocker.call({'reason': 'test'}, 'x..', 'state1'), + mocker.call({'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2'}, 'x..', 'state1') + ] + + assert mock_filter2.call_args_list == [ + mocker.call({'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2'}, + ['atcg'], 0.99, ['state1', 'state2']), + mocker.call({'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2'}, + ['atcg'], 0.1, ['state1', 'state2']), + ] + assert mock_write.call_args_list == [ + mocker.call(mocker.ANY, 'r1', {'reason': 'test'}, + ['regions', 'reason']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2' + }, + ['regions', 'reason']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2' + }, + ['regions']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2' + }, + ['regions', 'alternative_states', 'alternative_ids', + 'alternative_P_counts']), + mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', + 'alternative_states': '1', + 'alternative_ids': '0.8', + 'alternative_P_counts': '2' + }, + ['regions']), + ] + + +def test_record_data_hit(): + dt = {} + filter_regions.record_data_hit(dt, 0.9, 's1', 'k1') + assert dt == {0.9: {'s1': {'k1': 1}}} + filter_regions.record_data_hit(dt, 0.9, 's1', 'k1') + filter_regions.record_data_hit(dt, 0.9, 's1', 'k1') + assert dt == {0.9: {'s1': {'k1': 3}}} + filter_regions.record_data_hit(dt, 0.9, 's1', 'k2') + assert dt == { + 0.9: { + 's1': {'k1': 3, 'k2': 1} + } + } + filter_regions.record_data_hit(dt, 0.9, 's2', 'k2') + assert dt == { + 0.9: { + 's1': {'k1': 3, 'k2': 1}, + 's2': {'k2': 1} + } + } + filter_regions.record_data_hit(dt, 0.8, 's2', 'k2') + assert dt == { + 0.9: { + 's1': 
{'k1': 3, 'k2': 1}, + 's2': {'k2': 1} + }, + 0.8: { + 's2': {'k2': 1} + } + } + filter_regions.record_data_hit(dt, 0.9, 's2', 'k2') + assert dt == { + 0.9: { + 's1': {'k1': 3, 'k2': 1}, + 's2': {'k2': 2} + }, + 0.8: { + 's2': {'k2': 1} + } + } + + +def test_write_filtered_line(): + # single value, first field is ignored + output = StringIO() + filter_regions.write_filtered_line(output, 'r1', {'chr': 'I'}, ['', 'chr']) + + assert output.getvalue() == 'r1\tI\n' + + # no value + output = StringIO() + filter_regions.write_filtered_line(output, 'r1', {}, []) + + assert output.getvalue() == 'r1\t\n' + + # two values + output = StringIO() + filter_regions.write_filtered_line(output, 'r1', + {'a': 'b', 'c': 'd'}, + ['', 'c', 'a']) + + assert output.getvalue() == 'r1\td\tb\n' + + +def test_filter_introgressed(mocker): + # fail fraction gapped on reference + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 4, + 'num_sites_nonmask_pred': 0, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + } + + assert filter_regions.filter_introgressed(region, '', 'ref') == \ + (False, 'fraction gaps/masked in master = 0.6') + + # fail fraction gapped on predicted + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 3, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + } + + assert filter_regions.filter_introgressed(region, '', 'ref') == \ + (False, 'fraction gaps/masked in predicted = 0.7') + + # fail match counts + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + } + + assert filter_regions.filter_introgressed(region, 'CP', 'ref') == \ + (False, 'count_P = 1') + assert filter_regions.filter_introgressed(region, + 'CCCCCCCCPPPPPPP', 'ref') == \ + (False, 'count_P = 7 and count_C = 8') + + # fail divergence, master >= pred + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 5, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 6, + 'num_sites_nongap_ref': 10, + } + + assert filter_regions.filter_introgressed(region, 'CPPPPPPP', 'ref') == \ + (False, 'id with master = 0.6 and id with predicted = 0.5') + + # fail divergence, master >= 0.7 + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 6, + 'num_sites_nongap_ref': 10, + } + + assert filter_regions.filter_introgressed(region, 'CPPPPPPP', 'ref') == \ + (False, 'id with master = 0.6') + + # passes + region = {'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 7, + 'num_sites_nongap_ref': 10, + } + + assert filter_regions.filter_introgressed(region, 'CPPPPPPP', 'ref') == \ + (True, '') + + +def test_filter_ambiguous(mocker): + mocker.patch('analyze.filter_regions.gp.gap_symbol', '-') + mocker.patch('analyze.filter_regions.gp.unsequenced_symbol', 'n') + + region = {'predicted_species': '1', + } + seqs = [list('attatt'), # reference + list('aggcat'), # 4 / 5, 
p = 2 + list('a--tta'), # 2 / 4, p = 1 + list('nng---'), # no matches, '3' not in outputs + list('attatt'), # 2 / 5, p = 0 + list('ag-tat')] # test sequence + + threshold = 0 + filt, states, ids, p_count = filter_regions.filter_ambiguous( + region, seqs, threshold, ['ref', '1', '2', '3', '4']) + assert filt is False + assert states == ['1', '2', '4'] + assert ids == [0.8, 0.5, 0.4] + assert p_count == [2, 1, 0] + + threshold = 0.1 + filt, states, ids, p_count = filter_regions.filter_ambiguous( + region, seqs, threshold, ['ref', '1', '2', '3', '4']) + assert filt is False + assert states == ['1', '2'] + assert ids == [0.8, 0.5] + assert p_count == [2, 1] + + threshold = 0.9 + filt, states, ids, p_count = filter_regions.filter_ambiguous( + region, seqs, threshold, ['ref', '1', '2', '3', '4']) + assert filt is True + assert states == ['1'] + assert ids == [0.8] + assert p_count == [2] + + +def test_filter_ambiguous_on_region(mocker): + mocker.patch('analyze.filter_regions.gp.gap_symbol', '-') + mocker.patch('analyze.filter_regions.gp.unsequenced_symbol', 'n') + + fa = os.path.join(os.path.split(__file__)[0], 'r10805.fa') + + if os.path.exists(fa): + headers, seqs = read_fasta.read_fasta(fa, gz=False) + seqs = seqs[:-1] + p, alt_states, alt_ids, alt_P_counts = filter_regions.filter_ambiguous( + {'predicted_species': 'N_45'}, seqs, 0.1, + ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) + assert p is False + assert alt_states == ['CBS432', 'N_45', 'UWOPS91_917_1', 'DBVPG6304'] + assert alt_ids == approx([0.9983805668016195, 0.994331983805668, + 0.9642857142857143, 0.9618506493506493]) + assert alt_P_counts == [145, 143, 128, 129] + + p, alt_states, alt_ids, alt_P_counts = filter_regions.filter_ambiguous( + {'predicted_species': 'N_45'}, seqs, 0.98, + ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) + assert p is False + assert alt_states == ['CBS432', 'N_45'] + assert alt_ids == approx([0.9983805668016195, 0.994331983805668]) + assert alt_P_counts == [145, 143] + + else: + warnings.warn('Unable to test with datafile r10805.fa') From 2667a6a6c8d21c842440d17a251d357e81dcc544 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Tue, 21 May 2019 15:21:35 -0400 Subject: [PATCH 23/33] Filter Regions Refactor Continued refactoring of main methods onto filtering. Part of the changes saw a modification to the configuration object to simplify setting code into a more uniform interface. Have started checking flake8 on entire project, fixing occasionally. 
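Most of the align/ churn below follows one mechanical pattern:
replace sys.path manipulation with package imports, convert print
statements to print() calls, and wrap long lines so flake8 (run as
python -m flake8 code/) passes. A condensed sketch of the pattern,
not a literal hunk from this patch (the input filename here is
hypothetical):

    # before (python 2 era):
    #   sys.path.insert(0, '../misc')
    #   import read_fasta
    #   print chrm
    # after (python 3, flake8-clean package import):
    from misc import read_fasta

    headers, seqs = read_fasta.read_fasta('chrI.fa')  # hypothetical input
    print(len(seqs))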
--- code/align/aggregate_alignment_stats.py | 25 +- code/align/alignment_stats.py | 13 +- code/align/average_alignment_stats.py | 20 +- code/align/convert_coordinates.py | 16 +- code/align/convert_coordinates_main.py | 23 +- code/align/mask.py | 22 +- code/align/mask_helpers.py | 14 +- code/align/master_alignment.py | 72 +- code/align/master_alignment_main.py | 11 +- code/align/polymorphism.py | 2 + code/align/ref_ids_main.py | 42 +- code/align/run_alignment_stats.py | 24 +- code/align/run_mafft.py | 39 +- code/align/run_mafft_refs.py | 37 +- code/align/run_mugsy.py | 22 +- code/align/run_tcoffee.py | 21 +- code/analyze/filter_1_main.py | 76 -- code/analyze/filter_2_main.py | 80 -- code/analyze/filter_2_thresholds_main.py | 97 -- code/analyze/filter_helpers.py | 178 ---- code/analyze/filter_regions.py | 663 +++++++------ code/analyze/introgression_configuration.py | 373 ++++---- code/analyze/main.py | 130 ++- .../plotting/format_for_plot_gene_region.py | 58 +- code/analyze/plotting/format_for_plotting.py | 190 ++-- code/analyze/plotting/format_for_plotting2.py | 18 +- .../plotting/format_for_plotting_region.py | 24 +- .../plotting/format_polymorphism_for_r.py | 20 +- code/analyze/predict.py | 4 +- code/analyze/read_args.py | 20 +- code/analyze/structure/structure_1_main.py | 43 +- code/analyze/structure/structure_2_main.py | 54 +- code/analyze/structure/structure_3_main.py | 80 +- code/analyze/summarize_region_quality.py | 244 ++--- .../aggregate_genes_by_strains_main.py | 25 +- code/analyze/to_update/annotate_positions.py | 17 +- .../to_update/annotate_positions_main.py | 50 +- code/analyze/to_update/annotate_regions.py | 31 +- .../to_update/annotate_regions_main.py | 62 +- code/analyze/to_update/check_paralogs_main.py | 102 +- code/analyze/to_update/combine_all_strains.py | 163 ++-- .../combine_gene_all_strains_main.py | 110 ++- .../combine_region_all_strains_main.py | 86 +- code/analyze/to_update/compare.py | 56 +- .../to_update/compare_3strains_main.py | 25 +- .../to_update/compare_predictions_main.py | 43 +- code/analyze/to_update/compare_to_strope.py | 50 +- .../analyze/to_update/count_coding_changes.py | 31 +- .../to_update/count_coding_changes_main.py | 167 ++-- .../to_update/count_introgressed_main.py | 28 +- .../frequency_of_introgression_main.py | 21 +- code/analyze/to_update/gene_overlap_main.py | 133 ++- code/analyze/to_update/gene_predictions.py | 134 +-- ...otide_diversity_from_introgression_main.py | 223 ++--- code/annotate/fix.py | 8 +- code/annotate/makeblastdb.py | 11 +- code/annotate/orfs_main.py | 39 +- code/config.yaml | 22 +- code/global_params.py | 57 +- code/misc/to_bed.py | 13 +- code/test/analyze/test_filter_1_main.py | 59 -- code/test/analyze/test_filter_2_main.py | 89 -- .../analyze/test_filter_2_thresholds_main.py | 104 --- code/test/analyze/test_filter_helpers.py | 245 ----- code/test/analyze/test_filter_regions.py | 883 ++++++++++++------ code/test/analyze/test_id_regions.py | 14 +- .../test_introgression_configuration.py | 346 +++---- .../analyze/test_main_filter_regions_args.py | 283 ++++++ .../test_main_filter_regions_config.py | 282 ++++++ code/test/analyze/test_main_id_args.py | 4 +- code/test/analyze/test_main_id_config.py | 6 +- code/test/analyze/test_main_predict_args.py | 10 +- code/test/analyze/test_main_predict_config.py | 14 +- .../test_main_summarize_regions_args.py | 16 +- .../test_main_summarize_regions_config.py | 26 +- code/test/analyze/test_predict_hmm_builder.py | 6 +- code/test/analyze/test_predict_predictor.py | 6 +- 
.../analyze/test_summarize_region_quality.py | 39 +- .../helper_scripts/compare_filter_outputs.sh | 20 +- .../helper_scripts/run_filter_2_thresholds.sh | 11 +- code/test/hmm/test_hmm_bw.py | 36 +- 81 files changed, 3455 insertions(+), 3506 deletions(-) delete mode 100644 code/analyze/filter_1_main.py delete mode 100644 code/analyze/filter_2_main.py delete mode 100644 code/analyze/filter_2_thresholds_main.py delete mode 100644 code/analyze/filter_helpers.py delete mode 100644 code/test/analyze/test_filter_1_main.py delete mode 100644 code/test/analyze/test_filter_2_main.py delete mode 100644 code/test/analyze/test_filter_2_thresholds_main.py delete mode 100644 code/test/analyze/test_filter_helpers.py create mode 100644 code/test/analyze/test_main_filter_regions_args.py create mode 100644 code/test/analyze/test_main_filter_regions_config.py diff --git a/code/align/aggregate_alignment_stats.py b/code/align/aggregate_alignment_stats.py index 290179a..a3e72bd 100644 --- a/code/align/aggregate_alignment_stats.py +++ b/code/align/aggregate_alignment_stats.py @@ -1,14 +1,17 @@ import os -import sys -sys.path.insert(0, '..') import global_params as gp gp_dir = '../' -stats_files = [gp_dir + gp.alignments_dir + x for x in filter(\ - lambda x: 'stats' in x and 'summary' not in x, os.listdir(gp_dir + gp.alignments_dir))] +stats_files = [gp_dir + gp.alignments_dir + x for x in filter( + lambda x: 'stats' in x and 'summary' not in x, + os.listdir(gp_dir + gp.alignments_dir))] # goal is to generate file for R (e.g. for two references and test strain): -# chromosome strain frac_S288c_S288c frac_S288c_CBS432 frac_S288c_x frac_CBS432_S288c frac_CBS432_CBS432 frac_CBS432_x frac_x_S288c frac_x_CBS432 frac_x_x aligned_length_S288c aligned_length_CBS432 aligned_length_x num_align_columns_0 num_align_columns_1 num_align_columns_2 num_align_columns_3 +# chromosome strain frac_S288c_S288c frac_S288c_CBS432 frac_S288c_x +# frac_CBS432_S288c frac_CBS432_CBS432 frac_CBS432_x frac_x_S288c +# frac_x_CBS432 frac_x_x aligned_length_S288c aligned_length_CBS432 +# aligned_length_x num_align_columns_0 num_align_columns_1 +# num_align_columns_2 num_align_columns_3 f = open(gp_dir + gp.alignments_dir + 'mafft_stats_summary.txt', 'w') @@ -25,15 +28,15 @@ for i in range(0, len(gp.alignment_ref_order) + 2): f.write('\t' + 'num_align_columns_' + str(i)) - + f.write('\n') all_strains = gp.alignment_ref_order + ['x'] # one line for each of these files for fn in stats_files: - print fn - + print(fn) + lines = [line.strip() for line in open(fn, 'r').readlines()] # histogram of number of number of strains aligned @@ -43,10 +46,10 @@ c.append(float(lines[i + offset].split(',')[1])) # aligned lengths - l = [] + lengths = [] offset += len(all_strains) + 1 + 2 for i in range(len(all_strains)): - l.append(float(lines[i + offset].split(',')[1])) + lengths.append(float(lines[i + offset].split(',')[1])) sx = lines[offset + len(all_strains) - 1].split(',')[0] @@ -67,7 +70,7 @@ for j in range(len(all_strains)): f.write('\t' + str(fr[i][j])) for i in range(len(all_strains)): - f.write('\t' + str(l[i])) + f.write('\t' + str(lengths[i])) for i in range(len(all_strains) + 1): f.write('\t' + str(c[i])) f.write('\n') diff --git a/code/align/alignment_stats.py b/code/align/alignment_stats.py index 9a06939..58f5ac3 100644 --- a/code/align/alignment_stats.py +++ b/code/align/alignment_stats.py @@ -1,8 +1,6 @@ -import sys -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc') -import read_fasta +from misc import read_fasta + 
# count sites where n, ..., 3, 2, 1 genomes aligned, etc. def num_strains_aligned_by_site(seqs): @@ -17,7 +15,8 @@ def num_strains_aligned_by_site(seqs): return num_strains_hist -# fraction of each strain's sequence contained in alignment + +# fraction of each strain's sequence contained in alignment # (should be 1) def fraction_strains_aligned(headers, seqs): nseqs = len(seqs) @@ -34,6 +33,7 @@ def fraction_strains_aligned(headers, seqs): return fracs_aligned, seq_lengths + # using each genome as reference, percentage of other genomes aligned def frac_aligned_to_reference(seqs, seq_lengths): nseqs = len(seqs) @@ -47,7 +47,8 @@ def frac_aligned_to_reference(seqs, seq_lengths): else: total = 0 for i in range(nsites): - if seqs[ref][i] != gp.gap_symbol and seqs[other][i] != gp.gap_symbol: + if seqs[ref][i] != gp.gap_symbol and \ + seqs[other][i] != gp.gap_symbol: total += 1 r.append(float(total) / seq_lengths[other]) fracs_aligned_to_ref.append(r) diff --git a/code/align/average_alignment_stats.py b/code/align/average_alignment_stats.py index 718cd85..b6c04af 100644 --- a/code/align/average_alignment_stats.py +++ b/code/align/average_alignment_stats.py @@ -1,14 +1,12 @@ import os -import sys -sys.path.insert(0, '..') import global_params as gp import numpy gp_dir = '../' -stats_files = [gp_dir + gp.alignments_dir + x for x in filter(\ +stats_files = [gp_dir + gp.alignments_dir + x for x in filter( lambda x: 'stats' in x, os.listdir(gp_dir + gp.alignments_dir))] -#avg_frac_aligned_by_chrm = dict(zip(gp.chrms, [0]*len(gp.chrms))) +# avg_frac_aligned_by_chrm = dict(zip(gp.chrms, [0]*len(gp.chrms))) avg_frac_aligned_p = 0 avg_frac_aligned_x = 0 total_p = 0 @@ -26,18 +24,18 @@ avg_frac_aligned_x += fx * lx total_p += lp total_x += lx - #print fn[fn.find('chr')-8:], fx, lx, lc + # print fn[fn.find('chr')-8:], fx, lx, lc a.append(fx) avg_frac_aligned_p /= total_p avg_frac_aligned_x /= total_x -print len(stats_files) -print avg_frac_aligned_p -print avg_frac_aligned_x +print(len(stats_files)) +print(avg_frac_aligned_p) +print(avg_frac_aligned_x) hist, edges = numpy.histogram(a, bins=30) -print hist -print edges -print sum(hist[:-1]) +print(hist) +print(edges) +print(sum(hist[:-1])) diff --git a/code/align/convert_coordinates.py b/code/align/convert_coordinates.py index e38f872..a14ded6 100644 --- a/code/align/convert_coordinates.py +++ b/code/align/convert_coordinates.py @@ -1,13 +1,11 @@ -import sys import gzip -sys.path.insert(0, '..') import global_params as gp def convert(s1, s2): i2 = -1 i2d = 0 - l = [] + result = [] for i in range(len(s1)): if s2[i] == gp.gap_symbol: i2d += 1 @@ -16,14 +14,14 @@ def convert(s1, s2): i2d = 0 if s1[i] != gp.gap_symbol: if i2d == 0: - l.append(str(i2)) + result.append(str(i2)) else: - l.append(str(i2) + '.' + str(i2d)) - return l + result.append(str(i2) + '.' 
+ str(i2d)) + return result -def write_coordinates(l, fn): + +def write_coordinates(coords, fn): f = gzip.open(fn, 'wb') - f.write('\n'.join([str(x) for x in l])) + f.write('\n'.join([str(x) for x in coords])) f.write('\n') f.close() - diff --git a/code/align/convert_coordinates_main.py b/code/align/convert_coordinates_main.py index 22ff28e..08cc90a 100644 --- a/code/align/convert_coordinates_main.py +++ b/code/align/convert_coordinates_main.py @@ -1,33 +1,30 @@ -import sys import os -from convert_coordinates import * -sys.path.insert(0, '..') +from convert_coordinates import (write_coordinates, convert) import global_params as gp -sys.path.insert(0, '../misc/') -import read_fasta +from misc import read_fasta gp_dir = '../' fns = os.listdir(gp_dir + gp.alignments_dir) fns = filter(lambda fn: fn.endswith(gp.alignment_suffix), fns) for fn in fns: - print fn + print(fn) x = fn.split('_') chrm = x[-2] strain_names = x[0:-2] headers, seqs = read_fasta.read_fasta(gp_dir + gp.alignments_dir + fn) - + # for each index in cer reference, get index in other strain # (either par reference for 2-way alignment or cer strain for # 3-way) - coord_fn = gp.analysis_out_dir_absolute + 'coordinates/' + \ - strain_names[0] + '_to_' + strain_names[-1] + \ - '_' + chrm + '.txt.gz' + coord_fn = (gp.analysis_out_dir_absolute + 'coordinates/' + + strain_names[0] + '_to_' + strain_names[-1] + + '_' + chrm + '.txt.gz') write_coordinates(convert(seqs[0], seqs[-1]), coord_fn) # for each index in other strain, get index in cer reference - coord_fn = gp.analysis_out_dir_absolute + 'coordinates/' + \ - strain_names[-1] + '_to_' + strain_names[0] + \ - '_' + chrm + '.txt.gz' + coord_fn = (gp.analysis_out_dir_absolute + 'coordinates/' + + strain_names[-1] + '_to_' + strain_names[0] + + '_' + chrm + '.txt.gz') write_coordinates(convert(seqs[-1], seqs[0]), coord_fn) diff --git a/code/align/mask.py b/code/align/mask.py index df0ffa4..4d5ae74 100644 --- a/code/align/mask.py +++ b/code/align/mask.py @@ -1,7 +1,6 @@ import sys import os -from mask_helpers import * -import align_helpers +from mask_helpers import mask from analyze import read_args import global_params as gp @@ -33,25 +32,22 @@ current_strain_fn = d + strain_fn.replace('*', strain) current_strain_masked_fn = d + strain_masked_fn.replace('*', strain) - current_strain_intervals_fn = intervals_d + intervals_fn.replace('*', strain) + current_strain_intervals_fn = intervals_d + intervals_fn.replace('*', + strain) for chrm in gp.chrms: in_fn = current_strain_fn.replace('?', chrm) out_fn = current_strain_intervals_fn.replace('?', chrm) - + # get dustmasker intervals - cmd_string = gp.blast_install_path + 'dustmasker' + \ - ' -in ' + in_fn + \ - ' -out ' + out_fn + \ - ' -outfmt interval' - + cmd_string = (gp.blast_install_path + 'dustmasker' + + ' -in ' + in_fn + + ' -out ' + out_fn + + ' -outfmt interval') + os.system(cmd_string) # replace those intervals with Ns and write to masked fasta file masked_fn = current_strain_masked_fn.replace('?', chrm) mask(in_fn, masked_fn, out_fn) - - - - diff --git a/code/align/mask_helpers.py b/code/align/mask_helpers.py index 8bd7dc6..b71ede7 100644 --- a/code/align/mask_helpers.py +++ b/code/align/mask_helpers.py @@ -1,14 +1,11 @@ -import sys -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc') -import read_fasta -import write_fasta +from misc import read_fasta +from misc import write_fasta -def read_intervals(fn): +def read_intervals(fn): f = open(fn, 'r') - f.readline() # header + f.readline() # header 
line = f.readline() intervals = [] while line != '': @@ -18,8 +15,8 @@ def read_intervals(fn): f.close() return intervals -def mask(fn, masked_fn, intervals_fn): +def mask(fn, masked_fn, intervals_fn): headers, seqs = read_fasta.read_fasta(fn) seq = list(seqs[0]) intervals = read_intervals(intervals_fn) @@ -28,4 +25,3 @@ def mask(fn, masked_fn, intervals_fn): seq[i] = gp.unsequenced_symbol seq = ''.join(seq) write_fasta.write_fasta(headers, [seq], masked_fn) - diff --git a/code/align/master_alignment.py b/code/align/master_alignment.py index 7309012..aff6cc9 100644 --- a/code/align/master_alignment.py +++ b/code/align/master_alignment.py @@ -1,61 +1,60 @@ -# combine all chromosomal alignments into one master, indexed relative to cerevisiae reference - -import sys -import os -import copy -sys.path.insert(0, '../misc') -import read_maf -sys.path.insert(0, '..') +# combine all chromosomal alignments into one master +# indexed relative to cerevisiae reference + +from misc import read_maf import global_params as gp -complement = {'A':'T', 'T':'A', 'G':'C', 'C':'G', \ - 'a':'t', 't':'a', 'g':'c', 'c':'g', \ - 'N':'N', 'n':'n', '-':'-'} +complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', + 'a': 't', 't': 'a', 'g': 'c', 'c': 'g', + 'N': 'N', 'n': 'n', '-': '-'} + +flip = {'-': '+', '+': '-'} -flip = {'-':'+', '+':'-'} def reverse_start(start, length, total_length): return total_length - start - length -def reverse_complement(s): +def reverse_complement(s): r = [] for b in s[::-1]: r.append(complement[b]) return r -def forward_index(blocks): +def forward_index(blocks): # go through all blocks and add a field for start relative to # forward strand, and sequence in forward direction for label in blocks.keys(): - block = blocks[label] for strain in blocks[label]['strains'].keys(): start = blocks[label]['strains'][strain]['start'] seq = blocks[label]['strains'][strain]['sequence'] - + blocks[label]['strains'][strain]['forward_start'] = start blocks[label]['strains'][strain]['forward_sequence'] = seq if blocks[label]['strains'][strain]['strand'] == '-': - blocks[label]['strains'][strain]['forward_sequence'] = seq[::-1] + blocks[label]['strains'][strain]['forward_sequence'] = \ + seq[::-1] blocks[label]['strains'][strain]['forward_start'] = \ - reverse_start(start, blocks[label]['strains'][strain]['length'], \ - blocks[label]['strains'][strain]['aligned_length']) + reverse_start( + start, + blocks[label]['strains'][strain]['length'], + blocks[label]['strains'][strain]['aligned_length']) return blocks -# make all master sequences go in forward direction (+) and flip -# others as necessary -def master_forward(blocks, master): +def master_forward(blocks, master): + # make all master sequences go in forward direction (+) and flip + # others as necessary for label in blocks.keys(): - block = blocks[label] - if blocks[label]['strains'].has_key(master): + if master in blocks[label]['strains']: if blocks[label]['strains'][master]['strand'] == '-': for strain in blocks[label]['strains'].keys(): - aligned_length = blocks[label]['strains'][strain]['aligned_length'] + aligned_length = \ + blocks[label]['strains'][strain]['aligned_length'] seq = blocks[label]['strains'][strain]['sequence'] start = blocks[label]['strains'][strain]['start'] length = blocks[label]['strains'][strain]['length'] @@ -70,11 +69,12 @@ def master_forward(blocks, master): return blocks + def make_master(fn, master): # keyed by block label; most of info in each keyed by ['strains'][strain] blocks = read_maf.read_mugsy(fn) - + # flip all 
blocks so that master sequence is on + strand blocks = master_forward(blocks, master) # add fields giving index and sequence relative to + strand @@ -83,13 +83,13 @@ def make_master(fn, master): # make sequences with alignment columns present in master n = blocks['1']['strains'][master]['aligned_length'] all_strains = blocks['1']['strains'].keys() - a = dict(zip(all_strains, [[gp.unaligned_symbol] * n for s in all_strains])) + a = dict(zip(all_strains, + [[gp.unaligned_symbol] * n for s in all_strains])) # loop through all blocks for label in blocks.keys(): - block = blocks[label] # only care about aligned blocks that include master sequence - if blocks[label]['strains'].has_key(master): + if master in blocks[label]['strains']: absolute_ind = blocks[label]['strains'][master]['start'] master_seq = blocks[label]['strains'][master]['sequence'] block_length = len(master_seq) @@ -104,25 +104,23 @@ def make_master(fn, master): # apparently mugsy sometimes aligns the same part # of one genome to multiple parts of another # genome. this is a problem. - assert a[master][absolute_ind] == gp.unaligned_symbol, absolute_ind + assert a[master][absolute_ind] == gp.unaligned_symbol,\ + absolute_ind # loop through all the strains in this block for strain in strains: a[strain][absolute_ind] = \ - blocks[label]['strains'][strain]\ - ['forward_sequence'][relative_ind] + blocks[label]['strains'][strain][ + 'forward_sequence'][relative_ind] absolute_ind += 1 for strain in all_strains: a[strain] = ''.join(a[strain]) - print strain, a[strain].count(gp.unaligned_symbol) - - #assert total_aligned_master == n - a[master].count(gp.unaligned_symbol), \ - # str(total_aligned_master) + ' ' + str(n - a[master].count(gp.unaligned_symbol)) + print(strain, a[strain].count(gp.unaligned_symbol)) return a -def write_master(fn, a): +def write_master(fn, a): f = open(fn, 'w') for strain in a.keys(): f.write('> ' + strain + '\n') diff --git a/code/align/master_alignment_main.py b/code/align/master_alignment_main.py index 52d2126..f0c5a10 100644 --- a/code/align/master_alignment_main.py +++ b/code/align/master_alignment_main.py @@ -2,9 +2,8 @@ # to one reference import sys -from master_alignment import * -sys.path.insert(0, '..') -from align_helpers import * +from align.master_alignment import (make_master, write_master) +from align.align_helpers import (get_strains, flatten) import global_params as gp strains = get_strains(flatten(gp.non_ref_dirs.values())) @@ -19,9 +18,9 @@ '_'.join(gp.alignment_ref_order) + \ '_' + strain for chrm in gp.chrms: - print chrm + print(chrm) alignment_fn = alignment_prefix + '_chr' + chrm + gp.alignment_suffix - master_alignment_fn = alignment_prefix + '_chr' + chrm + '_master' + gp.fasta_suffix + master_alignment_fn = (alignment_prefix + '_chr' + + chrm + '_master' + gp.fasta_suffix) a = make_master(alignment_fn, gp.master_ref) write_master(master_alignment_fn, a) - diff --git a/code/align/polymorphism.py b/code/align/polymorphism.py index 566f2fe..1f191c4 100644 --- a/code/align/polymorphism.py +++ b/code/align/polymorphism.py @@ -1,6 +1,8 @@ # calculate polymorphism rate between reference genomes in 100-bp # windows across each chromosome +import sys +from misc import read_fasta headers, seqs = read_fasta.read_fasta(sys.argv[1]) a = dict(zip(headers, seqs)) diff --git a/code/align/ref_ids_main.py b/code/align/ref_ids_main.py index f6cf7b8..58812fd 100644 --- a/code/align/ref_ids_main.py +++ b/code/align/ref_ids_main.py @@ -1,14 +1,8 @@ -import re -import sys -import os -import copy from 
collections import defaultdict -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import mystats -import seq_functions -import read_fasta +from misc import mystats +from misc import seq_functions +from misc import read_fasta # get pairwise identities between all aligned references: # - overall average @@ -23,28 +17,28 @@ pair_chrm_ids = defaultdict(lambda: defaultdict(list)) for chrm in gp.chrms: - print chrm - fn = gp_dir + gp.alignments_dir + \ - '_'.join(gp.alignment_ref_order) + \ - '_chr' + chrm + '_mafft' + gp.alignment_suffix + print(chrm) + fn = (gp_dir + gp.alignments_dir + + '_'.join(gp.alignment_ref_order) + + '_chr' + chrm + '_mafft' + gp.alignment_suffix) headers, seqs = read_fasta.read_fasta(fn) for i in range(nrefs): ref1 = gp.alignment_ref_order[i] for j in range(i+1, nrefs): - print i, j + print(i, j) ref2 = gp.alignment_ref_order[j] ids = seq_functions.seq_id_windowed(seqs[i], seqs[j], window) - + pair_chrm_ids[(ref1, ref2)][chrm] = ids -fs = open(gp.analysis_out_dir_absolute + 'ref_ids_summary_' + \ - '_'.join(gp.alignment_ref_order) + '.txt', 'w') +fs = open(gp.analysis_out_dir_absolute + 'ref_ids_summary_' + + '_'.join(gp.alignment_ref_order) + '.txt', 'w') fs.write('pair\tchromosome\tmean\tmedian\n') -f = open(gp.analysis_out_dir_absolute + 'ref_ids_' + \ +f = open(gp.analysis_out_dir_absolute + 'ref_ids_' + '_'.join(gp.alignment_ref_order) + '.txt', 'w') f.write('pair\tid\n') @@ -53,14 +47,14 @@ pair_string = ','.join(pair) for chrm in gp.chrms: ids = pair_chrm_ids[pair][chrm] - fs.write(pair_string + '\t' + \ - chrm + '\t' + \ - str(mystats.mean(ids)) + '\t' + \ + fs.write(pair_string + '\t' + + chrm + '\t' + + str(mystats.mean(ids)) + '\t' + str(mystats.median(ids)) + '\n') all_ids += ids - fs.write(pair_string + '\t' + \ - 'all' + '\t' + \ - str(mystats.mean(all_ids)) + '\t' + \ + fs.write(pair_string + '\t' + + 'all' + '\t' + + str(mystats.mean(all_ids)) + '\t' + str(mystats.median(all_ids)) + '\n') for i in ids: diff --git a/code/align/run_alignment_stats.py b/code/align/run_alignment_stats.py index 0fd7052..7f99b73 100644 --- a/code/align/run_alignment_stats.py +++ b/code/align/run_alignment_stats.py @@ -1,8 +1,11 @@ +import sys import os -from alignment_stats import * -from align_helpers import * -sys.path.insert(0, '..') +from align.alignment_stats import (num_strains_aligned_by_site, + fraction_strains_aligned, + frac_aligned_to_reference) +from align.align_helpers import (flatten, get_strains) import global_params as gp +from misc import read_fasta # gives info related to how good an alignment is: # - number of sites where n, ..., 3, 2, 1, genomes aligned @@ -14,10 +17,11 @@ strain, d = s[int(sys.argv[1])] gp_dir = '../' -fn_start = gp_dir + gp.alignments_dir + '_'.join(gp.alignment_ref_order) + '_' + strain + '_chr' +fn_start = (gp_dir + gp.alignments_dir + '_'.join(gp.alignment_ref_order) + + '_' + strain + '_chr') for chrm in gp.chrms: - print chrm + print(chrm) sys.stdout.flush() if not os.path.isfile(fn_start + chrm + '_mafft.maf'): @@ -30,8 +34,9 @@ # number of sites where n,...,3,2,1 genomes aligned num_strains_by_site = num_strains_aligned_by_site(seqs) - f_out.write(\ - '# histogram of number of strains aligned across all alignment columns\n') + f_out.write( + '# histogram of number of strains ' + 'aligned across all alignment columns\n') for n in range(len(num_strains_by_site)): f_out.write(str(n) + ',' + str(num_strains_by_site[n]) + '\n') f_out.write('\n') @@ -44,7 +49,8 @@ # length of chromosomes 
f_out.write('chromosome aligned lengths\n') for n in range(len(seqs)): - f_out.write(headers[n][1:].strip().split(' ')[0] + ',' + str(seq_lengths[n]) + '\n') + f_out.write(headers[n][1:].strip().split(' ')[0] + + ',' + str(seq_lengths[n]) + '\n') f_out.write('\n') # using each genome as reference, fraction of other genomes aligned @@ -52,7 +58,7 @@ frac_aligned_to_ref = frac_aligned_to_reference(seqs, seq_lengths) for ref in range(len(seqs)): f_out.write(headers[ref][1:].strip().split(' ')[0]) - for other in range(len(seqs)): + for other in range(len(seqs)): f_out.write(',' + str(frac_aligned_to_ref[ref][other])) f_out.write('\n') f_out.write('\n') diff --git a/code/align/run_mafft.py b/code/align/run_mafft.py index 4681f77..e5f6681 100644 --- a/code/align/run_mafft.py +++ b/code/align/run_mafft.py @@ -1,6 +1,6 @@ import sys import os -from align_helpers import * +from align.align_helpers import (concatenate_fasta) from analyze import read_args import global_params as gp @@ -19,8 +19,8 @@ if os.stat(args['alignments_directory'] + fn).st_size != 0: a.append(fn) ref_prefix = '_'.join(args['references']) + '_' -ref_fns = [args['reference_directories'][r] + r + '_chr' + '?' + \ - mask_suffix + gp.fasta_suffix \ +ref_fns = [args['reference_directories'][r] + r + '_chr' + '?' + + mask_suffix + gp.fasta_suffix for r in args['references']] if ref_only: @@ -30,16 +30,16 @@ ref_fns_chrm = [x.replace('?', chrm) for x in ref_fns] combined_fn = 'run_mafft_' + chrm + '.temp' - concatenate_fasta(ref_fns_chrm, \ + concatenate_fasta(ref_fns_chrm, args['references'], combined_fn) - - align_fn = ref_prefix + 'chr' + chrm + \ - '_mafft' + gp.alignment_suffix + + align_fn = (ref_prefix + 'chr' + chrm + + '_mafft' + gp.alignment_suffix) align_fn_abs = args['alignments_directory'] + align_fn - cmd_string = gp.mafft_install_path + '/mafft ' + \ - combined_fn + ' > ' + align_fn_abs + '; ' - + cmd_string = (gp.mafft_install_path + '/mafft ' + + combined_fn + ' > ' + align_fn_abs + '; ') + cmd_string += 'rm ' + combined_fn + ';' print(cmd_string) @@ -60,7 +60,6 @@ # shell instance every time (I think there's a limit on the # command character count or something which is why we're not # making a single string for all strains) -#cmd_string = '' current_strain_fn = d + strain_fn.replace('*', strain) @@ -74,7 +73,7 @@ # if we don't already have an alignment for this strain/chromosome # (or that alignment file is empty), then make one - #if (align_fn not in a) or (os.stat(align_fn_abs).st_size == 0): + # if (align_fn not in a) or (os.stat(align_fn_abs).st_size == 0): if align_fn not in a: cmd_string = '' @@ -82,18 +81,18 @@ ref_fns_chrm = [x.replace('?', chrm) for x in ref_fns] current_strain_fn_chrm = current_strain_fn.replace('?', chrm) combined_fn = 'run_mafft_' + strain + chrm + '.temp' - - concatenate_fasta(ref_fns_chrm + [current_strain_fn_chrm], \ + + concatenate_fasta(ref_fns_chrm + [current_strain_fn_chrm], args['references'] + [strain], combined_fn) - + # add --ep 0.123 to maybe get shorter alignment - #cmd_string += gp.mafft_install_path + '/mafft --ep 0.123 ' + \ + # cmd_string += gp.mafft_install_path + '/mafft --ep 0.123 ' + \ # combined_fn + ' > ' + align_fn_abs + '; ' - #cmd_string += gp.mafft_install_path + '/mafft --retree 1 ' + \ + # cmd_string += gp.mafft_install_path + '/mafft --retree 1 ' + \ # combined_fn + ' > ' + align_fn_abs + '; ' cmd_string += gp.mafft_install_path + '/mafft ' + \ combined_fn + ' > ' + align_fn_abs + '; ' - + cmd_string += 'rm ' + combined_fn + ';' print(cmd_string) @@ -109,7 
+108,3 @@ else: print("already did this alignment: " + strain + ' chromosome ' + chrm) sys.stdout.flush() - -#print cmd_string -#sys.stdout.flush() -#os.system(cmd_string) diff --git a/code/align/run_mafft_refs.py b/code/align/run_mafft_refs.py index 21105c1..1242531 100644 --- a/code/align/run_mafft_refs.py +++ b/code/align/run_mafft_refs.py @@ -2,8 +2,7 @@ import sys import os -from align_helpers import * -sys.path.insert(0, '..') +from align.align_helpers import concatenate_fasta import global_params as gp masked = False @@ -16,9 +15,9 @@ if gp.resume_alignment: a = os.listdir(gp_dir + gp.alignments_dir) -ref_prefix = '_'.join(gp.alignment_ref_order) -ref_fns = [gp.ref_dir[r] + gp.ref_fn_prefix[r] + '_chr' + '?' + \ - mask_suffix + gp.fasta_suffix \ +ref_prefix = '_'.join(gp.alignment_ref_order) +ref_fns = [gp.ref_dir[r] + gp.ref_fn_prefix[r] + '_chr' + '?' + + mask_suffix + gp.fasta_suffix for r in gp.alignment_ref_order] @@ -26,11 +25,11 @@ # shell instance every time (I think there's a limit on the # command character count or something which is why we're not # making a single string for all strains) -#cmd_string = '' +# cmd_string = '' chrm = gp.chrms[int(sys.argv[1])] -print chrm +print(chrm) sys.stdout.flush() align_fn = ref_prefix + '_chr' + chrm + \ @@ -44,28 +43,24 @@ # first put all sequences in same (temporary) file ref_fns_chrm = [x.replace('?', chrm) for x in ref_fns] combined_fn = 'run_mafft_' + chrm + '.temp' - - concatenate_fasta(ref_fns_chrm, \ + + concatenate_fasta(ref_fns_chrm, gp.alignment_ref_order, combined_fn) - - cmd_string += gp.mafft_install_path + '/mafft ' + \ - combined_fn + ' > ' + align_fn_abs + '; ' - + + cmd_string += (gp.mafft_install_path + '/mafft ' + + combined_fn + ' > ' + align_fn_abs + '; ') + cmd_string += 'rm ' + combined_fn + ';' - - print cmd_string + + print(cmd_string) sys.stdout.flush() os.system(cmd_string) # want some kind of indication if alignment fails (due to # running out of memory probably) if os.stat(align_fn_abs).st_size == 0: - print 'alignment failed:' + ' chromosome ' + chrm + print('alignment failed: chromosome ' + chrm) sys.stdout.flush() sys.exit() else: - print "already did this alignment:" + ' chromosome ' + chrm - -#print cmd_string -#sys.stdout.flush() -#os.system(cmd_string) + print('already did this alignment: chromosome ' + chrm) diff --git a/code/align/run_mugsy.py b/code/align/run_mugsy.py index e858769..7af436b 100644 --- a/code/align/run_mugsy.py +++ b/code/align/run_mugsy.py @@ -1,7 +1,5 @@ -import sys import os -sys.path.insert(0, '..') -from align_helpers import * +from align.align_helpers import get_strains, flatten import global_params as gp # get all non-reference strains of cerevisiae and paradoxus @@ -22,23 +20,25 @@ ref_dirs = [gp.ref_dir[ref] for ref in gp.alignment_ref_order] for strain, d in s: - print strain + print(strain) cmd_string = cmd_string_start - + for chrm in [gp.chrms[-1]]: align_fn = ref_prefix + strain + '_chr' + chrm + gp.alignment_suffix - # if we don't already have an alignment for this strain/chromosome, then make one + # if we don't already have an alignment for this strain/chromosome, + # then make one if align_fn not in a: cmd_string += gp.mugsy_install_path + '/mugsy ' + \ '--directory ' + gp_dir + gp.alignments_dir + ' ' + \ '--prefix ' + ref_prefix + strain + '_chr' + chrm for ref in gp.alignment_ref_order: - cmd_string += ' ' + gp.ref_dir[ref] + '/' + \ - gp.ref_fn_prefix[ref] + '_chr' + chrm + gp.fasta_suffix - cmd_string += ' ' + d + '/' + strain + '_chr' + chrm + 
gp.fasta_suffix + '; ' - + cmd_string += (' ' + gp.ref_dir[ref] + '/' + + gp.ref_fn_prefix[ref] + '_chr' + + chrm + gp.fasta_suffix) + cmd_string += (' ' + d + '/' + strain + + '_chr' + chrm + gp.fasta_suffix + '; ') # commands can only be up to a certain length so break it up this way - print cmd_string + print(cmd_string) os.system(cmd_string) diff --git a/code/align/run_tcoffee.py b/code/align/run_tcoffee.py index c94a3af..dcf48bb 100644 --- a/code/align/run_tcoffee.py +++ b/code/align/run_tcoffee.py @@ -1,7 +1,6 @@ import sys import os -sys.path.insert(0, '..') -from align_helpers import * +from align.align_helpers import (get_strains, flatten, concatenate_fasta) import global_params as gp # get all non-reference strains of cerevisiae and paradoxus @@ -14,13 +13,14 @@ ref_prefix = '_'.join(gp.alignment_ref_order) + '_' -ref_fns = [gp.ref_dir[r] + gp.ref_fn_prefix[r] + '_chr' + '?' + gp.fasta_suffix \ - for r in gp.alignment_ref_order] +ref_fns = [gp.ref_dir[r] + gp.ref_fn_prefix[r] + + '_chr' + '?' + gp.fasta_suffix + for r in gp.alignment_ref_order] strain_fn = '*_chr?' + gp.fasta_suffix for strain, d in s: - print strain + print(strain) # building up one command string so that we don't create a new # shell instance every time (I think there's a limit on the @@ -29,9 +29,9 @@ cmd_string = '' current_strain_fn = d + strain_fn.replace('*', strain) - + for chrm in gp.chrms[:2]: - print chrm + print(chrm) align_fn = ref_prefix + strain + '_chr' + chrm + \ '_tcoffee' + gp.alignment_suffix # if we don't already have an alignment for this @@ -42,14 +42,15 @@ current_strain_fn_chrm = current_strain_fn.replace('?', chrm) combined_fn = 'run_tcoffee_' + strain + chrm + '.temp' - concatenate_fasta(ref_fns_chrm + [current_strain_fn_chrm], combined_fn) + concatenate_fasta(ref_fns_chrm + [current_strain_fn_chrm], + combined_fn) cmd_string += gp.tcoffee_install_path + '/t_coffee ' + \ combined_fn + '; ' - #cmd_string += 'rm ' + combined_fn + ';' + # cmd_string += 'rm ' + combined_fn + ';' # commands can only be up to a certain length so break it up this way - print cmd_string + print(cmd_string) os.system(cmd_string) sys.exit() diff --git a/code/analyze/filter_1_main.py b/code/analyze/filter_1_main.py deleted file mode 100644 index ea72c89..0000000 --- a/code/analyze/filter_1_main.py +++ /dev/null @@ -1,76 +0,0 @@ -# two levels of filtering: -# 1. remove regions that don't look confidently introgressed at all, -# based on fraction gaps/masked, number of matches to S288c and not S288c -# --> _filtered1 -# 2. 
remove regions that we can't confidently pin on a specific reference, -# based on whether it matches similarly to other reference(s) -# --> _filtered2 - -# just do the first level here, then run filter_2_thresholds_main.py -# to choose filtering thresholds for next level - - -import sys -from analyze import predict -from analyze.filter_helpers import passes_filters1, write_filtered_line -import global_params as gp -from misc import read_table -from misc.region_reader import Region_Reader - - -# TODO make this, filter 2 and thresholds a single method then refactor -def main() -> None: - ''' - Perform first step of filtering - Input files: - -blocks_{species}_quality.txt - - Output files: - -blocks_{species}_filtered1intermediate.txt - -blocks_{species}_filtered1.txt - -regions/{species}.fa.gz - -regions/{species}.pkl - ''' - args = predict.process_predict_args(sys.argv[1:]) - out_dir = gp.analysis_out_dir_absolute + args['tag'] - - for species_from in args['known_states'][1:]: - - print(species_from) - - region_summary, fields = read_table.read_table_rows( - f'{out_dir}/blocks_{species_from}_{args["tag"]}_quality.txt', - '\t') - - fields1i = fields + ['reason'] - fields1 = fields - - with open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' - '_filtered1intermediate.txt', 'w') as f_out1i, \ - open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' - '_filtered1.txt', 'w') as f_out1, \ - Region_Reader(f'{out_dir}/regions/{species_from}.fa.gz', - as_fa=True) as region_reader: - - f_out1i.write('\t'.join(fields1i) + '\n') - f_out1.write('\t'.join(fields1) + '\n') - - for region_id, header, seqs in region_reader.yield_fa(): - region = region_summary[region_id] - info_string = seqs[-1] - seqs = seqs[:-1] - - # filtering stage 1: things that we're confident in calling not - # S288c - p, reason = passes_filters1(region, - info_string, - args['known_states'][0]) - region['reason'] = reason - write_filtered_line(f_out1i, region_id, region, fields1i) - - if p: - write_filtered_line(f_out1, region_id, region, fields1) - - -if __name__ == "__main__": - main() diff --git a/code/analyze/filter_2_main.py b/code/analyze/filter_2_main.py deleted file mode 100644 index bda92bd..0000000 --- a/code/analyze/filter_2_main.py +++ /dev/null @@ -1,80 +0,0 @@ -# two levels of filtering: -# 1. remove regions that don't look confidently introgressed at all, -# based on fraction gaps/masked, number of matches to S288c and not S288c -# --> _filtered1 -# 2. 
remove regions that we can't confidently pin on a specific reference, -# based on whether it matches similarly to other reference(s) -# --> _filtered2 - -# do second level of filtering here, based on previously selected -# thresholds - -import sys -from analyze import predict -from analyze.filter_helpers import (write_filtered_line, - passes_filters2) -import global_params as gp -from misc import read_table -from misc.region_reader import Region_Reader - - -def main() -> None: - ''' - Perform second stage of filtering - Input files: - -blocks_{species}_filtered1.txt - regions/{species}.fa.gz - regions/{species}.pkl - - Output files: - -blocks_{species}_filtered2.txt - -blocks_{species}_filtered2intermediate.txt - ''' - args = predict.process_predict_args(sys.argv[2:]) - threshold = float(sys.argv[1]) - out_dir = gp.analysis_out_dir_absolute + args['tag'] - - for species_from in args['known_states'][1:]: - - print(species_from) - - region_summary, fields = read_table.read_table_rows( - f'{out_dir}/blocks_{species_from}_{args["tag"]}_filtered1.txt', - '\t') - - fields2i = fields + ['alternative_states', 'alternative_ids', - 'alternative_P_counts'] - fields2 = fields - - with open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' - '_filtered2intermediate.txt', 'w') as f_out2i, \ - open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' - '_filtered2.txt', 'w') as f_out2, \ - Region_Reader(f'{out_dir}/regions/{species_from}.fa.gz', - as_fa=True) as region_reader: - - f_out2i.write('\t'.join(fields2i) + '\n') - f_out2.write('\t'.join(fields2) + '\n') - - for region_id, header, seqs in \ - region_reader.yield_fa(region_summary.keys()): - region = region_summary[region_id] - - seqs = seqs[:-1] - - # filtering stage 2: things that we're confident in calling - # introgressed from one species specifically - p, alt_states, alt_ids, alt_P_counts = passes_filters2( - region, seqs, threshold, args['known_states']) - region['alternative_states'] = ','.join(alt_states) - region['alternative_ids'] = ','.join([str(x) for x in alt_ids]) - region['alternative_P_counts'] = ','.join( - [str(x) for x in alt_P_counts]) - write_filtered_line(f_out2i, region_id, region, fields2i) - - if p: - write_filtered_line(f_out2, region_id, region, fields2) - - -if __name__ == '__main__': - main() diff --git a/code/analyze/filter_2_thresholds_main.py b/code/analyze/filter_2_thresholds_main.py deleted file mode 100644 index f6cb61a..0000000 --- a/code/analyze/filter_2_thresholds_main.py +++ /dev/null @@ -1,97 +0,0 @@ -# explore different thresholds for calling introgressions for specific -# strains - -# specifically, try a range of thresholds, and for each one, calculate -# fraction of introgressions we've classified as 1 strain or every -# possible combination of strains - -# then we'll make some plots in R to see if there's a sort of obvious -# place to draw the line - -import sys -from analyze import predict -from analyze.filter_helpers import passes_filters2 -import global_params as gp -from misc import read_table -from misc.region_reader import Region_Reader - - -thresholds = [.999, .995, .985, .975, .965, .955, .945, - .935, .925, .915, .905, .89, .87, .86] -# thresholds = [.99, .98, .97, .96, .95, .94, .93, .92, -# .91, .9, .88, .85, .82, .8, .75, .7, .6, .5] -# thresholds = [1] - - -def main() -> None: - ''' - Perform second stage of filtering with several threshold levels - Input files: - -blocks_{species}_filtered1.txt - -regions/{species}.fa.gz - -regions/{species}.pkl - - Output files: - -filter_2_thresholds.txt - 
''' - args = predict.process_predict_args(sys.argv[1:]) - out_dir = gp.analysis_out_dir_absolute + args['tag'] - - open_mode = 'w' - with open(f'{out_dir}/filter_2_thresholds_{args["tag"]}.txt', open_mode)\ - as writer: - if open_mode == 'w': - writer.write( - 'threshold\tpredicted_state\talternative_states\tcount\n') - - data_table = {} - for species_from in args['known_states'][1:]: - print(f'* {species_from}') - - region_summary, fields = read_table.read_table_rows( - f'{out_dir}/blocks_{species_from}' - f'_{args["tag"]}_filtered1.txt', - '\t') - - with Region_Reader(f'{out_dir}/regions/{species_from}.fa.gz', - as_fa=True) as region_reader: - for region_id, header, seqs in \ - region_reader.yield_fa(region_summary.keys()): - - region = region_summary[region_id] - seqs = seqs[:-1] - - for threshold in thresholds: - _, alt_states, _, _ = \ - passes_filters2(region, seqs, threshold) - - record_data_hit(data_table, - threshold, - species_from, - ','.join(sorted(alt_states))) - - for threshold in thresholds: - for species in args['known_states'][1:]: - d = data_table[threshold][species] - for key in d.keys(): - writer.write(f'{threshold}\t{species}\t{key}\t{d[key]}\n') - - -def record_data_hit(data_dict, threshold, species, key): - ''' - adds an entry to the data table or increments if exists - ''' - if threshold not in data_dict: - data_dict[threshold] = {} - - if species not in data_dict[threshold]: - data_dict[threshold][species] = {} - - if key not in data_dict[threshold][species]: - data_dict[threshold][species][key] = 0 - - data_dict[threshold][species][key] += 1 - - -if __name__ == "__main__": - main() diff --git a/code/analyze/filter_helpers.py b/code/analyze/filter_helpers.py deleted file mode 100644 index c263436..0000000 --- a/code/analyze/filter_helpers.py +++ /dev/null @@ -1,178 +0,0 @@ -import global_params as gp -from misc import seq_functions -import numpy as np -from typing import List, Dict, TextIO, Tuple - - -def write_filtered_line(writer: TextIO, - region_id: str, - region: Dict, - fields: List) -> None: - ''' - Write the region id and values in "region" dict to open file writer - ''' - writer.write(f'{region_id}\t' - + '\t'.join([str(region[field]) - for field in fields[1:]]) - + '\n') - - -def passes_filters(region: Dict) -> bool: - ''' - test if the supplied region satisfies: - -Fraction of gaps and masked < 0.5 - -Number of matching > 7 - -Divergence < 0.7 - ''' - # fraction gaps + masked filter - fraction_gaps_masked_threshold = .5 - fraction_gaps_masked = \ - (float(region['number_gaps']) + - float(region['number_masked_non_gap'])) / \ - (int(region['end']) - int(region['start']) + 1) - if fraction_gaps_masked > fraction_gaps_masked_threshold: - return False - - # number sites match only par filter - number_match_only_threshold = 7 - number_match_only = int(region['number_match_ref2_not_ref1']) - if number_match_only < number_match_only_threshold: - return False - - # divergence from cer filter (idea is that poor alignments will - # result in much larger divergence than we'd expect) - id_ref1_threshold = .7 - id_ref1 = float(region['number_match_ref1']) / \ - (float(region['aligned_length']) - float(region['number_gaps'])) - if id_ref1 < id_ref1_threshold: - return False - - return True - - -def passes_filters1(region: Dict, - info: str, - reference_species: str) -> Tuple[bool, str]: - ''' - filtering out things that we can't call introgressed in general - with confidence (i.e. 
doesn't seem like a strong case against - being S288c) - Return true if the region passes the filter, or false with a string - specifying which filter failed - Tests: - -fraction of gaps masked in reference > 0.5 - -fraction of gaps masked in predicted species > 0.5 - -number of matches to predicted > 7 - -number of matches to predicted > number matches to reference - -divergence with predicted species - ''' - - predicted_species = region['predicted_species'] - - aligned_length = (int(region['end']) - int(region['start']) + 1) - - # FILTER: fraction gaps + masked - fraction_gaps_masked_threshold = .5 - # num_sites_nonmask_x is number of sites at which neither - # reference x nor the test sequence is masked or has a gap or - # unsequenced character - fraction_gaps_masked_r = \ - 1 - region['num_sites_nonmask_' + reference_species] / aligned_length - fraction_gaps_masked_s = \ - 1 - region['num_sites_nonmask_' + predicted_species] / aligned_length - - if fraction_gaps_masked_r > fraction_gaps_masked_threshold: - return False, f'fraction gaps/masked in master = '\ - f'{fraction_gaps_masked_r}' - if fraction_gaps_masked_s > fraction_gaps_masked_threshold: - return False, f'fraction gaps/masked in predicted = '\ - f'{fraction_gaps_masked_s}' - - # FILTER: number sites analyzed by HMM that match predicted (P) - # reference (C) - count_P = info.count('P') - count_C = info.count('C') - number_match_only_threshold = 7 - if count_P < number_match_only_threshold: - return False, f'count_P = {count_P}' - if count_P <= count_C: - return False, f'count_P = {count_P} and count_C = {count_C}' - - # FILTER: divergence with predicted reference and master reference - # (S288c) - id_predicted = float(region['match_nongap_' + predicted_species]) / \ - float(region['num_sites_nongap_' + predicted_species]) - id_master = float(region['match_nongap_' + reference_species]) / \ - float(region['num_sites_nongap_' + reference_species]) - - if id_master >= id_predicted: - return False, f'id with master = {id_master} '\ - f'and id with predicted = {id_predicted}' - if id_master < .7: - return False, f'id with master = {id_master}' - - return True, '' - - -def passes_filters2(region: Dict, - seqs: np.array, - threshold: float, - refs: List[str]) -> Tuple[bool, - List[str], - List[float], - List[int]]: - ''' - filter out things we can't assign to one species specifically; - return the other reasonable alternatives if we're filtering - it out - Returns a tuple of: - True if the region passes the filter - A list of likely species for the region - A list of fraction of matching sequence for each species - A list of total matching sites - Fails the filter if number of matches and fraction matching are >= more - than one state for the region - ''' - - s = region['predicted_species'] - - ids = {} - P_counts = {} - - seqs = np.asarray(seqs) - # skip any gap or unsequenced in ref or test - # also skip if ref and test equal (later test ri == test but not ref) - skip = np.any( - (seqs[0] == gp.gap_symbol, - seqs[0] == gp.unsequenced_symbol, - seqs[-1] == gp.gap_symbol, - seqs[-1] == gp.unsequenced_symbol, - seqs[0] == seqs[-1]), - axis=0) - - for ri, ref in enumerate(refs): - if ri == 0: - continue - r_match, r_total = seq_functions.seq_id(seqs[-1], seqs[ri]) - if r_total != 0: - ids[ref] = r_match / r_total - P_counts[ref] = np.sum( - np.logical_and( - np.logical_not(skip), - seqs[ri] == seqs[-1])) - - alts = {} - for r in ids.keys(): - # TODO should threshold be the same for both? 
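The TODO above and the test that follows it apply one threshold to two measures at once: a region is ambiguous when some other reference is within `threshold` of the predicted state on both the identity fraction and the matching-site count. A standalone sketch of that decision, with made-up reference names and values:

def ambiguous_alternatives(ids, P_counts, predicted, threshold):
    # keep every reference whose identity fraction AND matching-site
    # count are at least threshold times the predicted state's values
    alts = {
        ref: (ids[ref], P_counts[ref])
        for ref in ids
        if ids[ref] >= threshold * ids[predicted]
        and P_counts[ref] >= threshold * P_counts[predicted]
    }
    # the predicted state always matches itself, so more than one entry
    # means the region cannot be pinned on a single reference
    return len(alts) > 1, alts

ids = {'ref_A': 0.98, 'ref_B': 0.97, 'ref_C': 0.80}
P_counts = {'ref_A': 40, 'ref_B': 39, 'ref_C': 12}
print(ambiguous_alternatives(ids, P_counts, 'ref_A', 0.965))
# (True, {'ref_A': (0.98, 40), 'ref_B': (0.97, 39)})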
- if ids[r] >= threshold * ids[s] and \ - P_counts[r] >= threshold * P_counts[s]: - alts[r] = (ids[r], P_counts[r]) - - alt_states = sorted(alts.keys(), key=lambda x: alts[x][0], reverse=True) - alt_ids = [alts[state][0] for state in alt_states] - alt_P_counts = [alts[state][1] for state in alt_states] - - if len(alts) > 1: - return False, alt_states, alt_ids, alt_P_counts - - return True, alt_states, alt_ids, alt_P_counts diff --git a/code/analyze/filter_regions.py b/code/analyze/filter_regions.py index 50f6869..4d1695f 100644 --- a/code/analyze/filter_regions.py +++ b/code/analyze/filter_regions.py @@ -1,272 +1,409 @@ -import global_params as gp from misc import seq_functions import numpy as np from typing import List, Dict, TextIO, Tuple -import sys -from contextlib import ExitStack -from analyze import predict +import click +import logging as log +from contextlib import ExitStack, contextmanager from misc import read_table from misc.region_reader import Region_Reader - - -def write_filtered_line(writer: TextIO, - region_id: str, - region: Dict, - fields: List) -> None: - ''' - Write the region id and values in "region" dict to open file writer - ''' - writer.write(f'{region_id}\t' - + '\t'.join([str(region[field]) - for field in fields[1:]]) - + '\n') - - -def filter_introgressed(region: Dict, - info: str, - reference_species: str) -> Tuple[bool, str]: - ''' - filtering out things that we can't call introgressed in general - with confidence (i.e. doesn't seem like a strong case against - being S288c) - Return true if the region passes the filter, or false with a string - specifying which filter failed - Tests: - -fraction of gaps masked in reference > 0.5 - -fraction of gaps masked in predicted species > 0.5 - -number of matches to predicted > 7 - -number of matches to predicted > number matches to reference - -divergence with predicted species - ''' - - predicted_species = region['predicted_species'] - - aligned_length = (int(region['end']) - int(region['start']) + 1) - - # FILTER: fraction gaps + masked - fraction_gaps_masked_threshold = .5 - # num_sites_nonmask_x is number of sites at which neither - # reference x nor the test sequence is masked or has a gap or - # unsequenced character - fraction_gaps_masked_r = \ - 1 - region['num_sites_nonmask_' + reference_species] / aligned_length - fraction_gaps_masked_s = \ - 1 - region['num_sites_nonmask_' + predicted_species] / aligned_length - - if fraction_gaps_masked_r > fraction_gaps_masked_threshold: - return False, f'fraction gaps/masked in master = '\ - f'{fraction_gaps_masked_r}' - if fraction_gaps_masked_s > fraction_gaps_masked_threshold: - return False, f'fraction gaps/masked in predicted = '\ - f'{fraction_gaps_masked_s}' - - # FILTER: number sites analyzed by HMM that match predicted (P) - # reference (C) - count_P = info.count('P') - count_C = info.count('C') - number_match_only_threshold = 7 - if count_P < number_match_only_threshold: - return False, f'count_P = {count_P}' - if count_P <= count_C: - return False, f'count_P = {count_P} and count_C = {count_C}' - - # FILTER: divergence with predicted reference and master reference - # (S288c) - id_predicted = float(region['match_nongap_' + predicted_species]) / \ - float(region['num_sites_nongap_' + predicted_species]) - id_master = float(region['match_nongap_' + reference_species]) / \ - float(region['num_sites_nongap_' + reference_species]) - - if id_master >= id_predicted: - return False, f'id with master = {id_master} '\ - f'and id with predicted = {id_predicted}' - if 
id_master < .7: - return False, f'id with master = {id_master}' - - return True, '' - - -def filter_ambiguous(region: Dict, - seqs: np.array, - threshold: float, - refs: List[str]) -> Tuple[bool, - List[str], - List[float], - List[int]]: - ''' - filter out things we can't assign to one species specifically; - return the other reasonable alternatives if we're filtering - it out - Returns a tuple of: - True if the region passes the filter - A list of likely species for the region - A list of fraction of matching sequence for each species - A list of total matching sites - Fails the filter if number of matches and fraction matching are >= more - than one state for the region - ''' - - s = region['predicted_species'] - - ids = {} - P_counts = {} - - seqs = np.asarray(seqs) - # skip any gap or unsequenced in ref or test - # also skip if ref and test equal (later test ri == test but not ref) - skip = np.any( - (seqs[0] == gp.gap_symbol, - seqs[0] == gp.unsequenced_symbol, - seqs[-1] == gp.gap_symbol, - seqs[-1] == gp.unsequenced_symbol, - seqs[0] == seqs[-1]), - axis=0) - - for ri, ref in enumerate(refs): - if ri == 0: - continue - r_match, r_total = seq_functions.seq_id(seqs[-1], seqs[ri]) - if r_total != 0: - ids[ref] = r_match / r_total - P_counts[ref] = np.sum( - np.logical_and( - np.logical_not(skip), - seqs[ri] == seqs[-1])) - - alts = {} - for r in ids.keys(): - # TODO should threshold be the same for both? - if ids[r] >= threshold * ids[s] and \ - P_counts[r] >= threshold * P_counts[s]: - alts[r] = (ids[r], P_counts[r]) - - alt_states = sorted(alts.keys(), key=lambda x: alts[x][0], reverse=True) - alt_ids = [alts[state][0] for state in alt_states] - alt_P_counts = [alts[state][1] for state in alt_states] - - if len(alts) > 1: - return False, alt_states, alt_ids, alt_P_counts - - return True, alt_states, alt_ids, alt_P_counts - - -def main(thresholds=[]): - ''' - Perform first step of filtering - Input files: - -blocks_{species}_quality.txt - - Output files: - -blocks_{species}_filtered1intermediate.txt - -blocks_{species}_filtered1.txt - -regions/{species}.fa.gz - -regions/{species}.pkl - ''' - # thresholds = [.999, .995, .985, .975, .965, .955, .945, - # .935, .925, .915, .905, .89, .87, .86] - args = predict.process_predict_args(sys.argv[2:]) - out_dir = gp.analysis_out_dir_absolute + args['tag'] - threshold = float(sys.argv[1]) - - with ExitStack() as stack: - if thresholds != []: - threshold_writer = stack.enter_context( - open(f'{out_dir}/filter_2_thresholds_{args["tag"]}.txt', 'w')) - threshold_writer.write( +from analyze.introgression_configuration import Configuration + + +class Filterer(): + def __init__(self, configuration: Configuration): + self.config = configuration + + def filter_introgressed(self, + region: Dict, + info: str, + reference_species: str) -> Tuple[bool, str]: + ''' + filtering out things that we can't call introgressed in general + with confidence (i.e. 
doesn't seem like a strong case against
+        being S288c)
+        Return true if the region passes the filter, or false with a string
+        specifying which filter failed
+        Tests:
+        -fraction of gaps masked in reference > 0.5
+        -fraction of gaps masked in predicted species > 0.5
+        -number of matches to predicted > 7
+        -number of matches to predicted > number matches to reference
+        -divergence with predicted species
+        '''
+
+        predicted_species = region['predicted_species']
+
+        aligned_length = (int(region['end']) - int(region['start']) + 1)
+
+        # FILTER: fraction gaps + masked
+        fraction_gaps_masked_threshold = .5
+        # num_sites_nonmask_x is number of sites at which neither
+        # reference x nor the test sequence is masked or has a gap or
+        # unsequenced character
+        fraction_gaps_masked_r = \
+            1 - int(region['num_sites_nonmask_' +
+                        reference_species])/aligned_length
+        fraction_gaps_masked_s = \
+            1 - int(region['num_sites_nonmask_' +
+                        predicted_species])/aligned_length
+
+        if fraction_gaps_masked_r > fraction_gaps_masked_threshold:
+            return False, f'fraction gaps/masked in master = '\
+                f'{fraction_gaps_masked_r}'
+        if fraction_gaps_masked_s > fraction_gaps_masked_threshold:
+            return False, f'fraction gaps/masked in predicted = '\
+                f'{fraction_gaps_masked_s}'
+
+        # FILTER: number sites analyzed by HMM that match predicted (P)
+        # reference (C)
+        count_P = np.sum(info == 'P')
+        count_C = np.sum(info == 'C')
+        number_match_only_threshold = 7
+        if count_P < number_match_only_threshold:
+            return False, f'count_P = {count_P}'
+        if count_P <= count_C:
+            return False, f'count_P = {count_P} and count_C = {count_C}'
+
+        # FILTER: divergence with predicted reference and master reference
+        # (S288c)
+        id_predicted = float(region['match_nongap_' + predicted_species]) / \
+            float(region['num_sites_nongap_' + predicted_species])
+        id_master = float(region['match_nongap_' + reference_species]) / \
+            float(region['num_sites_nongap_' + reference_species])
+
+        if id_master >= id_predicted:
+            return False, f'id with master = {id_master} '\
+                f'and id with predicted = {id_predicted}'
+        if id_master < .7:
+            return False, f'id with master = {id_master}'
+
+        return True, ''
+
+    def filter_ambiguous(self,
+                         region: Dict,
+                         seqs: np.array,
+                         threshold: float,
+                         refs: List[str]) -> Tuple[bool,
+                                                   List[str],
+                                                   List[float],
+                                                   List[int]]:
+        '''
+        filter out things we can't assign to one species specifically;
+        return the other reasonable alternatives if we're filtering
+        it out
+        Returns:
+        True if the region passes the filter
+        Fails the filter if number of matches and fraction matching are
+        >= threshold for more than one state for the region
+        Region is updated with:
+        A list of likely species for the region
+        A list of fraction of matching sequence for each species
+        A list of total matching sites
+        '''
+
+        s = region['predicted_species']
+
+        ids = {}
+        P_counts = {}
+
+        seqs = np.asarray(seqs)
+        # skip any gap or unsequenced in ref or test
+        # also skip if ref and test equal (later test ri == test but not ref)
+        symbols = self.config.symbols
+        skip = np.any(
+            (seqs[0] == symbols['gap'],
+             seqs[0] == symbols['unsequenced'],
+             seqs[-1] == symbols['gap'],
+             seqs[-1] == symbols['unsequenced'],
+             seqs[0] == seqs[-1]),
+            axis=0)
+
+        for ri, ref in enumerate(refs):
+            if ri == 0:
+                continue
+            r_match, r_total = seq_functions.seq_id(seqs[-1], seqs[ri])
+            if r_total != 0:
+                ids[ref] = r_match / r_total
+                P_counts[ref] = np.sum(
+                    np.logical_and(
+                        np.logical_not(skip),
+                        seqs[ri] == seqs[-1]))
+
+        alts = {}
+        for r in ids.keys():
+            # TODO should threshold be the same for both?
+            if ids[r] >= threshold * ids[s] and \
+                    P_counts[r] >= threshold * P_counts[s]:
+                alts[r] = (ids[r], P_counts[r])
+
+        alt_states = sorted(alts.keys(),
+                            key=lambda x: alts[x][0],
+                            reverse=True)
+        region['alternative_states'] = ','.join(alt_states)
+
+        alt_ids = [alts[state][0] for state in alt_states]
+        region['alternative_ids'] = ','.join(
+            [str(x) for x in alt_ids])
+
+        alt_P_counts = [alts[state][1] for state in alt_states]
+        region['alternative_P_counts'] = ','.join(
+            [str(x) for x in alt_P_counts])
+
+        return len(alts) <= 1, alt_states
+
+    def validate_arguments(self):
+        args = [
+            'introgressed',
+            'introgressed_intermediate',
+            'ambiguous',
+            'ambiguous_intermediate',
+            'filter_threshold',
+            'known_states',
+            'regions',
+            'region_index',
+            'symbols',
+            'quality_blocks'
+        ]
+        variables = self.config.__dict__
+        for arg in args:
+            if arg not in variables or variables[arg] is None:
+                err = ('Failed to validate Filterer, required argument '
+                       f"'{arg}' was unset")
+                log.exception(err)
+                raise ValueError(err)
+
+        if 'filter_sweep' not in variables or \
+                variables['filter_sweep'] is None:
+            log.warning(f"'filter_sweep' was unset and will not be run")
+
+    def run(self, thresholds=[]):
+        '''
+        Filter region files based on threshold in config and sweep
+        with the supplied threshold list
+        '''
+        self.validate_arguments()
+        known_states = self.config.known_states
+        log.debug(f'Known states: {known_states}')
+
+        with Filter_Sweep(self.config.filter_sweep, thresholds) as sweeper,\
+                ExitStack() as stack:
+
+            progress_bar = None
+            if self.config.log_file:
+                progress_bar = stack.enter_context(
+                    click.progressbar(
+                        length=len(known_states[1:]),
+                        label='Filtering'))
+
+            sweeper.write_header()
+            writers = Filter_Writers(self.config)
+
+            for species_from in known_states[1:]:
+
+                log.info(species_from)
+
+                region_summary, fields = read_table.read_table_rows(
+                    self.config.quality_blocks.format(state=species_from),
+                    '\t')
+
+                with writers.open_state(species_from, fields) as writers,\
+                        Region_Reader(self.config.regions.format(
+                            state=species_from), as_fa=True) as region_reader:
+
+                    writers.write_headers()
+
+                    for region_id, _, seqs in region_reader.yield_fa():
+                        region = region_summary[region_id]
+                        seqs, info_string = seqs[:-1], seqs[-1]
+
+                        # filtering stage 1: things that we're confident in
+                        # calling not S288c
+                        passes, reason = self.filter_introgressed(
+                            region,
+                            info_string,
+                            known_states[0])
+                        region['reason'] = reason
+
+                        writers.write_introgressed(region_id, region, passes)
+
+                        if passes:
+                            sweeper.record(
+                                species_from,
+                                lambda thresh: self.filter_ambiguous(
+                                    region, seqs, thresh, known_states))
+
+                            passes, _ = self.filter_ambiguous(
+                                region, seqs,
+                                self.config.filter_threshold, known_states)
+                            writers.write_ambiguous(region_id, region, passes)
+
+                if progress_bar:
+                    progress_bar.update(1)
+
+            sweeper.write_results(known_states[1:])
+
+
+class Filter_Sweep():
+    def __init__(self,
+                 sweep_file: str,
+                 thresholds: List[float]):
+        self.sweep_file = sweep_file
+        self.sweep_writer = None
+        self.thresholds = thresholds
+        self.data_table = {}
+
+    def __enter__(self):
+        if self.sweep_file is not None and self.thresholds != []:
+            self.sweep_writer = open(self.sweep_file, 'w')
+
+        return self
+
+    def __exit__(self, type, value, traceback):
+        if self.sweep_writer:
+            self.sweep_writer.close()
+
+        return traceback is None
+
+    def write_header(self):
+        '''
+        Write the header for the sweep filter file
+        '''
+        if self.sweep_writer:
+            self.sweep_writer.write(
'threshold\tpredicted_state\talternative_states\tcount\n') - data_table = {} - - for species_from in args['known_states'][1:]: - - print(species_from) - - region_summary, fields = read_table.read_table_rows( - f'{out_dir}/blocks_{species_from}_{args["tag"]}_quality.txt', - '\t') - - fields1i = fields + ['reason'] - fields1 = fields - fields2i = fields + ['alternative_states', 'alternative_ids', - 'alternative_P_counts'] - fields2 = fields - - with open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' - '_filtered1intermediate.txt', 'w') as f_out1i, \ - open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' - '_filtered1.txt', 'w') as f_out1, \ - open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' - '_filtered2intermediate.txt', 'w') as f_out2i, \ - open(f'{out_dir}/blocks_{species_from}_{args["tag"]}' - '_filtered2.txt', 'w') as f_out2, \ - Region_Reader(f'{out_dir}/regions/{species_from}.fa.gz', - as_fa=True) as region_reader: - - f_out1i.write('\t'.join(fields1i) + '\n') - f_out1.write('\t'.join(fields1) + '\n') - f_out2i.write('\t'.join(fields2i) + '\n') - f_out2.write('\t'.join(fields2) + '\n') - - for region_id, header, seqs in region_reader.yield_fa(): - region = region_summary[region_id] - info_string = seqs[-1] - seqs = seqs[:-1] - - # filtering stage 1: things that we're confident in - # calling not S288c - p, reason = filter_introgressed(region, - info_string, - args['known_states'][0]) - region['reason'] = reason - write_filtered_line(f_out1i, region_id, region, fields1i) - - if p: - write_filtered_line(f_out1, region_id, region, fields1) - - for thresh in thresholds: - _, alt_states, _, _ = \ - filter_ambiguous(region, seqs, thresh, - args['known_states']) - - record_data_hit(data_table, - thresh, - species_from, - ','.join(sorted(alt_states))) - - (p, alt_states, - alt_ids, alt_P_counts) = filter_ambiguous( - region, seqs, threshold, args['known_states']) - region['alternative_states'] = ','.join(alt_states) - region['alternative_ids'] = ','.join( - [str(x) for x in alt_ids]) - region['alternative_P_counts'] = ','.join( - [str(x) for x in alt_P_counts]) - write_filtered_line(f_out2i, region_id, - region, fields2i) - - if p: - write_filtered_line(f_out2, region_id, - region, fields2) - - for thresh in thresholds: - for species in args['known_states'][1:]: - d = data_table[thresh][species] - for key in d.keys(): - threshold_writer.write( - f'{thresh}\t{species}\t{key}\t{d[key]}\n') - - -def record_data_hit(data_dict, threshold, species, key): + def record(self, species_from, thresh_lambda): + ''' + Record the thresholds for this filter sweep object. + The thresh lambda is an anonymous function that takes a threshold + and returns a tuple with the value at index 1 being the alternative + states. Filter_ambiguous is what this is meant for. 
+ ''' + if self.sweep_writer is None: + return + + for thresh in self.thresholds: + _, states = thresh_lambda(thresh) + self.record_data_hit(thresh, species_from, states) + + def record_data_hit(self, threshold: float, species: str, states: List): + ''' + adds an entry to the data table or increments if exists + ''' + key = ','.join(sorted(states)) + if threshold not in self.data_table: + self.data_table[threshold] = {} + + if species not in self.data_table[threshold]: + self.data_table[threshold][species] = {} + + if key not in self.data_table[threshold][species]: + self.data_table[threshold][species][key] = 0 + + self.data_table[threshold][species][key] += 1 + + def write_results(self, states): + if self.sweep_writer is None: + return + + for thresh in self.thresholds: + for species in states: + if thresh in self.data_table and \ + species in self.data_table[thresh]: + d = self.data_table[thresh][species] + for key, value in d.items(): + self.sweep_writer.write( + f'{thresh}\t{species}\t{key}\t{value}\n') + + +class Filter_Writers(): ''' - adds an entry to the data table or increments if exists + Writes the filter and intermediate files ''' - if threshold not in data_dict: - data_dict[threshold] = {} - - if species not in data_dict[threshold]: - data_dict[threshold][species] = {} - - if key not in data_dict[threshold][species]: - data_dict[threshold][species][key] = 0 - - data_dict[threshold][species][key] += 1 + def __init__(self, config): + self.files = { + 'introgressed': config.introgressed, + 'introgressed_int': config.introgressed_intermediate, + 'ambiguous': config.ambiguous, + 'ambiguous_int': config.ambiguous_intermediate + } + self.headers = None + self.writers = None + + @contextmanager + def open_state(self, state: str, fields: List): + ''' + Open output files for the particular state + ''' + self.headers = { + 'introgressed': fields, + 'introgressed_int': fields + ['reason'], + 'ambiguous': fields, + 'ambiguous_int': fields + ['alternative_states', + 'alternative_ids', + 'alternative_P_counts'] + } + + self.writers = {k: open(v.format(state=state), 'w') + for k, v in self.files.items()} + + yield self + + for writer in self.writers.values(): + writer.close() + + self.headers = None + self.writers = None + + def write_headers(self): + if self.headers is None or self.writers is None: + return + + for key, writer in self.writers.items(): + writer.write('\t'.join(self.headers[key]) + '\n') + + def write_filtered_line(self, + writer: TextIO, + region_id: str, + region: Dict, + fields: List) -> None: + ''' + Write the region id and values in "region" dict to open file writer + ''' + writer.write(f'{region_id}\t') + writer.write('\t'.join([str(region[field]) for field in fields[1:]])) + writer.write('\n') + + def write_introgressed(self, + region_id: str, + region: Dict, + passes: bool): + self.write_filtered_line( + self.writers['introgressed_int'], + region_id, + region, + self.headers['introgressed_int']) + + if passes: + self.write_filtered_line( + self.writers['introgressed'], + region_id, + region, + self.headers['introgressed']) + + def write_ambiguous(self, + region_id: str, + region: Dict, + passes: bool): + self.write_filtered_line( + self.writers['ambiguous_int'], + region_id, + region, + self.headers['ambiguous_int']) + + if passes: + self.write_filtered_line( + self.writers['ambiguous'], + region_id, + region, + self.headers['ambiguous']) diff --git a/code/analyze/introgression_configuration.py b/code/analyze/introgression_configuration.py index 4133c75..c8ccda0 100644 
--- a/code/analyze/introgression_configuration.py +++ b/code/analyze/introgression_configuration.py @@ -11,6 +11,53 @@ def __init__(self): self.config = {} self.log_file = None + # these are very regular variables with state as a wildcard + state_files = [ + 'blocks', + 'labeled_blocks', + 'quality_blocks', + 'introgressed', + 'introgressed_intermediate', + 'ambiguous', + 'ambiguous_intermediate', + 'regions', + 'region_index', + ] + # no wildcards, non nullable + nonwild_files = [ + 'hmm_initial', + 'hmm_trained', + 'positions', + 'probabilities', + ] + var_list = [ + Variable('chromosomes'), + Threshold_Variable(), + Convergence_Variable(), + Symbols_Variable(), + Filter_Threshold_Variable(), + Variable('log_file', 'paths.log_file', nullable=True), + Variable('filter_sweep', 'paths.analysis.filter_sweep', + nullable=True), + Variable('masks', 'paths.analysis.masked_intervals', + wildcards='strain,chrom'), + ] + [ + Variable(n, f'paths.analysis.{n}', wildcards='state') + for n in state_files + ] + [ + Variable(n, f'paths.analysis.{n}') + for n in nonwild_files + ] + + self.variables = {v.name: v for v in var_list} + # these require too much state from configuration to split out + self.other_parsers = { + 'states': self._set_states, + 'prefix': self._set_prefix, + 'strains': self._set_strains, + 'alignment': self._set_alignment + } + def add_config(self, configuration: Dict): ''' merge the provided configuration dictionary with this object. @@ -19,6 +66,25 @@ def add_config(self, configuration: Dict): self.config = clean_config( merge_dicts(self.config, configuration)) + def set(self, *args, **kwargs): + ''' + Set the supplied variable to the value provided. + If just a name is provided, set the value with a value of None + ''' + kwargs.update({a: None for a in args}) + for key, value in kwargs.items(): + if key in self.variables: + variable = self.variables[key] + self.__dict__[key] = variable.parse(value, self.config) + + elif key in self.other_parsers: + self.other_parsers[key](value) + + else: + err = f'Unknown variable to set: {key}' + log.exception(err) + raise ValueError(err) + def get_states(self) -> Tuple[List, List]: ''' Build lists of known and unknown states from the analysis params @@ -66,7 +132,7 @@ def get_interval_states(self) -> List: else s['name'] for s in ref + known] - def set_states(self, states: List[str] = None): + def _set_states(self, states: List[str] = None): ''' Set the states for which to perform region naming ''' @@ -83,101 +149,12 @@ def set_states(self, states: List[str] = None): log.exception(err) raise ValueError(err) - def set_log_file(self, log_file: str = ''): - ''' - sets log file based on provided value or config - ''' - if log_file == '': - self.log_file = get_nested(self.config, 'paths.log_file') - else: - self.log_file = log_file - - def set_chromosomes(self): - ''' - Gets the chromosome list from config, raising a ValueError - if undefined. - ''' - self.chromosomes = validate( - self.config, - 'chromosomes', - 'No chromosomes specified in config file!') - - def set_threshold(self, threshold: str = None): - ''' - Set the threshold. Checks if set and converts to float if possible. 
- Failing float casting, will store a string if it is 'viterbi', - otherwise throws a ValueError - ''' - self.threshold = validate( - self.config, - 'analysis_params.threshold', - 'No threshold provided', - threshold) - try: - self.threshold = float(self.threshold) - except ValueError: - if self.threshold != 'viterbi': - err = f'Unsupported threshold value: {self.threshold}' - log.exception(err) - raise ValueError(err) - - def set_blocks_file(self, blocks: str = None): - ''' - Set the block wildcard filename. Checks for appropriate wildcards - ''' - self.blocks = validate( - self.config, - 'paths.analysis.blocks', - 'No block file provided', - blocks) - - check_wildcards(self.blocks, 'state') - - def set_labeled_blocks_file(self, blocks: str = None): - ''' - Set the labeled block wildcard filename. - Checks for appropriate wildcards - ''' - self.labeled_blocks = validate( - self.config, - 'paths.analysis.labeled_blocks', - 'No labeled block file provided', - blocks) - - check_wildcards(self.labeled_blocks, 'state') - - def set_quality_file(self, quality: str = None): - ''' - Set the quality block wildcard filename. - Checks for appropriate wildcards - ''' - self.quality_blocks = validate( - self.config, - 'paths.analysis.quality', - 'No quality block file provided', - quality) - - check_wildcards(self.quality_blocks, 'state') - - def set_masked_file(self, masks: str = None): - ''' - Set the masked interval block wildcard filename. - Checks for appropriate wildcards - ''' - self.masks = validate( - self.config, - 'paths.analysis.masked_intervals', - 'No masked interval file provided', - masks) - - check_wildcards(self.masks, 'strain,chrom') - - def set_prefix(self, prefix: str = ''): + def _set_prefix(self, prefix: str = ''): ''' Set prefix string of the predictor to the supplied value or build it from the known states ''' - if prefix == '': + if not prefix: if self.known_states == []: err = 'Unable to build prefix, no known states provided' log.exception(err) @@ -187,11 +164,11 @@ def set_prefix(self, prefix: str = ''): else: self.prefix = prefix - def set_strains(self, test_strains: str = ''): + def _set_strains(self, test_strains: str = ''): ''' build the strains to perform prediction on ''' - if test_strains == '': + if not test_strains: test_strains = get_nested(self.config, 'paths.test_strains') else: # need to support list for test strains @@ -262,37 +239,7 @@ def find_strains(self, test_strains: List[str] = None): else: # strains set in config self.strains = list(sorted(set(strains))) - def set_predict_files(self, - hmm_initial: str, - hmm_trained: str, - positions: str, - probabilities: str, - alignment: str): - ''' - Set output files from provided values or config. - Raises value errors if a file is not provided. - Checks alignment for all wildcards and replaces prefix. - ''' - self.hmm_initial = validate(self.config, - 'paths.analysis.hmm_initial', - 'No initial hmm file provided', - hmm_initial) - - self.hmm_trained = validate(self.config, - 'paths.analysis.hmm_trained', - 'No trained hmm file provided', - hmm_trained) - - self.set_positions(positions) - - self.probabilities = validate(self.config, - 'paths.analysis.probabilities', - 'No probabilities file provided', - probabilities) - - self.set_alignment(alignment) - - def set_alignment(self, alignment: str): + def _set_alignment(self, alignment: str): ''' Set the alignment file, checking wildcards prefix, strain and chrom. 
If prefix is present, it is substituted, otherwise checks just @@ -300,49 +247,125 @@ def set_alignment(self, alignment: str): ''' alignment = validate(self.config, 'paths.analysis.alignment', - 'No alignment file provided', + 'No alignment provided', alignment) + + check_wildcards(alignment, 'strain,chrom') if '{prefix}' in alignment: - check_wildcards(alignment, 'prefix,strain,chrom') self.alignment = alignment.replace('{prefix}', self.prefix) else: - check_wildcards(alignment, 'strain,chrom') self.alignment = alignment - def set_positions(self, positions: str): - ''' - Sets the position file - ''' - self.positions = validate(self.config, - 'paths.analysis.positions', - 'No positions file provided', - positions) - - def set_regions_files(self, - regions: str = None, - region_index: str = None): + def get(self, key: str): ''' - Set the region and pickle wildcard filename. Checks for state wildcards + Get nested key from underlying dictionary. Returning none if any + key is not in dict ''' - self.regions = validate( - self.config, - 'paths.analysis.regions', - 'No region file provided', - regions) - check_wildcards(self.regions, 'state') - - self.region_index = validate( - self.config, - 'paths.analysis.region_index', - 'No region index file provided', - region_index) - check_wildcards(self.region_index, 'state') - - def set_HMM_symbols(self): + return get_nested(self.config, key) + + def __repr__(self): + return ('Config file:\n' + + print_dict(self.config) + + '\nSettings:\n' + + print_dict({k: v for k, v in self.__dict__.items() + if k != 'config' and k != 'variables' + and k != 'other_parsers'}) + ) + + +class Variable(): + def __init__(self, name, config_path=None, nullable=False, wildcards=None): + self.name = name + if config_path: + self.config_path = config_path + else: + self.config_path = name + + self.nullable = nullable + self.wildcards = wildcards + + def parse(self, value, config={}): + if self.nullable: + if not value: + value = get_nested(config, self.config_path) + + else: + value = validate(config, self.config_path, + f'No {self.name} provided', value) + + if self.wildcards: + check_wildcards(value, self.wildcards) + + return value + + +class Threshold_Variable(Variable): + def __init__(self): + super().__init__('threshold', 'analysis_params.threshold') + + def parse(self, value, config={}): + value = super().parse(value, config) + + try: + value = float(value) + + except ValueError: + if value != 'viterbi': + err = f'Unsupported threshold value: {value}' + log.exception(err) + raise ValueError(err) + + return value + + +class Filter_Threshold_Variable(Variable): + def __init__(self): + super().__init__('filter_threshold', + 'analysis_params.filter_threshold') + + def parse(self, value, config={}): + value = super().parse(value, config) + + try: + value = float(value) + + except (ValueError, TypeError): + err = 'Filter threshold is not a valid number' + log.exception(err) + raise ValueError(err) + + return value + + +class Convergence_Variable(Variable): + def __init__(self): + super().__init__('convergence', + 'analysis_params.convergence_threshold', + nullable=True) + + def parse(self, value, config={}): + value = super().parse(value, config) + + try: + value = float(value) + + except (ValueError, TypeError): + log.warning('No value set for convergence_threshold, using ' + 'default of 0.001') + value = 0.001 + + return value + + +class Symbols_Variable(Variable): + def __init__(self): + super().__init__('symbols', '') + + def parse(self, value, config): ''' Set symbols 
based on config values, using defaults if unset ''' - self.symbols = { + symbols = { 'match': '+', 'mismatch': '-', 'unknown': '?', @@ -351,48 +374,24 @@ def set_HMM_symbols(self): 'unaligned': '?', 'masked': 'x' } - config_symbols = get_nested(self.config, 'HMM_symbols') + config_symbols = get_nested(config, 'HMM_symbols') if config_symbols is not None: for k, v in config_symbols.items(): - if k not in self.symbols: + if k not in symbols: log.warning("Unused symbol in configuration: " f"{k} -> '{v}'") else: - self.symbols[k] = v + symbols[k] = v log.debug(f"Overwriting default symbol for {k} with '{v}'") - for k, v in self.symbols.items(): + for k, v in symbols.items(): if k not in config_symbols: log.warning(f'Symbol for {k} unset in config, ' f"using default '{v}'") else: - for k, v in self.symbols.items(): + for k, v in symbols.items(): log.warning(f'Symbol for {k} unset in config, ' f"using default '{v}'") - def set_convergence(self): - ''' - Set convergence for HMM training, using default if unset - ''' - self.convergence = get_nested(self.config, - 'analysis_params.convergence_threshold') - if self.convergence is None: - log.warning('No value set for convergence_threshold, using ' - 'default of 0.001') - self.convergence = 0.001 - - def get(self, key: str): - ''' - Get nested key from underlying dictionary. Returning none if any - key is not in dict - ''' - return get_nested(self.config, key) - - def __repr__(self): - return ('Config file:\n' + - print_dict(self.config) + - '\nSettings:\n' + - print_dict({k: v for k, v in self.__dict__.items() - if k != 'config'}) - ) + return symbols diff --git a/code/analyze/main.py b/code/analyze/main.py index 2a1342f..46a10bc 100644 --- a/code/analyze/main.py +++ b/code/analyze/main.py @@ -5,6 +5,7 @@ from analyze.introgression_configuration import Configuration from analyze.id_regions import ID_producer from analyze.summarize_region_quality import Summarizer +from analyze.filter_regions import Filterer # TODO also check for snakemake object? 
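The hunks above collapse the one-setter-per-option API into a registry: each option is a Variable that knows its config path, wildcards, and parsing, and Configuration.set() just dispatches to it. A condensed, self-contained sketch of the pattern (names simplified, error handling trimmed; not the full class):

class Variable:
    def __init__(self, name, nullable=False):
        self.name = name
        self.nullable = nullable

    def parse(self, value, config):
        # fall back to the config file, then complain if still unset
        if value is None:
            value = config.get(self.name)
        if value is None and not self.nullable:
            raise ValueError(f'No {self.name} provided')
        return value


class Configuration:
    def __init__(self, config):
        self.config = config
        self.variables = {n: Variable(n) for n in ('threshold', 'blocks')}

    def set(self, *args, **kwargs):
        # bare names mean "take the value from the config file"
        kwargs.update({a: None for a in args})
        for key, value in kwargs.items():
            if key not in self.variables:
                raise ValueError(f'Unknown variable to set: {key}')
            setattr(self, key, self.variables[key].parse(value, self.config))


config = Configuration({'threshold': 'viterbi'})
config.set('threshold', blocks='blocks_{state}.txt')
print(config.threshold, config.blocks)  # viterbi blocks_{state}.txt

Subclasses like Threshold_Variable then override parse() for type coercion, which is why the main.py hunks below shrink to plain config.set(...) calls.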
@@ -40,7 +41,7 @@ def cli(ctx, config, verbosity, log_file): conf = yaml.safe_load(path) ctx.obj.add_config(conf) - ctx.obj.set_log_file(log_file) + ctx.obj.set(log_file=log_file) if ctx.obj.log_file is not None: log.basicConfig(level=level, filename=ctx.obj.log_file, filemode='w') else: @@ -94,20 +95,20 @@ def predict(ctx, only_poly_sites): config = ctx.obj - config.set_chromosomes() + config.set('chromosomes') log.info(f'Found {len(config.chromosomes)} chromosomes in config') - config.set_threshold(threshold) + config.set(threshold=threshold) log.info(f'Threshold value is \'{config.threshold}\'') - config.set_blocks_file(blocks) + config.set(blocks=blocks) log.info(f'Output blocks file is \'{config.blocks}\'') - config.set_states() - config.set_prefix(prefix) + config.set('states') + config.set(prefix=prefix) log.info(f'Prefix is \'{config.prefix}\'') - config.set_strains(test_strains) + config.set(strains=test_strains) if config.test_strains is None: log.info(f'No test_strains provided') else: @@ -118,11 +119,11 @@ def predict(ctx, log.info(f'Found {str_len} unique strain' f'{"" if str_len == 1 else "s"}') - config.set_predict_files(hmm_initial, - hmm_trained, - positions, - probabilities, - alignment) + config.set(hmm_initial=hmm_initial, + hmm_trained=hmm_trained, + positions=positions, + probabilities=probabilities, + alignment=alignment) log.info(f'Hmm_initial file is \'{config.hmm_initial}\'') log.info(f'Hmm_trained file is \'{config.hmm_trained}\'') log.info(f'Positions file is \'{config.positions}\'') @@ -144,24 +145,23 @@ def predict(ctx, @click.pass_context def id_regions(ctx, blocks, labeled, state): config = ctx.obj - config.set_chromosomes() + config.set('chromosomes') log.info(f'Found {len(config.chromosomes)} chromosomes in config') state = list(state) - config.set_states(state) + config.set(states=state) log.info(f'Found {len(config.states)} states to process') - config.set_blocks_file(blocks) + config.set(blocks=blocks) log.info(f'Input blocks file is \'{config.blocks}\'') - config.set_labeled_blocks_file(labeled) + config.set(labeled_blocks=labeled) log.info(f'Output blocks file is \'{config.labeled_blocks}\'') id_producer = ID_producer(config) id_producer.add_ids() -# TODO add in summarize region quality here! 
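NOTE: for the threshold handling in these commands, Threshold_Variable
(defined earlier in introgression_configuration.py) accepts either a
number or the literal 'viterbi'. A small usage sketch, assuming
validate() returns the supplied value when one is given:

    tv = Threshold_Variable()
    tv.parse('0.05', config={})     # -> 0.05, coerced to float
    tv.parse('viterbi', config={})  # -> 'viterbi', passed through
    tv.parse('cutoff', config={})   # raises ValueError:
                                    #   'Unsupported threshold value: cutoff'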
@cli.command() @click.option('--state', multiple=True, help='States to summarize') @click.option('--labeled', default='', @@ -192,36 +192,110 @@ def summarize_regions(ctx, region_index): config = ctx.obj - config.set_states() - - config.set_chromosomes() + config.set('states', + 'chromosomes') log.info(f'Found {len(config.chromosomes)} chromosomes in config') - config.set_labeled_blocks_file(labeled) + config.set(labeled_blocks=labeled) log.info(f'Labeled blocks file is \'{config.labeled_blocks}\'') - config.set_quality_file(quality) + config.set(quality_blocks=quality) log.info(f'Quality file is \'{config.quality_blocks}\'') - config.set_masked_file(masks) + config.set(masks=masks) log.info(f'Mask file is \'{config.masks}\'') - config.set_prefix() - config.set_alignment(alignment) + config.set('prefix') + config.set(alignment=alignment) log.info(f'Alignment file is \'{config.alignment}\'') - config.set_positions(positions) + config.set(positions=positions) log.info(f'Positions file is \'{config.positions}\'') - config.set_regions_files(region, region_index) + config.set(regions=region, region_index=region_index) log.info(f'Region file is \'{config.regions}\'') log.info(f'Region index file is \'{config.region_index}\'') - config.set_HMM_symbols() + config.set('symbols') summarizer = Summarizer(config) summarizer.run(list(state)) +@cli.command() +@click.option('--thresh', help='Threshold to apply to ambiguous filter', + default=None, type=float) +@click.option('--introgress-filter', default='', + help='Filtered block file location with {state}.' + ' Contains only regions passing introgression filter') +@click.option('--introgress-inter', default='', + help='Filtered block file location with {state}.' + ' Contains all regions with reasons they failed filtering') +@click.option('--ambiguous-filter', default='', + help='Filtered block file location with {state}.' + ' Contains only regions passing ambiguous filter') +@click.option('--ambiguous-inter', default='', + help='Filtered block file location with {state}.' 
+ ' Contains all regions passing introgressing filtering, ' + 'with reasons they failed ambiguous filtering') +@click.option('--filter-sweep', default='', + help='Contains summary results for applying ambiguous filter ' + 'with various threshold values supplied as arguments.') +@click.option('--region', default='', + help='Region file with {state}, gzipped') +@click.option('--region-index', default='', + help='Region index file with {state}, pickled') +@click.option('--quality', default='', + help='Quality file with {state}') +@click.argument('thresholds', nargs=-1, type=float) +@click.pass_context +def filter_regions(ctx, + thresh, + introgress_filter, + introgress_inter, + ambiguous_filter, + ambiguous_inter, + filter_sweep, + region, + region_index, + quality, + thresholds): + config = ctx.obj # type: Configuration + config.set('states') + + config.set(filter_threshold=thresh) + log.info(f"Filter threshold set to '{config.filter_threshold}'") + + config.set(introgressed=introgress_filter, + introgressed_intermediate=introgress_inter, + ambiguous=ambiguous_filter, + ambiguous_intermediate=ambiguous_inter, + filter_sweep=filter_sweep) + log.info(f"Introgressed filtered file is '{config.introgressed}'") + log.info('Introgressed intermediate file is ' + f"'{config.introgressed_intermediate}'") + log.info(f"Ambiguous filtered file is '{config.ambiguous}'") + log.info('Ambiguous intermediate file is ' + f"'{config.ambiguous_intermediate}'") + if config.filter_sweep is not None: + log.info(f"Filter sweep file is '{config.filter_sweep}'") + + config.set(regions=region, + region_index=region_index) + log.info(f'Region file is \'{config.regions}\'') + log.info(f'Region index file is \'{config.region_index}\'') + + config.set(quality_blocks=quality) + log.info(f'Quality file is \'{config.quality_blocks}\'') + + config.set('symbols') + + thresholds = list(thresholds) + log.info(f'Threshold sweep with: {thresholds}') + + filterer = Filterer(config) + filterer.run(thresholds) + + if __name__ == '__main__': cli() diff --git a/code/analyze/plotting/format_for_plot_gene_region.py b/code/analyze/plotting/format_for_plot_gene_region.py index 306576e..176912f 100644 --- a/code/analyze/plotting/format_for_plot_gene_region.py +++ b/code/analyze/plotting/format_for_plot_gene_region.py @@ -7,29 +7,23 @@ # - for gap -import re -import sys -import os -import copy import gzip -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../align/') -import align_helpers -sys.path.insert(0, '../misc/') -import read_fasta -import read_table +from align import align_helpers +from misc import read_fasta # copy pasta + def try_int(s, default=-1): try: i = int(s) return i - except: + except ValueError: return default -def referize(strain_seq, ref_ind_to_strain_ind, skip_char = 'N'): + +def referize(strain_seq, ref_ind_to_strain_ind, skip_char='N'): s = [skip_char for r in ref_ind_to_strain_ind] for i in range(len(ref_ind_to_strain_ind)): si = ref_ind_to_strain_ind[i] @@ -39,31 +33,34 @@ def referize(strain_seq, ref_ind_to_strain_ind, skip_char = 'N'): s[i] = strain_seq[si] return s -#region_start = 787000 -#region_end = 794000 -#chrm = 'II' + +# region_start = 787000 +# region_end = 794000 +# chrm = 'II' region_start = 917571 - 100 region_end = 921647 + 100 chrm = 'IV' region_length = region_end - region_start + 1 -##====== +# ====== # get strains -##====== +# ====== -strain_dirs = align_helpers.get_strains(align_helpers.flatten(gp.non_ref_dirs.values())) +strain_dirs = align_helpers.get_strains( + 
align_helpers.flatten(gp.non_ref_dirs.values())) num_strains = len(strain_dirs) -##====== +# ====== # loop through all strains, getting appropriate sequence -##====== +# ====== # master reference and other reference seqs master_ref = gp.alignment_ref_order[0] master_fn = gp.ref_dir[master_ref] + gp.ref_fn_prefix[master_ref] + '_chr' + \ chrm + gp.fasta_suffix -master_seq = read_fasta.read_fasta(master_fn)[1][0][region_start:region_end+1].lower() +master_seq = read_fasta.read_fasta(master_fn)[1][0][ + region_start:region_end+1].lower() other_ref = gp.alignment_ref_order[1] @@ -74,30 +71,30 @@ def referize(strain_seq, ref_ind_to_strain_ind, skip_char = 'N'): ref_ind_to_strain_ind = [try_int(line[:-1]) for line in f_coord.readlines()] other_ref_fn = gp.ref_dir[other_ref] + gp.ref_fn_prefix[other_ref] + \ '_chr' + chrm + gp.fasta_suffix -other_ref_seq = referize(read_fasta.read_fasta(other_ref_fn)[1][0].lower(), \ +other_ref_seq = referize(read_fasta.read_fasta(other_ref_fn)[1][0].lower(), ref_ind_to_strain_ind)[region_start:region_end+1] # other strains seqs = {} for i in range(num_strains): strain, d = strain_dirs[i] - print strain - coord_fn = gp.analysis_out_dir_absolute + 'coordinates/' + \ - gp.master_ref + '_to_' + strain + \ - '_chr' + chrm + '.txt.gz' + print(strain) + coord_fn = (gp.analysis_out_dir_absolute + 'coordinates/' + + gp.master_ref + '_to_' + strain + + '_chr' + chrm + '.txt.gz') f_coord = gzip.open(coord_fn, 'rb') - ref_ind_to_strain_ind = [try_int(line[:-1]) for line in f_coord.readlines()] + ref_ind_to_strain_ind = [try_int(line[:-1]) + for line in f_coord.readlines()] strain_fn = d + strain + '_chr' + chrm + gp.fasta_suffix - seqs[strain] = referize(read_fasta.read_fasta(strain_fn)[1][0].lower(), \ + seqs[strain] = referize(read_fasta.read_fasta(strain_fn)[1][0].lower(), ref_ind_to_strain_ind)[region_start:region_end+1] # write file fn = 'gene_region_variants.txt' f = open(fn, 'w') -f.write('ps\t' + '\t'.join([x[0] for x in strain_dirs]) + '\n') +f.write('ps\t' + '\t'.join([x[0] for x in strain_dirs]) + '\n') for i in range(region_length): - f.write(str(region_start + i)) for strain, d in strain_dirs: x = seqs[strain][i] @@ -119,4 +116,3 @@ def referize(strain_seq, ref_ind_to_strain_ind, skip_char = 'N'): f.write('n') f.write('\n') f.close() - diff --git a/code/analyze/plotting/format_for_plotting.py b/code/analyze/plotting/format_for_plotting.py index fca2521..d0aca70 100644 --- a/code/analyze/plotting/format_for_plotting.py +++ b/code/analyze/plotting/format_for_plotting.py @@ -1,35 +1,29 @@ # format output files to be read easily and plotted in R -import re import sys -import os -import copy import gene_predictions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../sim/') -import sim_analyze_hmm_bw as sim -sys.path.insert(0, '../misc/') -import mystats +from misc import mystats -##====== +# ====== # read in analysis parameters -##====== +# ====== suffix = '' if len(sys.argv == 3): suffix = sys.argv[1] -all_predict_args = [x.strip().split() for x in open(sys.argv[2], 'r').readlines()] -all_predict_args = [{'tag':x[0], 'improvement_frac':x[1], 'threshold':x[2], \ - 'expected_length':x[-2], 'expected_frac':x[-1]} \ +all_predict_args = [x.strip().split() + for x in open(sys.argv[2], 'r').readlines()] +all_predict_args = [{'tag': x[0], 'improvement_frac': x[1], 'threshold': x[2], + 'expected_length':x[-2], 'expected_frac':x[-1]} for x in all_predict_args] -l = range(0,36) -l.remove(19) -l.remove(25) -l = [0] -all_predict_args = 
[all_predict_args[i] for i in l]
+arg_inds = list(range(0, 36))
+arg_inds.remove(19)
+arg_inds.remove(25)
+arg_inds = [0]
+all_predict_args = [all_predict_args[i] for i in arg_inds]
 
 '''
 finished = range(1,36)
@@ -49,39 +43,42 @@
 
 sep = '\t'
 
-##======
+# ======
 # for plot: lengths of all introgressed regions
-##======
+# ======
 
 # one table for each tag
 # strain chrm region_length
 
 # one table for all tags
-# tag improvement_frac threshold expected_length expected_frac avg_length lower upper median min max total_num_regions
+# tag improvement_frac threshold expected_length expected_frac
+# avg_length lower upper median min max total_num_regions
 
-print 'working on region lengths'
+print('working on region lengths')
 
 f = open(gp.analysis_out_dir_absolute + 'plot_region_lengths.txt', 'w')
 for i in range(len(all_predict_args)):
-    print '-', i
+    print('-', i)
     args = all_predict_args[i]
-    f_tag = open(gp.analysis_out_dir_absolute + args['tag'] + '/' + \
-                 'plot_region_lengths' + suffix + '_' + args['tag'] + '.txt', 'w')
-    fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \
-         'introgressed_blocks_par' + suffix + '_' + args['tag'] + '_summary_plus.txt'
+    f_tag = open(gp.analysis_out_dir_absolute + args['tag'] + '/' +
+                 'plot_region_lengths' + suffix + '_' + args['tag'] +
+                 '.txt', 'w')
+    fn = (gp.analysis_out_dir_absolute + args['tag'] + '/' +
+          'introgressed_blocks_par' + suffix + '_' +
+          args['tag'] + '_summary_plus.txt')
     region_summary = gene_predictions.read_region_summary(fn)
     lengths_all = []
     for region in region_summary:
         length = int(region_summary[region]['end']) - \
                  int(region_summary[region]['start']) + 1
         if int(region_summary[region]['number_match_ref2_not_ref1']) >= 5:
-            f_tag.write(region + sep + region_summary[region]['strain'] + sep + \
-                        region_summary[region]['chromosome'] + sep + \
+            f_tag.write(region + sep + region_summary[region]['strain'] + sep +
+                        region_summary[region]['chromosome'] + sep +
                         str(length) + '\n')
             lengths_all.append(length)
     f_tag.close()
-    f.write(args['tag'] + sep + args['improvement_frac'] + sep + \
-            args['threshold'] + sep + args['expected_length'] + sep + \
+    f.write(args['tag'] + sep + args['improvement_frac'] + sep +
+            args['threshold'] + sep + args['expected_length'] + sep +
             args['expected_frac'] + sep)
     f.write(str(mystats.mean(lengths_all)) + sep)
     bs_lower, bs_upper = mystats.bootstrap(lengths_all)
@@ -92,43 +89,44 @@
 f.write(str(len(lengths_all)) + '\n')
 f.close()
 
-print 'done'
+print('done')
 
 sys.exit()
 
-##======
+# ======
 # for plot: number of genes per introgressed region
-##======
+# ======
 
 # one table for each tag
 # strain chrm region number_genes
 
 # one table for all tags
-# tag improvement_frac threshold expected_length expected_frac avg_genes_per_region lower upper median min max
+# tag improvement_frac threshold expected_length expected_frac
+# avg_genes_per_region lower upper median min max
 
-print 'working on number of genes for each region'
+print('working on number of genes for each region')
 
 f = open(gp.analysis_out_dir_absolute + 'plot_number_genes_by_region.txt', 'w')
 for i in range(len(all_predict_args)):
-    print '-', i
+    print('-', i)
    args = all_predict_args[i]
-    f_tag = open(gp.analysis_out_dir_absolute + args['tag'] + '/' + \
+    f_tag = open(gp.analysis_out_dir_absolute + args['tag'] + '/' +
                  'plot_number_genes_by_region_' + args['tag'] + '.txt', 'w')
-    fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \
-         'genes_for_each_region_' + args['tag'] + '.txt'
+    fn = (gp.analysis_out_dir_absolute + args['tag'] + '/' +
+          'genes_for_each_region_' 
+ args['tag'] + '.txt') genes = gene_predictions.read_genes_for_each_region_summary(fn) - fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'introgressed_blocks_par_' + args['tag'] + '_summary.txt' + fn = (gp.analysis_out_dir_absolute + args['tag'] + '/' + + 'introgressed_blocks_par_' + args['tag'] + '_summary.txt') region_summary = gene_predictions.read_region_summary(fn) num_genes_all = [] for region in genes: - f_tag.write(region + sep + region_summary[region]['strain'] + sep + \ - region_summary[region]['chromosome'] + sep + \ + f_tag.write(region + sep + region_summary[region]['strain'] + sep + + region_summary[region]['chromosome'] + sep + genes[region]['num_genes'] + '\n') num_genes_all.append(int(genes[region]['num_genes'])) f_tag.close() - f.write(args['tag'] + sep + args['improvement_frac'] + sep + \ - args['threshold'] + sep + args['expected_length'] + sep + \ + f.write(args['tag'] + sep + args['improvement_frac'] + sep + + args['threshold'] + sep + args['expected_length'] + sep + args['expected_frac'] + sep) f.write(str(mystats.mean(num_genes_all)) + sep) bs_lower, bs_upper = mystats.bootstrap(num_genes_all) @@ -138,91 +136,96 @@ f.write(str(max(num_genes_all)) + '\n') f.close() -print 'done' +print('done') -##====== +# ====== # for plot: number of introgressed bases for each strain -##====== +# ====== # one table for all tags -# tag improvement_frac threshold expected_length expected_frac strain number_bases +# tag improvement_frac threshold expected_length +# expected_frac strain number_bases -print 'working on number of bases for each strain' +print('working on number of bases for each strain') -f = open(gp.analysis_out_dir_absolute + \ +f = open(gp.analysis_out_dir_absolute + 'plot_number_introgressed_bases_by_strain.txt', 'w') for i in range(len(all_predict_args)): - print '-', i + print('-', i) args = all_predict_args[i] - fn = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \ - 'regions_for_each_strain_' + args['tag'] + '.txt' + fn = (gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + + 'regions_for_each_strain_' + args['tag'] + '.txt') regions = gene_predictions.read_regions_for_each_strain(fn) for strain in regions: total = 0 for r in regions[strain]['region_list']: total += int(r[1]) - f.write(args['tag'] + sep + args['improvement_frac'] + sep + \ - args['threshold'] + sep + args['expected_length'] + sep + \ + f.write(args['tag'] + sep + args['improvement_frac'] + sep + + args['threshold'] + sep + args['expected_length'] + sep + args['expected_frac'] + sep + strain + sep + str(total) + '\n') f.close() -print 'done' +print('done') -##====== +# ====== # for plot: number of introgressed genes for each strain -##====== +# ====== # one table for all tags -# tag improvement_frac threshold expected_length expected_frac strain number_genes +# tag improvement_frac threshold expected_length +# expected_frac strain number_genes -print 'working on number of genes for each strain' +print('working on number of genes for each strain') -f = open(gp.analysis_out_dir_absolute + \ +f = open(gp.analysis_out_dir_absolute + 'plot_number_introgressed_genes_by_strain.txt', 'w') for i in range(len(all_predict_args)): - print '-', i + print('-', i) args = all_predict_args[i] - fn = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \ - 'genes_for_each_strain_' + args['tag'] + '.txt' + fn = (gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + + 'genes_for_each_strain_' + args['tag'] + '.txt') genes = gene_predictions.read_genes_for_each_strain(fn) for strain 
in genes: - f.write(args['tag'] + sep + args['improvement_frac'] + sep + \ - args['threshold'] + sep + args['expected_length'] + sep + \ - args['expected_frac'] + sep + strain + sep + \ + f.write(args['tag'] + sep + args['improvement_frac'] + sep + + args['threshold'] + sep + args['expected_length'] + sep + + args['expected_frac'] + sep + strain + sep + genes[strain]['num_genes'] + sep + '\n') f.close() -print 'done' +print('done') -##====== -# for plot: number of strains each gene introgressed in -##====== +# ====== +# for plot: number of strains each gene introgressed in +# ====== # one table for each tag # gene num_strains # one table for all tags -# tag improvement_frac threshold expected_length expected_frac avg_strains_per_gene lower upper median min max total_num_genes total_num_genes_1 total_num_genes_>1 +# tag improvement_frac threshold expected_length expected_frac +# avg_strains_per_gene lower upper median min max total_num_genes +# total_num_genes_1 total_num_genes_>1 -print 'working on number of strains for each gene' +print('working on number of strains for each gene') -f = open(gp.analysis_out_dir_absolute + 'plot_number_strains_by_genes.txt', 'w') +f = open(gp.analysis_out_dir_absolute + 'plot_number_strains_by_genes.txt', + 'w') for i in range(len(all_predict_args)): - print '-', i + print('-', i) args = all_predict_args[i] - f_tag = open(gp.analysis_out_dir_absolute + args['tag'] + '/' + \ + f_tag = open(gp.analysis_out_dir_absolute + args['tag'] + '/' + 'plot_number_strains_by_genes_' + args['tag'] + '.txt', 'w') - fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'strains_for_each_gene_' + args['tag'] + '.txt' + fn = (gp.analysis_out_dir_absolute + args['tag'] + '/' + + 'strains_for_each_gene_' + args['tag'] + '.txt') strains = gene_predictions.read_strains_for_each_gene(fn) num_strains_all = [] for gene in strains: f_tag.write(gene + sep + strains[gene]['num_strains'] + '\n') num_strains_all.append(int(strains[gene]['num_strains'])) f_tag.close() - f.write(args['tag'] + sep + args['improvement_frac'] + sep + \ - args['threshold'] + sep + args['expected_length'] + sep + \ + f.write(args['tag'] + sep + args['improvement_frac'] + sep + + args['threshold'] + sep + args['expected_length'] + sep + args['expected_frac'] + sep) f.write(str(mystats.mean(num_strains_all)) + sep) bs_lower, bs_upper = mystats.bootstrap(num_strains_all) @@ -235,24 +238,25 @@ f.write(str(len(filter(lambda x: x > 1, num_strains_all))) + '\n') f.close() -print 'done' +print('done') -##====== -# for plot: average fraction of each (introgressed) gene that's introgressed -##====== +# ====== +# for plot: average fraction of each (introgressed) gene that's introgressed +# ====== # one table for each tag # gene avg_frac_introgressed lower upper median min max -print 'working on fraction of gene introgressed' +print('working on fraction of gene introgressed') for i in range(len(all_predict_args)): - print '-', i + print('-', i) args = all_predict_args[i] - f_tag = open(gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'plot_frac_introgressed_by_genes_' + args['tag'] + '.txt', 'w') - fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'strains_for_each_gene_' + args['tag'] + '.txt' + f_tag = open(gp.analysis_out_dir_absolute + args['tag'] + '/' + + 'plot_frac_introgressed_by_genes_' + args['tag'] + '.txt', + 'w') + fn = (gp.analysis_out_dir_absolute + args['tag'] + '/' + + 'strains_for_each_gene_' + args['tag'] + '.txt') strains = gene_predictions.read_strains_for_each_gene(fn) for gene 
in strains: fracs = [float(x[1]) for x in strains[gene]['strain_list']] @@ -265,4 +269,4 @@ f_tag.write(str(max(fracs)) + '\n') f_tag.close() -print 'done' +print('done') diff --git a/code/analyze/plotting/format_for_plotting2.py b/code/analyze/plotting/format_for_plotting2.py index 3362e63..4d9eaca 100644 --- a/code/analyze/plotting/format_for_plotting2.py +++ b/code/analyze/plotting/format_for_plotting2.py @@ -1,14 +1,6 @@ -import re import sys -import os -import copy -import gene_predictions -sys.path.insert(0, '..') +from analyze.to_update import gene_predictions import global_params as gp -sys.path.insert(0, '../sim/') -import sim_analyze_hmm_bw as sim -sys.path.insert(0, '../misc/') -import mystats tag = sys.argv[1] @@ -16,15 +8,15 @@ if len(sys.argv == 3): suffix = sys.argv[2] -fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ - 'introgressed_blocks_par' + suffix + '_' + args['tag'] + '_summary_plus.txt' +fn = gp.analysis_out_dir_absolute + tag + '/' + \ + 'introgressed_blocks_par' + suffix + '_' + tag + '_summary_plus.txt' region_summary = gene_predictions.read_region_summary(fn) sep = '\t' -##====== +# ====== # for plot: lengths of all introgressed regions -##====== +# ====== # one table for each tag # strain chrm region_length diff --git a/code/analyze/plotting/format_for_plotting_region.py b/code/analyze/plotting/format_for_plotting_region.py index b6454dc..15cebad 100644 --- a/code/analyze/plotting/format_for_plotting_region.py +++ b/code/analyze/plotting/format_for_plotting_region.py @@ -1,24 +1,20 @@ -import gene_predictions +from analyze.to_update import gene_predictions import sys import os import gzip -sys.path.insert(0, '../misc/') -import read_fasta import global_params as gp -sys.path.insert(0, '../sim/') + def read_annotated_alignment(fn, nstrains): f = gzip.open(fn, 'rb') lines = f.readlines() f.close() - strains = [l[:-1] for l in lines[:nstrains]] genes = lines[nstrains + 2][len('genes:'):-1].split() - + x = 11 match_cer = '' match_par = '' gene = '' - gene_ind = -1 intd = '' while x < len(lines): @@ -38,13 +34,12 @@ def read_annotated_alignment(fn, nstrains): return match_cer, match_par, gene, genes, intd -def write_ps_annotated(match_cer, match_par, gene, glist, intd, region, fn): +def write_ps_annotated(match_cer, match_par, gene, glist, intd, region, fn): f = open(fn, 'w') f.write('ps\tmatch\tintd\tgene\n') - block_start = int(region['start']) - intd.index('I') - block_end = len(intd) - intd.rindex('I') + int(region['end']) + block_start = int(region['start']) - intd.index('I') out_of_gene = True gene_ind = -1 @@ -64,11 +59,13 @@ def write_ps_annotated(match_cer, match_par, gene, glist, intd, region, fn): f.write('\n') f.close() + tag = sys.argv[1] region = sys.argv[2] -blocks_fn = gp.analysis_out_dir_absolute + tag + '/' + \ - 'introgressed_blocks_filtered_' + 'par' + '_' + tag + '_summary.txt' +blocks_fn = (gp.analysis_out_dir_absolute + tag + '/' + + 'introgressed_blocks_filtered_' + 'par' + + '_' + tag + '_summary.txt') r = gene_predictions.read_region_summary(blocks_fn) strain = r[region]['strain'] chrm = r[region]['chromosome'] @@ -86,6 +83,5 @@ def write_ps_annotated(match_cer, match_par, gene, glist, intd, region, fn): write_ps_annotated(match_cer, match_par, gene, glist, intd, r[region], fn_out) -#probs_f = gzip.open(gp.analysis_out_dir_absolute + tag + '/' + \ +# probs_f = gzip.open(gp.analysis_out_dir_absolute + tag + '/' + \ # 'probs_' + tag + '.txt.gz', 'rb') - diff --git a/code/analyze/plotting/format_polymorphism_for_r.py 
b/code/analyze/plotting/format_polymorphism_for_r.py
index 58d10ce..3665d0d 100644
--- a/code/analyze/plotting/format_polymorphism_for_r.py
+++ b/code/analyze/plotting/format_polymorphism_for_r.py
@@ -1,7 +1,5 @@
 # lol because i'm so bad at R
 
-import sys
-sys.path.insert(0, '..')
 import global_params as gp
 
 tag = 'u3_i.001_tv_l1000_f.01'
@@ -16,7 +14,7 @@
 d2_sums = {}
 for line in lines[1:]:
     chrm = line[0]
-    if not d_sums.has_key(chrm):
+    if chrm not in d_sums:
         d_sums[chrm] = 0
         d2_sums[chrm] = 0
         d[chrm] = {}
@@ -53,24 +51,24 @@
             fab += c
     try:
         fo = str(float(fo)/d_sums[chrm])
-    except:
+    except (ValueError, ZeroDivisionError):
         fo = 'NaN'
     try:
         fob = str(float(fob)/d2_sums[chrm])
-    except:
+    except (ValueError, ZeroDivisionError):
         fob = 'NaN'
     try:
         fa = str(float(fa)/d_sums[chrm])
-    except:
+    except (ValueError, ZeroDivisionError):
         fa = 'NaN'
     try:
         fab = str(float(fab)/d2_sums[chrm])
-    except:
+    except (ValueError, ZeroDivisionError):
         fab = 'NaN'
-    f.write(chrm + '\tone\tpolymorphic\t' + fo + '\t' + str(d_sums[chrm]) + '\n')
-    f.write(chrm + '\tone\tbiallelic\t' + fob + '\t' + str(d2_sums[chrm]) + '\n')
-    f.write(chrm + '\tall\tpolymorphic\t' + fa + '\t' + str(d_sums[chrm]) + '\n')
-    f.write(chrm + '\tall\tbiallelic\t' + fab + '\t' + str(d2_sums[chrm]) + '\n')
+    f.write(f'{chrm}\tone\tpolymorphic\t{fo}\t{d_sums[chrm]}\n')
+    f.write(f'{chrm}\tone\tbiallelic\t{fob}\t{d2_sums[chrm]}\n')
+    f.write(f'{chrm}\tall\tpolymorphic\t{fa}\t{d_sums[chrm]}\n')
+    f.write(f'{chrm}\tall\tbiallelic\t{fab}\t{d2_sums[chrm]}\n')
 
 f.close()
diff --git a/code/analyze/predict.py b/code/analyze/predict.py
index 62e8df4..bea0d33 100644
--- a/code/analyze/predict.py
+++ b/code/analyze/predict.py
@@ -401,9 +401,9 @@ def convert_to_blocks(self,
 class HMM_Builder():
     def __init__(self, configuration: Configuration):
         self.config = configuration
-        self.config.set_HMM_symbols()
+        self.config.set('symbols')
         self.symbols = self.config.symbols
-        self.config.set_convergence()
+        self.config.set('convergence')
 
     def update_emission_symbols(self, repeats: int):
         '''
diff --git a/code/analyze/read_args.py b/code/analyze/read_args.py
index 26c6698..b78c022 100644
--- a/code/analyze/read_args.py
+++ b/code/analyze/read_args.py
@@ -1,10 +1,11 @@
 import sys
 from align import align_helpers
 
+
 def process_predict_args(args):
-    
+
     d = {}
-    
+
     i = 0
 
     d['tag'] = args[i]
@@ -14,7 +15,7 @@
     i += 1
     d['improvement_frac'] = float(args[i])
-    
+
     i += 1
     d['threshold'] = args[i]
     if d['threshold'] != 'viterbi':
@@ -34,8 +35,9 @@
         d['expected_length'][state] = float(args[i])
         i += 1
         d['expected_frac'][state] = float(args[i])
-    d['expected_frac'][d['known_states'][0]] = 1 - sum(d['expected_frac'].values())
-    d['expected_length'][d['known_states'][0]] = 0 # calculate later
+    d['expected_frac'][d['known_states'][0]] = \
+        1 - sum(d['expected_frac'].values())
+    d['expected_length'][d['known_states'][0]] = 0  # calculate later
 
     i += 1
     while i < len(args):
@@ -57,7 +59,8 @@
     d['setup_args'] = setup_args
 
     return d
-    
+
+
 def read_setup_args(fn):
 
     x = {}
@@ -72,7 +75,8 @@
     d = {}
 
     d['references'] = x['references']
-    d['reference_directories'] = dict(zip(x['references'], x['reference_directories']))
+    d['reference_directories'] = \
+        dict(zip(x['references'], x['reference_directories']))
 
     d['alignments_directory'] = x['alignments_directory'][0]
 
     d['strain_dirs'] = \
@@ -80,6 +84,7 @@
 
     return d
 
+
 def get_predict_args_by_tag(fn, tag):
     f = open(fn, 'r')
     line = f.readline()
@@ -90,4 +95,3 @@
         line 
= f.readline() print(f'tag not found: {tag}') return None - diff --git a/code/analyze/structure/structure_1_main.py b/code/analyze/structure/structure_1_main.py index 27e060b..bb6063d 100644 --- a/code/analyze/structure/structure_1_main.py +++ b/code/analyze/structure/structure_1_main.py @@ -3,24 +3,18 @@ import sys import os -import gzip import predict from collections import defaultdict -import gene_predictions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import read_fasta -import read_table -import seq_functions +from misc import read_fasta args = predict.process_predict_args(sys.argv[2:]) chrm = gp.chrms[int(sys.argv[1])] # maybe getting strains should be simpler -strains = [line.split('\t')[0] for line in \ - open(gp.analysis_out_dir_absolute + args['tag'] + \ +strains = [line.split('\t')[0] for line in + open(gp.analysis_out_dir_absolute + args['tag'] + '/state_counts_by_strain.txt', 'r').readlines()[1:]] nucs = set(['a', 't', 'g', 'c']) @@ -32,10 +26,10 @@ gp_dir = '../' -##====== +# ====== # use program ldselect to find set of tag snps all in low LD for # specified chromosome -##====== +# ====== # input file for ldselect is formatted so that each row is a snp and # each column is the genotype for a strain, e.g. @@ -45,13 +39,13 @@ snps = defaultdict(list) # loop through all the strains for strain in strains: - print '-', strain + print('-', strain) # read multiple alignment file for this strain with the master # reference (and other references which we don't care about # here) - headers, seqs = read_fasta.read_fasta(gp_dir + gp.alignments_dir + \ - '_'.join(gp.alignment_ref_order) + \ - '_' + strain + '_chr' + chrm + \ + headers, seqs = read_fasta.read_fasta(gp_dir + gp.alignments_dir + + '_'.join(gp.alignment_ref_order) + + '_' + strain + '_chr' + chrm + '_mafft.maf') # look at all alignment columns, keeping track of the index in # the master reference @@ -66,11 +60,11 @@ # get reference sequence (unaligned, without gaps) # TODO correct alignment file location -ref_seq = read_fasta.read_fasta(gp_dir + gp.alignments_dir + \ - '_'.join(gp.alignment_ref_order) + \ - '_' + strains[0] + '_chr' + chrm + \ +ref_seq = read_fasta.read_fasta(gp_dir + gp.alignments_dir + + '_'.join(gp.alignment_ref_order) + + '_' + strains[0] + '_chr' + chrm + '_mafft.maf')[1][0].replace(gp.gap_symbol, '') -open(out_dir + 'chromosome_lengths.txt', 'a').write(chrm + '\t' + \ +open(out_dir + 'chromosome_lengths.txt', 'a').write(chrm + '\t' + str(len(ref_seq)) + '\n') # loop through all the sites we collected above @@ -83,13 +77,13 @@ # TODO do names have to be integers and/or equal in length? 
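+    # write one row per (snp, strain) pair: snp id, strain name, allele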
snp_id = str(snp) # write row for master reference - f.write(snp_id + '\t' + \ - gp.alignment_ref_order[0] + '\t' + \ + f.write(snp_id + '\t' + + gp.alignment_ref_order[0] + '\t' + ref_seq[snp] + '\n') # and one row for each of the other strains for si in range(len(strains)): - f.write(snp_id + '\t' + \ - strains[si] + '\t' + \ + f.write(snp_id + '\t' + + strains[si] + '\t' + snps[snp][si] + '\n') f.close() @@ -97,7 +91,8 @@ """ # run ldselect on this input file fn_out = fn.replace('input', 'output') -os.system('perl ' + gp.ldselect_install_path + 'ldSelect.pl -pb ' + fn + ' > ' + fn_out) +os.system('perl ' + gp.ldselect_install_path + + 'ldSelect.pl -pb ' + fn + ' > ' + fn_out) # extract one tag snp from each set of equivalent tag snps from # ldselect output file diff --git a/code/analyze/structure/structure_2_main.py b/code/analyze/structure/structure_2_main.py index 9b22ed9..3298a44 100644 --- a/code/analyze/structure/structure_2_main.py +++ b/code/analyze/structure/structure_2_main.py @@ -1,15 +1,8 @@ import sys import os -import gzip import predict from collections import defaultdict -import gene_predictions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import read_fasta -import read_table -import seq_functions args = predict.process_predict_args(sys.argv[2:]) @@ -23,18 +16,18 @@ os.makedirs(out_dir_run + '/population_ranges') # maybe getting strains should be simpler -strains = [line.split('\t')[0] for line in \ - open(gp.analysis_out_dir_absolute + args['tag'] + \ +strains = [line.split('\t')[0] for line in + open(gp.analysis_out_dir_absolute + args['tag'] + '/state_counts_by_strain.txt', 'r').readlines()[1:]] gp_dir = '../' -nuc_to_int = {'a':1, 't':2, 'g':3, 'c':4} +nuc_to_int = {'a': 1, 't': 2, 'g': 3, 'c': 4} -##====== +# ====== # use program structure to find population proportion using either # unlinked tagsnps from ldselect, or just all snps -##====== +# ====== use_all_snps = True @@ -93,12 +86,13 @@ f = open(out_dir_run + 'structure_input_run' + run_id + '.txt', 'w') for chrm in gp.chrms: - f.write('\t\t\t' + '\t'.join([chrm + '_' + str(x) \ - for x in sorted(all_snps[chrm].keys())])) + f.write('\t\t\t' + '\t'.join([chrm + '_' + str(x) + for x in sorted(all_snps[chrm].keys())])) f.write('\n') for chrm in gp.chrms: - f.write('\t\t\t' + '\t'.join([str(map_distances[chrm][x]) \ - for x in sorted(map_distances[chrm].keys())])) + f.write('\t\t\t' + '\t'.join( + [str(map_distances[chrm][x]) + for x in sorted(map_distances[chrm].keys())])) f.write('\n') for strain in strains: @@ -118,7 +112,8 @@ """ os.system(gp.structure_install_path + 'structure -L ' + str(num_snps) + \ ' -K 6 -i ' + out_dir_run + 'structure_input_run' + run_id + \ - '.txt -o ' + out_dir_run + 'structure_output_k6_run' + run_id + '.txt') + '.txt -o ' + out_dir_run + 'structure_output_k6_run' + + run_id + '.txt') os.system('mv ' + out_dir_run + 'structure_output_k6_run' + \ run_id + '.txt_ss ' + out_dir_run + \ @@ -136,7 +131,7 @@ line = f.readline() while line != "Inferred ancestry of individuals:\n": line = f.readline() -f.readline() # column headings +f.readline() # column headings line = f.readline() f_out.write('strain\tpopulation\tfraction\tindex\n') while line != "\n": @@ -149,7 +144,7 @@ ind = i break for i in range(len(fracs)): - f_out.write(strain + '\t' + str(i + 1) + '\t' + \ + f_out.write(strain + '\t' + str(i + 1) + '\t' + str(fracs[i]) + '\t' + str(ind + 1) + '\n') line = f.readline() f.close() @@ -161,7 +156,8 @@ f = open(out_dir_run + 
'structure_output_ss_k6_run' + run_id + '.txt', 'r') k = 6 -# read in posterior probabilities for each strain locus being in each population +# read in posterior probabilities for +# each strain locus being in each population line = f.readline() while line.strip() == '\n': line = f.readline() @@ -187,8 +183,6 @@ line = f.readline() f.close() - - # TODO at some point associate numbered populations with logical names # (i.e. ones from strope et al) @@ -198,17 +192,17 @@ # population_ranges_strain_chrX.txt # start end popx # start end popx/popy -# start end +# start end -chrm_lengths = [line[:-1].split('\t') for line in \ - open(out_dir + 'chromosome_lengths.txt', 'r').readlines()] -chrm_lengths = dict(zip([x[0] for x in chrm_lengths], \ +chrm_lengths = [line[:-1].split('\t') for line in + open(out_dir + 'chromosome_lengths.txt', 'r').readlines()] +chrm_lengths = dict(zip([x[0] for x in chrm_lengths], [int(x[1]) for x in chrm_lengths])) for strain in strains: for chrm in gp.chrms: ranges = [] - snps = sorted(strain_snp_pop[strain][chrm].keys()) + snps = sorted(strain_snp_pop[strain][chrm].keys()) start = snps[0] end = start previous_pop = strain_snp_pop[strain][chrm][start] @@ -221,7 +215,8 @@ else: ranges.append((start, end, previous_pop)) - ranges.append((end + 1, snp - 1, previous_pop + '/' + current_pop)) + ranges.append((end + 1, snp - 1, + previous_pop + '/' + current_pop)) start = snp end = snp previous_pop = current_pop @@ -231,7 +226,8 @@ ranges.append((end + 1, chrm_lengths[chrm], 'end')) # TODO file location - f = open(out_dir_run + 'population_ranges/population_ranges_' + strain + '_chr' + chrm + '_run' + run_id + '.txt', 'w') + f = open(out_dir_run + 'population_ranges/population_ranges_' + + strain + '_chr' + chrm + '_run' + run_id + '.txt', 'w') for r in ranges: f.write('\t'.join([str(x) for x in r]) + '\n') f.close() diff --git a/code/analyze/structure/structure_3_main.py b/code/analyze/structure/structure_3_main.py index 4446cdd..fc1403c 100644 --- a/code/analyze/structure/structure_3_main.py +++ b/code/analyze/structure/structure_3_main.py @@ -1,26 +1,19 @@ -## generate three files: +# generate three files: -## 1. introgressed regions annotated by which population background(s) -## they overlap +# 1. introgressed regions annotated by which population background(s) +# they overlap -## 2. population backgrounds annotated by how much introgression they -## have from each reference strain (or ambiguous strains) +# 2. population backgrounds annotated by how much introgression they +# have from each reference strain (or ambiguous strains) -## 3. counts of bases in for each strain x population background x -## introgresssing reference [or lack of introgression] +# 3. 
counts of bases for each strain x population background x
+# introgressing reference [or lack of introgression]
 
 import sys
-import os
-import gzip
 import predict
 from collections import defaultdict
-import gene_predictions
-sys.path.insert(0, '..')
 import global_params as gp
-sys.path.insert(0, '../misc/')
-import read_fasta
-import read_table
-import seq_functions
+from misc import read_table
 
 args = predict.process_predict_args(sys.argv[3:])
 
@@ -33,8 +26,8 @@
 
 # TODO maybe getting strains should be simpler...at least make this
 # not copy pasta
-strains = [line.split('\t')[0] for line in \
-           open(gp.analysis_out_dir_absolute + args['tag'] + \
+strains = [line.split('\t')[0] for line in
+           open(gp.analysis_out_dir_absolute + args['tag'] +
                 '/state_counts_by_strain.txt', 'r').readlines()[1:]]
 
@@ -58,8 +51,10 @@ def find_pops(start, end, pop_ranges):
         bases.append(r[1] - r[0] + 1)
     return pops, bases
+
+
 population_int_counts = defaultdict(lambda: defaultdict(int))
-strain_population_int_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
+strain_population_int_counts = defaultdict(
+    lambda: defaultdict(lambda: defaultdict(int)))
 population_totals = defaultdict(int)
 strain_population_totals = defaultdict(lambda: defaultdict(int))
 all_alternative_states = set([])
@@ -74,24 +69,24 @@ def find_pops(start, end, pop_ranges):
         strain = regions[region_id]['strain']
         regions_strain_chrm[strain][chrm][region_id] = regions[region_id]
     new_regions_fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \
-                     'blocks_' + ref + \
-                     '_' + args['tag'] + '_populations.txt'
+        'blocks_' + ref + \
+        '_' + args['tag'] + '_populations.txt'
     f = open(new_regions_fn, 'w')
     labels = labels[1:] + ['population']
     f.write('region_id' + '\t' + '\t'.join(labels) + '\n')
-    #for chrm in regions_strain_chrm[strain]:
+    # for chrm in regions_strain_chrm[strain]:
 
 for strain in strains:
     for chrm in gp.chrms:
         # TODO get rid of run_id in filenames?
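+        # population_ranges files (written by structure_2_main.py) hold
+        # one 'start end population' range per line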
pop_ranges_fn = out_dir_run + 'population_ranges/' + \ 'population_ranges_' + strain + \ '_chr' + chrm + '_run' + run_id + '.txt' - pop_ranges = [line[:-1].split('\t') for line in \ + pop_ranges = [line[:-1].split('\t') for line in open(pop_ranges_fn, 'r').readlines()] pop_ranges = [(int(x[0]), int(x[1]), x[2]) for x in pop_ranges] for pr in pop_ranges: - population_totals[pr[2]] += pr[1] - pr[0] + 1 + population_totals[pr[2]] += pr[1] - pr[0] + 1 strain_population_totals[strain][pr[2]] += pr[1] - pr[0] + 1 for region_id in regions_strain_chrm[strain][chrm]: @@ -101,40 +96,47 @@ def find_pops(start, end, pop_ranges): # find the population ranges that the region start and end # coordinates fall within - pops, overlaps = find_pops(int(r['start']), int(r['end']), pop_ranges) + pops, overlaps = find_pops(int(r['start']), int(r['end']), + pop_ranges) regions_strain_chrm[strain][chrm][region_id]['population'] = \ ','.join(pops) - f.write(region_id + '\t' + \ - '\t'.join([str(regions_strain_chrm[strain][chrm][region_id][x])\ - for x in labels]) + '\n') + f.write(region_id + '\t' + + '\t'.join( + [str(regions_strain_chrm[strain][ + chrm][region_id][x]) + for x in labels]) + '\n') for i in range(len(pops)): - population_int_counts[pops[i]][r['alternative_states']] += \ + population_int_counts[pops[i]][r['alternative_states']] +=\ overlaps[i] - strain_population_int_counts[strain][pops[i]]\ - [r['alternative_states']] += overlaps[i] + strain_population_int_counts[strain][pops[i]][ + r['alternative_states']] += overlaps[i] -f = open(out_dir_run + 'population_introgression_counts_run' + run_id + '.txt', 'w') -f.write('population\treference\tnum_bases_introgressed\tfrac_bases_introgressed\n') +f = open(out_dir_run + 'population_introgression_counts_run' + run_id + '.txt', + 'w') +f.write('population\treference\t' + 'num_bases_introgressed\tfrac_bases_introgressed\n') for i in population_int_counts.keys(): for ref in population_int_counts[i].keys(): - f.write(str(i) + '\t' + ref + '\t' + str(population_int_counts[i][ref]) + '\t' + \ - str(float(population_int_counts[i][ref])/population_totals[i]) + '\n') + f.write(str(i) + '\t' + ref + '\t' + + str(population_int_counts[i][ref]) + '\t' + + str(float(population_int_counts[i][ref])/population_totals[i]) + + '\n') f.close() -f = open(out_dir_run + 'strain_population_introgression_counts_run' + \ +f = open(out_dir_run + 'strain_population_introgression_counts_run' + run_id + '.txt', 'w') -f.write('strain\tpopulation\treference\tnum_bases_introgressed' + \ +f.write('strain\tpopulation\treference\tnum_bases_introgressed' + '\tfrac_bases_introgressed\n') for strain in strains: for i in strain_population_int_counts[strain].keys(): for ref in all_alternative_states: count = strain_population_int_counts[strain][i][ref] total = strain_population_totals[strain][i] - #frac = 0 - #if total > 0: + # frac = 0 + # if total > 0: frac = float(count)/total - f.write(strain + '\t' + str(i) + '\t' + ref + '\t' + + f.write(strain + '\t' + str(i) + '\t' + ref + '\t' + str(count) + '\t' + str(frac) + '\n') f.close() diff --git a/code/analyze/summarize_region_quality.py b/code/analyze/summarize_region_quality.py index 43724c4..c2a4160 100644 --- a/code/analyze/summarize_region_quality.py +++ b/code/analyze/summarize_region_quality.py @@ -312,6 +312,125 @@ def states_to_process(self, return ref_ind, to_process +class Flag_Info(): + ''' + Collection of boolean flags for sequence summary + ''' + def __init__(self): + self.gap_any = None + self.mask_any = None + self.unseq_any = 
None
+        self.hmm = None
+        self.gap = None
+        self.mask = None
+        self.unseq = None
+        self.match = None
+
+    def initialize_flags(self, number_sequences: int, number_states: int):
+        '''
+        Initialize internal flags to np arrays of false
+        '''
+        self.gap_any = np.zeros((number_sequences), bool)
+        self.mask_any = np.zeros((number_sequences), bool)
+        self.unseq_any = np.zeros((number_sequences), bool)
+        self.gap = np.zeros((number_sequences, number_states), bool)
+        self.mask = np.zeros((number_sequences, number_states), bool)
+        self.unseq = np.zeros((number_sequences, number_states), bool)
+        self.match = np.zeros((number_sequences, number_states), bool)
+
+    def add_sequence_flags(self, other: Flag_Info, state: int):
+        '''
+        Join the other flag info with this info by replacing values
+        in the gap, unseq, and match arrays and performing OR with anys
+        '''
+        # only write the first time
+        if state == 0:
+            self.hmm = other.hmm
+
+        self.gap_any = np.logical_or(self.gap_any, other.gap)
+        self.unseq_any = np.logical_or(self.unseq_any, other.unseq)
+
+        self.gap[:, state] = other.gap
+        self.unseq[:, state] = other.unseq
+        self.match[:, state] = other.match
+
+    def add_mask_flags(self, other: Flag_Info, state: int):
+        '''
+        Join the other flag info with this by replacing values in mask and
+        performing an OR with mask_any
+        '''
+        self.mask_any = np.logical_or(self.mask_any, other.mask)
+        self.mask[:, state] = other.mask
+
+    def encode_info(self,
+                    master_ind: int,
+                    predict_ind: int) -> str:
+        '''
+        Summarize info flags into a string. master_ind is the index of
+        the master reference state. predict_ind is the index of the predicted
+        state. The return string is encoded for each position as:
+        '-': if either master or predict has a gap
+        '_': if either master or predict is masked
+        '.': if any state has a match
+        'b': both predict and master match
+        'c': master matches but not predict
+        'p': predict matches but not master
+        'x': no other condition applies
+        if the position is in the hmm_flag
+        it will be capitalized for x, p, c, or b
+        in order of precedence, e.g. if a position satisfies both '-' and '.',
+        it will be '-'.
+        '''
+
+        if predict_ind >= self.match.shape[1]:
+            return self.encode_unknown_info(master_ind)
+
+        decoder = np.array(list('xXpPcCbB._-'))
+        indices = np.zeros(self.match.shape[0], int)
+
+        indices[self.match[:, predict_ind]] += 2  # x to p if true
+        indices[self.match[:, master_ind]] += 4  # x to c, p to b
+        indices[self.hmm] += 1  # to upper
+
+        matches = np.all(self.match, axis=1)
+        indices[matches] = 8  # .
+        indices[np.any(
+            self.mask[:, [master_ind, predict_ind]],
+            axis=1)] = 9  # _
+        indices[np.any(
+            self.gap[:, [master_ind, predict_ind]],
+            axis=1)] = 10  # -
+
+        return ''.join(decoder[indices])
+
+    def encode_unknown_info(self,
+                            master_ind: int) -> str:
+        '''
+        Summarize info dictionary into a string for unknown state.
+        master_ind is the index of the master reference state.
+        The return string is encoded for each position as:
+        '-': if any state has a gap
+        '_': if any state has a mask
+        '.': all states match
+        'x': master matches
+        'X': no other condition applies
+        in order of precedence, e.g. if a position satisfies both '-' and '.',
+        it will be '-'.
+        '''
+
+        # used with indices to decode result
+        decoder = np.array(list('Xx._-'))
+        indices = np.zeros(self.gap_any.shape, int)
+
+        indices[self.match[:, master_ind]] = 1  # x
+        matches = np.all(self.match, axis=1)
+        indices[matches] = 2  # .
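+        # note: these assignments run in increasing precedence; the mask
+        # and gap flags assigned below overwrite the match encodings above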
+ indices[self.mask_any] = 3 # _ + indices[self.gap_any] = 4 # - + + return ''.join(decoder[indices]) + + class Sequence_Analyzer(): ''' Performs handling of masking, reading, and analyzing sequence data for @@ -466,10 +585,10 @@ def seq_id_unmasked(self, offset: int, exclude_sites1: List[int], exclude_sites2: List[int]) -> Tuple[ - int, int, Flag_info]: + int, int, Flag_Info]: ''' - Compare two sequences and provide statistics of their overlap considering - only the included sites. + Compare two sequences and provide statistics of their overlap + considering only the included sites. Takes two sequences, an offset applied to each excluded sites list Returns: -total number of matching sites in non-excluded sites. A position is @@ -633,125 +752,6 @@ def get_slice(self, return slice_start, slice_end -class Flag_Info(): - ''' - Collection of boolean flags for sequence summary - ''' - def __init__(self): - self.gap_any = None - self.mask_any = None - self.unseq_any = None - self.hmm = None - self.gap = None - self.mask = None - self.unseq = None - self.match = None - - def initialize_flags(self, number_sequences: int, number_states: int): - ''' - Initialize internal flags to np arrays of false - ''' - self.gap_any = np.zeros((number_sequences), bool) - self.mask_any = np.zeros((number_sequences), bool) - self.unseq_any = np.zeros((number_sequences), bool) - self.gap = np.zeros((number_sequences, number_states), bool) - self.mask = np.zeros((number_sequences, number_states), bool) - self.unseq = np.zeros((number_sequences, number_states), bool) - self.match = np.zeros((number_sequences, number_states), bool) - - def add_sequence_flags(self, other: Flag_Info, state: int): - ''' - Join the other flag info with this info by replacing values - in the gap, unseq, and match arrays and performing OR with anys - ''' - # only write the first time - if state == 0: - self.hmm = other.hmm - - self.gap_any = np.logical_or(self.gap_any, other.gap) - self.unseq_any = np.logical_or(self.unseq_any, other.unseq) - - self.gap[:, state] = other.gap - self.unseq[:, state] = other.unseq - self.match[:, state] = other.match - - def add_mask_flags(self, other: Flag_Info, state: int): - ''' - Join the other flag info with this by replacing values in mask and - performing an OR with mask_any - ''' - self.mask_any = np.logical_or(self.mask_any, other.mask) - self.mask[:, state] = other.mask - - def encode_info(self, - master_ind: int, - predict_ind: int) -> str: - ''' - Summarize info flags into a string. master_ind is the index of - the master reference state. predict_ind is the index of the predicted - state. The return string is encoded for each position as: - '-': if either master or predict has a gap - '_': if either master or predict is masked - '.': if any state has a match - 'b': both predict and master match - 'c': master matches but not predict - 'p': predict matches but not master - 'x': no other condition applies - if the position is in the hmm_flag - it will be capitalized for x, p, c, or b - in order of precidence, e.g. if a position satisfies both '-' and '.', - it will be '-'. - ''' - - if predict_ind >= self.match.shape[1]: - return self.encode_unknown_info(master_ind) - - decoder = np.array(list('xXpPcCbB._-')) - indices = np.zeros(self.match.shape[0], int) - - indices[self.match[:, predict_ind]] += 2 # x to p if true - indices[self.match[:, master_ind]] += 4 # x to c, p to b - indices[self.hmm] += 1 # to upper - - matches = np.all(self.match, axis=1) - indices[matches] = 8 # . 
- indices[np.any( - self.mask[:, [master_ind, predict_ind]], - axis=1)] = 9 # _ - indices[np.any( - self.gap[:, [master_ind, predict_ind]], - axis=1)] = 10 # - - - return ''.join(decoder[indices]) - - def encode_unknown_info(self, - master_ind: int) -> str: - ''' - Summarize info dictionary into a string for unknown state. - master_ind is the index of the master reference state. - The return string is encoded as each position as: - '-': if any state has a gap - '_': if any state has a mask - '.': all states match - 'x': master matches - 'X': no other condition applies - in order of precidence, e.g. if a position satisfies both '-' and '.', - it will be '-'. - ''' - - # used with indices to decode result - decoder = np.array(list('Xx._-')) - indices = np.zeros(self.gap_any.shape, int) - - indices[self.match[:, master_ind]] = 1 # x - matches = np.all(self.match, axis=1) - indices[matches] = 2 # . - indices[self.mask_any] = 3 # _ - indices[self.gap_any] = 4 # - - - return ''.join(decoder[indices]) - - class Region_Database(): ''' Contains data and logic for regions data during summarizing diff --git a/code/analyze/to_update/aggregate_genes_by_strains_main.py b/code/analyze/to_update/aggregate_genes_by_strains_main.py index e7ac91e..f7e008f 100644 --- a/code/analyze/to_update/aggregate_genes_by_strains_main.py +++ b/code/analyze/to_update/aggregate_genes_by_strains_main.py @@ -1,28 +1,19 @@ import sys -import os -import gzip -import predict from collections import defaultdict -from summarize_region_quality import * -import gene_predictions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import read_fasta -import read_table -import seq_functions +from misc import read_table tag = sys.argv[1] fn = gp.analysis_out_dir_absolute + tag + \ '/introgressed_blocks_filtered_par_' + tag + '_summary_plus.txt' -regions_filtered, l = read_table.read_table_rows(fn, "\t") +regions_filtered, _ = read_table.read_table_rows(fn, "\t") gene_strains = defaultdict(set) strain_genes = defaultdict(lambda: defaultdict(set)) for chrm in gp.chrms: - + fn = gp.analysis_out_dir_absolute + tag + \ '/genes_for_each_region_chr' + chrm + '_' + \ tag + '.txt' @@ -43,18 +34,18 @@ for gene in gene_strains: gene_counts[gene] = len(gene_strains[gene]) -f_out = open(gp.analysis_out_dir_absolute + tag + \ - '/genes_for_each_strain_filtered_' + \ +f_out = open(gp.analysis_out_dir_absolute + tag + + '/genes_for_each_strain_filtered_' + tag + '.txt', 'w') f_out.write('strain\tchromosome\tnum_genes\n') for chrm in gp.chrms: for strain in strain_genes[chrm]: - f_out.write(strain + '\t' + chrm + '\t' + \ + f_out.write(strain + '\t' + chrm + '\t' + str(len(strain_genes[chrm][strain])) + '\n') f_out.close() -f_out = open(gp.analysis_out_dir_absolute + tag + \ - '/genes_strain_hist_' + \ +f_out = open(gp.analysis_out_dir_absolute + tag + + '/genes_strain_hist_' + tag + '.txt', 'w') f_out.write('gene\tnum_strains\n') for gene in sorted(gene_counts.keys()): diff --git a/code/analyze/to_update/annotate_positions.py b/code/analyze/to_update/annotate_positions.py index 16b6378..0393c58 100644 --- a/code/analyze/to_update/annotate_positions.py +++ b/code/analyze/to_update/annotate_positions.py @@ -1,9 +1,8 @@ -import sys import re import gzip -sys.path.insert(0, '../misc/') -import overlap -import read_fasta +from misc import overlap +from misc import read_fasta + def get_genes(fn): @@ -15,14 +14,17 @@ def get_genes(fn): f.close() return genes + def get_orfs(fn): headers, seqs = read_fasta.read_fasta(fn) orfs = 
{}
     for h in headers:
-        m = re.search(' (?P<name>[a-zA-Z0-9]+)_(?P<chrm>[a-zA-Z0-9\.]+):(?P<start>[0-9]+):(?P<end>[0-9]+)', h)
+        m = re.search(r' (?P<name>[a-zA-Z0-9]+)_(?P<chrm>[a-zA-Z0-9\.]+)'
+                      ':(?P<start>[0-9]+):(?P<end>[0-9]+)', h)
         orfs[(int(m.group('start')), int(m.group('end')))] = m.group('name')
     return orfs
 
+
 def write_annotated_file(coords, genes, orfs, fn):
     # could definitely do this all way more efficiently
     sep = '\t'
@@ -33,14 +35,13 @@ def write_annotated_file(coords, genes, orfs, fn):
         if int(coords[i]) == coords[i]:
             f.write(str(int(coords[i])) + sep)
             gene = overlap.contained_any_named(coords[i], genes)
-            if gene != None:
+            if gene is not None:
                 f.write(gene)
             f.write(sep)
         else:
             f.write(str(coords[i]) + sep + sep)
         orf = overlap.contained_any_named(i, orfs)
-        if orf != None:
+        if orf is not None:
             f.write(orf)
         f.write('\n')
     f.close()
-
diff --git a/code/analyze/to_update/annotate_positions_main.py b/code/analyze/to_update/annotate_positions_main.py
index c5647b0..f721acb 100644
--- a/code/analyze/to_update/annotate_positions_main.py
+++ b/code/analyze/to_update/annotate_positions_main.py
@@ -4,28 +4,24 @@
 # gene
 # in ORF?
 
-import re
 import sys
 import os
-import copy
 import gzip
-from annotate_positions import *
-sys.path.insert(0, '..')
+from annotate_positions import (get_genes, get_orfs, write_annotated_file)
 import global_params as gp
-sys.path.insert(0, '../align/')
-import align_helpers
+from align import align_helpers
 
-##======
+# ======
 # get strains
-##======
+# ======
 
 i = int(sys.argv[1])
 s = align_helpers.get_strains(align_helpers.flatten(gp.non_ref_dirs.values()))
 strain, d = s[i]
 
-##======
+# ======
 # get genes on each chromosome
-##======
+# ======
 
 genes_by_chrm = {}
 for chrm in gp.chrms:
@@ -33,10 +29,10 @@
                '_genes.txt'
     genes_by_chrm[chrm] = get_genes(fn)
 
-##======
+# ======
 # loop through all strains and chromosomes, generating annotated
 # position file for each
-##======
+# ======
 
 coord_dir = gp.analysis_out_dir_absolute + 'coordinates/'
 if not os.path.exists(coord_dir + 'annotated'):
@@ -44,37 +40,15 @@
 
 for chrm in gp.chrms:
 
-    print strain, chrm
+    print(strain, chrm)
 
     fn = strain + '_to_' + gp.master_ref + '_chr' + chrm + '.txt.gz'
     fn_orfs = d + 'orfs/' + strain + '_chr' + chrm + \
-              '_orfs' + gp.fasta_suffix
+        '_orfs' + gp.fasta_suffix
     orfs = get_orfs(fn_orfs)
     fn_out = coord_dir + 'annotated/' + fn
-    coords = [float(line) for line in gzip.open(coord_dir + fn, 'rb').readlines()]
+    coords = [float(line)
+              for line in gzip.open(coord_dir + fn, 'rb').readlines()]
 
     write_annotated_file(coords, genes_by_chrm[chrm], orfs, fn_out)
-
-
-
-
-
-#for strain, d in s:
-
-    #m = re.search('(?P<strain1>[a-zA-Z0-9]+)_to_(?P<strain2>[a-zA-Z0-9]+)_chr(?P<chrm>[IVXM]+)', fn)
-    #if m == None:
-    #    continue
-    #strain1 = m.group('strain1')
-    #strain2 = m.group('strain2')
-    #chrm = m.group('chrm')
-
-    #if strain1 == gp.master_ref:
-    #    continue
-
-    # don't deal with paradoxus just for now
-    #if strain1 in gp.alignment_ref_order or strain2 != gp.master_ref:
-    #    continue
-
-    #print fn
-
diff --git a/code/analyze/to_update/annotate_regions.py b/code/analyze/to_update/annotate_regions.py
index fd032a0..8758484 100644
--- a/code/analyze/to_update/annotate_regions.py
+++ b/code/analyze/to_update/annotate_regions.py
@@ -1,9 +1,6 @@
 import gzip
 import gene_predictions
-import sys
 import global_params as gp
-sys.path.insert(0, '../misc/')
-
 
 def get_block_by_site(all_regions, seq):
@@ -19,8 +16,8 @@ def get_block_by_site(all_regions, seq):
     return introgressed_by_site
 
-def write_predictions_annotated(alignment_headers, alignment_seqs, master, \
-
match_by_site, \ +def write_predictions_annotated(alignment_headers, alignment_seqs, master, + strain_labels, match_by_site, gene_by_site, block_by_site, masked, fn): f = gzip.open(fn, 'wb') @@ -36,10 +33,10 @@ def write_predictions_annotated(alignment_headers, alignment_seqs, master, \ individual_indices = [0] * num_seqs # header - f.write('ps_ref' + sep + 'ps_strain' + sep + \ - sep.join(strain_labels) + sep + \ - 'match' + sep + \ - 'gene' + sep + 'block' + sep + \ + f.write('ps_ref' + sep + 'ps_strain' + sep + + sep.join(strain_labels) + sep + + 'match' + sep + + 'gene' + sep + 'block' + sep + sep.join([lab + '_masked' for lab in strain_labels]) + '\n') lines = [] @@ -57,7 +54,7 @@ def write_predictions_annotated(alignment_headers, alignment_seqs, master, \ ind_ref += 1 ps_ref = str(ind_ref) line += ps_ref + sep - + # index in strain ps_strain = None if alignment_seqs[-1][i] == gp.gap_symbol: @@ -76,7 +73,7 @@ def write_predictions_annotated(alignment_headers, alignment_seqs, master, \ line += match_by_site[r][i] line += sep - if gene_by_site[i] != None: + if gene_by_site[i] is not None: line += gene_by_site[i] line += sep @@ -87,17 +84,18 @@ def write_predictions_annotated(alignment_headers, alignment_seqs, master, \ line += sep if alignment_seqs[si][i] != gp.gap_symbol: # TODO update n to x - if masked[si][individual_indices[si]] == 'n': #gp.masked_symbol: + if masked[si][individual_indices[si]] == 'n': # masked line += gp.masked_symbol individual_indices[si] += 1 - + line += '\n' - + lines.append(line) f.writelines(lines) f.close() + # TODO give this a more general name/place def read_predictions_annotated(fn): sep = '\t' @@ -109,10 +107,7 @@ def read_predictions_annotated(fn): line = line[:-1].split(sep) for i in range(len(labels)): d[labels[i]].append(line[i]) - #d[line[0]] = dict(zip(labels[1:], line[1:])) + # d[line[0]] = dict(zip(labels[1:], line[1:])) line = f.readline() f.close() return d - - - diff --git a/code/analyze/to_update/annotate_regions_main.py b/code/analyze/to_update/annotate_regions_main.py index a06bb2c..3a60544 100644 --- a/code/analyze/to_update/annotate_regions_main.py +++ b/code/analyze/to_update/annotate_regions_main.py @@ -1,28 +1,25 @@ -# ps_cer ps_strain cer_ref par_ref strain gene introgressed_region cer_masked par_masked strain_masked +# ps_cer ps_strain cer_ref par_ref strain gene introgressed_region +# cer_masked par_masked strain_masked -import re import sys import os -import copy -import gene_predictions -from annotate_regions import * +import gene_predictions +from annotate_regions import (write_predictions_annotated, + get_block_by_site) import predict -import pickle -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import read_fasta +from misc import read_fasta -##====== +# ====== # read in analysis parameters -##====== +# ====== refs, strains, args = predict.process_args(sys.argv[1:]) chrm = sys.argv[1] -##====== +# ====== # read in introgressed/unknown regions and alignments -##====== +# ====== gp_dir = '../' @@ -40,9 +37,9 @@ fn_align_prefix = gp_dir + gp.alignments_dir fn_align_prefix += '_'.join([refs[s][0] for s in args['species']]) + '_' -##====== +# ====== # produce annotated files -##====== +# ====== # for keeping track of all genes introgressed in each strain, and the # fraction introgressed @@ -68,20 +65,20 @@ fn_genes = gp.analysis_out_dir_absolute + '/' + \ master_ref + '_chr' + chrm + '_genes.txt' -print 'reading genes on chromosome', chrm +print('reading genes on chromosome', chrm) # dictionary 
keyed by name: (start, end)
 genes = gene_predictions.read_genes(fn, fn_genes)
-print 'done reading genes'
+print('done reading genes')
 
 # loop through all strains that we've called introgression in, and
 # associate genes with the regions they overlap
 for strain in regions.keys():
-    
-    print '***', strain, chrm
+
+    print('***', strain, chrm)
     sys.stdout.flush()
 
-    fn_out = gp.analysis_out_dir_absolute + args['tag'] + '/site_summaries/' + \
-             'predictions_' + strain + '_chr' + chrm + '_site_summary.txt.gz'
+    fn_out = gp.analysis_out_dir_absolute + args['tag'] + '/site_summaries/' +\
+        'predictions_' + strain + '_chr' + chrm + '_site_summary.txt.gz'
 
     if not os.path.exists(os.path.dirname(fn_out)):
         os.makedirs(os.path.dirname(fn_out))
@@ -92,29 +89,30 @@
 
     # read alignment blocks for this strain and chromosome
     fn_align = fn_align_prefix + \
-               strain + '_chr' + chrm + '_mafft' + gp.alignment_suffix
+        strain + '_chr' + chrm + '_mafft' + gp.alignment_suffix
     alignment_headers, alignment_seqs = read_fasta.read_fasta(fn_align)
 
     # read masked (unaligned) sequences
     seq_masked_fns = [header.split()[-1] for header in alignment_headers]
-    seq_masked_fns = [mfn[:-len(gp.fasta_suffix)] + '_masked' + gp.fasta_suffix \
+    seq_masked_fns = [mfn[:-len(gp.fasta_suffix)] + '_masked' + gp.fasta_suffix
                       for mfn in seq_masked_fns]
    seqs_masked = [read_fasta.read_fasta(mfn)[1][0] for mfn in seq_masked_fns]
 
     labels = ref_labels + [strain]
-    
+
     # mark each site as matching each reference or not
-    ref_match_by_site = gene_predictions.get_ref_match_by_site(alignment_seqs, labels)
+    ref_match_by_site = gene_predictions.get_ref_match_by_site(alignment_seqs,
+                                                               labels)
 
     # mark each site as in a gene or not
-    genes_by_site = gene_predictions.get_genes_by_site(genes, alignment_seqs[0])
+    genes_by_site = gene_predictions.get_genes_by_site(genes,
+                                                       alignment_seqs[0])
 
     # mark each site as introgressed or not
     all_regions = [regions[strain][chrm]]
-    if regions_unk.has_key(strain) and regions_unk[strain].has_key(chrm):
+    if strain in regions_unk and chrm in regions_unk[strain]:
         all_regions.append(regions_unk[strain][chrm])
     block_by_site = get_block_by_site(all_regions, alignment_seqs[0])
 
-    write_predictions_annotated(alignment_headers, alignment_seqs, 0, \
-                                ref_labels + [strain], ref_match_by_site, \
-                                genes_by_site, block_by_site, seqs_masked, fn_out)
-
-
+    write_predictions_annotated(alignment_headers, alignment_seqs, 0,
+                                ref_labels + [strain], ref_match_by_site,
+                                genes_by_site, block_by_site,
+                                seqs_masked, fn_out)
diff --git a/code/analyze/to_update/check_paralogs_main.py b/code/analyze/to_update/check_paralogs_main.py
index 1bc39bd..bde7994 100644
--- a/code/analyze/to_update/check_paralogs_main.py
+++ b/code/analyze/to_update/check_paralogs_main.py
@@ -1,6 +1,6 @@
 # Loop through all introgressed genes (might be just a small part)
 # that have paralogs
-# Extract introgressed portion of gene 
+# Extract introgressed portion of gene
 # Blast that portion against:
 # - Cerevisiae gene
 # - Paradoxus gene (region aligned to cerevisiae gene)
@@ -13,23 +13,14 @@
 # - Paradoxus paralog -> interesting...
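The ranking step this header describes reduces to reading blast's tabular
output, which blastn sorts by e-value, and taking the subject id of the
top row. A minimal sketch of that step (not part of the patch; it assumes
outfmt '6 sseqid slen evalue bitscore' as requested later in this file,
and best_blast_hit is a hypothetical helper):

    def best_blast_hit(out_fn):
        # blastn sorts tabular output by e-value, so the first row
        # names the best-matching subject (the cer gene, the par gene,
        # or one of the paralogs in the combined database)
        with open(out_fn) as f:
            lines = f.readlines()
        if not lines:
            return 'none'  # no hit at all
        return lines[0].split('\t')[0]

This mirrors the pattern below, where an empty output file triggers a
blastn-short retry and a missing hit is recorded as 'none'.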
-import re -import sys import os import math -import Bio.SeqIO -import copy import gzip -import gene_predictions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../align/') -import align_helpers -sys.path.insert(0, '../misc/') -import read_table -import read_fasta -import write_fasta -import mystats +from align import align_helpers +from misc import read_table +from misc import read_fasta +from misc import write_fasta postprocess = False @@ -48,7 +39,7 @@ # dict of dicts keyed by region id and column names; includes filtered # and unfiltered regions region_to_genes = {} -f = open(gp.analysis_out_dir_absolute + tag + \ +f = open(gp.analysis_out_dir_absolute + tag + '/genes_for_each_region_' + tag + '.txt', 'r') line = f.readline() while line != '': @@ -60,10 +51,10 @@ f.close() # dict of lists keyed by region id -t_regions_filtered, l = \ - read_table.read_table_rows(gp.analysis_out_dir_absolute + tag + \ - '/introgressed_blocks_filtered_par_' + tag + \ - '_summary_plus.txt', \ +t_regions_filtered, _ = \ + read_table.read_table_rows(gp.analysis_out_dir_absolute + tag + + '/introgressed_blocks_filtered_par_' + tag + + '_summary_plus.txt', '\t', header=True) @@ -73,12 +64,12 @@ for region_id in region_to_genes: genes = region_to_genes[region_id] for gene in genes: - if not gene_to_regions.has_key(gene): + if gene not in gene_to_regions: gene_to_regions[gene] = [] gene_to_regions[gene].append(region_id) if region_id in t_regions_filtered: for gene in genes: - if not gene_to_regions_filtered.has_key(gene): + if gene not in gene_to_regions_filtered: gene_to_regions_filtered[gene] = [] gene_to_regions_filtered[gene].append(region_id) @@ -95,7 +86,7 @@ # read in all gene coordinates gene_coords = {} for chrm in gp.chrms: - f = open(gp.analysis_out_dir_absolute + \ + f = open(gp.analysis_out_dir_absolute + 'S288c_chr' + chrm + '_genes.txt', 'r') lines = [line.strip().split('\t') for line in f.readlines()] f.close() @@ -108,8 +99,9 @@ genes_to_analyze = gene_to_regions_filtered.keys() if postprocess: - genes_to_analyze = [line.split('\t')[0] for line in \ - open('check_paralogs_out_cer_paralog.tsv', 'r').readlines()] + genes_to_analyze = [line.split('\t')[0] for line in + open('check_paralogs_out_cer_paralog.tsv', + 'r').readlines()] genes_to_analyze = list(set(genes_to_analyze)) ip = 0 @@ -117,13 +109,13 @@ if gene not in paralogs: continue - print ip + print(ip) ip += 1 chrm, ref_gene_start, ref_gene_end = gene_coords[gene] gene_headers, gene_seqs = \ - read_fasta.read_fasta(gp.analysis_out_dir_absolute + tag + '/genes/' + \ + read_fasta.read_fasta(gp.analysis_out_dir_absolute + tag + '/genes/' + gene + '/' + gene + '_from_alignment.fa') gene_headers = [x[1:].strip() for x in gene_headers] strain_seqs = dict(zip(gene_headers, gene_seqs)) @@ -133,7 +125,7 @@ paralog = paralogs[gene] gene_headers, gene_seqs = \ - read_fasta.read_fasta(gp.analysis_out_dir_absolute + tag + '/genes/' + \ + read_fasta.read_fasta(gp.analysis_out_dir_absolute + tag + '/genes/' + paralog + '/' + paralog + '_from_alignment.fa') gene_headers = [x[1:].strip() for x in gene_headers] strain_paralog_seqs = dict(zip(gene_headers, gene_seqs)) @@ -153,8 +145,8 @@ f.close() cmd_string = gp.blast_install_path + 'makeblastdb' + \ - ' -in ' + db_fn + \ - ' -dbtype nucl' + ' -in ' + db_fn + \ + ' -dbtype nucl' os.system(cmd_string) strain_intd_seqs = {} @@ -164,25 +156,29 @@ ref_region_start = int(t_regions_filtered[region]['start']) ref_region_end = int(t_regions_filtered[region]['end']) - ref_to_strain_coords 
= [float(x[:-1]) for x in \ - gzip.open(gp.analysis_out_dir_absolute + \ - 'coordinates/S288c_to_' + strain + \ - '_chr' + chrm + '.txt.gz').readlines()] + ref_to_strain_coords = [float(x[:-1]) for x in + gzip.open(gp.analysis_out_dir_absolute + + 'coordinates/S288c_to_' + strain + + '_chr' + chrm + + '.txt.gz').readlines()] - gene_start = int(max(0, math.ceil(ref_to_strain_coords[ref_gene_start]))) + gene_start = int(max(0, math.ceil( + ref_to_strain_coords[ref_gene_start]))) gene_end = int(math.floor(ref_to_strain_coords[ref_gene_end])) - - region_start = int(max(0, math.ceil(ref_to_strain_coords[ref_region_start]))) + + region_start = int(max(0, math.ceil( + ref_to_strain_coords[ref_region_start]))) region_end = int(math.floor(ref_to_strain_coords[ref_region_end])) start = max(gene_start, region_start) end = min(gene_end, region_end) - chrom_seq = read_fasta.read_fasta(strain_dirs[strain] + strain + '_chr' + \ + chrom_seq = read_fasta.read_fasta(strain_dirs[strain] + + strain + '_chr' + chrm + gp.fasta_suffix)[1][0] seq = chrom_seq[start:end+1] - if not strain_intd_seqs.has_key(strain): + if strain not in strain_intd_seqs: strain_intd_seqs[strain] = chrom_seq[gene_start:gene_end+1].lower() relative_start = start - gene_start relative_end = end - gene_start @@ -196,11 +192,11 @@ f.close() cmd_string = gp.blast_install_path + 'blastn' + \ - ' -db ' + db_fn + \ - ' -query ' + query_fn + \ - ' -out ' + out_fn + \ - ' -outfmt ' + outfmt - print cmd_string + ' -db ' + db_fn + \ + ' -query ' + query_fn + \ + ' -out ' + out_fn + \ + ' -outfmt ' + outfmt + print(cmd_string) os.system(cmd_string) if os.stat(out_fn).st_size == 0: @@ -210,9 +206,9 @@ ' -out ' + out_fn + \ ' -task "blastn-short"' + \ ' -outfmt ' + outfmt - print cmd_string + print(cmd_string) os.system(cmd_string) - + lines = open(out_fn, 'r').readlines() best_key = 'none' if len(lines) != 0: @@ -236,10 +232,10 @@ # write reference genes and paralogs and all introgressed # genes to file and then align fn = gp.analysis_out_dir_absolute + tag + '/paralogs/' + \ - gene + gp.fasta_suffix - headers = ['S288c ' + gene, 'CBS432 ' + gene, \ + gene + gp.fasta_suffix + headers = ['S288c ' + gene, 'CBS432 ' + gene, 'S288c ' + paralog, 'CBS432 ' + paralog] - seqs = [cer_seq.lower(), par_seq.lower(), \ + seqs = [cer_seq.lower(), par_seq.lower(), cer_paralog_seq.lower(), par_paralog_seq.lower()] for strain in strain_intd_seqs: headers.append(strain + ' ' + gene) @@ -248,10 +244,10 @@ aligned_fn = fn.replace(gp.fasta_suffix, gp.alignment_suffix) cmd_string = gp.mafft_install_path + '/mafft ' + \ - ' --quiet --reorder --preservecase ' + \ - fn + ' > ' + aligned_fn + ' --quiet --reorder --preservecase ' + \ + fn + ' > ' + aligned_fn os.system(cmd_string) - + f = open('check_paralogs_out.tsv', 'w') f.write('category\tnum_total_genes\tnum_unique_genes\n') for key in keys: @@ -265,7 +261,5 @@ for item in all_rankings[key]: fk.write('\t'.join(item) + '\n') fk.close() - -f.close() - +f.close() diff --git a/code/analyze/to_update/combine_all_strains.py b/code/analyze/to_update/combine_all_strains.py index 947764b..fa331ca 100644 --- a/code/analyze/to_update/combine_all_strains.py +++ b/code/analyze/to_update/combine_all_strains.py @@ -3,19 +3,12 @@ import os import math import Bio.SeqIO -import copy -import gene_predictions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../sim/') -import sim_analyze_hmm_bw as sim -sys.path.insert(0, '../misc/') -import seq_functions -import read_table -import read_fasta -import write_fasta 
-import mystats -import overlap +from misc import seq_functions +from misc import read_table +from misc import read_fasta +from misc import overlap + def get_range_seq(start, end, seq_fn): @@ -23,16 +16,17 @@ def get_range_seq(start, end, seq_fn): range_seq = chrm_seq[start:end+1] return range_seq + def get_ref_gene_seq(gene, gene_coords_fn, seq_fn): - d1, labels = read_table.read_table_rows(gene_coords_fn, '\t', \ + d1, labels = read_table.read_table_rows(gene_coords_fn, '\t', header=False, key_ind=0) d = {} for g in d1: if d1[g][0] == '""': d[g] = d1[g][1:] else: - d[d1[g][0]] = d1[g][1:] + d[d1[g][0]] = d1[g][1:] gene_start = int(d[gene][2]) - 1 gene_end = int(d[gene][3]) - 1 @@ -45,6 +39,7 @@ def get_ref_gene_seq(gene, gene_coords_fn, seq_fn): assert gene_start < gene_end return gene_seq, gene_start, gene_end, strand + def get_inds_from_alignment(fn, flip_ref, rind=0, sind=1): headers, seqs = read_fasta.read_fasta(fn) n = len(seqs[0]) @@ -63,29 +58,30 @@ def get_inds_from_alignment(fn, flip_ref, rind=0, sind=1): pr.append(str(ri)) ps.append(str(si)) if flip_ref: - return {'ps_ref':ps, 'ps_strain':pr} - return {'ps_ref':pr, 'ps_strain':ps} + return {'ps_ref': ps, 'ps_strain': pr} + return {'ps_ref': pr, 'ps_strain': ps} # by taking part of sequence aligned with reference coordinates -def get_range_seqs(strains, chrm, start, end, tag, gp_dir = '../'): +def get_range_seqs(strains, chrm, start, end, tag, gp_dir='../'): # TODO this shouldn't actually be dependent on tag strain_range_seqs = {} for strain, d in strains: - print strain + print(strain) fn = d + strain + '_chr' + chrm + gp.fasta_suffix chrm_seq = read_fasta.read_fasta(fn)[1][0] t = None try: - t, labels = read_table.read_table_columns(gp.analysis_out_dir_absolute + \ - tag + '/' + \ - 'site_summaries/predictions_' + \ - strain + \ - '_chr' + chrm + \ - '_site_summary.txt.gz', '\t') - except: + t, labels = read_table.read_table_columns( + gp.analysis_out_dir_absolute + + tag + '/' + + 'site_summaries/predictions_' + + strain + + '_chr' + chrm + + '_site_summary.txt.gz', '\t') + except FileNotFoundError: # for par reference which doesn't have site summary file align_fn = gp_dir + gp.alignments_dir + \ '_'.join(gp.alignment_ref_order) + '_chr' + chrm + \ @@ -97,13 +93,13 @@ def get_range_seqs(strains, chrm, start, end, tag, gp_dir = '../'): start_strain = int(math.ceil(float(ref_ind_to_strain_ind[str(start)]))) end_strain = int(math.floor(float(ref_ind_to_strain_ind[str(end)]))) - - strain_range_seqs[strain] = (chrm_seq[start_strain:end_strain+1], \ - start_strain, end_strain) + strain_range_seqs[strain] = (chrm_seq[start_strain:end_strain+1], + start_strain, end_strain) return strain_range_seqs -def choose_best_hit_rev(hits, query_fn, ref_chrm_fn, orf_headers, orf_seqs, start, end): +def choose_best_hit_rev(hits, query_fn, ref_chrm_fn, + orf_headers, orf_seqs, start, end): # choosing best hit by reciprocal blast -> not reliable tho if len(hits) == 1: return hits[0][0] @@ -124,10 +120,10 @@ def choose_best_hit_rev(hits, query_fn, ref_chrm_fn, orf_headers, orf_seqs, star f.write(seq + '\n') f.close() cmd_string = gp.blast_install_path + 'blastn' + \ - ' -db ' + ref_chrm_fn + \ - ' -query ' + orf_query_fn + \ - ' -out ' + out_fn + \ - ' -outfmt ' + outfmt + ' -db ' + ref_chrm_fn + \ + ' -query ' + orf_query_fn + \ + ' -out ' + out_fn + \ + ' -outfmt ' + outfmt os.system(cmd_string) f = open(out_fn, 'r') nhits = [line[:-1].split('\t') for line in f.readlines()] @@ -135,7 +131,7 @@ def choose_best_hit_rev(hits, query_fn, ref_chrm_fn, 
orf_headers, orf_seqs, star nstart = int(nhits[0][-2]) nend = int(nhits[0][-1]) # this division is hacky and unprincipled - o = overlap.overlap(start, end, nstart, nend) / float(hit[1]) + o = overlap.overlap(start, end, nstart, nend) / float(hit[1]) if o > greatest_overlap: greatest_overlap = o best_hit = hit[0] @@ -143,10 +139,11 @@ def choose_best_hit_rev(hits, query_fn, ref_chrm_fn, orf_headers, orf_seqs, star os.remove(out_fn) return best_hit -def choose_best_hit(hits, start, end, tag, strain, chrm, headers, seqs,\ + +def choose_best_hit(hits, start, end, tag, strain, chrm, headers, seqs, strain_ind_to_ref_ind, gp_dir='../'): - greatest_overlap = 0 # don't want to take overlaps of 0 + greatest_overlap = 0 # don't want to take overlaps of 0 best_hit = None x = None seq = None @@ -171,7 +168,7 @@ def choose_best_hit(hits, start, end, tag, strain, chrm, headers, seqs,\ c2 = chunk2.find(':', c1+1) seq = seqs[i] orf_start = int(chunk2[c1+1:c2]) - orf_end = int(chunk2[c2+1:]) + orf_end = int(chunk2[c2+1:]) strand = '1' if orf_start > orf_end: temp = orf_end @@ -179,8 +176,10 @@ def choose_best_hit(hits, start, end, tag, strain, chrm, headers, seqs,\ orf_start = temp strand = '-1' break - current_start = int(math.ceil(float(strain_ind_to_ref_ind[str(orf_start)]))) - current_end = int(math.floor(float(strain_ind_to_ref_ind[str(orf_end)]))) + current_start = int(math.ceil( + float(strain_ind_to_ref_ind[str(orf_start)]))) + current_end = int(math.floor( + float(strain_ind_to_ref_ind[str(orf_end)]))) o = overlap.overlap(start, end, current_start, current_end) if o > greatest_overlap: greatest_overlap = o @@ -190,16 +189,19 @@ def choose_best_hit(hits, start, end, tag, strain, chrm, headers, seqs,\ orf_start_max = orf_start orf_end_max = orf_end strand_max = strand - seq_max = seq # don't need to reverse complement (blast does this) + seq_max = seq # don't need to reverse complement (blast does this) - print greatest_overlap + print(greatest_overlap) return best_hit, x_max, seq_max, orf_start_max, orf_end_max, strand_max + # by blasting ORFs -def get_gene_seqs(query_fn, strains, chrm, ref_chrm_fn, start, end, strand, tag, +def get_gene_seqs(query_fn, strains, chrm, ref_chrm_fn, + start, end, strand, tag, strain_ind_to_ref_ind): - - #outfmt = '"6 qseqid sseqid slen qstart qend length mismatch gapopen gaps sseq"' + + # outfmt = '"6 qseqid sseqid slen qstart qend \ + # length mismatch gapopen gaps sseq"' outfmt = '"6 sseqid slen evalue bitscore"' strain_gene_seqs = {} @@ -208,38 +210,39 @@ def get_gene_seqs(query_fn, strains, chrm, ref_chrm_fn, start, end, strand, tag, if strain != 'yjm1332': continue - print '-', strain + print('-', strain) sys.stdout.flush() - fn = d + 'orfs/' + strain + '_chr' + chrm + '_orfs' + gp.fasta_suffix + fn = d + 'orfs/' + strain + '_chr' + chrm + '_orfs' + gp.fasta_suffix cmd_string = gp.blast_install_path + 'blastn' + \ - ' -db ' + fn + \ - ' -query ' + query_fn + \ - ' -out ' + out_fn + \ - ' -outfmt ' + outfmt - #print cmd_string + ' -db ' + fn + \ + ' -query ' + query_fn + \ + ' -out ' + out_fn + \ + ' -outfmt ' + outfmt + # print(cmd_string) os.system(cmd_string) - hits = [line[:-1].split('\t') for line in open(out_fn, 'r').readlines()] - num_hits = len(hits) + hits = [line[:-1].split('\t') + for line in open(out_fn, 'r').readlines()] if len(hits) == 0: strain_gene_seqs[strain] = ('nohit', '', -1, -1, '') continue - #best_orf_id = hits[0][0] + # best_orf_id = hits[0][0] headers, seqs = read_fasta.read_fasta(fn) best_orf_id, x, seq, orf_start, orf_end, orf_strand = \ 
-            choose_best_hit(hits, start, end, tag, strain, chrm, headers, seqs, \
+            choose_best_hit(hits, start, end, tag, strain,
+                            chrm, headers, seqs,
                             strain_ind_to_ref_ind[strain])
-        print hits
-        print best_orf_id
-        print orf_strand, strand
+        print(hits)
+        print(best_orf_id)
+        print(orf_strand, strand)
         sys.exit()
-        if best_orf_id == None or orf_strand != strand:
+        if best_orf_id is None or orf_strand != strand:
             strain_gene_seqs[strain] = ('nohit', '', -1, -1, '')
             continue
         strain_gene_seqs[strain] = (x, seq, orf_start, orf_end, orf_strand)
     os.remove(out_fn)
     return strain_gene_seqs
-    
+
 
 # can't actually count on annotations
 def get_gene_seqs_gb(fn, gene, chrm):
@@ -249,38 +252,40 @@
     strains = set([])
     for strain_chrm_record in gb_records:
         desc = strain_chrm_record.description
-        m = re.search(' (?P<strain>[a-zA-Z0-9]+) chromosome (?P<chrm>[IVXM]+)', \
+        m = re.search(' (?P<strain>[a-zA-Z0-9]+) chromosome (?P<chrm>[IVXM]+)',
                       desc)
         chrm_current = m.group('chrm')
         strain = m.group('strain').lower()
         strains.add(strain)
-        #if len(strain_gene_seqs) > 82:
-        #    break
-        print strain, chrm_current
+        # if len(strain_gene_seqs) > 82:
+        #     break
+        print(strain, chrm_current)
         if chrm_current != chrm:
             continue
         for feature in strain_chrm_record.features:
-            if feature.type == 'CDS' and feature.qualifiers.has_key('gene') and \
+            if feature.type == 'CDS' and 'gene' in feature.qualifiers and \
               feature.qualifiers['gene'][0] == gene:
                 desc = strain_chrm_record.description
-                m = re.search(\
-                    ' (?P<strain>[a-zA-Z0-9]+) chromosome (?P<chrm>[IVXM]+)', \
+                m = re.search(
+                    ' (?P<strain>[a-zA-Z0-9]+) '
+                    'chromosome (?P<chrm>[IVXM]+)',
                    desc)
                 seq = str(feature.extract(strain_chrm_record.seq).lower())
                 start = str(feature.location.start)
                 end = str(feature.location.end)
                 strand = str(feature.location.strand)
                 locus_tag = feature.qualifiers['locus_tag'][0]
-                strain_gene_seqs[strain] = {'seq':seq, \
-                                            'chrm':chrm, \
-                                            'start':start, \
-                                            'end':end, \
-                                            'strand':strand,\
-                                            'locus_tag':locus_tag}
-
-                print '- found gene in', strain
+                strain_gene_seqs[strain] = {'seq': seq,
+                                            'chrm': chrm,
+                                            'start': start,
+                                            'end': end,
+                                            'strand': strand,
+                                            'locus_tag': locus_tag}
+
+                print('- found gene in', strain)
     return strain_gene_seqs, list(strains)
 
+
 # because don't have gb file for paradoxus...
def get_gene_seqs_fsa(fn, gene, chrm): f = open(fn, 'r') @@ -295,12 +300,13 @@ def get_gene_seqs_fsa(fn, gene, chrm): line = f.readline() f.close() - seqfa = open(gp.ref_dir['CBS432'] + 'CBS432_chr' + chrm + '.fa', 'r').read() + seqfa = open(gp.ref_dir['CBS432'] + + 'CBS432_chr' + chrm + '.fa', 'r').read() seqfa = seqfa.replace('\n', '') if seq in seqfa: - print 'found paradoxus seq' + print('found paradoxus seq') else: - print 'did not find paradoxus seq' + print('did not find paradoxus seq') fg = open('a.txt', 'w') fg.write(seq + '\n') fg.write(seqfa + '\n') @@ -308,4 +314,3 @@ def get_gene_seqs_fsa(fn, gene, chrm): return seq.lower() line = f.readline() - diff --git a/code/analyze/to_update/combine_gene_all_strains_main.py b/code/analyze/to_update/combine_gene_all_strains_main.py index 2816954..b8c6912 100644 --- a/code/analyze/to_update/combine_gene_all_strains_main.py +++ b/code/analyze/to_update/combine_gene_all_strains_main.py @@ -1,4 +1,5 @@ -# TODO - when blasting, take best gene, except if there are multiple hits, prioritize the one that overlaps the region we'd expect based on alignment +# TODO - when blasting, take best gene, except if there are multiple hits, +# prioritize the one that overlaps the region we'd expect based on alignment # input a gene or start/end coordinates @@ -6,34 +7,30 @@ # - for gene, relies on annotations/orfs # - for coordinates, relies on alignments -import re import sys import os import math -import Bio.SeqIO import copy -from combine_all_strains import * +from analyze.to_update.combine_all_strains import (get_gene_seqs, + get_inds_from_alignment, + get_ref_gene_seq) import gene_predictions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../align/') -import align_helpers -sys.path.insert(0, '../misc/') -import read_table -import read_fasta -import write_fasta -import mystats +from align import align_helpers +from misc import read_table +from misc import write_fasta tag = sys.argv[1] gene = sys.argv[2] chrm = sys.argv[3] -#all_outfiles = [] +# all_outfiles = [] -print 'getting gene sequence from reference strain' +print('getting gene sequence from reference strain') ref = 'S288c' ref_gene_coords_fn = '../../data/S288c_verified_orfs.tsv' -ref_seq_fn = gp.ref_dir[ref] + gp.ref_fn_prefix[ref] + '_chr' + chrm + gp.fasta_suffix +ref_seq_fn = gp.ref_dir[ref] + gp.ref_fn_prefix[ref] \ + + '_chr' + chrm + gp.fasta_suffix ref_gene_seq, ref_start, ref_end, ref_strand = \ get_ref_gene_seq(gene, ref_gene_coords_fn, ref_seq_fn) query_fn = gene + '.txt' @@ -41,19 +38,19 @@ f.write(ref_gene_seq + '\n') f.close() -print 'getting gene sequences from all strains' +print('getting gene sequences from all strains') gp_dir = '../' s = align_helpers.get_strains(align_helpers.flatten(gp.non_ref_dirs.values())) ref_ind_to_strain_ind = {} strain_ind_to_ref_ind = {} for strain, d in s: - print '*', strain + print('*', strain) sys.stdout.flush() - t, labels = read_table.read_table_columns(gp.analysis_out_dir_absolute + \ - tag + '/' + \ - 'site_summaries/predictions_' + \ - strain + \ - '_chr' + chrm + \ + t, labels = read_table.read_table_columns(gp.analysis_out_dir_absolute + + tag + '/' + + 'site_summaries/predictions_' + + strain + + '_chr' + chrm + '_site_summary.txt.gz', '\t') ref_ind_to_strain_ind[strain] = dict(zip(t['ps_ref'], t['ps_strain'])) strain_ind_to_ref_ind[strain] = dict(zip(t['ps_strain'], t['ps_ref'])) @@ -63,15 +60,19 @@ '_mafft' + gp.alignment_suffix t = get_inds_from_alignment(align_fn, True) other_ref_strain = 
gp.ref_fn_prefix[gp.alignment_ref_order[1]] -ref_ind_to_strain_ind[other_ref_strain] = dict(zip(t['ps_ref'], t['ps_strain'])) -strain_ind_to_ref_ind[other_ref_strain] = dict(zip(t['ps_strain'], t['ps_ref'])) +ref_ind_to_strain_ind[other_ref_strain] = dict( + zip(t['ps_ref'], t['ps_strain'])) +strain_ind_to_ref_ind[other_ref_strain] = dict( + zip(t['ps_strain'], t['ps_ref'])) s.append((other_ref_strain, gp.ref_dir[gp.alignment_ref_order[1]])) -strain_gene_seqs = get_gene_seqs(query_fn, s, chrm, ref_seq_fn, ref_start, ref_end, ref_strand, tag, strain_ind_to_ref_ind) +strain_gene_seqs = get_gene_seqs(query_fn, s, chrm, ref_seq_fn, ref_start, + ref_end, ref_strand, + tag, strain_ind_to_ref_ind) os.remove(query_fn) -print 'writing all gene sequences to file' +print('writing all gene sequences to file') keys = sorted(strain_gene_seqs.keys()) -headers = [key + ' ' + strain_gene_seqs[key][0] + ' ' + \ +headers = [key + ' ' + strain_gene_seqs[key][0] + ' ' + strain_gene_seqs[key][-1] for key in keys] seqs = [strain_gene_seqs[key][1] for key in keys] strains = [ref] + keys @@ -86,18 +87,18 @@ suffixes = ['', '_filtered'] for suffix in suffixes: - print ' '.join(['finding', suffix, 'regions that overlap gene']) + print(' '.join(['finding', suffix, 'regions that overlap gene'])) # read in filtered regions fn_regions = gp.analysis_out_dir_absolute + tag + '/' + \ - 'introgressed_blocks' + suffix + '_par_' + tag + '_summary_plus.txt' - regions, l = read_table.read_table_rows(fn_regions, '\t') + 'introgressed_blocks' + suffix + '_par_' + tag + '_summary_plus.txt' + regions, _ = read_table.read_table_rows(fn_regions, '\t') # figure out which strains are introgressed/which regions overlap gene fn_genes_regions = gp.analysis_out_dir_absolute + tag + '/' + \ - 'genes_for_each_region_chr' + chrm + '_' + tag + '.txt' + 'genes_for_each_region_chr' + chrm + '_' + tag + '.txt' region_to_genes = \ gene_predictions.read_genes_for_each_region_summary(fn_genes_regions) - #strains = [x[0] for x in s] + # strains = [x[0] for x in s] regions_overlapping = {} # TODO does this actually ensure that regions are sorted appropriately # in fasta headers below? 
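Both of the combine_* scripts lean on the same coordinate-mapping
convention: the site-summary tables give a (possibly fractional) strain
position for each reference position, and an interval is mapped by
rounding its start up and its end down so it stays inside the aligned
part of the strain sequence. A minimal sketch of that convention (an
illustration, not part of the patch; map_interval is a hypothetical
helper, and the dict is assumed to map string positions to string float
positions, like ref_ind_to_strain_ind[strain] built above):

    import math

    def map_interval(ref_start, ref_end, ref_to_strain):
        # positions that fall in a strain gap are interpolated and thus
        # fractional; ceil the start and floor the end to round inward
        start = int(math.ceil(float(ref_to_strain[str(ref_start)])))
        end = int(math.floor(float(ref_to_strain[str(ref_end)])))
        return start, end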
@@ -106,12 +107,12 @@ if regions[region]['chromosome'] == chrm and \ gene in [x[0] for x in region_to_genes[region]['gene_list']]: strain = regions[region]['strain'] - if not regions_overlapping.has_key(strain): + if strain not in regions_overlapping: regions_overlapping[strain] = [] regions_overlapping[strain].append(region) - print ' '.join(['writing all gene sequences to file, with', \ - suffix, 'introgressed bases capitalized']) + print(' '.join(['writing all gene sequences to file, with', + suffix, 'introgressed bases capitalized'])) headers_current = copy.deepcopy(headers) seqs_current = copy.deepcopy(seqs) for i in range(len(seqs)): @@ -123,38 +124,35 @@ if strain not in regions_overlapping: continue g = strain_gene_seqs[strain] - t, labels = read_table.read_table_columns(gp.analysis_out_dir_absolute + \ - tag + '/' + \ - 'site_summaries/predictions_' + \ - strain + \ - '_chr' + chrm + \ - '_site_summary.txt.gz', '\t') + t, labels = read_table.read_table_columns( + gp.analysis_out_dir_absolute + tag + '/' + + 'site_summaries/predictions_' + strain + '_chr' + chrm + + '_site_summary.txt.gz', '\t') for region in regions_overlapping[strain]: header += ' ' + region start_strain = \ - math.ceil(float(\ - ref_ind_to_strain_ind[strain][regions[region]['start']])) + math.ceil(float(ref_ind_to_strain_ind[ + strain][regions[region]['start']])) end_strain = \ - math.floor(float(\ - ref_ind_to_strain_ind[strain][regions[region]['end']])) + math.floor(float(ref_ind_to_strain_ind[ + strain][regions[region]['end']])) start_relative = int(max(start_strain - int(g[2]), 0)) end_relative = int(end_strain - int(g[2])) seq = seq[:start_relative] + \ - seq[start_relative:end_relative+1].upper() + \ - seq[end_relative+1:] - seqs_current[i] = seq + seq[start_relative:end_relative+1].upper() + \ + seq[end_relative+1:] + seqs_current[i] = seq headers_current[i] = header gene_seqs_fn = gp.analysis_out_dir_absolute + tag + \ - '/genes/' + gene + '/' + gene + \ - '_introgressed' + suffix + gp.fasta_suffix + '/genes/' + gene + '/' + gene + \ + '_introgressed' + suffix + gp.fasta_suffix write_fasta.write_fasta(headers_current, seqs_current, gene_seqs_fn) - - print 'aligning gene sequences' - gene_seqs_aligned_fn = gene_seqs_fn.replace(gp.fasta_suffix, gp.alignment_suffix) + print('aligning gene sequences') + gene_seqs_aligned_fn = gene_seqs_fn.replace(gp.fasta_suffix, + gp.alignment_suffix) cmd_string = gp.mafft_install_path + '/mafft ' + \ - ' --quiet --reorder --preservecase ' + \ - gene_seqs_fn + ' > ' + gene_seqs_aligned_fn + ' --quiet --reorder --preservecase ' + \ + gene_seqs_fn + ' > ' + gene_seqs_aligned_fn os.system(cmd_string) - diff --git a/code/analyze/to_update/combine_region_all_strains_main.py b/code/analyze/to_update/combine_region_all_strains_main.py index cc79a6b..234671a 100644 --- a/code/analyze/to_update/combine_region_all_strains_main.py +++ b/code/analyze/to_update/combine_region_all_strains_main.py @@ -3,44 +3,38 @@ # - for gene, relies on annotations/orfs # - for coordinates, relies on alignments -import re import sys import os import math -import Bio.SeqIO import copy -from combine_all_strains import * -import gene_predictions -sys.path.insert(0, '..') +from combine_all_strains import (get_range_seq, + get_range_seqs) import global_params as gp -sys.path.insert(0, '../align/') -import align_helpers -sys.path.insert(0, '../misc/') -import read_table -import read_fasta -import write_fasta -import mystats - +from align import align_helpers +from misc import read_table +from misc import 
write_fasta tag = sys.argv[1] start = int(sys.argv[2]) end = int(sys.argv[3]) chrm = sys.argv[4] -print 'getting range sequence from reference strain' +print('getting range sequence from reference strain') ref = 'S288c' -ref_seq_fn = gp.ref_dir[ref] + gp.ref_fn_prefix[ref] + '_chr' + chrm + gp.fasta_suffix +ref_seq_fn = gp.ref_dir[ref] + gp.ref_fn_prefix[ref] + \ + '_chr' + chrm + gp.fasta_suffix ref_range_seq = get_range_seq(start, end, ref_seq_fn) -print 'getting range sequences from all strains' +print('getting range sequences from all strains') s = align_helpers.get_strains(align_helpers.flatten(gp.non_ref_dirs.values())) -s.append((gp.ref_fn_prefix[gp.alignment_ref_order[1]], gp.ref_dir[gp.alignment_ref_order[1]])) +s.append((gp.ref_fn_prefix[gp.alignment_ref_order[1]], + gp.ref_dir[gp.alignment_ref_order[1]])) # keyed by strain: (seq, start, end) strain_range_seqs = get_range_seqs(s, chrm, start, end, tag) -print 'writing all range sequences to file' +print('writing all range sequences to file') keys = sorted(strain_range_seqs.keys()) -headers = [key + ' ' + str(strain_range_seqs[key][1]) + ':' + \ +headers = [key + ' ' + str(strain_range_seqs[key][1]) + ':' + str(strain_range_seqs[key][2]) for key in keys] seqs = [strain_range_seqs[key][0] for key in keys] strains = [ref] + keys @@ -55,11 +49,11 @@ suffixes = ['', '_filtered'] for suffix in suffixes: - print ' '.join(['finding', suffix, 'regions that overlap range']) + print(' '.join(['finding', suffix, 'regions that overlap range'])) # read in filtered regions fn_regions = gp.analysis_out_dir_absolute + tag + '/' + \ - 'introgressed_blocks' + suffix + '_par_' + tag + '_summary_plus.txt' - regions, l = read_table.read_table_rows(fn_regions, '\t') + 'introgressed_blocks' + suffix + '_par_' + tag + '_summary_plus.txt' + regions, _ = read_table.read_table_rows(fn_regions, '\t') regions_overlapping = {} # TODO does this actually ensure that regions are sorted appropriately @@ -67,17 +61,17 @@ region_keys_ordered = sorted(regions.keys(), key=lambda x: int(x[1:])) for region in region_keys_ordered: if regions[region]['chromosome'] == chrm and \ - ((int(regions[region]['start']) > start and \ - int(regions[region]['start']) < end) or \ - (int(regions[region]['end']) > start and \ + ((int(regions[region]['start']) > start and + int(regions[region]['start']) < end) or + (int(regions[region]['end']) > start and int(regions[region]['end']) < end)): strain = regions[region]['strain'] - if not regions_overlapping.has_key(strain): + if strain not in regions_overlapping: regions_overlapping[strain] = [] regions_overlapping[strain].append(region) - print ' '.join(['writing all range sequences to file, with', \ - suffix, 'introgressed bases capitalized']) + print(' '.join(['writing all range sequences to file, with', + suffix, 'introgressed bases capitalized'])) headers_current = copy.deepcopy(headers) seqs_current = copy.deepcopy(seqs) for i in range(len(seqs)): @@ -89,35 +83,33 @@ if strain not in regions_overlapping: continue r = strain_range_seqs[strain] - t, labels = read_table.read_table_columns(gp.analysis_out_dir_absolute + \ - tag + '/' + \ - 'site_summaries/predictions_' + \ - strain + \ - '_chr' + chrm + \ - '_site_summary.txt.gz', '\t') + t, labels = read_table.read_table_columns( + gp.analysis_out_dir_absolute + tag + '/' + + 'site_summaries/predictions_' + strain + '_chr' + chrm + + '_site_summary.txt.gz', '\t') ref_ind_to_strain_ind = dict(zip(t['ps_ref'], t['ps_strain'])) for region in regions_overlapping[strain]: header += ' ' + 
region
-            start_strain = math.ceil(float(\
-                ref_ind_to_strain_ind[regions[region]['start']]))
-            end_strain = math.floor(float(\
-                ref_ind_to_strain_ind[regions[region]['end']]))
+            start_strain = math.ceil(float(
+                ref_ind_to_strain_ind[regions[region]['start']]))
+            end_strain = math.floor(float(
+                ref_ind_to_strain_ind[regions[region]['end']]))
             start_relative = int(max(start_strain - int(r[1]), 0))
             end_relative = int(end_strain - int(r[1]))
-            seq = seq[:start_relative] + \
-                  seq[start_relative:end_relative+1].upper() + \
-                  seq[end_relative+1:]
+            seq = (seq[:start_relative] +
+                   seq[start_relative:end_relative+1].upper() +
+                   seq[end_relative+1:])
             seqs_current[i] = seq
             headers_current[i] = header
 
     range_seqs_fn = gp.analysis_out_dir_absolute + tag + '/ranges/' + label + \
-                    '/' + label + '_introgressed' + suffix + gp.fasta_suffix
+        '/' + label + '_introgressed' + suffix + gp.fasta_suffix
     write_fasta.write_fasta(headers_current, seqs_current, range_seqs_fn)
-    
-    print 'aligning range sequences'
-    range_seqs_aligned_fn = range_seqs_fn.replace(gp.fasta_suffix, gp.alignment_suffix)
+
+    print('aligning range sequences')
+    range_seqs_aligned_fn = range_seqs_fn.replace(gp.fasta_suffix,
+                                                  gp.alignment_suffix)
     cmd_string = gp.mafft_install_path + '/mafft ' + \
-                 ' --reorder --preservecase ' + \
-                 range_seqs_fn + ' > ' + range_seqs_aligned_fn
+        ' --reorder --preservecase ' + \
+        range_seqs_fn + ' > ' + range_seqs_aligned_fn
     os.system(cmd_string)
diff --git a/code/analyze/to_update/compare.py b/code/analyze/to_update/compare.py
index 460203e..57d9b95 100644
--- a/code/analyze/to_update/compare.py
+++ b/code/analyze/to_update/compare.py
@@ -1,10 +1,9 @@
-import sys
-import os
-sys.path.insert(0, '..')
 import global_params as gp
 
-lines = [x.split(',') for x in open('/tigress/AKEY/akey_vol2/aclark4/nobackup/introgression/data/Table_S5_introgressed_genes.csv', 'r').readlines()]
+lines = [x.split(',') for x in open(
+    '/tigress/AKEY/akey_vol2/aclark4/nobackup/introgression/data/'
+    'Table_S5_introgressed_genes.csv', 'r').readlines()]
 genes = []
 genes_verified = []
 for i in range(2, len(lines)):
@@ -14,12 +13,13 @@
 
 gp.analysis_out_dir_absolute
 
-lines = [x.split(' ') for x in open('../../results/introgressed_id_genes.txt', 'r').readlines()]
+lines = [x.split(' ') for x in open('../../results/introgressed_id_genes.txt',
+                                    'r').readlines()]
 my_genes = [x[0] for x in lines]
 
-print len(genes), 'genes from paper'
-print len(genes_verified), 'verified genes from paper'
-print len(my_genes), '(verified) genes I identify'
+print(len(genes), 'genes from paper')
+print(len(genes_verified), 'verified genes from paper')
+print(len(my_genes), '(verified) genes I identify')
 
 pm = []
 pnm = []
@@ -33,27 +33,28 @@
     if g not in genes:
         npm.append(g)
 
-print 'genes found in paper that I found (', len(pm), '):'
+print('genes found in paper that I found (', len(pm), '):')
 for x in pm:
-    print x
-print 'genes found in paper that I didn\'t find (', len(pnm), '):'
+    print(x)
+print('genes found in paper that I didn\'t find (', len(pnm), '):')
 for x in pnm:
-    print x
-print 'genes that I found not in paper(', len(npm), '):'
+    print(x)
+print('genes that I found not in paper(', len(npm), '):')
 for x in npm:
-    print x
+    print(x)
 
-lines = [x.strip().split(' ') for x in open('../../results/introgressed_id_genes_fns.txt', 'r').readlines()]
+lines = [x.strip().split(' ') for x in open(
+    '../../results/introgressed_id_genes_fns.txt', 'r').readlines()]
 gene_to_fns = {}
 for line in lines:
     gene_to_fns[line[0]] = line[1:]
 
 while True:
-    gene = 
raw_input('=========================================\nwhich gene? ') + gene = input('=========================================\nwhich gene? ') try: gene_to_fns[gene] - except: - print 'that gene wasn\'t one i found' + except KeyError: + print('that gene wasn\'t one i found') continue for fn in gene_to_fns[gene]: f = open(fn) @@ -78,16 +79,17 @@ seq += 'p' else: seq += '-' - print '==========', fn + print('==========', fn) line_length = 10000 for i in range(0, len(seq), line_length): - print seqc[i:i+line_length] - print seqp[i:i+line_length] - print seqx[i:i+line_length] - print seq[i:i+line_length] - print - print - raw_input('') + print(seqc[i:i+line_length]) + print(seqp[i:i+line_length]) + print(seqx[i:i+line_length]) + print(seq[i:i+line_length]) + print() + print() + input('') -# TODO: get alignments for genes found in paper but not by me; print positions in each genome before alignments +# TODO: get alignments for genes found in paper but not by me; +# print positions in each genome before alignments diff --git a/code/analyze/to_update/compare_3strains_main.py b/code/analyze/to_update/compare_3strains_main.py index 7e7e560..ab190d5 100644 --- a/code/analyze/to_update/compare_3strains_main.py +++ b/code/analyze/to_update/compare_3strains_main.py @@ -1,20 +1,7 @@ -import re -import sys -import os -import copy import itertools from collections import defaultdict -import gene_predictions -import predict -from filter_helpers import * -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../sim/') -import sim_analyze_hmm_bw as sim -sys.path.insert(0, '../misc/') -import mystats -import read_table -import read_fasta +from misc import read_table strains = ['yjm1252', 'yjm1078', 'yjm248'] @@ -36,12 +23,13 @@ for base in range(start, end + 1): bases_by_strains[chrm][base].append(strain) -#for base in sorted(bases_by_strains['I'].keys()): +# for base in sorted(bases_by_strains['I'].keys()): # print base, bases_by_strains['I'][base] - + categories = [] -for i in range(1,len(strains) + 1): - categories += [tuple(sorted(x)) for x in itertools.combinations(strains, i)] +for i in range(1, len(strains) + 1): + categories += [tuple(sorted(x)) + for x in itertools.combinations(strains, i)] cat_counts = defaultdict(int) for chrm in bases_by_strains.keys(): @@ -55,4 +43,3 @@ for cat in categories: f.write(','.join(cat) + '\t' + str(cat_counts[cat]) + '\n') f.close() - diff --git a/code/analyze/to_update/compare_predictions_main.py b/code/analyze/to_update/compare_predictions_main.py index c967c92..753c0a1 100644 --- a/code/analyze/to_update/compare_predictions_main.py +++ b/code/analyze/to_update/compare_predictions_main.py @@ -1,20 +1,9 @@ -import re import sys -import os -import copy -import itertools -import gene_predictions import predict from collections import defaultdict -from filter_helpers import * -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../sim/') -import sim_analyze_hmm_bw as sim -sys.path.insert(0, '../misc/') -import mystats -import read_table -import read_fasta +from misc import read_table + # similar to find_pops function in structure_3_main.py def overlap_with_any(start, end, blocks): @@ -32,23 +21,24 @@ def overlap_with_any(start, end, blocks): break return count + args = predict.process_predict_args(sys.argv[1:]) -## comparing to other prediction run; e.g. 
comparing using just one -## introgressed reference state to using multiple; this is a little -## janky because some of the file names and formatting have changed +# comparing to other prediction run; e.g. comparing using just one +# introgressed reference state to using multiple; this is a little +# janky because some of the file names and formatting have changed other_region_fn = gp.analysis_out_dir_absolute + 'u3_i.001_tv_l1000_f.01/' + \ - 'introgressed_blocks_filtered_par_u3_i.001_tv_l1000_f.01_summary_plus.txt' + 'introgressed_blocks_filtered_par_u3_i.001_tv_l1000_f.01_summary_plus.txt' rt_other, fields_other = read_table.read_table_rows(other_region_fn, '\t') regions_other = defaultdict(lambda: defaultdict(list)) for region_id in rt_other: chrm = rt_other[region_id]['chromosome'] strain = rt_other[region_id]['strain'] - regions_other[chrm][strain].append((int(rt_other[region_id]['start']), \ + regions_other[chrm][strain].append((int(rt_other[region_id]['start']), int(rt_other[region_id]['end']))) for chrm in gp.chrms: for strain in regions_other[chrm].keys(): - regions_other[chrm][strain].sort(key = lambda x: x[0]) + regions_other[chrm][strain].sort(key=lambda x: x[0]) regions = defaultdict(lambda: defaultdict(list)) @@ -63,12 +53,12 @@ def overlap_with_any(start, end, blocks): for region_id in rt: chrm = rt[region_id]['chromosome'] strain = rt[region_id]['strain'] - regions[chrm][strain].append((int(rt[region_id]['start']), \ + regions[chrm][strain].append((int(rt[region_id]['start']), int(rt[region_id]['end']), rt[region_id]['alternative_states'])) for chrm in gp.chrms: for strain in regions[chrm].keys(): - regions[chrm][strain].sort(key = lambda x: x[0]) + regions[chrm][strain].sort(key=lambda x: x[0]) # count bases found in every possible combination of species_from + # presence/absence in regions_other @@ -79,14 +69,14 @@ def overlap_with_any(start, end, blocks): # current predictions for strain in regions[chrm].keys(): for region in regions[chrm][strain]: - x = overlap_with_any(region[0], region[1], regions_other[chrm][strain]) + x = overlap_with_any(region[0], + region[1], regions_other[chrm][strain]) length = region[1] - region[0] + 1 alt_states = region[2].split(',') d[strain][tuple(['other'] + alt_states)] += x d[strain][tuple(alt_states)] += length - x assert x <= length - # other predictions for strain in regions_other[chrm].keys(): for region in regions_other[chrm][strain]: @@ -95,13 +85,14 @@ def overlap_with_any(start, end, blocks): d[strain][('other', 'any')] += x d[strain][('other',)] += length - x assert x <= length - -fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + 'state_counts_comparison.txt' +fn = gp.analysis_out_dir_absolute + args['tag'] +\ + '/' + 'state_counts_comparison.txt' f = open(fn, 'w') f.write('strain\tlabel\tcount\n') for strain in d.keys(): for label in d[strain].keys(): - f.write(strain + '\t' + ','.join(label) + '\t' + str(d[strain][label]) + '\n') + f.write(strain + '\t' + ','.join(label) + + '\t' + str(d[strain][label]) + '\n') f.close() diff --git a/code/analyze/to_update/compare_to_strope.py b/code/analyze/to_update/compare_to_strope.py index 297ebd4..7fed803 100644 --- a/code/analyze/to_update/compare_to_strope.py +++ b/code/analyze/to_update/compare_to_strope.py @@ -1,22 +1,10 @@ # compare set of genes I've called to set called in Strope et al (100 # genomes paper) -import re import sys -import os -import math -import Bio.SeqIO -import copy import gene_predictions -sys.path.insert(0, '..') import global_params as gp 
-sys.path.insert(0, '../align/') -import align_helpers -sys.path.insert(0, '../misc/') -import read_table -import read_fasta -import write_fasta -import mystats +from misc import read_table tag = sys.argv[1] @@ -35,14 +23,14 @@ if line[7+i] == 'P': strains_int_par.append(strains[i]) n_int_par = len(strains_int_par) - genes_strope[line[2]] = (n_int_par, n_int_other, n_del, strains_int_par, \ + genes_strope[line[2]] = (n_int_par, n_int_other, n_del, strains_int_par, line[1], line[4]) sys_standard_strope[line[1]] = line[2] - + fn_regions = gp.analysis_out_dir_absolute + tag + '/' + \ 'introgressed_blocks_filtered_par_' + tag + '_summary_plus.txt' # dict keyed by region: {strain:, start:, end:, etc} -regions, l = read_table.read_table_rows(fn_regions, '\t') +regions, _ = read_table.read_table_rows(fn_regions, '\t') region_to_genes = {} for chrm in gp.chrms: fn_genes_regions = gp.analysis_out_dir_absolute + tag + '/' + \ @@ -53,9 +41,9 @@ region_to_genes.update(region_to_genes_current) genes_by_strain = {} for region in regions: - if not genes_by_strain.has_key(regions[region]['strain']): + if regions[region]['strain'] not in genes_by_strain: genes_by_strain[regions[region]['strain']] = set([]) - [genes_by_strain[regions[region]['strain']].add(gene) \ + [genes_by_strain[regions[region]['strain']].add(gene) for gene in [x[0] for x in region_to_genes[region]['gene_list']]] genes = {} @@ -97,7 +85,8 @@ # TODO fix my gene list then get rid of this all_genes = {} for chrm in gp.chrms: - fn_all_genes = gp.analysis_out_dir_absolute + 'S288c_chr' + chrm + '_genes.txt' + fn_all_genes = gp.analysis_out_dir_absolute +\ + 'S288c_chr' + chrm + '_genes.txt' f_all_genes = open(fn_all_genes, 'r') lines = [line.strip().split('\t') for line in f_all_genes.readlines()] f_all_genes.close() @@ -107,7 +96,6 @@ strand = 'NA' all_genes[line[0]] = ('NA', chrm, start, end, strand) - fn_paralogs = '../../data/S288c_paralogs.tsv' f_paralogs = open(fn_paralogs, 'r') lines = [line.strip().split('\t') for line in f_paralogs.readlines()] @@ -117,7 +105,6 @@ if line[0] != "": paralogs[line[0]] = line[3] - f_s = open('compare_to_strope/genes_strope_only.txt', 'w') f_m = open('compare_to_strope/genes_me_only.txt', 'w') f_sm = open('compare_to_strope/genes_both.txt', 'w') @@ -147,7 +134,8 @@ f_sp.write(gene + '\n') if gene in genes or (gene in sys_standard and sys_standard[gene] in genes): continue - elif not (gene in all_genes or (gene in sys_standard and sys_standard[gene] in all_genes)): + elif not (gene in all_genes or + (gene in sys_standard and sys_standard[gene] in all_genes)): continue elif genes_strope[gene][0] == 0: continue @@ -157,19 +145,19 @@ f_s.write(gene + '\n') c_s += 1 if gene in paralogs: - c_s_p +=1 + c_s_p += 1 f_s.close() f_m.close() f_sm.close() f_mp.close() f_sp.close() -print 'number strope only:', c_s -print 'number me only:', c_m -print 'number strope and me:', c_sm -print 'number strope only paralogs', c_s_p -print 'number me only paralogs', c_m_p -print 'number strope and me paralogs', c_sm_p -print 'number paralogs', len(paralogs) +print('number strope only:', c_s) +print('number me only:', c_m) +print('number strope and me:', c_sm) +print('number strope only paralogs', c_s_p) +print('number me only paralogs', c_m_p) +print('number strope and me paralogs', c_sm_p) +print('number paralogs', len(paralogs)) -print paralogs.keys() +print(paralogs.keys()) diff --git a/code/analyze/to_update/count_coding_changes.py b/code/analyze/to_update/count_coding_changes.py index a4af578..8ce9ec6 100644 --- 
a/code/analyze/to_update/count_coding_changes.py +++ b/code/analyze/to_update/count_coding_changes.py @@ -1,10 +1,7 @@ -import sys -import os -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import seq_functions -import read_fasta +from misc import seq_functions +from misc import read_fasta + def get_aligned_genes(fn, strains): headers, seqs = read_fasta.read_fasta(fn) @@ -44,7 +41,6 @@ def ambiguous(gene, ref_start, ref_end, coords, orfs): def count_coding(seq_master, seq_ref, seq_strain, start, end): - if not seq_master.startswith('ATG'): seq_master = seq_functions.reverse_complement(seq_master) assert seq_master.startswith('ATG'), seq_master @@ -87,16 +83,15 @@ def count_coding(seq_master, seq_ref, seq_strain, start, end): def count_coding_with_gaps(seq_master, seq_ref, seq_strain, start, end): - - print seq_master - print seq_ref - print seq_strain - print start, end + print(seq_master) + print(seq_ref) + print(seq_strain) + print(start, end) seq_master = seq_master.upper() seq_ref = seq_ref.upper() seq_strain = seq_strain.upper() - + ind_master = 0 ind_ref = 0 ind_strain = 0 @@ -177,10 +172,9 @@ def count_coding_with_gaps(seq_master, seq_ref, seq_strain, start, end): if codon_strain != codon_master: aa_master = seq_functions.codon_table.get(codon_master) - aa_ref = seq_functions.codon_table.get(codon_ref) aa_strain = seq_functions.codon_table.get(codon_strain) - if aa_master == None or aa_strain == None: + if aa_master is None or aa_strain is None: if gaps_master > gaps_strain: t_insert += gaps_master - gaps_strain else: @@ -212,10 +206,9 @@ def count_coding_with_gaps(seq_master, seq_ref, seq_strain, start, end): else: t_non_ref += 1 - print t_syn, t_non, t_syn_ref, t_non_ref - print t_insert, t_delete, t_insert_ref, t_delete_ref - print frameshift + print(t_syn, t_non, t_syn_ref, t_non_ref) + print(t_insert, t_delete, t_insert_ref, t_delete_ref) + print(frameshift) return t_syn, t_non, t_syn_ref, t_non_ref, \ t_insert/3.0, t_delete/3.0, t_insert_ref/3.0, t_delete_ref/3.0, \ gene_delete, gene_delete_ref, frameshift_count - diff --git a/code/analyze/to_update/count_coding_changes_main.py b/code/analyze/to_update/count_coding_changes_main.py index d0a4220..96c1705 100644 --- a/code/analyze/to_update/count_coding_changes_main.py +++ b/code/analyze/to_update/count_coding_changes_main.py @@ -1,36 +1,34 @@ import sys import os import gzip -from count_coding_changes import * +from count_coding_changes import get_aligned_genes, count_coding_with_gaps import annotate_positions -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import overlap -import read_table -import read_fasta +from misc import overlap +from misc import read_table +from misc import read_fasta -##====== +# ====== # command line arguments -##====== +# ====== tag = sys.argv[1] -##====== +# ====== # read in introgressed regions -##====== +# ====== -# key region ids by chromosome and then strain +# key region ids by chromosome and then strain fn_regions = gp.analysis_out_dir_absolute + tag + '/' + \ 'introgressed_blocks_filtered_par_' + tag + '_summary_plus.txt' -regions, l = read_table.read_table_rows(fn_regions, '\t') +regions, _ = read_table.read_table_rows(fn_regions, '\t') region_ids_by_chrm_strain = {} for r in regions.keys(): strain = regions[r]['strain'] chrm = regions[r]['chromosome'] - if not region_ids_by_chrm_strain.has_key(chrm): + if chrm not in region_ids_by_chrm_strain: region_ids_by_chrm_strain[chrm] = {} - if not 
region_ids_by_chrm_strain[chrm].has_key(strain): + if strain not in region_ids_by_chrm_strain[chrm]: region_ids_by_chrm_strain[chrm][strain] = [] region_ids_by_chrm_strain[chrm][strain].append(r) @@ -48,82 +46,84 @@ f.close() -##====== +# ====== # count sites within all regions that are coding/noncoding, plus some # more details about coding changes -##====== +# ====== other_ref = gp.alignment_ref_order[1] region_totals = {} gene_totals = {} strain_totals = {} -totals = {'syn':0, 'non':0, 'syn_ref':0, 'non_ref':0, \ - 'insert':0, 'delete':0, 'insert_ref':0, 'delete_ref':0, \ - 'gene_delete':0, 'gene_delete_ref':0, \ - 'ref_gene_only':0, 'strain_orf_only':0, \ - 'coding':0, 'noncoding':0, 'frameshift':0} +totals = {'syn': 0, 'non': 0, 'syn_ref': 0, 'non_ref': 0, + 'insert': 0, 'delete': 0, 'insert_ref': 0, 'delete_ref': 0, + 'gene_delete': 0, 'gene_delete_ref': 0, + 'ref_gene_only': 0, 'strain_orf_only': 0, + 'coding': 0, 'noncoding': 0, 'frameshift': 0} for chrm in gp.chrms: - print chrm + print(chrm) # read in cer reference genes fn = gp.analysis_out_dir_absolute + gp.master_ref + '_chr' + chrm + \ - '_genes.txt' - genes, l = read_table.read_table_rows(fn, '\t', header=False, key_ind=0) + '_genes.txt' + genes, _ = read_table.read_table_rows(fn, '\t', header=False, key_ind=0) for gene in genes: genes[gene] = (int(genes[gene][0]), int(genes[gene][1])) # read in cer ref -> par ref position file fn = gp.analysis_out_dir_absolute + 'coordinates/' + gp.master_ref + \ - '_to_' + other_ref + '_chr' + chrm + '.txt.gz' - master_to_other_ref_pos = [float(line[:-1]) \ + '_to_' + other_ref + '_chr' + chrm + '.txt.gz' + master_to_other_ref_pos = [float(line[:-1]) for line in gzip.open(fn, 'rb').readlines()] # read in cer ref chromosome sequence fn = gp.ref_dir[gp.master_ref] + gp.ref_fn_prefix[gp.master_ref] + \ - '_chr' + chrm + gp.fasta_suffix + '_chr' + chrm + gp.fasta_suffix master_seq = read_fasta.read_fasta(fn)[1][0] # read in par ref chromosome sequence fn = gp.ref_dir[other_ref] + gp.ref_fn_prefix[other_ref] + \ - '_chr' + chrm + gp.fasta_suffix + '_chr' + chrm + gp.fasta_suffix other_ref_seq = read_fasta.read_fasta(fn)[1][0] # read in par ref ORFs fn = gp.ref_dir[other_ref] + 'orfs/' + other_ref + \ - '_chr' + chrm + '_orfs' + gp.fasta_suffix + '_chr' + chrm + '_orfs' + gp.fasta_suffix ref_orfs = annotate_positions.get_orfs(fn) for strain in region_ids_by_chrm_strain[chrm].keys(): - print '-', strain - - if not strain_totals.has_key(strain): - strain_totals[strain] = {'syn':0, 'non':0, 'syn_ref':0, 'non_ref':0, \ - 'ref_gene_only':0, 'strain_orf_only':0, \ - 'coding':0, 'noncoding':0} - - # read in cer ref -> strain position file + print('-', strain) + + if strain not in strain_totals: + strain_totals[strain] = { + 'syn': 0, 'non': 0, 'syn_ref': 0, 'non_ref': 0, + 'ref_gene_only': 0, 'strain_orf_only': 0, + 'coding': 0, 'noncoding': 0} + + # read in cer ref -> strain position file fn = gp.analysis_out_dir_absolute + 'coordinates/' + gp.master_ref + \ - '_to_' + strain + '_chr' + chrm + '.txt.gz' - master_to_strain_pos = [float(line[:-1]) \ + '_to_' + strain + '_chr' + chrm + '.txt.gz' + master_to_strain_pos = [float(line[:-1]) for line in gzip.open(fn, 'rb').readlines()] # read in strain chromosome sequence fn = gp.non_ref_dirs[gp.master_ref][0] + strain + \ - '_chr' + chrm + gp.fasta_suffix + '_chr' + chrm + gp.fasta_suffix strain_seq = read_fasta.read_fasta(fn)[1][0] # read in strain ORFs fn = gp.non_ref_dirs[gp.master_ref][0] + 'orfs/' + strain + \ - '_chr' + chrm + '_orfs' + 
gp.fasta_suffix + '_chr' + chrm + '_orfs' + gp.fasta_suffix orfs = annotate_positions.get_orfs(fn) for region in region_ids_by_chrm_strain[chrm][strain]: - region_totals[region] = {'syn':0, 'non':0, 'syn_ref':0, 'non_ref':0, \ - 'ref_gene_only':0, 'strain_orf_only':0, \ - 'coding':0, 'noncoding':0} + region_totals[region] = { + 'syn': 0, 'non': 0, 'syn_ref': 0, 'non_ref': 0, + 'ref_gene_only': 0, 'strain_orf_only': 0, + 'coding': 0, 'noncoding': 0} # is each site in region in a master ref gene and/or # strain ORF? @@ -132,10 +132,11 @@ t_gene_not_orf = 0 t_not_gene_orf = 0 t_not_gene_not_orf = 0 - for site in range(int(regions[region]['start']), \ + for site in range(int(regions[region]['start']), int(regions[region]['end'])): in_gene = overlap.contained_any(site, genes.values()) - in_orf = overlap.contained_any(master_to_strain_pos[site], orfs.keys()) + in_orf = overlap.contained_any( + master_to_strain_pos[site], orfs.keys()) if in_gene: if in_orf: t_gene_orf += 1 @@ -170,58 +171,63 @@ # read multiple alignment for the gene, in which we've # previously selected the best orfs to match the gene - fn = gp.analysis_out_dir_absolute + tag + '/genes/' + gene + '/' + \ - gene + '_introgressed_filtered.maf' + fn = gp.analysis_out_dir_absolute + tag + '/genes/' \ + + gene + '/' + gene + '_introgressed_filtered.maf' if not os.path.isfile(fn): - print 'do not have alignment for', gene + print('do not have alignment for', gene) continue - aligned_genes = get_aligned_genes(fn, \ - [gp.master_ref, other_ref, strain]) + aligned_genes = get_aligned_genes( + fn, [gp.master_ref, other_ref, strain]) - print gene, strain + print(gene, strain) # for now, ignore cerevisiae reference genes that # don't map perfectly to an ORF in the strain and # paradoxus reference - #if ambiguous(gene, gene_start, gene_end, master_to_strain_pos, orfs): + # if ambiguous(gene, gene_start, gene_end, + # master_to_strain_pos, orfs): # continue - #if ambiguous(gene, gene_start, gene_end, \ + # if ambiguous(gene, gene_start, gene_end, \ # master_to_other_ref_pos, ref_orfs): # continue - + # extract gene sequence from references and strain g_master = master_seq[gene_start:gene_end+1] - g_ref = other_ref_seq[int(master_to_other_ref_pos[gene_start]):\ + g_ref = other_ref_seq[int(master_to_other_ref_pos[gene_start]): int(master_to_other_ref_pos[gene_end])+1] - g_strain = strain_seq[int(master_to_strain_pos[gene_start]):\ + g_strain = strain_seq[int(master_to_strain_pos[gene_start]): int(master_to_strain_pos[gene_end])+1] # get overlap between gene and introgressed region - o_start, o_end = overlap.overlap_region(genes[gene][0], \ - genes[gene][1], \ - int(regions[region]['start']), \ - int(regions[region]['end'])) + o_start, o_end = overlap.overlap_region( + genes[gene][0], + genes[gene][1], + int(regions[region]['start']), + int(regions[region]['end'])) # count synonymous and non synonymous changes due to # paradoxus (deal with gene direction correctly) # t_syn, t_non = count_coding(g_master, g_ref, g_strain, \ - # o_start-gene_start, o_end-gene_start) + # o_start-gene_start, + # o_end-gene_start) # alternative method that deals with imperfect matches t_syn, t_non, t_syn_ref, t_non_ref, \ t_insert, t_delete, t_insert_ref, t_delete_ref, \ gene_delete, gene_delete_ref, frameshift = \ - count_coding_with_gaps(aligned_genes[gp.master_ref], \ - aligned_genes[other_ref], \ - aligned_genes[strain], \ - o_start-gene_start, o_end-gene_start) + count_coding_with_gaps(aligned_genes[gp.master_ref], + aligned_genes[other_ref], + 
aligned_genes[strain], + o_start-gene_start, + o_end-gene_start) # add to totals for region, gene, strain, and overall - if not gene_totals.has_key(gene): - gene_totals[gene] = {'syn':0, 'non':0, 'syn_ref':0, 'non_ref':0, \ - 'insert':0, 'delete':0, \ - 'insert_ref':0, 'delete_ref':0, \ - 'gene_delete':0, 'gene_delete_ref':0, \ - 'frameshift':0} + if gene not in gene_totals: + gene_totals[gene] = { + 'syn': 0, 'non': 0, 'syn_ref': 0, 'non_ref': 0, + 'insert': 0, 'delete': 0, + 'insert_ref': 0, 'delete_ref': 0, + 'gene_delete': 0, 'gene_delete_ref': 0, + 'frameshift': 0} gene_totals[gene]['syn'] += t_syn gene_totals[gene]['non'] += t_non gene_totals[gene]['syn_ref'] += t_syn_ref @@ -256,9 +262,9 @@ totals['gene_delete_ref'] += gene_delete_ref totals['frameshift'] += frameshift -##====== +# ====== # write output file -##====== +# ====== fn = gp.analysis_out_dir_absolute + tag + '/' + 'coding_changes_summary_' + \ tag + '.txt' @@ -271,30 +277,26 @@ for strain in strain_totals: for key in strain_totals[strain].keys(): - f.write(strain + sep + 'strain' + sep + \ + f.write(strain + sep + 'strain' + sep + str(strain_totals[strain][key]) + sep + key + '\n') for gene in gene_totals: for key in gene_totals[gene].keys(): - f.write(gene + sep + 'gene' + sep + \ + f.write(gene + sep + 'gene' + sep + str(gene_totals[gene][key]) + sep + key + '\n') for region in region_totals: for key in region_totals[region].keys(): - f.write(region + sep + 'region' + sep + \ + f.write(region + sep + 'region' + sep + str(region_totals[region][key]) + sep + key + '\n') -f.close() - - - +f.close() # new plan # for each region # for each site in region # is it in ref gene and/or strain orf? (keep track of four totals) -# # for each gene # get corresponding orfs in par and strain @@ -305,4 +307,3 @@ # - categories: # multiples of 3 # not multiples of 3 -> stop counting/ignore gene? 
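The syn/non-syn bookkeeping that count_coding_changes_main.py tallies above reduces to one comparison per aligned codon: translate both codons and check whether the amino acid changes. A minimal, self-contained sketch of that classification step for reference while reviewing (the three-entry codon_table is an illustrative excerpt and classify_codon_change is an invented name, not a helper from this repo; the repo's misc.seq_functions.codon_table is assumed to behave like this dict, returning None for codons containing gaps or ambiguous bases):

# Editor's sketch of the per-codon classification performed in
# count_coding_with_gaps; codon_table maps codon strings to
# one-letter amino acids (tiny excerpt, not the full table).
codon_table = {'GAA': 'E', 'GAG': 'E', 'GAT': 'D'}


def classify_codon_change(codon_master, codon_strain):
    aa_master = codon_table.get(codon_master)
    aa_strain = codon_table.get(codon_strain)
    if aa_master is None or aa_strain is None:
        # gap or ambiguous base in a codon: handled as indel/unknown
        return 'unknown'
    if codon_master == codon_strain:
        return 'identical'
    return 'synonymous' if aa_master == aa_strain else 'nonsynonymous'


assert classify_codon_change('GAA', 'GAG') == 'synonymous'     # E -> E
assert classify_codon_change('GAA', 'GAT') == 'nonsynonymous'  # E -> D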
- diff --git a/code/analyze/to_update/count_introgressed_main.py b/code/analyze/to_update/count_introgressed_main.py index 2663c12..5cbc0d3 100644 --- a/code/analyze/to_update/count_introgressed_main.py +++ b/code/analyze/to_update/count_introgressed_main.py @@ -1,24 +1,12 @@ # counts total amount of sites introgressed on each chromosome -import re -import sys -import os -import math -import gzip -import itertools -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../align/') -import align_helpers -sys.path.insert(0, '../misc/') -import overlap -import read_table -import read_fasta -import write_fasta -import mystats +from misc import read_table -chrm_sizes = [230218, 813184, 316620, 1531933, 576874, 270161, 1090940, 562643, 439888, 745751, 666816, 1078177, 924431, 784333, 1091291, 948066] +chrm_sizes = [230218, 813184, 316620, 1531933, 576874, + 270161, 1090940, 562643, 439888, 745751, + 666816, 1078177, 924431, 784333, 1091291, 948066] tag = 'u3_i.001_tv_l1000_f.01' @@ -30,13 +18,13 @@ for region in d: chrm = d[region]['chromosome'] strain = d[region]['strain'] - regions_by_chrm[chrm].append((strain, \ - int(d[region]['start']), \ + regions_by_chrm[chrm].append((strain, + int(d[region]['start']), int(d[region]['end']))) hist = {} for chrm in gp.chrms: - print chrm + print(chrm) chrm_size = chrm_sizes[gp.chrms.index(chrm)] x = [0 for i in range(chrm_size)] for ri in range(len(regions_by_chrm[chrm])): @@ -56,7 +44,7 @@ total = 0 for chrm in gp.chrms: f.write(chrm + '\t') - chrm_size = chrm_sizes[gp.chrms.index(chrm)] + chrm_size = chrm_sizes[gp.chrms.index(chrm)] at_least_one = chrm_size - hist[chrm][0] total += at_least_one f.write(str(at_least_one) + '\t') diff --git a/code/analyze/to_update/frequency_of_introgression_main.py b/code/analyze/to_update/frequency_of_introgression_main.py index 8004fdb..89b2da0 100644 --- a/code/analyze/to_update/frequency_of_introgression_main.py +++ b/code/analyze/to_update/frequency_of_introgression_main.py @@ -1,25 +1,11 @@ -import re -import sys -import os -import copy -import itertools from collections import defaultdict -import gene_predictions -import predict -from filter_helpers import * -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../sim/') -import sim_analyze_hmm_bw as sim -sys.path.insert(0, '../misc/') -import mystats -import read_table -import read_fasta +from misc import read_table tag = 'u3_i.001_tv_l1000_f.01' species_from = 'par' -#strains3 = ['yjm1252', 'yjm1078', 'yjm248'] +# strains3 = ['yjm1252', 'yjm1078', 'yjm248'] fn = gp.analysis_out_dir_absolute + tag + '/' + \ 'introgressed_blocks_filtered_' + species_from + \ @@ -30,7 +16,7 @@ strains = set([]) for region_id in regions: strain = regions[region_id]['strain'] - #if strain not in strains3: + # if strain not in strains3: strains.add(strain) chrm = regions[region_id]['chromosome'] start = int(regions[region_id]['start']) @@ -50,4 +36,3 @@ for i in range(len(strains)): f.write(str(i) + '\t' + str(counts[i]) + '\n') f.close() - diff --git a/code/analyze/to_update/gene_overlap_main.py b/code/analyze/to_update/gene_overlap_main.py index 9c274e4..3e431e0 100644 --- a/code/analyze/to_update/gene_overlap_main.py +++ b/code/analyze/to_update/gene_overlap_main.py @@ -10,10 +10,10 @@ # bases within coding sequence are upper case. 
In addition, there is a # corresponding file S288c_CBS432_strain_chrX_start-end.genes.txt # listing the genes that overlap this region, and the indices of -# the bases they overlap, in this format: -# gene_name\t0-149\t25236-25385 +# the bases they overlap, in this format: +# gene_name\t0-149\t25236-25385 # gene_name\t200-600\t.... -# +# # also generate a file in results/tag/gene_alignments/ for each # introgressed gene, which contains one threeway alignment for each # strain in which the gene was called introgressed...followed by all @@ -27,93 +27,86 @@ # versions (gene_introgressed.fasta), and also to all of the versions # (gene_all.fasta). -# TODO: -## _annotated file should be .txt not .maf -## also modify so that 80 characters per line -## and extra row showing summary of which references match +# TODO: +# _annotated file should be .txt not .maf +# also modify so that 80 characters per line +# and extra row showing summary of which references match -import re import sys -import os -import copy -from gene_predictions import * +from gene_predictions import read_gene_file import predict -import pickle from collections import defaultdict -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../misc/') -import read_fasta -import overlap +from misc import overlap -##====== +# ====== # read in analysis parameters -##====== +# ====== args = predict.process_predict_args(sys.argv[1:]) gp_dir = '../' open_mode = 'w' -##====== +# ====== # read in reference gene coordinates -##====== +# ====== genes = {} for chrm in gp.chrms: fn_genes = gp.analysis_out_dir_absolute + '/' + \ gp.master_ref + '_chr' + chrm + '_genes.txt' - # + genes[chrm] = read_gene_file(fn_genes) -##====== +# ====== # do all the stuff -##====== +# ====== for species_from in args['states']: - ##====== + # ====== # read in introgressed regions for current state - ##====== + # ====== # strain chromosome predicted_species start end number_non_gap blocks_fn = gp.analysis_out_dir_absolute + args['tag'] + '/' + \ 'introgressed_blocks_' + species_from + '_' + args['tag'] + \ '_labeled.txt' - # introgressed regions keyed by strain and then chromosome: + # introgressed regions keyed by strain and then chromosome: # (region_id, start, end, number_non_gap) regions = predict.read_blocks(blocks_fn, labeled=True) - ##====== + # ====== # extract alignments and genes for introgressed regions - ##====== + # ====== - fn_genes_regions = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \ - 'genes_for_each_region_' + species_from + '_' + \ - args['tag'] + '.txt' + fn_genes_regions = gp.analysis_out_dir_absolute + '/' + args['tag'] + \ + '/' + 'genes_for_each_region_' + species_from + '_' + \ + args['tag'] + '.txt' f_genes_regions = open(fn_genes_regions, open_mode) f_genes_regions.write('region_id\tnumber_genes\tgenes\tfracs\n') d_regions_to_genes = defaultdict(lambda: defaultdict(float)) - #fn_regions_strains = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \ - # 'regions_for_each_strain_' species_from + '_' + \ - # args['tag'] + '.txt' - #f_regions_strains = open(fn_regions_strains, open_mode) - #f_regions_strains.write('strain\tregions\n') - #d_strains_to_regions = defaultdict(list) + # fn_regions_strains = gp.analysis_out_dir_absolute + '/' + args['tag']\ + # + '/' + 'regions_for_each_strain_' species_from\ + # + '_' + args['tag'] + '.txt' + # f_regions_strains = open(fn_regions_strains, open_mode) + # f_regions_strains.write('strain\tregions\n') + # d_strains_to_regions = defaultdict(list) - fn_genes_strains = 
gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \
-                       'genes_for_each_strain_' + species_from + '_' + \
-                       args['tag'] + '.txt'
+    fn_genes_strains = gp.analysis_out_dir_absolute + '/' + args['tag'] + \
+        '/' + 'genes_for_each_strain_' + species_from + '_' + \
+        args['tag'] + '.txt'
     f_genes_strains = open(fn_genes_strains, open_mode)
     f_genes_strains.write('strain\tnumber_genes\tgenes\tfracs\n')
     d_strains_to_genes = defaultdict(lambda: defaultdict(float))

     fn_strains_genes = gp.analysis_out_dir_absolute + '/' + args['tag'] + \
         '/' + 'strains_for_each_gene_' + species_from + '_' + \
         args['tag'] + '.txt'
     f_strains_genes = open(fn_strains_genes, open_mode)
     f_strains_genes.write('gene\tnum_strains\tstrains\tfracs\n')
     d_genes_to_strains = defaultdict(lambda: defaultdict(float))

@@ -123,42 +116,42 @@
             for entry in regions[strain][chrm]:
                 region_id, start, end, number_non_gap = entry
                 for gene in genes[chrm]:
-                    o = overlap.overlap(start, end, \
-                                        genes[chrm][gene][0], genes[chrm][gene][1])
+                    o = overlap.overlap(start, end,
+                                        genes[chrm][gene][0],
+                                        genes[chrm][gene][1])
                     if o > 0:
-                        gene_length = float(genes[chrm][gene][1] - \
+                        gene_length = float(genes[chrm][gene][1] -
                                             genes[chrm][gene][0] + 1)
                         frac_o = o / gene_length
                         d_regions_to_genes[region_id][gene] += frac_o
                         d_strains_to_genes[strain][gene] += frac_o
                         d_genes_to_strains[gene][strain] += frac_o

-
     for region in sorted(d_regions_to_genes.keys(), key=lambda x: int(x[1:])):
         g = sorted(d_regions_to_genes[region].keys())
         f_genes_regions.write(region + '\t' + str(len(g)) + '\t')
         f_genes_regions.write(','.join(g) + '\t')
-        f_genes_regions.write(','.join([str(d_regions_to_genes[region][x]) \
+        f_genes_regions.write(','.join([str(d_regions_to_genes[region][x])
                                         for x in g]) + '\n')

     for strain in sorted(d_strains_to_genes.keys()):
         g = sorted(d_strains_to_genes[strain].keys())
         f_genes_strains.write(strain + '\t' + str(len(g)) + '\t')
         f_genes_strains.write(','.join(g) + '\t')
-        f_genes_strains.write(','.join([str(d_strains_to_genes[strain][x]) \
+        f_genes_strains.write(','.join([str(d_strains_to_genes[strain][x])
                                         for x in g]) + '\n')

     for gene in sorted(d_genes_to_strains.keys()):
         s = sorted(d_genes_to_strains[gene].keys())
         f_strains_genes.write(gene + '\t' + str(len(s)) + '\t')
         f_strains_genes.write(','.join(s) + '\t')
-        f_strains_genes.write(','.join([str(d_genes_to_strains[gene][x]) \
+        f_strains_genes.write(','.join([str(d_genes_to_strains[gene][x])
                                         for x in s]) + '\n')

     f_genes_regions.close()
     f_genes_strains.close()
     f_strains_genes.close()

-"""
+"""

# produce region summary file with all the same info, but also with
# region ids (r1-rn), and with genes overlapping each region

@@ -178,7 +171,8 @@
fn_align_prefix += '_'.join([refs[s][0] for s in args['species']]) + '_'

# for annotated region files (output)
-fn_region_prefix = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/regions/'
+fn_region_prefix = gp.analysis_out_dir_absolute + \
+    '/' + args['tag'] + '/regions/'
if not os.path.isdir(fn_region_prefix):
    os.makedirs(fn_region_prefix)

@@ -193,19 +187,23 @@
    write_region_summary_header(refs_ordered, f_region_summary)

    fn_genes_regions = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \
-        'genes_for_each_region_chr' + chrm + '_' + args['tag'] + '.txt'
+        'genes_for_each_region_chr' + chrm + '_' \
+        + args['tag'] + '.txt'
    f_genes_regions = open(fn_genes_regions, open_mode)
fn_regions_strains = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \ - 'regions_for_each_strain_chr' + chrm + '_' + args['tag'] + '.txt' + 'regions_for_each_strain_chr' + chrm + '_' + \ + args['tag'] + '.txt' f_regions_strains = open(fn_regions_strains, open_mode) fn_genes_strains = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \ - 'genes_for_each_strain_chr' + chrm + '_' + args['tag'] + '.txt' + 'genes_for_each_strain_chr' + chrm + '_' + \ + args['tag'] + '.txt' f_genes_strains = open(fn_genes_strains, open_mode) fn_strains_genes = gp.analysis_out_dir_absolute + '/' + args['tag'] + '/' + \ - 'strains_for_each_gene_chr' + chrm + '_' + args['tag'] + '.txt' + 'strains_for_each_gene_chr' + chrm + '_' + \ + args['tag'] + '.txt' f_strains_genes = open(fn_strains_genes, open_mode) # for keeping track of all genes introgressed in each strain, and the @@ -240,7 +238,7 @@ # loop through all strains that we've called introgression in, and # associate genes with the regions they overlap for strain in regions.keys(): - + print '***', strain, chrm sys.stdout.flush() # skip this strain x chromosome if there are no introgressed @@ -254,7 +252,7 @@ alignment_headers, alignment_seqs = read_fasta.read_fasta(fn_align) labels = ref_labels + [strain] - + # mark each site as matching each reference or not ref_match_by_site = get_ref_match_by_site(alignment_seqs, labels) # mark each site as in a gene or not @@ -281,7 +279,6 @@ # regions are indexed by (unaligned) master ref sequence write_region_alignment(alignment_headers, alignment_seqs, fn_region, \ entry[0], entry[1], 0) - # write region to file in annotated/readable format fn_region_annotated = fn_region_current_prefix + '_annotated' + \ @@ -290,34 +287,35 @@ write_region_alignment_annotated(labels, alignment_seqs, \ fn_region_annotated, \ entry[0], entry[1], 0, \ - genes, ref_match_by_site, + genes, ref_match_by_site, genes_by_site, \ introgressed_by_site, 100) #==== # region summary file with extra info #==== - + # strain chromosome predicted_species start end number_non_gap # number_match_ref1 number_match_ref2 number_match_only_ref1 # number_match_ref2_not_ref1 number_mismatch_all_ref write_region_summary_line(entry, strain, chrm, species_from, \ alignment_seqs, labels, \ - relative_start, relative_end, f_region_summary) + relative_start, relative_end, + f_region_summary) #==== # genes for each region summary file #==== # region_id num_genes gene frac_intd gene frac_intd - + frac_intd = write_genes_for_each_region_summary_line(entry[3], \ genes_by_site, \ genes, \ relative_start, \ relative_end, \ - alignment_seqs[0], \ + alignment_seqs[0], f_genes_regions) for gene in frac_intd: # keep track of all genes for each strain... 
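The frac_intd bookkeeping in this file (and the overlap.overlap calls earlier in gene_overlap_main.py) amounts to measuring how much of a gene an introgressed region covers, using inclusive coordinates. misc/overlap.py is not shown in this patch, so the exact semantics are assumed here; a minimal sketch of the interval arithmetic (interval_overlap is an invented name for illustration, not the repo's function):

# Editor's sketch of the assumed overlap semantics: inclusive
# endpoints, returning the number of positions two intervals share.
def interval_overlap(start1, end1, start2, end2):
    return max(0, min(end1, end2) - max(start1, start2) + 1)


# A region at 100..200 covers 41 of the 81 bases of a gene at
# 160..240, so the gene's introgressed fraction is 41/81.
o = interval_overlap(100, 200, 160, 240)
gene_length = 240 - 160 + 1
assert o == 41
print(o / gene_length)  # ~0.506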
@@ -331,14 +329,13 @@
            gene_strains_dic[gene][strain] = 0
        gene_strains_dic[gene][strain] += frac_intd[gene]

-
#====
# strains for each gene summary file
#====

-# (could do this for one chromsoome at a time if we wanted)
+# (could do this for one chromosome at a time if we wanted)
# gene num_strains strain frac_intd strain frac_intd
-
+
write_strains_for_each_gene_lines(gene_strains_dic, f_strains_genes)

#====
@@ -363,7 +360,5 @@
f_regions_strains.close()
f_genes_strains.close()
f_strains_genes.close()
-
-
"""
diff --git a/code/analyze/to_update/gene_predictions.py b/code/analyze/to_update/gene_predictions.py
index 0507a44..0d118e3 100644
--- a/code/analyze/to_update/gene_predictions.py
+++ b/code/analyze/to_update/gene_predictions.py
@@ -1,14 +1,9 @@
 import re
-import sys
 import os
-import copy
 import gzip
-sys.path.insert(0, '..')
 import global_params as gp
-sys.path.insert(0, '../sim/')
-import sim_analyze_hmm_bw as sim
-sys.path.insert(0, '../misc/')
-import write_fasta
+from misc import write_fasta
+

 def index_ignoring_gaps(s, i, s_start):
     '''returns the index of the ith (starting at 0) non-gap character in
@@ -31,6 +26,7 @@ def index_ignoring_gaps(s, i, s_start):
         x += 1
     return x

+
 def get_ref_match_by_site(seqs, labels):

     # for master: matches _only_ that ref
@@ -51,7 +47,7 @@ def get_ref_match_by_site(seqs, labels):

         if seqs[0][i] == seqs[-1][i]:
             ref_match_by_site[0][i] = labels[0][0]
-
+
         for r in range(1, nrefs):
             if seqs[r][i] == seqs[-1][i]:
                 # matches this ref and master ref -> both blank
@@ -69,10 +65,9 @@ def get_ref_match_by_site(seqs, labels):
                 else:
                     ref_match_by_site[r][i] = '.'
                     ref_match_by_site[0][i] = '.'
-
     return [''.join(s) for s in ref_match_by_site]

-
+
 def get_ref_match_by_site_2(seqs, labels):

@@ -99,6 +94,7 @@ def get_ref_match_by_site_2(seqs, labels):

     return [''.join(s) for s in ref_match_by_site]

+
 def get_genes_by_site(genes, seq):

     genes_by_site = [None for site in seq]
@@ -110,6 +106,7 @@ def get_genes_by_site(genes, seq):
             genes_by_site[i] = gene_name
     return genes_by_site

+
 def get_introgressed_by_site(regions, seq):

     introgressed_by_site = [' ' for site in seq]
@@ -119,19 +116,20 @@ def get_introgressed_by_site(regions, seq):
         for i in range(start_ind, end_ind+1):
             introgressed_by_site[i] = 'i'
     return ''.join(introgressed_by_site)
-
+

 def write_region_alignment(headers, seqs, fn, start, end, master_ind):
-
+
     relative_start = max(0, index_ignoring_gaps(seqs[master_ind], start, 0))
     relative_end = index_ignoring_gaps(seqs[master_ind], end, 0)
-
+
     region_seqs = [seq[relative_start:relative_end+1] for seq in seqs]

     write_fasta.write_fasta(headers, region_seqs, fn, gz=True)

+
 def get_genes_in_region(start, end, genes):
-
+
     region_genes = []
     for gene_name in genes:
         gene_start, gene_end = genes[gene_name]
@@ -142,24 +140,27 @@ def get_genes_in_region(start, end, genes):
     region_genes.sort(key=lambda x: x[1])
     return region_genes

-def write_region_alignment_annotated(labels, seqs, fn, start, end, \
-                                     master_ind, genes, ref_match_by_site, \
-                                     genes_by_site, \
+
+def write_region_alignment_annotated(labels, seqs, fn, start, end,
+                                     master_ind, genes, ref_match_by_site,
+                                     genes_by_site,
                                      introgressed_by_site, context):

     relative_start_with_context = \
         max(0, index_ignoring_gaps(seqs[master_ind], start-context, 0))
     relative_start = max(0, index_ignoring_gaps(seqs[master_ind], start, 0))
     relative_end = index_ignoring_gaps(seqs[master_ind], end, 0)
-    relative_end_with_context = index_ignoring_gaps(seqs[master_ind], end+context, 0)
-
-    region_seqs = 
[seq[relative_start_with_context:relative_end_with_context+1] \ - for seq in seqs] + relative_end_with_context = index_ignoring_gaps(seqs[master_ind], + end+context, 0) + + region_seqs = [ + seq[relative_start_with_context:relative_end_with_context+1] + for seq in seqs] # for reference matching lines ref_match_strings = [] for r in ref_match_by_site: - ref_match_strings.append(\ + ref_match_strings.append( r[relative_start_with_context:relative_end_with_context+1]) # for gene line @@ -168,10 +169,12 @@ def write_region_alignment_annotated(labels, seqs, fn, start, end, \ region_genes_set = list(set(region_genes)) try: region_genes_set.remove(None) - except: + except ValueError: pass region_genes_set.sort(key=lambda x: genes[x][1]) - gene_string = ''.join([' ' if entry == None else '=' for entry in region_genes]) + gene_string = ''.join([' ' + if entry is None else '=' + for entry in region_genes]) # for introgression line introgressed_string = \ @@ -188,7 +191,7 @@ def write_region_alignment_annotated(labels, seqs, fn, start, end, \ # assume master ref comes first f.write('matches only ' + labels[0] + '\n') # and assume ref seqs come before predict seq - for label in labels[1:-1]: + for label in labels[1:-1]: f.write('matches ' + label + ' and mismatches ' + labels[0] + '\n') f.write('genes: ' + ' '.join(region_genes_set) + '\n') f.write('introgressed\n\n') @@ -212,6 +215,7 @@ def write_region_alignment_annotated(labels, seqs, fn, start, end, \ return relative_start, relative_end + def read_gene_file(fn): f = open(fn, 'r') genes = {} @@ -223,6 +227,7 @@ def read_gene_file(fn): f.close() return genes + def write_gene_file(genes, fn): f = open(fn, 'w') for gene in genes: @@ -230,15 +235,17 @@ def write_gene_file(genes, fn): f.write(gene + '\t' + str(start) + '\t' + str(end) + '\n') f.close() + def write_region_summary_header(refs, f): - f.write('region_id\tstrain\tchromosome\tpredicted_species\tstart\tend\t' + \ + f.write('region_id\tstrain\tchromosome\tpredicted_species\tstart\tend\t' + 'number_non_gap\t') f.write('\t'.join(['number_match_' + ref for ref in refs]) + '\t') f.write('\t'.join(['number_match_only_' + ref for ref in refs]) + '\t') f.write('number_mismatch_all_refs\n') -def write_region_summary_line(region, strain, chrm, predicted_species, seqs, labels, - start, end, f): + +def write_region_summary_line(region, strain, chrm, predicted_species, + seqs, labels, start, end, f): # region_id [strain chromosome predicted_species start end number_non_gap] # number_match_ref1 number_match_ref2 number_match_only_ref1 @@ -246,8 +253,8 @@ def write_region_summary_line(region, strain, chrm, predicted_species, seqs, lab sep = '\t' - f.write(region[3] + sep + strain + sep + chrm + sep + predicted_species + \ - sep + str(region[0]) + sep + str(region[1]) + sep + \ + f.write(region[3] + sep + strain + sep + chrm + sep + predicted_species + + sep + str(region[0]) + sep + str(region[1]) + sep + str(region[2]) + sep) ids = [0] * (len(seqs) - 1) @@ -280,12 +287,13 @@ def write_region_summary_line(region, strain, chrm, predicted_species, seqs, lab continue for r in range(1, len(seqs) - 1): unique_ids[r] += match_refs[r] - + f.write(sep.join([str(x) for x in ids]) + sep) f.write(sep.join([str(x) for x in unique_ids]) + sep) f.write(str(mismatch_all) + '\n') f.flush() + def read_region_summary(fn): # region_id [strain chromosome predicted_species start end number_non_gap] # number_match_ref1 number_match_ref2 number_match_only_ref1 @@ -294,39 +302,41 @@ def read_region_summary(fn): f = open(fn, 'r') 
line = f.readline() d = {} - fields = ['strain', 'chromosome', 'predicted_species', 'start', 'end', \ - 'number_non_gap', 'number_match_ref1', 'number_match_ref2', \ - 'number_match_only_ref1', 'number_match_ref2_not_ref1', \ + fields = ['strain', 'chromosome', 'predicted_species', 'start', 'end', + 'number_non_gap', 'number_match_ref1', 'number_match_ref2', + 'number_match_only_ref1', 'number_match_ref2_not_ref1', 'number_mismatch_all_ref'] while line != '': line = line[:-1].split('\t') - #TODO actually fix the multiple header lines scattered throughout + # TODO actually fix the multiple header lines scattered throughout if line[0] != 'region_id': d[line[0]] = dict(zip(fields, line[1:])) line = f.readline() f.close() return d -def write_genes_for_each_region_summary_line(region_id, genes_by_site, gene_summary, \ + +def write_genes_for_each_region_summary_line(region_id, genes_by_site, + gene_summary, start, end, seq, f): - + # region_id num_genes gene frac_intd gene frac_intd genes = genes_by_site[start:end+1] genes_set = list(set(genes)) try: genes_set.remove(None) - except: + except ValueError: pass seq_region = seq[start:end+1] gene_site_counts = dict(zip(genes_set, [0]*len(genes_set))) for i in range(len(seq_region)): - if seq_region[i] != gp.gap_symbol and genes[i] != None: + if seq_region[i] != gp.gap_symbol and genes[i] is not None: gene_site_counts[genes[i]] += 1 frac_intd = {} for gene in genes_set: gene_length = gene_summary[gene][1] - gene_summary[gene][0] + 1 frac_intd[gene] = float(gene_site_counts[gene]) / gene_length - + sep = '\t' f.write(region_id + sep) f.write(str(len(genes_set))) @@ -337,6 +347,7 @@ def write_genes_for_each_region_summary_line(region_id, genes_by_site, gene_summ return frac_intd + def read_genes_for_each_region_summary(fn): # region_id num_genes gene frac_intd gene frac_intd @@ -348,18 +359,19 @@ def read_genes_for_each_region_summary(fn): gene_list = [] for i in range(2, len(line), 2): gene_list.append((line[i], line[i+1])) - d[line[0]] = {'num_genes':line[1], 'gene_list':gene_list} + d[line[0]] = {'num_genes': line[1], 'gene_list': gene_list} line = f.readline() f.close() return d + def write_regions_for_each_strain(regions, f): # strain num_regions region length region length sep = '\t' for strain in regions: f.write(strain + sep) - num_regions = sum([len(regions[strain][chrm]) \ + num_regions = sum([len(regions[strain][chrm]) for chrm in regions[strain].keys()]) f.write(str(num_regions)) for chrm in regions[strain].keys(): @@ -369,6 +381,7 @@ def write_regions_for_each_strain(regions, f): f.write('\n') f.flush() + def read_regions_for_each_strain(fn): # strain num_regions region length region length @@ -380,11 +393,12 @@ def read_regions_for_each_strain(fn): region_list = [] for i in range(2, len(line), 2): region_list.append((line[i], line[i+1])) - d[line[0]] = {'num_regions':line[1], 'region_list':region_list} + d[line[0]] = {'num_regions': line[1], 'region_list': region_list} line = f.readline() f.close() return d - + + def write_genes_for_each_strain(strain_genes_dic, f): # strain num_genes gene frac_intd gene frac_intd @@ -396,6 +410,7 @@ def write_genes_for_each_strain(strain_genes_dic, f): f.write('\n') f.flush() + def read_genes_for_each_strain(fn): # strain num_genes gene frac_intd gene frac_intd @@ -407,11 +422,12 @@ def read_genes_for_each_strain(fn): gene_list = [] for i in range(2, len(line), 2): gene_list.append((line[i], line[i+1])) - d[line[0]] = {'num_genes':line[1], 'gene_list':gene_list} + d[line[0]] = {'num_genes': line[1], 
'gene_list': gene_list}
         line = f.readline()
     f.close()
     return d

+
 def write_strains_for_each_gene_lines(gene_strains_dic, f):

     # (this is actually the same as above function, but it's confusing
@@ -426,6 +442,7 @@ def write_strains_for_each_gene_lines(gene_strains_dic, f):
         f.write('\n')
     f.flush()

+
 def read_strains_for_each_gene(fn):

     # gene num_strains strain frac_intd strain frac_intd
@@ -437,11 +454,12 @@ def read_strains_for_each_gene(fn):
         strain_list = []
         for i in range(2, len(line), 2):
             strain_list.append((line[i], line[i+1]))
-        d[line[0]] = {'num_strains':line[1], 'strain_list':strain_list}
+        d[line[0]] = {'num_strains': line[1], 'strain_list': strain_list}
         line = f.readline()
     f.close()
     return d

+
 def read_genes(fn, fn_genes):

     if os.path.isfile(fn_genes):
@@ -464,11 +482,12 @@ def read_genes(fn, fn_genes):
             break

         # starting with new gene
-        #assert line.strip().startswith('gene'), line
+        # assert line.strip().startswith('gene'), line
         skip_this_gene = False

         # regex for finding coordinates
-        m = re.search(r'[><]?(?P<start>[0-9]+)[.><,0-9]*\.\.[><]?(?P<end>[0-9]+)', line)
+        m = re.search(r'[><]?(?P<start>[0-9]+)'
+                      r'[.><,0-9]*\.\.[><]?(?P<end>[0-9]+)', line)

         # subtract one to index from zero TODO is this correct? end is
         # inclusive
@@ -495,20 +514,25 @@
             if gene_name != '':
                 genes[gene_name] = (start, end)
             else:
-                print 'gene name not found: ' + line
+                print('gene name not found: ' + line)

     f.close()

     write_gene_file(genes, fn_genes)

     return genes

+
"""
def summarize_gene_info(fn_all, fn_strains, fn_strains_g, \
                        introgressed_genes, gene_info, tag, threshold=0):
-
+
    f_all = open(fn_all, 'w')
-    f_all.write('gene\tchromosome\tstart\tend\tnumber_strains\taverage_introgressed_fraction\taverage_number_non_gap\taverage_ref_from_count\n')
+    f_all.write('gene\tchromosome\tstart\tend\tnumber_strains'
+                '\taverage_introgressed_fraction\taverage_number_non_gap'
+                '\taverage_ref_from_count\n')

-    f_gene_heading = 'region_id\tstrain\tstart\tend\tintrogressed_fraction\tnumber_non_gap\tref_from_count\n'
+    f_gene_heading = ('region_id\tstrain\tstart\tend\t'
+                      'introgressed_fraction\tnumber_non_gap'
+                      '\tref_from_count\n')

    strain_genes = {}

@@ -518,7 +542,8 @@
        sum_introgressed_fraction = {}
        sum_number_non_gap = {}
        sum_ref_from_count = {}
-        fn_gene = gp.analysis_out_dir_absolute + tag + '/genes/' + gene + '.txt'
+        fn_gene = (gp.analysis_out_dir_absolute + tag +
+                   '/genes/' + gene + '.txt')
        if not os.path.exists(os.path.dirname(fn_gene)):
            os.makedirs(os.path.dirname(fn_gene))
        f_gene = open(fn_gene, 'w')
@@ -530,7 +555,8 @@
                sum_introgressed_fraction[strain] = 0
                sum_number_non_gap[strain] = 0
                sum_ref_from_count[strain] = 0
-            sum_introgressed_fraction[strain] += entry['introgressed_fraction']
+            sum_introgressed_fraction[strain] += entry[
+                'introgressed_fraction']
            sum_number_non_gap[strain] += entry['number_non_gap']
            sum_ref_from_count[strain] += entry['ref_from_count']
            if strain not in strain_genes:
diff --git a/code/analyze/to_update/nucleotide_diversity_from_introgression_main.py b/code/analyze/to_update/nucleotide_diversity_from_introgression_main.py
index 6bda7a6..c53b29a 100644
--- a/code/analyze/to_update/nucleotide_diversity_from_introgression_main.py
+++ b/code/analyze/to_update/nucleotide_diversity_from_introgression_main.py
@@ -1,33 +1,25 @@
-## calculate nucleotide diversity for all sites and for all sites
-## excluding introgression; also calculate the same but only in coding
-## 
regions +# calculate nucleotide diversity for all sites and for all sites +# excluding introgression; also calculate the same but only in coding +# regions -import re import sys -import os import copy -import math import gzip -import itertools -sys.path.insert(0, '..') import global_params as gp -sys.path.insert(0, '../align/') -import align_helpers -sys.path.insert(0, '../misc/') -import overlap -import read_table -import read_fasta -import write_fasta -import mystats +from align import align_helpers +from misc import read_table +from misc import read_fasta + def try_int(s, default=-1): try: i = int(s) return i - except: + except ValueError: return default -def count_diffs(s, t, skip_char = 'N'): + +def count_diffs(s, t, skip_char='N'): assert len(s) == len(t) num = 0 den = 0 @@ -39,10 +31,11 @@ def count_diffs(s, t, skip_char = 'N'): den += 1 return num, den -## generate a sequence that has the current strain's base for each -## site in the reference sequence, and skip_char for any site where -## the base is a gap/unknown (this is all based on the alignment) -def referize(strain_seq, ref_ind_to_strain_ind, skip_char = 'N'): + +# generate a sequence that has the current strain's base for each +# site in the reference sequence, and skip_char for any site where +# the base is a gap/unknown (this is all based on the alignment) +def referize(strain_seq, ref_ind_to_strain_ind, skip_char='N'): s = [skip_char for r in ref_ind_to_strain_ind] for i in range(len(ref_ind_to_strain_ind)): si = ref_ind_to_strain_ind[i] @@ -52,6 +45,7 @@ def referize(strain_seq, ref_ind_to_strain_ind, skip_char = 'N'): s[i] = strain_seq[si] return s + def mark_excluded(seq, regions, fill='N'): seqi = copy.deepcopy(seq) for start, end in regions: @@ -59,6 +53,7 @@ def mark_excluded(seq, regions, fill='N'): seqi[i] = fill return seqi + def mark_included(seq, regions, fill='N'): s = [fill for r in seq] for start, end in regions: @@ -66,15 +61,17 @@ def mark_included(seq, regions, fill='N'): s[i] = seq[i] return s + tag = 'u3_i.001_tv_l1000_f.01' -######## -## read in introgressed regions, as well as strains and reference genes -######## +# ###### +# read in introgressed regions, as well as strains and reference genes +# ###### -## dictionary of introgressed regions keyed by chromosome and then -## strain -regions_by_chrm_and_strain = dict(zip(gp.chrms, [{} for i in range(len(gp.chrms))])) +# dictionary of introgressed regions keyed by chromosome and then +# strain +regions_by_chrm_and_strain = dict(zip(gp.chrms, + [{} for i in range(len(gp.chrms))])) fn_regions = gp.analysis_out_dir_absolute + tag + '/' + \ 'introgressed_blocks_filtered_par_' + tag + '_summary_plus.txt' d, labels = read_table.read_table_rows(fn_regions, '\t') @@ -82,20 +79,20 @@ def mark_included(seq, regions, fill='N'): for region in d: chrm = d[region]['chromosome'] strain = d[region]['strain'] - if not regions_by_chrm_and_strain[chrm].has_key(strain): + if strain not in regions_by_chrm_and_strain[chrm]: regions_by_chrm_and_strain[chrm][strain] = [] - regions_by_chrm_and_strain[chrm][strain].append((int(d[region]['start']), \ + regions_by_chrm_and_strain[chrm][strain].append((int(d[region]['start']), int(d[region]['end']))) -## read in all strains +# read in all strains strain_dirs = align_helpers.get_strains(gp.non_ref_dirs[gp.master_ref]) num_strains = len(strain_dirs) -## read in genes in reference sequence into dictionary keyed by -## chromosome +# read in genes in reference sequence into dictionary keyed by +# chromosome ref_genes = {} for chrm in 
gp.chrms: ref_genes[chrm] = [] - f = open(gp.analysis_out_dir_absolute + gp.master_ref + \ + f = open(gp.analysis_out_dir_absolute + gp.master_ref + '_chr' + chrm + '_genes.txt', 'r') line = f.readline() while line != '': @@ -104,9 +101,9 @@ def mark_included(seq, regions, fill='N'): line = f.readline() f.close() -######## -## calculate nucleotide diversity -######## +# ###### +# calculate nucleotide diversity +# ###### # all sites total_frac = 0 @@ -127,7 +124,7 @@ def mark_included(seq, regions, fill='N'): # total number of strain pairs num_comparisons = 0 -## loop through all strains +# loop through all strains for i in range(num_strains): strain_i, d_i = strain_dirs[i] strain_i_seqs = {} @@ -135,49 +132,53 @@ def mark_included(seq, regions, fill='N'): strain_i_seqs_coding = {} strain_i_seqs_coding_nonint = {} - ## for each + # for each for chrm in gp.chrms: - ## coordinate conversion between reference and current strain + # coordinate conversion between reference and current strain coord_fn = gp.analysis_out_dir_absolute + 'coordinates/' + \ gp.master_ref + '_to_' + strain_i + \ '_chr' + chrm + '.txt.gz' f_coord = gzip.open(coord_fn, 'rb') - ref_ind_to_strain_i_ind = [try_int(line[:-1]) for line in f_coord.readlines()] + ref_ind_to_strain_i_ind = [try_int(line[:-1]) + for line in f_coord.readlines()] - ## current strain fasta file for current chromosome + # current strain fasta file for current chromosome strain_fn = d_i + strain_i + '_chr' + chrm + gp.fasta_suffix - print strain_i, chrm - - ## get chromosome sequence for this strain relative to - ## reference strain (the base for this strain at each site in - ## the reference, based on original alignment); - ## gaps/unsequenced sites/etc marked as 'N' - strain_i_seqs[chrm] = referize(read_fasta.read_fasta(strain_fn)[1][0].lower(),\ - ref_ind_to_strain_i_ind) - - ## get version of sequence where everything that doesn't fall - ## within gene is replaced by 'N' - strain_i_seqs_coding[chrm] = mark_included(strain_i_seqs[chrm],\ + print(strain_i, chrm) + + # get chromosome sequence for this strain relative to + # reference strain (the base for this strain at each site in + # the reference, based on original alignment); + # gaps/unsequenced sites/etc marked as 'N' + strain_i_seqs[chrm] = referize( + read_fasta.read_fasta(strain_fn)[1][0].lower(), + ref_ind_to_strain_i_ind) + + # get version of sequence where everything that doesn't fall + # within gene is replaced by 'N' + strain_i_seqs_coding[chrm] = mark_included(strain_i_seqs[chrm], ref_genes[chrm]) - ## also get version of above sequences where introgressed sites are - ## replaced by 'N' + # also get version of above sequences where introgressed sites are + # replaced by 'N' strain_i_seqs_nonint[chrm] = copy.deepcopy(strain_i_seqs[chrm]) - strain_i_seqs_coding_nonint[chrm] = copy.deepcopy(strain_i_seqs_coding[chrm]) - if regions_by_chrm_and_strain[chrm].has_key(strain_i): - strain_i_seqs_nonint[chrm] = mark_excluded(strain_i_seqs[chrm],\ - regions_by_chrm_and_strain[chrm][strain_i]) + strain_i_seqs_coding_nonint[chrm] = copy.deepcopy( + strain_i_seqs_coding[chrm]) + if strain_i in regions_by_chrm_and_strain[chrm]: + strain_i_seqs_nonint[chrm] = mark_excluded( + strain_i_seqs[chrm], + regions_by_chrm_and_strain[chrm][strain_i]) strain_i_seqs_coding_nonint[chrm] = \ - mark_excluded(strain_i_seqs_coding[chrm],\ + mark_excluded(strain_i_seqs_coding[chrm], regions_by_chrm_and_strain[chrm][strain_i]) - ## loop through all strains to get second strain for current pair + # loop through all 
strains to get second strain for current pair for j in range(i+1, num_strains): strain_j, d_j = strain_dirs[j] - print strain_i, strain_j - ## keep track of total number of strain pairs we're looking - ## at, so we can divide total by that later + print(strain_i, strain_j) + # keep track of total number of strain pairs we're looking + # at, so we can divide total by that later num_comparisons += 1 num = 0 @@ -190,46 +191,49 @@ def mark_included(seq, regions, fill='N'): den_coding_nonint = 0 for chrm in gp.chrms: - ## do the same reading in of sequence for this strain, - ## relative to reference, and also excluding introgressed - ## sites + # do the same reading in of sequence for this strain, + # relative to reference, and also excluding introgressed + # sites coord_fn = gp.analysis_out_dir_absolute + 'coordinates/' + \ gp.master_ref + '_to_' + strain_j + \ '_chr' + chrm + '.txt.gz' f_coord = gzip.open(coord_fn, 'rb') - ref_ind_to_strain_ind = [try_int(line[:-1]) for line in f_coord.readlines()] - + ref_ind_to_strain_ind = [try_int(line[:-1]) + for line in f_coord.readlines()] + strain_fn = d_j + strain_j + '_chr' + chrm + gp.fasta_suffix - strain_j_seq = referize(read_fasta.read_fasta(strain_fn)[1][0].lower(),\ - ref_ind_to_strain_ind) + strain_j_seq = referize( + read_fasta.read_fasta(strain_fn)[1][0].lower(), + ref_ind_to_strain_ind) strain_j_seq_coding = mark_included(strain_j_seq, ref_genes[chrm]) strain_j_seq_nonint = copy.deepcopy(strain_j_seq) strain_j_seq_coding_nonint = copy.deepcopy(strain_j_seq_coding) - if regions_by_chrm_and_strain[chrm].has_key(strain_j): - strain_j_seq_nonint = mark_excluded(strain_j_seq,\ - regions_by_chrm_and_strain[chrm][strain_j]) - strain_j_seq_coding_nonint = mark_excluded(strain_j_seq_coding,\ - regions_by_chrm_and_strain[chrm][strain_j]) - - ## count sites that differ between the two strains - ## (ignoring any sites where one of the strains has 'N') - ## and add to appropriate running total - - ## all sites + if strain_j in regions_by_chrm_and_strain[chrm]: + strain_j_seq_nonint = mark_excluded( + strain_j_seq, regions_by_chrm_and_strain[chrm][strain_j]) + strain_j_seq_coding_nonint = mark_excluded( + strain_j_seq_coding, + regions_by_chrm_and_strain[chrm][strain_j]) + + # count sites that differ between the two strains + # (ignoring any sites where one of the strains has 'N') + # and add to appropriate running total + + # all sites num_chrm, den_chrm = count_diffs(strain_i_seqs[chrm], strain_j_seq) num += num_chrm den += den_chrm total_fracs[chrm] += float(num_chrm)/den_chrm # nonintrogressed - num_chrm_nonint, den_chrm_nonint = count_diffs(strain_i_seqs_nonint[chrm],\ - strain_j_seq_nonint) + num_chrm_nonint, den_chrm_nonint = count_diffs( + strain_i_seqs_nonint[chrm], strain_j_seq_nonint) num_nonint += num_chrm_nonint den_nonint += den_chrm_nonint total_fracs_nonint[chrm] += float(num_chrm_nonint)/den_chrm_nonint - ## all coding sites + # all coding sites num_chrm_coding, den_chrm_coding = \ count_diffs(strain_i_seqs_coding[chrm], strain_j_seq_coding) num_coding += num_chrm_coding @@ -238,20 +242,20 @@ def mark_included(seq, regions, fill='N'): # coding, nonintrogressed num_chrm_coding_nonint, den_chrm_coding_nonint = \ - count_diffs(strain_i_seqs_coding_nonint[chrm],\ + count_diffs(strain_i_seqs_coding_nonint[chrm], strain_j_seq_coding_nonint) num_coding_nonint += num_chrm_coding_nonint den_coding_nonint += den_chrm_coding_nonint total_fracs_coding_nonint[chrm] += \ float(num_chrm_coding_nonint)/den_chrm_coding_nonint - print num_comparisons, 
chrm, \ - total_fracs[chrm], \ - total_fracs_nonint[chrm], \ - 1 - total_fracs_nonint[chrm]/total_fracs[chrm], \ - total_fracs_coding[chrm], \ - total_fracs_coding_nonint[chrm], \ - 1 - total_fracs_coding_nonint[chrm]/total_fracs_coding[chrm] + print(num_comparisons, chrm, + total_fracs[chrm], + total_fracs_nonint[chrm], + 1 - total_fracs_nonint[chrm]/total_fracs[chrm], + total_fracs_coding[chrm], + total_fracs_coding_nonint[chrm], + 1 - total_fracs_coding_nonint[chrm]/total_fracs_coding[chrm]) # and keep track across all chromosomes total_frac += float(num)/den @@ -259,9 +263,10 @@ def mark_included(seq, regions, fill='N'): total_frac_coding += float(num_coding)/den_coding total_frac_coding_nonint += float(num_coding_nonint)/den_coding_nonint - print num_comparisons, total_frac, total_frac_nonint, \ - 1 - total_frac_nonint/total_frac, total_frac_coding, \ - total_frac_coding_nonint, 1 - total_frac_coding_nonint/total_frac_coding + print(num_comparisons, total_frac, total_frac_nonint, + 1 - total_frac_nonint/total_frac, total_frac_coding, + total_frac_coding_nonint, + 1 - total_frac_coding_nonint/total_frac_coding) sys.stdout.flush() # nucleotide diversity is the running total of fractions of sites that @@ -272,23 +277,23 @@ def mark_included(seq, regions, fill='N'): nuc_div_coding = total_frac_coding/num_comparisons nuc_div_coding_nonint = total_frac_coding_nonint/num_comparisons -print nuc_div -print nuc_div_nonint -print nuc_div_coding -print nuc_div_coding_nonint +print(nuc_div) +print(nuc_div_nonint) +print(nuc_div_coding) +print(nuc_div_coding_nonint) -######## -## write overall results and results for individual chromosome to file -######## +# ###### +# write overall results and results for individual chromosome to file +# ###### -f = open(gp.analysis_out_dir_absolute + tag + '/polymorphism/' + \ +f = open(gp.analysis_out_dir_absolute + tag + '/polymorphism/' + 'nucleotide_diversity_c.txt', 'w') f.write('chromosome\tpi\tpi_nonint\tpi_coding\tpi_coding_nonint\n') -f.write('all\t' + str(nuc_div) + '\t' + str(nuc_div_nonint) + \ +f.write('all\t' + str(nuc_div) + '\t' + str(nuc_div_nonint) + '\t' + str(nuc_div_coding) + '\t' + str(nuc_div_coding_nonint) + '\n') for chrm in gp.chrms: - f.write(chrm + '\t' + str(total_fracs[chrm]/num_comparisons) + '\t' + \ - str(total_fracs_nonint[chrm]/num_comparisons) + '\t' + \ - str(total_fracs_coding[chrm]/num_comparisons) + '\t' + \ + f.write(chrm + '\t' + str(total_fracs[chrm]/num_comparisons) + '\t' + + str(total_fracs_nonint[chrm]/num_comparisons) + '\t' + + str(total_fracs_coding[chrm]/num_comparisons) + '\t' + str(total_fracs_coding_nonint[chrm]/num_comparisons) + '\n') f.close() diff --git a/code/annotate/fix.py b/code/annotate/fix.py index 4c7ca15..76cf352 100644 --- a/code/annotate/fix.py +++ b/code/annotate/fix.py @@ -1,12 +1,6 @@ -import sys import os -#from orf import * -sys.path.insert(0, '../align') -import align_helpers -sys.path.insert(0, '..') -import global_params as gp -#d = '/tigress/AKEY/akey_vol2/aclark4/nobackup/100_genomes/genomes_gb/orfs/' +# d = '/tigress/AKEY/akey_vol2/aclark4/nobackup/100_genomes/genomes_gb/orfs/' d = '../../data/CBS432/orfs/' fns = os.listdir(d) for fn in fns: diff --git a/code/annotate/makeblastdb.py b/code/annotate/makeblastdb.py index 606969a..4e7e967 100644 --- a/code/annotate/makeblastdb.py +++ b/code/annotate/makeblastdb.py @@ -1,18 +1,13 @@ -import sys import os -#from orf import * -sys.path.insert(0, '../align') -import align_helpers -sys.path.insert(0, '..') import global_params as gp -#d = 
'/tigress/AKEY/akey_vol2/aclark4/nobackup/100_genomes/genomes_gb/orfs/' +# d = '/tigress/AKEY/akey_vol2/aclark4/nobackup/100_genomes/genomes_gb/orfs/' d = '../../data/CBS432/orfs/' -#d = '/tigress/AKEY/akey_vol2/aclark4/nobackup/100_genomes/genomes_gb/orfs/' +# d = '/tigress/AKEY/akey_vol2/aclark4/nobackup/100_genomes/genomes_gb/orfs/' fns = os.listdir(d) for fn in fns: cmd_string = gp.blast_install_path + 'makeblastdb' + \ ' -dbtype nucl' + \ ' -in ' + d + fn - print cmd_string + print(cmd_string) os.system(cmd_string) diff --git a/code/annotate/orfs_main.py b/code/annotate/orfs_main.py index 9c0a906..6ec3a05 100644 --- a/code/annotate/orfs_main.py +++ b/code/annotate/orfs_main.py @@ -3,49 +3,48 @@ import sys import os -#from orf import * -sys.path.insert(0, '../align') -import align_helpers -sys.path.insert(0, '..') +from align import align_helpers import global_params as gp -ref_fns = [gp.ref_dir[r] + gp.ref_fn_prefix[r] + '_chr' + '?' + \ - gp.fasta_suffix \ +ref_fns = [gp.ref_dir[r] + gp.ref_fn_prefix[r] + '_chr' + '?' + + gp.fasta_suffix for r in gp.alignment_ref_order] # get all non-reference strains of cerevisiae and paradoxus s = align_helpers.get_strains(align_helpers.flatten(gp.non_ref_dirs.values())) # and get paradoxus reference as well -s.append((gp.ref_fn_prefix[gp.alignment_ref_order[1]], gp.ref_dir[gp.alignment_ref_order[1]])) +s.append((gp.ref_fn_prefix[gp.alignment_ref_order[1]], + gp.ref_dir[gp.alignment_ref_order[1]])) strain_fn = '*_chr?' + gp.fasta_suffix f = open('orfs.sh', 'w') -for i in range(78,94): +for i in range(78, 94): strain, d = s[i] - print strain + print(strain) current_strain_fn = strain_fn.replace('*', strain) for chrm in gp.chrms: - print chrm + print(chrm) sys.stdout.flush() - + current_strain_chrm_fn = current_strain_fn.replace('?', chrm) - orf_fn = strain + '_chr' + chrm + \ - '_orfs' + gp.fasta_suffix + orf_fn = strain + '_chr' + chrm + '_orfs' + gp.fasta_suffix orf_d = d + '/orfs/' if not os.path.isdir(orf_d): os.makedirs(orf_d) - cmd_string = gp.orffinder_install_path + '/ORFfinder' + \ - ' -in ' + d + current_strain_chrm_fn + \ - ' -s 0' + \ - ' -out ' + orf_d + orf_fn + \ - ' -outfmt 1 -n true; \n' - #print cmd_string + cmd_string = (gp.orffinder_install_path + '/ORFfinder' + + ' -in ' + d + current_strain_chrm_fn + + ' -s 0' + + ' -out ' + orf_d + orf_fn + + ' -outfmt 1 -n true; \n') + # print(cmd_string) os.system(cmd_string) f.write(cmd_string) f.close() -# "../../../../software/ORFfinder -in /tigress/AKEY/akey_vol2/aclark4/nobackup/100_genomes/genomes_gb/yjm248_chrI.fa -out a.txt -outfmt 1 -n true" +# "../../../../software/ORFfinder \ +# -in /tigress/AKEY/akey_vol2/aclark4/nobackup/\ +# 100_genomes/genomes_gb/yjm248_chrI.fa -out a.txt -outfmt 1 -n true" diff --git a/code/config.yaml b/code/config.yaml index 24d4698..1da5546 100644 --- a/code/config.yaml +++ b/code/config.yaml @@ -41,18 +41,25 @@ paths: analysis: analysis_base: __OUTPUT_ROOT__/analysisp4e2 - regions: __ANALYSIS_BASE__/regions/{state}.fa.gz - region_index: __ANALYSIS_BASE__/regions/{state}.pkl + regions: __OUTPUT_ROOT__/analysis_test/regions/{state}.fa.gz + region_index: __OUTPUT_ROOT__/analysis_test/regions/{state}.pkl genes: __ANALYSIS_BASE__/genes/ - blocks: __ANALYSIS_BASE__/blocks_{state}_p4e2.txt - labeled_blocks: "__ANALYSIS_BASE__/blocks_{state}_p4e2_labeled.txt" - quality: __ANALYSIS_BASE__/block_{state}_quality.txt + blocks: __ANALYSIS_BASE__/blocks_{state}.txt + labeled_blocks: __ANALYSIS_BASE__/blocks_{state}_p4e2_labeled.txt + quality_blocks: 
__ANALYSIS_BASE__/blocks_{state}_p4e2_quality.txt
     hmm_initial: __ANALYSIS_BASE__/hmm_initial.txt
     hmm_trained: __ANALYSIS_BASE__/hmm_trained.txt
     probabilities: __ANALYSIS_BASE__/probabilities.txt.gz
     alignment: __ALIGNMENTS__/{prefix}_{strain}_chr{chrom}_mafft.maf
     positions: __ANALYSIS_BASE__/positions.txt.gz
     masked_intervals: __MASKS__/{strain}_chr{chrom}_intervals.txt
+    introgressed: __ANALYSIS_BASE__/blocks_{state}_filter1.txt
+    introgressed_intermediate: "__ANALYSIS_BASE__/\
+      blocks_{state}_filter1inter.txt"
+    ambiguous: __ANALYSIS_BASE__/blocks_{state}_filter2.txt
+    ambiguous_intermediate: "__ANALYSIS_BASE__/\
+      blocks_{state}_filter2inter.txt"
+    filter_sweep: __ANALYSIS_BASE__/filter2_thresholds.txt

   # software install locations
   software:
@@ -69,7 +76,6 @@ paths:
     ldselect: __ROOT_INSTALL__/ldSelect/
     structure: __ROOT_INSTALL__/structure/

-# chromosomes: ['I']
chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII',
              'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']

@@ -83,12 +89,12 @@ chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII',
analysis_params:
  tag: p2e4
  convergence_threshold: 0.001
-  # threshold can be 'viterbi' or a float to threshold probabilities
+  # threshold can be 'viterbi' or a float to threshold HMM probabilities
  threshold: viterbi
  input_root: /tigress/AKEY/akey_vol2/aclark4/nobackup
+  filter_threshold: 0.98

  # master known state, prepended to list of known states
-  # TODO need to use the other name for S288c!
  reference:
    name: S288c
    base_dir: __INPUT_ROOT__/100_genomes/genomes/S288c_SGD-R64/
diff --git a/code/global_params.py b/code/global_params.py
index f01ec9d..a21ef1e 100644
--- a/code/global_params.py
+++ b/code/global_params.py
@@ -1,33 +1,34 @@
-#====
+# ====
 # biological parameters
-#====
+# ====

 mu = 1.84 * 10 ** -10

-#====
+# ====
 # file extensions
-#====
+# ====

 # suffix for _all_ fasta files
 fasta_suffix = '.fa'

-# suffix for _all_ alignment files; this needs to match the suffix output by mugsy
+# suffix for _all_ alignment files
+# this needs to match the suffix output by mugsy
 alignment_suffix = '.maf'

-#====
+# ====
 # sequence locations/names
-#====
+# ====

-## now specified in setup_args file
+# now specified in setup_args file

-#====
+# ====
 # alignment files
-#====
+# ====

-## alignments directory now specified in setup_args file
+# alignments directory now specified in setup_args file

 mask_dir = '../alignments/masked/'
-#mask_dir = '/tigress/tcomi/aclark4_temp/par4/masked/'
+# mask_dir = '/tigress/tcomi/aclark4_temp/par4/masked/'

 # should we leave the alignments already completed in the alignments
 # directory alone?
@@ -36,9 +37,9 @@
 # master_ref now automatically assumed to be first
 # reference specified in setup_args file

-#====
+# ====
 # HMM
-#====
+# ====

 match_symbol = '+'
 mismatch_symbol = '-'
@@ -49,13 +50,12 @@
 unaligned_symbol = '?' 
masked_symbol = 'x'

-#====
+# ====
 # simulations
-#====
+# ====

 # output directory for simulations
 sim_out_dir_absolute = '/tigress/tcomi/aclark4_temp/results/sim'
-#sim_out_dir_absolute = '/tigress/AKEY/akey_vol2/aclark4/projects/introgression/results/sim/'

 # prefix for simulation output
 sim_out_prefix = 'sim_out_'
@@ -63,9 +63,9 @@
 # suffix for simulation output
 sim_out_suffix = '.txt'

-#====
+# ====
 # analysis
-#====
+# ====

 analysis_out_dir_absolute = \
     '/tigress/AKEY/akey_vol2/aclark4/projects/introgression/results/analysis/'
@@ -74,20 +74,22 @@

 genes_out_dir_absolute = analysis_out_dir_absolute + '/genes/'

-#====
+# ====
 # software install locations
-#====
+# ====

 mugsy_install_path = '/tigress/anneec/software/mugsy/'

-tcoffee_install_path = '/tigress/anneec/software/T-COFFEE_installer_Version_11.00.8cbe486_linux_x64/bin/'
+tcoffee_install_path = ('/tigress/anneec/software/'
+                        'T-COFFEE_installer_Version_11.00.8cbe486_linux_x64'
+                        '/bin/')

 mafft_install_path = '/tigress/anneec/software/mafft/bin/'

 ms_install_path = '/tigress/anneec/software/msdir/'

 # including dustmasker
-blast_install_path = '/tigress/anneec/software/ncbi-blast-2.7.1+-src/c++/ReleaseMT/bin/'
+blast_install_path = ('/tigress/anneec/software/'
+                      'ncbi-blast-2.7.1+-src/c++/ReleaseMT/bin/')

 orffinder_install_path = '/tigress/anneec/software/'

@@ -95,11 +97,12 @@

 structure_install_path = '/tigress/anneec/software/structure/'

-#====
+# ====
 # other
-#====
+# ====

-chrms = ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']
-#chrms = ['I']
+chrms = ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII',
+         'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']
+# chrms = ['I']

 chrms_ara = dict(zip(chrms, range(1, len(chrms)+1)))
diff --git a/code/misc/to_bed.py b/code/misc/to_bed.py
index e6ac10a..d14fb07 100644
--- a/code/misc/to_bed.py
+++ b/code/misc/to_bed.py
@@ -1,6 +1,7 @@
 import re

-chrms_roman = ['0', 'I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XIV']
+chrms_roman = ['0', 'I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII',
+               'IX', 'X', 'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI']

 f = open('../../results/introgressed_id.txt', 'r')

@@ -12,11 +13,11 @@
 while line != '':
     if strain in line:
         line = line.strip().split(',')
-
-        m = re.search('chr(?P<chrm>[A-Z]+)\.', line[0])
+
+        m = re.search(r'chr(?P<chrm>[A-Z]+)\.', line[0])
         chrm = m.group('chrm')
         chrm = 'chr' + str(chrms_roman.index(chrm))
-
+
         d = line[2].find('-')
         start = line[2][:d]
         end = line[2][d+1:]
@@ -24,8 +25,8 @@
         region_count += 1
         score = '0'
         strand = line[1][line[1].find(' strand') - 1]
-        f_out.write(chrm + '\t' + start + '\t' + end + '\t' + name + '\t' + score + '\t' + strand + '\n')
+        f_out.write(chrm + '\t' + start + '\t' + end + '\t' +
+                    name + '\t' + score + '\t' + strand + '\n')
         line = f.readline()
 f.close()
 f_out.close()
-
diff --git a/code/test/analyze/test_filter_1_main.py b/code/test/analyze/test_filter_1_main.py
deleted file mode 100644
index 5d3f097..0000000
--- a/code/test/analyze/test_filter_1_main.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from analyze import filter_1_main as main
-
-
-def test_main(mocker, capsys):
-    mocker.patch('analyze.filter_1_main.predict.process_predict_args',
-                 return_value={
-                     'known_states': ['state1', 'state2'],
-                     'tag': 'tag'
-                 })
-    mocker.patch('analyze.filter_1_main.gp.analysis_out_dir_absolute',
-                 '/dir')
-    mocker.patch('analyze.filter_1_main.read_table.read_table_rows',
-                 return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions']))
-    mocked_file = 
diff --git a/code/test/analyze/test_filter_1_main.py b/code/test/analyze/test_filter_1_main.py
deleted file mode 100644
index 5d3f097..0000000
--- a/code/test/analyze/test_filter_1_main.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from analyze import filter_1_main as main
-
-
-def test_main(mocker, capsys):
-    mocker.patch('analyze.filter_1_main.predict.process_predict_args',
-                 return_value={
-                     'known_states': ['state1', 'state2'],
-                     'tag': 'tag'
-                 })
-    mocker.patch('analyze.filter_1_main.gp.analysis_out_dir_absolute',
-                 '/dir')
-    mocker.patch('analyze.filter_1_main.read_table.read_table_rows',
-                 return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions']))
-    mocked_file = mocker.patch('analyze.filter_1_main.open')
-
-    mock_read = mocker.patch('analyze.filter_1_main.Region_Reader')
-    mock_read().__enter__().yield_fa.return_value = iter([
-        ('r1', ['> seq', '> info'], ['atcg', 'x..']),
-        ('r2', ['> seq', '> info'], ['atcg', 'x..'])])
-
-    mock_filter = mocker.patch('analyze.filter_1_main.passes_filters1',
                               side_effect=[(False, 'test'),  # r1
-                                            (True, '')])  # r2
-    mock_write = mocker.patch('analyze.filter_1_main.write_filtered_line')
-
-    main.main()
-
-    captured = capsys.readouterr().out
-    assert captured == 'state2\n'
-
-    assert mock_read.call_count == 2  # called once during setup
-    mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True)
-
-    assert mocked_file.call_count == 2
-    mocked_file.assert_any_call(
-        '/dirtag/blocks_state2_tag_filtered1intermediate.txt', 'w')
-    mocked_file.assert_any_call(
-        '/dirtag/blocks_state2_tag_filtered1.txt', 'w')
-
-    # just headers, capture others
-    mocked_file().__enter__().write.assert_has_calls([
-        mocker.call('regions\treason\n'),
-        mocker.call('regions\n')])
-
-    assert mock_filter.call_count == 2
-    # seems like this references the object, which changes after call
- mock_filter.assert_has_calls([ - mocker.call( - {'alternative_states': '1,2', - 'alternative_ids': '0.8,0.5', - 'alternative_P_counts': '2,1,0'}, - ['atcg'], 0.1, ['state1', 'state2']), - mocker.call( - {'a': 1, - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2'}, - ['atcg'], 0.1, ['state1', 'state2'])]) - - assert mock_write.call_count == 3 - mock_write.assert_has_calls([ - mocker.call(mocker.ANY, 'r1', - {'alternative_states': '1,2', - 'alternative_ids': '0.8,0.5', - 'alternative_P_counts': '2,1,0'}, - ['regions', 'alternative_states', - 'alternative_ids', 'alternative_P_counts'] - ), - mocker.call(mocker.ANY, 'r2', - {'a': 1, - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2'}, - ['regions', 'alternative_states', - 'alternative_ids', 'alternative_P_counts'] - ), - mocker.call(mocker.ANY, 'r2', - {'a': 1, - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2'}, - ['regions'] - ) - ]) diff --git a/code/test/analyze/test_filter_2_thresholds_main.py b/code/test/analyze/test_filter_2_thresholds_main.py deleted file mode 100644 index 9758737..0000000 --- a/code/test/analyze/test_filter_2_thresholds_main.py +++ /dev/null @@ -1,104 +0,0 @@ -from analyze import filter_2_thresholds_main as main - - -def test_main(mocker, capsys): - mocker.patch('sys.argv', ['', '0.1']) - mocker.patch( - 'analyze.filter_2_thresholds_main.predict.process_predict_args', - return_value={ - 'known_states': ['state1', 'state2'], - 'tag': 'tag' - }) - mocker.patch( - 'analyze.filter_2_thresholds_main.thresholds', - [0.99, 0.95]) - mocker.patch( - 'analyze.filter_2_thresholds_main.gp.analysis_out_dir_absolute', - '/dir') - mocker.patch('analyze.filter_2_thresholds_main.read_table.read_table_rows', - return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) - - mocked_file = mocker.patch('analyze.filter_2_thresholds_main.open') - mock_read = mocker.patch('analyze.filter_2_thresholds_main.Region_Reader') - mock_read().__enter__().yield_fa.return_value = iter([ - ('r1', ['> seq', '> info'], ['atcg', 'x..']), - ('r2', ['> seq', '> info'], ['atcg', 'x..'])]) - mock_filter = mocker.patch( - 'analyze.filter_2_thresholds_main.passes_filters2', - side_effect=[ - (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]), - (True, ['1'], [0.8], [2]), - (True, ['1'], [0.8], [2]), - (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]) - ]) - - main.main() - - captured = capsys.readouterr().out - assert captured == '* state2\n' - - assert mock_read.call_count == 2 - mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True) - - assert mocked_file.call_count == 1 - mocked_file.assert_any_call( - '/dirtag/filter_2_thresholds_tag.txt', 'w') - - mocked_file().__enter__().write.assert_has_calls([ - mocker.call('threshold\tpredicted_state\talternative_states\tcount\n'), - mocker.call('0.99\tstate2\t1,2\t1\n'), - mocker.call('0.99\tstate2\t1\t1\n'), - mocker.call('0.95\tstate2\t1\t1\n'), - mocker.call('0.95\tstate2\t1,2\t1\n'), - ]) - - assert mock_filter.call_count == 4 - print(mock_filter.call_args_list) - mock_filter.assert_has_calls([ - mocker.call({}, ['atcg'], 0.99), - mocker.call({}, ['atcg'], 0.95), - mocker.call({'a': 1}, ['atcg'], 0.99), - mocker.call({'a': 1}, ['atcg'], 0.95), - ]) - - -def test_record_data_hit(): - dt = {} - main.record_data_hit(dt, 0.9, 's1', 'k1') - assert dt == {0.9: {'s1': {'k1': 1}}} - main.record_data_hit(dt, 0.9, 's1', 'k1') - main.record_data_hit(dt, 0.9, 's1', 'k1') - assert dt == {0.9: {'s1': {'k1': 3}}} - 
main.record_data_hit(dt, 0.9, 's1', 'k2') - assert dt == { - 0.9: { - 's1': {'k1': 3, 'k2': 1} - } - } - main.record_data_hit(dt, 0.9, 's2', 'k2') - assert dt == { - 0.9: { - 's1': {'k1': 3, 'k2': 1}, - 's2': {'k2': 1} - } - } - main.record_data_hit(dt, 0.8, 's2', 'k2') - assert dt == { - 0.9: { - 's1': {'k1': 3, 'k2': 1}, - 's2': {'k2': 1} - }, - 0.8: { - 's2': {'k2': 1} - } - } - main.record_data_hit(dt, 0.9, 's2', 'k2') - assert dt == { - 0.9: { - 's1': {'k1': 3, 'k2': 1}, - 's2': {'k2': 2} - }, - 0.8: { - 's2': {'k2': 1} - } - } diff --git a/code/test/analyze/test_filter_helpers.py b/code/test/analyze/test_filter_helpers.py deleted file mode 100644 index 4838156..0000000 --- a/code/test/analyze/test_filter_helpers.py +++ /dev/null @@ -1,245 +0,0 @@ -from analyze import filter_helpers -from io import StringIO -from misc import read_fasta -import os -import warnings -from pytest import approx - - -def test_write_filtered_line(): - # single value, first field is ignored - output = StringIO() - filter_helpers.write_filtered_line(output, 'r1', {'chr': 'I'}, ['', 'chr']) - - assert output.getvalue() == 'r1\tI\n' - - # no value - output = StringIO() - filter_helpers.write_filtered_line(output, 'r1', {}, []) - - assert output.getvalue() == 'r1\t\n' - - # two values - output = StringIO() - filter_helpers.write_filtered_line(output, 'r1', - {'a': 'b', 'c': 'd'}, - ['', 'c', 'a']) - - assert output.getvalue() == 'r1\td\tb\n' - - -def test_passes_filters(): - # check gaps + number masked / end-start+1 > 0.5 - region = {'number_gaps': 1, - 'number_masked_non_gap': 0, - 'start': 0, - 'end': 1, - 'number_match_ref2_not_ref1': 0, - 'number_match_ref1': 0, - 'aligned_length': 0, - } - assert filter_helpers.passes_filters(region) is False - region = {'number_gaps': 1, - 'number_masked_non_gap': 1, - 'start': 0, - 'end': 1, - 'number_match_ref2_not_ref1': 0, - 'number_match_ref1': 0, - 'aligned_length': 0, - } - assert filter_helpers.passes_filters(region) is False - - # check match only > 7 - region = {'number_gaps': 0, - 'number_masked_non_gap': 0, - 'start': 0, - 'end': 1, - 'number_match_ref2_not_ref1': 6, - 'number_match_ref1': 0, - 'aligned_length': 0, - } - assert filter_helpers.passes_filters(region) is False - - # check divergences (match_ref1 / aligned - gapped) < 0.7 - region = {'number_masked_non_gap': 0, - 'start': 0, - 'end': 1, - 'number_match_ref2_not_ref1': 7, - 'number_match_ref1': 6, - 'aligned_length': 11, - 'number_gaps': 1} - assert filter_helpers.passes_filters(region) is False - - # passes - region = {'number_gaps': 0, - 'number_masked_non_gap': 0, - 'start': 0, - 'end': 1, # fraction gaps > 0.5 - 'number_match_ref2_not_ref1': 7, # >= 7 - 'number_match_ref1': 7, # div >= 0.7 - 'aligned_length': 10, - } - assert filter_helpers.passes_filters(region) is True - - -def test_passes_filters1(mocker): - # fail fraction gapped on reference - region = {'predicted_species': 'pred', - 'start': 0, - 'end': 9, - 'num_sites_nonmask_ref': 4, - 'num_sites_nonmask_pred': 0, - 'match_nongap_pred': 0, - 'num_sites_nongap_pred': 0, - 'match_nongap_ref': 0, - 'num_sites_nongap_ref': 0, - } - - assert filter_helpers.passes_filters1(region, '', 'ref') == \ - (False, 'fraction gaps/masked in master = 0.6') - - # fail fraction gapped on predicted - region = {'predicted_species': 'pred', - 'start': 0, - 'end': 9, - 'num_sites_nonmask_ref': 5, - 'num_sites_nonmask_pred': 3, - 'match_nongap_pred': 0, - 'num_sites_nongap_pred': 0, - 'match_nongap_ref': 0, - 'num_sites_nongap_ref': 0, - } - - assert 
filter_helpers.passes_filters1(region, '', 'ref') == \ - (False, 'fraction gaps/masked in predicted = 0.7') - - # fail match counts - region = {'predicted_species': 'pred', - 'start': 0, - 'end': 9, - 'num_sites_nonmask_ref': 5, - 'num_sites_nonmask_pred': 5, - 'match_nongap_pred': 0, - 'num_sites_nongap_pred': 0, - 'match_nongap_ref': 0, - 'num_sites_nongap_ref': 0, - } - - assert filter_helpers.passes_filters1(region, 'CP', 'ref') == \ - (False, 'count_P = 1') - assert filter_helpers.passes_filters1(region, - 'CCCCCCCCPPPPPPP', 'ref') == \ - (False, 'count_P = 7 and count_C = 8') - - # fail divergence, master >= pred - region = {'predicted_species': 'pred', - 'start': 0, - 'end': 9, - 'num_sites_nonmask_ref': 5, - 'num_sites_nonmask_pred': 5, - 'match_nongap_pred': 5, - 'num_sites_nongap_pred': 10, - 'match_nongap_ref': 6, - 'num_sites_nongap_ref': 10, - } - - assert filter_helpers.passes_filters1(region, 'CPPPPPPP', 'ref') == \ - (False, 'id with master = 0.6 and id with predicted = 0.5') - - # fail divergence, master >= 0.7 - region = {'predicted_species': 'pred', - 'start': 0, - 'end': 9, - 'num_sites_nonmask_ref': 5, - 'num_sites_nonmask_pred': 5, - 'match_nongap_pred': 8, - 'num_sites_nongap_pred': 10, - 'match_nongap_ref': 6, - 'num_sites_nongap_ref': 10, - } - - assert filter_helpers.passes_filters1(region, 'CPPPPPPP', 'ref') == \ - (False, 'id with master = 0.6') - - # passes - region = {'predicted_species': 'pred', - 'start': 0, - 'end': 9, - 'num_sites_nonmask_ref': 5, - 'num_sites_nonmask_pred': 5, - 'match_nongap_pred': 8, - 'num_sites_nongap_pred': 10, - 'match_nongap_ref': 7, - 'num_sites_nongap_ref': 10, - } - - assert filter_helpers.passes_filters1(region, 'CPPPPPPP', 'ref') == \ - (True, '') - - -def test_passes_filters2(mocker): - mocker.patch('analyze.filter_helpers.gp.gap_symbol', '-') - mocker.patch('analyze.filter_helpers.gp.unsequenced_symbol', 'n') - - region = {'predicted_species': '1', - } - seqs = [list('attatt'), # reference - list('aggcat'), # 4 / 5, p = 2 - list('a--tta'), # 2 / 4, p = 1 - list('nng---'), # no matches, '3' not in outputs - list('attatt'), # 2 / 5, p = 0 - list('ag-tat')] # test sequence - - threshold = 0 - filt, states, ids, p_count = filter_helpers.passes_filters2( - region, seqs, threshold, ['ref', '1', '2', '3', '4']) - assert filt is False - assert states == ['1', '2', '4'] - assert ids == [0.8, 0.5, 0.4] - assert p_count == [2, 1, 0] - - threshold = 0.1 - filt, states, ids, p_count = filter_helpers.passes_filters2( - region, seqs, threshold, ['ref', '1', '2', '3', '4']) - assert filt is False - assert states == ['1', '2'] - assert ids == [0.8, 0.5] - assert p_count == [2, 1] - - threshold = 0.9 - filt, states, ids, p_count = filter_helpers.passes_filters2( - region, seqs, threshold, ['ref', '1', '2', '3', '4']) - assert filt is True - assert states == ['1'] - assert ids == [0.8] - assert p_count == [2] - - -def test_passes_filters2_on_region(mocker): - mocker.patch('analyze.filter_helpers.gp.gap_symbol', '-') - mocker.patch('analyze.filter_helpers.gp.unsequenced_symbol', 'n') - - fa = os.path.join(os.path.split(__file__)[0], 'r10805.fa') - - if os.path.exists(fa): - headers, seqs = read_fasta.read_fasta(fa, gz=False) - seqs = seqs[:-1] - p, alt_states, alt_ids, alt_P_counts = filter_helpers.passes_filters2( - {'predicted_species': 'N_45'}, seqs, 0.1, - ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) - assert p is False - assert alt_states == ['CBS432', 'N_45', 'UWOPS91_917_1', 'DBVPG6304'] - assert alt_ids == 
approx([0.9983805668016195, 0.994331983805668, - 0.9642857142857143, 0.9618506493506493]) - assert alt_P_counts == [145, 143, 128, 129] - - p, alt_states, alt_ids, alt_P_counts = filter_helpers.passes_filters2( - {'predicted_species': 'N_45'}, seqs, 0.98, - ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) - assert p is False - assert alt_states == ['CBS432', 'N_45'] - assert alt_ids == approx([0.9983805668016195, 0.994331983805668]) - assert alt_P_counts == [145, 143] - - else: - warnings.warn('Unable to test with datafile r10805.fa') diff --git a/code/test/analyze/test_filter_regions.py b/code/test/analyze/test_filter_regions.py index b56681f..1da5227 100644 --- a/code/test/analyze/test_filter_regions.py +++ b/code/test/analyze/test_filter_regions.py @@ -1,293 +1,306 @@ from analyze import filter_regions +import pytest from io import StringIO from misc import read_fasta import os +import numpy as np import warnings -from pytest import approx - - -def test_main_no_thresh(mocker, capsys): - mocker.patch('sys.argv', ['', '0.1']) - mocker.patch('analyze.filter_regions.predict.process_predict_args', - return_value={ - 'known_states': ['state1', 'state2'], - 'tag': 'tag' - }) - mocker.patch('analyze.filter_regions.gp.analysis_out_dir_absolute', - '/dir') +from analyze.introgression_configuration import Configuration + + +@pytest.fixture +def filterer(): + config = Configuration() + config.set('symbols', + introgressed='int_{state}.txt', + introgressed_intermediate='int_int_{state}.txt', + ambiguous='amb_{state}.txt', + ambiguous_intermediate='amb_int_{state}.txt', + filter_sweep='sweep.txt', + filter_threshold=0.1, + regions='region_{state}.fa.gz', + region_index='region_{state}.pkl', + quality_blocks='block_{state}_quality.txt') + config.add_config({ + 'analysis_params': + {'reference': {'name': 'ref'}, + 'known_states': [ + {'name': 'pred'}, + {'name': 'pred2'}, + ], + } + }) + config.set('states') + return filter_regions.Filterer(config) + + +class NoCloseStringIO(StringIO): + def close(self): + pass + + def super_close(self): + super(StringIO).close(self) + + +def test_run_no_thresh_file(filterer, mocker): mocker.patch('analyze.filter_regions.read_table.read_table_rows', - return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) - mocked_file = mocker.patch('analyze.filter_regions.open') + return_value=({'r1': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 4, + 'num_sites_nonmask_pred': 0, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + }, 'r2': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 7, + 'num_sites_nongap_ref': 10, + }, 'r3': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 7, + 'num_sites_nongap_ref': 10, + }}, ['regions'])) + + files = [NoCloseStringIO() for i in range(8)] + mocked_file = mocker.patch('analyze.filter_regions.open', + side_effect=files) mock_read = mocker.patch('analyze.filter_regions.Region_Reader') mock_read().__enter__().yield_fa.return_value = iter([ ('r1', ['> seq', '> info'], ['atcg', 'x..']), - ('r2', ['> seq', '> info'], ['atcg', 'x..'])]) + ('r2', ['> seq', '> info'], ['attatt', 'aggcat', 'attatt', + 'ag-tat', np.array(list('CPPPPPPP'))]), + ('r3', 
['> seq', '> info'], ['actata', 'attatt', 'nng---', + 'ag-tat', np.array(list('CPPPPPPP'))])]) - mock_filter1 = mocker.patch('analyze.filter_regions.filter_introgressed', - side_effect=[(False, 'test'), # r1 - (True, '')]) # r2 - mock_filter2 = mocker.patch( - 'analyze.filter_regions.filter_ambiguous', - side_effect=[ - (True, ['1'], [0.8], [2]), - (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]) - ]) - mock_write = mocker.patch('analyze.filter_regions.write_filtered_line') - - filter_regions.main() + filterer.config.filter_sweep = None + filterer.run([.9]) - captured = capsys.readouterr().out - assert captured == 'state2\n' - - assert mock_read.call_count == 2 # called once during setup - mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True) + assert mock_read.call_count == 3 # called once during setup + mock_read.assert_called_with('region_pred2.fa.gz', as_fa=True) assert mocked_file.call_args_list == [ - mocker.call('/dirtag/blocks_state2_tag_filtered1intermediate.txt', - 'w'), - mocker.call('/dirtag/blocks_state2_tag_filtered1.txt', 'w'), - mocker.call('/dirtag/blocks_state2_tag_filtered2intermediate.txt', - 'w'), - mocker.call('/dirtag/blocks_state2_tag_filtered2.txt', 'w'), - ] - - # just headers, capture others - assert mocked_file().__enter__().write.call_args_list == [ - mocker.call('regions\treason\n'), - mocker.call('regions\n'), - mocker.call('regions\talternative_states\t' - 'alternative_ids\talternative_P_counts\n'), - mocker.call('regions\n'), - ] - - assert mock_filter1.call_count == 2 - # seems like this references the object, which changes after call - assert mock_filter1.call_args_list == [ - mocker.call({'reason': 'test'}, 'x..', 'state1'), - mocker.call({'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2'}, 'x..', 'state1') - ] - - assert mock_filter2.call_args_list == [ - mocker.call({'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2'}, - ['atcg'], 0.1, ['state1', 'state2']), - ] - assert mock_write.call_args_list == [ - mocker.call(mocker.ANY, 'r1', {'reason': 'test'}, - ['regions', 'reason']), - mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2' - }, - ['regions', 'reason']), - mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2' - }, - ['regions']), - mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2' - }, - ['regions', 'alternative_states', 'alternative_ids', - 'alternative_P_counts']), - mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2' - }, - ['regions']), + mocker.call('int_pred.txt', 'w'), + mocker.call('int_int_pred.txt', 'w'), + mocker.call('amb_pred.txt', 'w'), + mocker.call('amb_int_pred.txt', 'w'), + mocker.call('int_pred2.txt', 'w'), + mocker.call('int_int_pred2.txt', 'w'), + mocker.call('amb_pred2.txt', 'w'), + mocker.call('amb_int_pred2.txt', 'w'), ] - -def test_main(mocker, capsys): - mocker.patch('sys.argv', ['', '0.1']) - mocker.patch('analyze.filter_regions.predict.process_predict_args', - return_value={ - 'known_states': ['state1', 'state2'], - 'tag': 'tag' - }) - mocker.patch('analyze.filter_regions.gp.analysis_out_dir_absolute', - '/dir') + assert files[0].getvalue() == 
'regions\nr2\t\nr3\t\n' # pass filter 1 + assert files[1].getvalue() == ( + 'regions\treason\n' + 'r1\tfraction gaps/masked in master = 0.6\n' + 'r2\t\n' + 'r3\t\n' + ) + assert files[2].getvalue() == 'regions\nr3\t\n' # pass filter 2 + assert files[3].getvalue() == ( + 'regions\talternative_states\talternative_ids\talternative_P_counts\n' + 'r2\tpred,pred2\t1.0,1.0\t0,0\n' + 'r3\tpred\t1.0\t0\n' + ) + # files 4:8 are just headers + + +def test_run_no_thresh(filterer, mocker): mocker.patch('analyze.filter_regions.read_table.read_table_rows', - return_value=({'r1': {}, 'r2': {'a': 1}}, ['regions'])) - mocked_file = mocker.patch('analyze.filter_regions.open') + return_value=({'r1': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 4, + 'num_sites_nonmask_pred': 0, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + }, 'r2': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 7, + 'num_sites_nongap_ref': 10, + }, 'r3': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 7, + 'num_sites_nongap_ref': 10, + }}, ['regions'])) + + files = [NoCloseStringIO() for i in range(8)] + mocked_file = mocker.patch('analyze.filter_regions.open', + side_effect=files) mock_read = mocker.patch('analyze.filter_regions.Region_Reader') mock_read().__enter__().yield_fa.return_value = iter([ - ('r1', ['> seq', '> info'], ['atcg', 'x..']), - ('r2', ['> seq', '> info'], ['atcg', 'x..'])]) - - mock_filter1 = mocker.patch('analyze.filter_regions.filter_introgressed', - side_effect=[(False, 'test'), # r1 - (True, '')]) # r2 - mock_filter2 = mocker.patch( - 'analyze.filter_regions.filter_ambiguous', - side_effect=[ - (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]), - (True, ['1'], [0.8], [2]), - (True, ['1'], [0.8], [2]), - (False, ['1', '2'], [0.8, 0.5], [2, 1, 0]) - ]) - mock_write = mocker.patch('analyze.filter_regions.write_filtered_line') + ('r1', ['> seq', '> info'], ['atcg', np.array(list('x..'))]), + ('r2', ['> seq', '> info'], ['attatt', 'aggcat', 'attatt', + 'ag-tat', np.array(list('CPPPPPPP'))]), + ('r3', ['> seq', '> info'], ['actata', 'attatt', 'nng---', + 'ag-tat', np.array(list('CPPPPPPP'))])]) - filter_regions.main([0.99]) + mock_log = mocker.patch('analyze.filter_regions.log') - captured = capsys.readouterr().out - assert captured == 'state2\n' + filterer.run() - assert mock_read.call_count == 2 # called once during setup - mock_read.assert_called_with('/dirtag/regions/state2.fa.gz', as_fa=True) - - assert mocked_file.call_count == 5 - assert mocked_file.call_args_list == [ - mocker.call('/dirtag/filter_2_thresholds_tag.txt', 'w'), - mocker.call('/dirtag/blocks_state2_tag_filtered1intermediate.txt', - 'w'), - mocker.call('/dirtag/blocks_state2_tag_filtered1.txt', 'w'), - mocker.call('/dirtag/blocks_state2_tag_filtered2intermediate.txt', - 'w'), - mocker.call('/dirtag/blocks_state2_tag_filtered2.txt', 'w'), + assert mock_log.info.call_args_list == [ + mocker.call('pred'), + mocker.call('pred2'), ] - # just headers, capture others - assert mocked_file().__enter__().write.call_args_list == [ - mocker.call('threshold\tpredicted_state\talternative_states\tcount\n'), - mocker.call('regions\treason\n'), - mocker.call('regions\n'), - 
mocker.call('regions\talternative_states\t' - 'alternative_ids\talternative_P_counts\n'), - mocker.call('regions\n'), - mocker.call('0.99\tstate2\t1,2\t1\n') - ] - - assert mock_filter1.call_count == 2 - # seems like this references the object, which changes after call - assert mock_filter1.call_args_list == [ - mocker.call({'reason': 'test'}, 'x..', 'state1'), - mocker.call({'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2'}, 'x..', 'state1') - ] + assert mock_read.call_count == 3 # called once during setup + mock_read.assert_called_with('region_pred2.fa.gz', as_fa=True) - assert mock_filter2.call_args_list == [ - mocker.call({'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2'}, - ['atcg'], 0.99, ['state1', 'state2']), - mocker.call({'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2'}, - ['atcg'], 0.1, ['state1', 'state2']), - ] - assert mock_write.call_args_list == [ - mocker.call(mocker.ANY, 'r1', {'reason': 'test'}, - ['regions', 'reason']), - mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2' - }, - ['regions', 'reason']), - mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2' - }, - ['regions']), - mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2' - }, - ['regions', 'alternative_states', 'alternative_ids', - 'alternative_P_counts']), - mocker.call(mocker.ANY, 'r2', {'a': 1, 'reason': '', - 'alternative_states': '1', - 'alternative_ids': '0.8', - 'alternative_P_counts': '2' - }, - ['regions']), + assert mocked_file.call_args_list == [ + mocker.call('int_pred.txt', 'w'), + mocker.call('int_int_pred.txt', 'w'), + mocker.call('amb_pred.txt', 'w'), + mocker.call('amb_int_pred.txt', 'w'), + mocker.call('int_pred2.txt', 'w'), + mocker.call('int_int_pred2.txt', 'w'), + mocker.call('amb_pred2.txt', 'w'), + mocker.call('amb_int_pred2.txt', 'w'), ] + assert files[0].getvalue() == 'regions\nr2\t\nr3\t\n' # pass filter 1 + assert files[1].getvalue() == ( + 'regions\treason\n' + 'r1\tfraction gaps/masked in master = 0.6\n' + 'r2\t\n' + 'r3\t\n' + ) + assert files[2].getvalue() == 'regions\nr3\t\n' # pass filter 2 + assert files[3].getvalue() == ( + 'regions\talternative_states\talternative_ids\talternative_P_counts\n' + 'r2\tpred,pred2\t1.0,1.0\t0,0\n' + 'r3\tpred\t1.0\t0\n' + ) + + +def test_run(filterer, mocker): + mocker.patch('analyze.filter_regions.read_table.read_table_rows', + return_value=({'r1': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 4, + 'num_sites_nonmask_pred': 0, + 'match_nongap_pred': 0, + 'num_sites_nongap_pred': 0, + 'match_nongap_ref': 0, + 'num_sites_nongap_ref': 0, + }, 'r2': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 7, + 'num_sites_nongap_ref': 10, + }, 'r3': { + 'predicted_species': 'pred', + 'start': 0, + 'end': 9, + 'num_sites_nonmask_ref': 5, + 'num_sites_nonmask_pred': 5, + 'match_nongap_pred': 8, + 'num_sites_nongap_pred': 10, + 'match_nongap_ref': 7, + 'num_sites_nongap_ref': 10, + }}, ['regions'])) + + files = [NoCloseStringIO() for i in 
range(9)] + mocked_file = mocker.patch('analyze.filter_regions.open', + side_effect=files) -def test_record_data_hit(): - dt = {} - filter_regions.record_data_hit(dt, 0.9, 's1', 'k1') - assert dt == {0.9: {'s1': {'k1': 1}}} - filter_regions.record_data_hit(dt, 0.9, 's1', 'k1') - filter_regions.record_data_hit(dt, 0.9, 's1', 'k1') - assert dt == {0.9: {'s1': {'k1': 3}}} - filter_regions.record_data_hit(dt, 0.9, 's1', 'k2') - assert dt == { - 0.9: { - 's1': {'k1': 3, 'k2': 1} - } - } - filter_regions.record_data_hit(dt, 0.9, 's2', 'k2') - assert dt == { - 0.9: { - 's1': {'k1': 3, 'k2': 1}, - 's2': {'k2': 1} - } - } - filter_regions.record_data_hit(dt, 0.8, 's2', 'k2') - assert dt == { - 0.9: { - 's1': {'k1': 3, 'k2': 1}, - 's2': {'k2': 1} - }, - 0.8: { - 's2': {'k2': 1} - } - } - filter_regions.record_data_hit(dt, 0.9, 's2', 'k2') - assert dt == { - 0.9: { - 's1': {'k1': 3, 'k2': 1}, - 's2': {'k2': 2} - }, - 0.8: { - 's2': {'k2': 1} - } - } - - -def test_write_filtered_line(): - # single value, first field is ignored - output = StringIO() - filter_regions.write_filtered_line(output, 'r1', {'chr': 'I'}, ['', 'chr']) - - assert output.getvalue() == 'r1\tI\n' - - # no value - output = StringIO() - filter_regions.write_filtered_line(output, 'r1', {}, []) + mock_read = mocker.patch('analyze.filter_regions.Region_Reader') + mock_read().__enter__().yield_fa.return_value = iter([ + ('r1', ['> seq', '> info'], ['atcg', 'x..']), + ('r2', ['> seq', '> info'], ['attatt', 'aggcat', 'attatt', + 'ag-tat', np.array(list('CPPPPPPP'))]), + ('r3', ['> seq', '> info'], ['actata', 'attatt', 'nng---', + 'ag-tat', np.array(list('CPPPPPPP'))])]) + mock_log = mocker.patch('analyze.filter_regions.log') - assert output.getvalue() == 'r1\t\n' + filterer.run([0.99, 0.8, 0.1]) - # two values - output = StringIO() - filter_regions.write_filtered_line(output, 'r1', - {'a': 'b', 'c': 'd'}, - ['', 'c', 'a']) + assert mock_log.info.call_args_list == [ + mocker.call('pred'), + mocker.call('pred2'), + ] - assert output.getvalue() == 'r1\td\tb\n' + assert mock_read.call_count == 3 # called once during setup + mock_read.assert_called_with('region_pred2.fa.gz', as_fa=True) + assert mocked_file.call_args_list == [ + mocker.call('sweep.txt', 'w'), + mocker.call('int_pred.txt', 'w'), + mocker.call('int_int_pred.txt', 'w'), + mocker.call('amb_pred.txt', 'w'), + mocker.call('amb_int_pred.txt', 'w'), + mocker.call('int_pred2.txt', 'w'), + mocker.call('int_int_pred2.txt', 'w'), + mocker.call('amb_pred2.txt', 'w'), + mocker.call('amb_int_pred2.txt', 'w'), + ] -def test_filter_introgressed(mocker): + print(files[0].getvalue()) + assert files[0].getvalue() == ( + 'threshold\tpredicted_state\talternative_states\tcount\n' + '0.99\tpred\tpred,pred2\t1\n' + '0.99\tpred\tpred\t1\n' + '0.8\tpred\tpred,pred2\t1\n' + '0.8\tpred\tpred\t1\n' + '0.1\tpred\tpred,pred2\t1\n' + '0.1\tpred\tpred\t1\n' + ) + + assert files[1].getvalue() == 'regions\nr2\t\nr3\t\n' + assert files[2].getvalue() == ( + 'regions\treason\n' + 'r1\tfraction gaps/masked in master = 0.6\n' + 'r2\t\n' + 'r3\t\n' + ) + assert files[3].getvalue() == 'regions\nr3\t\n' # pass filter 2 + assert files[4].getvalue() == ( + 'regions\talternative_states\talternative_ids\talternative_P_counts\n' + 'r2\tpred,pred2\t1.0,1.0\t0,0\n' + 'r3\tpred\t1.0\t0\n' + ) + + +def test_filter_introgressed(filterer, mocker): # fail fraction gapped on reference region = {'predicted_species': 'pred', 'start': 0, @@ -300,7 +313,7 @@ def test_filter_introgressed(mocker): 'num_sites_nongap_ref': 0, } - assert 
filter_regions.filter_introgressed(region, '', 'ref') == \ + assert filterer.filter_introgressed(region, '', 'ref') == \ (False, 'fraction gaps/masked in master = 0.6') # fail fraction gapped on predicted @@ -315,7 +328,7 @@ def test_filter_introgressed(mocker): 'num_sites_nongap_ref': 0, } - assert filter_regions.filter_introgressed(region, '', 'ref') == \ + assert filterer.filter_introgressed(region, '', 'ref') == \ (False, 'fraction gaps/masked in predicted = 0.7') # fail match counts @@ -330,10 +343,12 @@ def test_filter_introgressed(mocker): 'num_sites_nongap_ref': 0, } - assert filter_regions.filter_introgressed(region, 'CP', 'ref') == \ + assert filterer.filter_introgressed(region, + np.array(list('CP')), 'ref') == \ (False, 'count_P = 1') - assert filter_regions.filter_introgressed(region, - 'CCCCCCCCPPPPPPP', 'ref') == \ + assert filterer.filter_introgressed(region, + np.array(list('CCCCCCCCPPPPPPP')), + 'ref') == \ (False, 'count_P = 7 and count_C = 8') # fail divergence, master >= pred @@ -348,7 +363,8 @@ def test_filter_introgressed(mocker): 'num_sites_nongap_ref': 10, } - assert filter_regions.filter_introgressed(region, 'CPPPPPPP', 'ref') == \ + assert filterer.filter_introgressed(region, + np.array(list('CPPPPPPP')), 'ref') == \ (False, 'id with master = 0.6 and id with predicted = 0.5') # fail divergence, master >= 0.7 @@ -363,7 +379,8 @@ def test_filter_introgressed(mocker): 'num_sites_nongap_ref': 10, } - assert filter_regions.filter_introgressed(region, 'CPPPPPPP', 'ref') == \ + assert filterer.filter_introgressed(region, + np.array(list('CPPPPPPP')), 'ref') == \ (False, 'id with master = 0.6') # passes @@ -378,16 +395,13 @@ def test_filter_introgressed(mocker): 'num_sites_nongap_ref': 10, } - assert filter_regions.filter_introgressed(region, 'CPPPPPPP', 'ref') == \ + assert filterer.filter_introgressed(region, + np.array(list('CPPPPPPP')), 'ref') == \ (True, '') -def test_filter_ambiguous(mocker): - mocker.patch('analyze.filter_regions.gp.gap_symbol', '-') - mocker.patch('analyze.filter_regions.gp.unsequenced_symbol', 'n') - - region = {'predicted_species': '1', - } +def test_filter_ambiguous(filterer, mocker): + region = {'predicted_species': '1'} seqs = [list('attatt'), # reference list('aggcat'), # 4 / 5, p = 2 list('a--tta'), # 2 / 4, p = 1 @@ -396,55 +410,332 @@ def test_filter_ambiguous(mocker): list('ag-tat')] # test sequence threshold = 0 - filt, states, ids, p_count = filter_regions.filter_ambiguous( + filt, states = filterer.filter_ambiguous( region, seqs, threshold, ['ref', '1', '2', '3', '4']) assert filt is False + assert region['alternative_states'] == '1,2,4' + assert region['alternative_ids'] == '0.8,0.5,0.4' + assert region['alternative_P_counts'] == '2,1,0' assert states == ['1', '2', '4'] - assert ids == [0.8, 0.5, 0.4] - assert p_count == [2, 1, 0] threshold = 0.1 - filt, states, ids, p_count = filter_regions.filter_ambiguous( + filt, _ = filterer.filter_ambiguous( region, seqs, threshold, ['ref', '1', '2', '3', '4']) assert filt is False - assert states == ['1', '2'] - assert ids == [0.8, 0.5] - assert p_count == [2, 1] + assert region['alternative_states'] == '1,2' + assert region['alternative_ids'] == '0.8,0.5' + assert region['alternative_P_counts'] == '2,1' threshold = 0.9 - filt, states, ids, p_count = filter_regions.filter_ambiguous( + filt, _ = filterer.filter_ambiguous( region, seqs, threshold, ['ref', '1', '2', '3', '4']) assert filt is True - assert states == ['1'] - assert ids == [0.8] - assert p_count == [2] + assert 
region['alternative_states'] == '1' + assert region['alternative_ids'] == '0.8' + assert region['alternative_P_counts'] == '2' -def test_filter_ambiguous_on_region(mocker): - mocker.patch('analyze.filter_regions.gp.gap_symbol', '-') - mocker.patch('analyze.filter_regions.gp.unsequenced_symbol', 'n') +def test_filter_ambiguous_on_region_10817(filterer, mocker): - fa = os.path.join(os.path.split(__file__)[0], 'r10805.fa') + fa = os.path.join(os.path.split(__file__)[0], 'r10817.fa') if os.path.exists(fa): headers, seqs = read_fasta.read_fasta(fa, gz=False) seqs = seqs[:-1] - p, alt_states, alt_ids, alt_P_counts = filter_regions.filter_ambiguous( - {'predicted_species': 'N_45'}, seqs, 0.1, + region = {'predicted_species': 'CBS432'} + p, _ = filterer.filter_ambiguous( + region, seqs, 0.98, ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) assert p is False - assert alt_states == ['CBS432', 'N_45', 'UWOPS91_917_1', 'DBVPG6304'] - assert alt_ids == approx([0.9983805668016195, 0.994331983805668, - 0.9642857142857143, 0.9618506493506493]) - assert alt_P_counts == [145, 143, 128, 129] + assert region['alternative_states'] == ( + 'CBS432,N_45') + assert region['alternative_P_counts'] == '111,110' + + else: + warnings.warn('Unable to test with datafile r10817.fa') + - p, alt_states, alt_ids, alt_P_counts = filter_regions.filter_ambiguous( - {'predicted_species': 'N_45'}, seqs, 0.98, +def test_filter_ambiguous_on_region_10805(filterer, mocker): + + fa = os.path.join(os.path.split(__file__)[0], 'r10805.fa') + + if os.path.exists(fa): + headers, seqs = read_fasta.read_fasta(fa, gz=False) + seqs = seqs[:-1] + region = {'predicted_species': 'N_45'} + p, _ = filterer.filter_ambiguous( + region, seqs, 0.1, + ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) + assert p is False + assert region['alternative_states'] == ( + 'CBS432,N_45,UWOPS91_917_1,DBVPG6304') + assert region['alternative_ids'] == ( + '0.9983805668016195,0.994331983805668,' + '0.9642857142857143,0.9618506493506493') + assert region['alternative_P_counts'] == '145,143,128,129' + + region = {'predicted_species': 'N_45'} + p, _ = filterer.filter_ambiguous( + region, seqs, 0.98, ['S288c', 'CBS432', 'N_45', 'DBVPG6304', 'UWOPS91_917_1']) assert p is False - assert alt_states == ['CBS432', 'N_45'] - assert alt_ids == approx([0.9983805668016195, 0.994331983805668]) - assert alt_P_counts == [145, 143] + assert region['alternative_states'] == 'CBS432,N_45' + assert region['alternative_ids'] == ( + '0.9983805668016195,0.994331983805668') + assert region['alternative_P_counts'] == '145,143' else: warnings.warn('Unable to test with datafile r10805.fa') + + +@pytest.fixture +def filter_sweep(): + return filter_regions.Filter_Sweep(None, []) + + +def test_filter_sweep_context(mocker): + # no file, no list + mock_open = mocker.patch('analyze.filter_regions.open') + fs = filter_regions.Filter_Sweep(None, []) + mock_open.assert_not_called() + + fs.__enter__() + assert fs.sweep_writer is None + mock_open.assert_not_called() + fs.__exit__(None, None, None) + mock_open.return_value.close.assert_not_called() + + # file, no list + mock_open = mocker.patch('analyze.filter_regions.open') + fs = filter_regions.Filter_Sweep('sweep.txt', []) + mock_open.assert_not_called() + + fs.__enter__() + assert fs.sweep_writer is None + mock_open.assert_not_called() + assert not fs.__exit__(None, None, 'trace') + mock_open.return_value.close.assert_not_called() + + # file, list + mock_open = mocker.patch('analyze.filter_regions.open') + fs = 
filter_regions.Filter_Sweep('sweep.txt', [.99]) + mock_open.assert_not_called() + + fs.__enter__() + mock_open.assert_called_once_with('sweep.txt', 'w') + assert fs.sweep_writer is not None + assert fs.__exit__(None, None, None) + mock_open.return_value.close.assert_called_once() + + +def test_sweep_write_header(filter_sweep): + output = StringIO() + filter_sweep.sweep_writer = output + + filter_sweep.write_header() + assert output.getvalue() == \ + 'threshold\tpredicted_state\talternative_states\tcount\n' + + +def test_sweep_record(filter_sweep, mocker): + mock_lambda = mocker.MagicMock( + side_effect=[ + (0, ['s1']), + (0, ['s1', 's2']), + (0, ['s2', 's3']), + (0, ['s4', 's3']), + ]) + filter_sweep.thresholds = [1, 0.9, 0.8, 0.7] + + filter_sweep.record('test', mock_lambda) + mock_lambda.assert_not_called() + + filter_sweep.sweep_writer = '' + filter_sweep.record('test', mock_lambda) + assert mock_lambda.call_args_list == [ + mocker.call(1), + mocker.call(0.9), + mocker.call(0.8), + mocker.call(0.7)] + + assert filter_sweep.data_table == { + 1: {'test': {'s1': 1}}, + 0.9: {'test': {'s1,s2': 1}}, + 0.8: {'test': {'s2,s3': 1}}, + 0.7: {'test': {'s3,s4': 1}}, + } + + +def test_sweep_write_results(filter_sweep): + filter_sweep.data_table == { + 1: {'test': {'s1': 1}}, + 0.9: {'test': {'s1,s2': 1}}, + 0.8: {'test': {'s2,s3': 1}}, + 0.7: {'test': {'s3,s4': 1}}, + } + filter_sweep.thresholds = [1, 0.9, 0.8, 0.7, 0] + + filter_sweep.write_results([]) + + output = StringIO() + filter_sweep.sweep_writer = output + + filter_sweep.write_results(['state']) + assert output.getvalue() == '' + + filter_sweep.write_results(['test']) + assert output.getvalue() == ( + '' + ) + + +def test_record_data_hit(filter_sweep): + filter_sweep.record_data_hit(0.9, 's1', ['k1']) + assert filter_sweep.data_table == {0.9: {'s1': {'k1': 1}}} + filter_sweep.record_data_hit(0.9, 's1', ['k1']) + filter_sweep.record_data_hit(0.9, 's1', ['k1']) + assert filter_sweep.data_table == {0.9: {'s1': {'k1': 3}}} + filter_sweep.record_data_hit(0.9, 's1', ['k2']) + assert filter_sweep.data_table == { + 0.9: { + 's1': {'k1': 3, 'k2': 1} + } + } + filter_sweep.record_data_hit(0.9, 's2', ['k2']) + assert filter_sweep.data_table == { + 0.9: { + 's1': {'k1': 3, 'k2': 1}, + 's2': {'k2': 1} + } + } + filter_sweep.record_data_hit(0.8, 's2', ['k2']) + assert filter_sweep.data_table == { + 0.9: { + 's1': {'k1': 3, 'k2': 1}, + 's2': {'k2': 1} + }, + 0.8: { + 's2': {'k2': 1} + } + } + filter_sweep.record_data_hit(0.9, 's2', ['k2', 'k3']) + assert filter_sweep.data_table == { + 0.9: { + 's1': {'k1': 3, 'k2': 1}, + 's2': {'k2': 1, 'k2,k3': 1} + }, + 0.8: { + 's2': {'k2': 1} + } + } + + +@pytest.fixture +def filter_writer(): + config = Configuration() + config.set(introgressed='int_{state}.txt', + introgressed_intermediate='int_int_{state}.txt', + ambiguous='amb_{state}.txt', + ambiguous_intermediate='amb_int_{state}.txt') + return filter_regions.Filter_Writers(config) + + +def test_filter_writer_init(filter_writer): + assert filter_writer.files == { + 'introgressed': 'int_{state}.txt', + 'introgressed_int': 'int_int_{state}.txt', + 'ambiguous': 'amb_{state}.txt', + 'ambiguous_int': 'amb_int_{state}.txt' + } + assert filter_writer.writers is None + assert filter_writer.headers is None + + +def test_filter_writer_context(filter_writer, mocker): + mock_open = mocker.patch('analyze.filter_regions.open') + with filter_writer.open_state('s1', ['h1']) as filter_writer: + assert mock_open.call_args_list == [ + mocker.call('int_s1.txt', 'w'), + 
mocker.call('int_int_s1.txt', 'w'), + mocker.call('amb_s1.txt', 'w'), + mocker.call('amb_int_s1.txt', 'w')] + mock_open.return_value.close.assert_not_called() + assert filter_writer.headers == { + 'introgressed': ['h1'], + 'introgressed_int': ['h1', 'reason'], + 'ambiguous': ['h1'], + 'ambiguous_int': ['h1', 'alternative_states', + 'alternative_ids', 'alternative_P_counts'] + } + + assert mock_open.return_value.close.call_count == 4 + assert filter_writer.writers is None + assert filter_writer.headers is None + mock_open.reset_mock() + + with filter_writer.open_state('s2', ['h2']) as filter_writer: + assert mock_open.call_args_list == [ + mocker.call('int_s2.txt', 'w'), + mocker.call('int_int_s2.txt', 'w'), + mocker.call('amb_s2.txt', 'w'), + mocker.call('amb_int_s2.txt', 'w')] + + mock_open.return_value.close.assert_not_called() + + assert filter_writer.headers == { + 'introgressed': ['h2'], + 'introgressed_int': ['h2', 'reason'], + 'ambiguous': ['h2'], + 'ambiguous_int': ['h2', 'alternative_states', + 'alternative_ids', 'alternative_P_counts'] + } + + assert mock_open.return_value.close.call_count == 4 + + +def test_filter_writers_write_headers(filter_writer): + filter_writer.write_headers() # nop + + filter_writer.writers = { + 'introgressed': StringIO(), + 'introgressed_int': StringIO(), + 'ambiguous': StringIO(), + 'ambiguous_int': StringIO() + } + + filter_writer.write_headers() # nop + + filter_writer.headers = { + 'introgressed': ['h1'], + 'introgressed_int': ['h2', 'h3'], + 'ambiguous': ['h4'], + 'ambiguous_int': ['h5'] + } + + filter_writer.write_headers() + assert filter_writer.writers['introgressed'].getvalue() == 'h1\n' + assert filter_writer.writers['introgressed_int'].getvalue() == 'h2\th3\n' + assert filter_writer.writers['ambiguous'].getvalue() == 'h4\n' + assert filter_writer.writers['ambiguous_int'].getvalue() == 'h5\n' + + +def test_write_filtered_line(filter_writer): + # single value, first field is ignored + output = StringIO() + filter_writer.write_filtered_line(output, 'r1', {'chr': 'I'}, ['', 'chr']) + + assert output.getvalue() == 'r1\tI\n' + + # no value + output = StringIO() + filter_writer.write_filtered_line(output, 'r1', {}, []) + + assert output.getvalue() == 'r1\t\n' + + # two values + output = StringIO() + filter_writer.write_filtered_line(output, 'r1', + {'a': 'b', 'c': 'd'}, + ['', 'c', 'a']) + + assert output.getvalue() == 'r1\td\tb\n' diff --git a/code/test/analyze/test_id_regions.py b/code/test/analyze/test_id_regions.py index f811285..8c1cf58 100644 --- a/code/test/analyze/test_id_regions.py +++ b/code/test/analyze/test_id_regions.py @@ -18,7 +18,7 @@ def id_producer(): 'unknown_states': [{'name': 'unknown'}] } }) - config.set_states() + config.set('states') result = id_regions.ID_producer(config) return result @@ -42,9 +42,9 @@ def test_add_ids_empty(id_producer, mocker): }}}) id_producer.config.states = 'ref state1 unknown'.split() - id_producer.config.set_blocks_file() - id_producer.config.set_labeled_blocks_file() - id_producer.config.set_chromosomes() + id_producer.config.set(blocks=None, + labeled_blocks=None, + chromosomes=None) mocker.patch('analyze.id_regions.read_blocks', return_value={}) @@ -76,9 +76,9 @@ def test_add_ids(id_producer, mocker): }}}) id_producer.config.states = 'ref state1 unknown'.split() - id_producer.config.set_blocks_file() - id_producer.config.set_labeled_blocks_file() - id_producer.config.set_chromosomes() + id_producer.config.set(blocks=None, + labeled_blocks=None, + chromosomes=None) regions = [ { diff --git 
a/code/test/analyze/test_introgression_configuration.py b/code/test/analyze/test_introgression_configuration.py index e8ce776..3e3379e 100644 --- a/code/test/analyze/test_introgression_configuration.py +++ b/code/test/analyze/test_introgression_configuration.py @@ -1,4 +1,5 @@ -from analyze.introgression_configuration import Configuration +from analyze.introgression_configuration import ( + Configuration, Variable) import pytest @@ -7,29 +8,80 @@ def config(): return Configuration() -def test_set_log_file(config): - config.set_log_file() +def test_set(config): + # unknown key + with pytest.raises(ValueError) as e: + config.set(asdf=None) + assert 'Unknown variable to set: asdf' in str(e) + + # chromosomes + with pytest.raises(ValueError) as e: + config.set('chromosomes') + assert 'No chromosomes provided' in str(e) + + config.config = {'chromosomes': ['I']} + config.set('chromosomes') + assert config.chromosomes == ['I'] + + # log file + config.set(log_file='') assert config.log_file is None - config.set_log_file('test') + config.set(log_file='test') assert config.log_file == 'test' config.config = {'paths': {'log_file': 'log'}} - config.set_log_file() + config.set(log_file='') assert config.log_file == 'log' - config.set_log_file('test') + config.set(log_file='test') assert config.log_file == 'test' -def test_set_chromosomes(config): - with pytest.raises(ValueError) as e: - config.set_chromosomes() - assert 'No chromosomes specified in config file!' in str(e) - - config.config = {'chromosomes': ['I']} - config.set_chromosomes() - assert config.chromosomes == ['I'] +def test_set_state_files(config): + state_files = [ + 'blocks', + 'labeled_blocks', + 'quality_blocks', + 'introgressed', + 'introgressed_intermediate', + 'ambiguous', + 'ambiguous_intermediate', + ] + for sf in state_files: + with pytest.raises(ValueError) as e: + config.set(**{sf: None}) + assert f'No {sf} provided' in str(e) + + with pytest.raises(ValueError) as e: + config.set(**{sf: 'test'}) + assert '{state} not found in test' in str(e) + + config.set(**{sf: 'test{state}'}) + assert config.__dict__[sf] == 'test{state}' + + config.config = {'paths': {'analysis': {sf: 'test2{state}'}}} + config.set(**{sf: None}) + assert config.__dict__[sf] == 'test2{state}' + + +def test_set_nonwild_files(config): + nonwild_files = [ + 'hmm_initial', + 'hmm_trained', + 'positions' + ] + for nwf in nonwild_files: + with pytest.raises(ValueError) as e: + config.set(**{nwf: None}) + assert f'No {nwf} provided' in str(e) + + config.set(**{nwf: 'test'}) + assert config.__dict__[nwf] == 'test' + + config.config = {'paths': {'analysis': {nwf: 'test2'}}} + config.set(**{nwf: None}) + assert config.__dict__[nwf] == 'test2' def test_get_states(config): @@ -141,7 +193,7 @@ def test_set_states(config): } } - config.set_states() + config.set('states') assert config.known_states ==\ 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() assert config.unknown_states ==\ @@ -149,7 +201,7 @@ def test_set_states(config): assert config.states ==\ 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1 unknown'.split() - config.set_states([]) + config.set(states=[]) assert config.known_states ==\ 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1'.split() assert config.unknown_states ==\ @@ -157,122 +209,75 @@ def test_set_states(config): assert config.states ==\ 'S288c CBS432 N_45 DBVPG6304 UWOPS91_917_1 unknown'.split() - config.set_states('testing 123'.split()) + config.set(states='testing 123'.split()) assert config.states == ['testing', '123'] config.config = {} with 
pytest.raises(ValueError) as e: - config.set_states() + config.set('states') assert 'No states specified' in str(e) def test_set_threshold(config): with pytest.raises(ValueError) as e: - config.set_threshold() + config.set('threshold') assert 'No threshold provided' in str(e) config.config = {'analysis_params': {'threshold': 'asdf'}} with pytest.raises(ValueError) as e: - config.set_threshold() + config.set('threshold') assert 'Unsupported threshold value: asdf' in str(e) - config.set_threshold(0.05) + config.set(threshold=0.05) assert config.threshold == 0.05 config.config = {'analysis_params': {'threshold': 'viterbi'}} - config.set_threshold() + config.set('threshold') assert config.threshold == 'viterbi' -def test_set_labeled_blocks_file(config): - with pytest.raises(ValueError) as e: - config.set_labeled_blocks_file('blocks_file') - assert '{state} not found in blocks_file' in str(e) - - config.set_labeled_blocks_file('blocks_file{state}') - assert config.labeled_blocks == 'blocks_file{state}' - - with pytest.raises(ValueError) as e: - config.set_labeled_blocks_file() - assert 'No labeled block file provided' in str(e) - - config.config = {'paths': {'analysis': - {'labeled_blocks': 'blocks_file'}}} - with pytest.raises(ValueError) as e: - config.set_labeled_blocks_file() - assert '{state} not found in blocks_file' in str(e) - - config.config = {'paths': {'analysis': {'labeled_blocks': - 'blocks_file{state}'}}} - config.set_labeled_blocks_file() - assert config.labeled_blocks == 'blocks_file{state}' - - -def test_set_blocks_file(config): - with pytest.raises(ValueError) as e: - config.set_blocks_file('blocks_file') - assert '{state} not found in blocks_file' in str(e) - - config.set_blocks_file('blocks_file{state}') - assert config.blocks == 'blocks_file{state}' - - with pytest.raises(ValueError) as e: - config.set_blocks_file() - assert 'No block file provided' in str(e) - - config.config = {'paths': {'analysis': {'blocks': 'blocks_file'}}} - with pytest.raises(ValueError) as e: - config.set_blocks_file() - assert '{state} not found in blocks_file' in str(e) - - config.config = {'paths': {'analysis': {'blocks': - 'blocks_file{state}'}}} - config.set_blocks_file() - assert config.blocks == 'blocks_file{state}' - - def test_set_prefix(config): config.known_states = ['s1'] - config.set_prefix() + config.set('prefix') assert config.prefix == 's1' config.known_states = 's1 s2'.split() - config.set_prefix() + config.set('prefix') assert config.prefix == 's1_s2' - config.set_prefix('prefix') + config.set(prefix='prefix') assert config.prefix == 'prefix' config.known_states = [] with pytest.raises(ValueError) as e: - config.set_prefix() + config.set('prefix') assert 'Unable to build prefix, no known states provided' in str(e) def test_set_strains(config, mocker): mock_find = mocker.patch.object(Configuration, 'find_strains') - config.set_strains() + config.set('strains') mock_find.called_with(None) with pytest.raises(ValueError) as e: config.config = {'paths': {'test_strains': ['test']}} - config.set_strains() + config.set('strains') assert '{strain} not found in test' in str(e) with pytest.raises(ValueError) as e: config.config = {'paths': {'test_strains': ['test{strain}']}} - config.set_strains() + config.set('strains') assert '{chrom} not found in test{strain}' in str(e) config.config = {'paths': {'test_strains': ['test{strain}{chrom}']}} - config.set_strains() + config.set('strains') mock_find.called_with(['test{strain}{chrom}']) - config.set_strains('test{strain}{chrom}') + 
config.set(strains='test{strain}{chrom}') mock_find.called_with(['test{strain}{chrom}']) @@ -355,161 +360,94 @@ def test_find_strains(config, mocker): assert config.strains == ['s1', 's2', 's3'] -def test_set_predict_files(config): - with pytest.raises(ValueError) as e: - config.set_predict_files('', '', '', '', '') - assert 'No initial hmm file provided' in str(e) - - with pytest.raises(ValueError) as e: - config.set_predict_files('init', '', '', '', '') - assert 'No trained hmm file provided' in str(e) - - with pytest.raises(ValueError) as e: - config.set_predict_files('init', 'trained', 'pos', 'prob', '') - assert 'No alignment file provided' in str(e) - - with pytest.raises(ValueError) as e: - config.set_predict_files('init', 'trained', 'pos', 'prob', 'align') - assert '{strain} not found in align' in str(e) - - with pytest.raises(ValueError) as e: - config.set_predict_files('init', 'trained', 'pos', 'prob', - 'align{prefix}') - assert '{strain} not found in align{prefix}' in str(e) - - with pytest.raises(ValueError) as e: - config.set_predict_files('init', 'trained', 'pos', 'prob', - 'align{prefix}{strain}') - assert '{chrom} not found in align{prefix}{strain}' in str(e) - - config.prefix = 'pre' - config.set_predict_files('init', 'trained', 'pos', 'prob', - 'align{prefix}{strain}{chrom}') - assert config.hmm_initial == 'init' - assert config.hmm_trained == 'trained' - assert config.positions == 'pos' - assert config.probabilities == 'prob' - assert config.alignment == 'alignpre{strain}{chrom}' - - with pytest.raises(ValueError) as e: - config.config = {'paths': {'analysis': {'hmm_initial': 'init'}}} - config.set_predict_files('', '', '', '', '') - assert 'No trained hmm file provided' in str(e) - - with pytest.raises(ValueError) as e: - config.config = {'paths': {'analysis': {'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'positions': 'pos' - }}} - config.set_predict_files('', '', '', '', '') - assert 'No probabilities file provided' in str(e) - - with pytest.raises(ValueError) as e: - config.config = {'paths': {'analysis': {'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'positions': 'pos', - 'probabilities': 'prob' - }}} - config.set_predict_files('', '', '', '', '') - assert 'No alignment file provided' in str(e) - - config.config = {'paths': {'analysis': { - 'hmm_initial': 'init', - 'hmm_trained': 'trained', - 'positions': 'pos', - 'probabilities': 'prob', - 'alignment': 'align{prefix}{strain}{chrom}' - }}} - config.set_predict_files('', '', '', '', '') - - assert config.hmm_initial == 'init' - assert config.hmm_trained == 'trained' - assert config.positions == 'pos' - assert config.probabilities == 'prob' - assert config.alignment == 'alignpre{strain}{chrom}' - - def test_set_alignment(config): - config.set_alignment('align{strain}{chrom}') + config.set(alignment='align{strain}{chrom}') assert config.alignment == 'align{strain}{chrom}' with pytest.raises(AttributeError) as e: - config.set_alignment('align{prefix}{strain}{chrom}') + config.set(alignment='align{prefix}{strain}{chrom}') assert "'Configuration' object has no attribute 'prefix'" in str(e) config.prefix = 'prefix' - config.set_alignment('align{prefix}{strain}{chrom}') + config.set(alignment='align{prefix}{strain}{chrom}') assert config.alignment == 'alignprefix{strain}{chrom}' -def test_set_regions_file(config): +def test_set_masked_file(config): with pytest.raises(ValueError) as e: - config.set_regions_files() - assert 'No region file provided' in str(e) + config.set('masks') + assert 'No masks provided' 
in str(e) with pytest.raises(ValueError) as e: - config.set_regions_files('region') - assert '{state} not found in region' in str(e) + config.set(masks='mask') + assert '{strain} not found in mask' in str(e) with pytest.raises(ValueError) as e: - config.set_regions_files('region{state}') - assert 'No region index file provided' in str(e) + config.set(masks='mask{strain}') + assert '{chrom} not found in mask{strain}' in str(e) - with pytest.raises(ValueError) as e: - config.set_regions_files('region{state}', 'index') - assert '{state} not found in index' in str(e) + config.set(masks='mask{strain}{chrom}') + assert config.masks == 'mask{strain}{chrom}' + + config.config = {'paths': {'analysis': + {'masked_intervals': 'msk{strain}{chrom}'}}} + config.set('masks') + assert config.masks == 'msk{strain}{chrom}' - config.set_regions_files('region{state}', 'index{state}') - assert config.regions == 'region{state}' - assert config.region_index == 'index{state}' - config.config = {'paths': {'analysis': {'regions': 'region{state}', - 'region_index': 'index{state}', - }}} - config.set_regions_files() - assert config.regions == 'region{state}' - assert config.region_index == 'index{state}' +def test_set_filter_threshold(config): + with pytest.raises(ValueError) as e: + config.set('filter_threshold') + assert 'No filter_threshold provided' in str(e) - # args overwrite config - config.set_regions_files('reg{state}', 'ind{state}') - assert config.regions == 'reg{state}' - assert config.region_index == 'ind{state}' + config.set(filter_threshold=0.9) + assert config.filter_threshold == 0.9 + config.config = {'analysis_params': {'filter_threshold': 0.8}} + config.set('filter_threshold') + assert config.filter_threshold == 0.8 -def test_set_quality_file(config): with pytest.raises(ValueError) as e: - config.set_quality_file() - assert 'No quality block file provided' in str(e) + config.set(filter_threshold='test') + assert 'Filter threshold is not a valid number' in str(e) - with pytest.raises(ValueError) as e: - config.set_quality_file('qual') - assert '{state} not found in qual' in str(e) - config.set_quality_file('qual{state}') - assert config.quality_blocks == 'qual{state}' +@pytest.fixture +def variable(): + return Variable('test') - config.config = {'paths': {'analysis': {'quality': 'qua{state}'}}} - config.set_quality_file() - assert config.quality_blocks == 'qua{state}' +def test_variable_init(variable): + assert variable.name == 'test' + assert variable.config_path == 'test' + assert variable.nullable is False + assert variable.wildcards is None -def test_set_masked_file(config): - with pytest.raises(ValueError) as e: - config.set_masked_file() - assert 'No masked interval file provided' in str(e) + var2 = Variable('test2', 'test.path', True, 'wild') + assert var2.name == 'test2' + assert var2.config_path == 'test.path' + assert var2.nullable is True + assert var2.wildcards == 'wild' - with pytest.raises(ValueError) as e: - config.set_masked_file('mask') - assert '{strain} not found in mask' in str(e) +def test_variable_parse(variable): with pytest.raises(ValueError) as e: - config.set_masked_file('mask{strain}') - assert '{chrom} not found in mask{strain}' in str(e) + variable.parse(None) + assert 'No test provided' in str(e) - config.set_masked_file('mask{strain}{chrom}') - assert config.masks == 'mask{strain}{chrom}' + assert variable.parse('test', {}) == 'test' + assert variable.parse(None, {'test': 'test'}) == 'test' - config.config = {'paths': {'analysis': - {'masked_intervals': 
'msk{strain}{chrom}'}}} - config.set_masked_file() - assert config.masks == 'msk{strain}{chrom}' + variable.config_path = 'test.path' + assert variable.parse(None, {'test': {'path': 'test'}}) == 'test' + + variable.nullable = True + assert variable.parse(None) is None + assert variable.parse('test') == 'test' + + variable.wildcards = 'state' + with pytest.raises(ValueError) as e: + variable.parse('test') + assert '{state} not found in test' in str(e) + + assert variable.parse('test{state}') == 'test{state}' diff --git a/code/test/analyze/test_main_filter_regions_args.py b/code/test/analyze/test_main_filter_regions_args.py new file mode 100644 index 0000000..f0c1a8d --- /dev/null +++ b/code/test/analyze/test_main_filter_regions_args.py @@ -0,0 +1,283 @@ +import pytest +from click.testing import CliRunner +import analyze.main as main +import yaml +from analyze.filter_regions import Filterer + + +''' +Unit tests for the filter_regions command of main.py when parameters are +provided by args +''' + + +@pytest.fixture +def runner(): + return CliRunner() + + +def test_threshold(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + '--thresh 0.9' + ) + + assert result.exit_code != 0 + assert str(result.exception) == 'No introgressed provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + ] + + +def test_filter_files(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + }, f) + + files = ['--introgress-filter int_{state}.txt', + '--introgress-inter int_int_{state}.txt', + '--ambiguous-filter amb_{state}.txt', + '--ambiguous-inter amb_int_{state}.txt', + '--filter-sweep filter.txt', + ] + results = [ + 'No introgressed_intermediate provided', + 'No ambiguous provided', + 'No ambiguous_intermediate provided', + 'No regions provided', # sweep is not required + 'No regions provided', + ] + for i, expected in enumerate(results): + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + '--thresh 0.9 ' + + ' '.join(files[0:i+1]) + ) + + assert result.exit_code != 0 + assert str(result.exception) == expected + log = [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + ] + if i >= 3: + log += [ + mocker.call('Introgressed filtered file ' + 'is \'int_{state}.txt\''), + mocker.call('Introgressed intermediate file ' + 'is \'int_int_{state}.txt\''), + mocker.call('Ambiguous filtered file ' + 'is \'amb_{state}.txt\''), + mocker.call('Ambiguous intermediate file ' + 'is \'amb_int_{state}.txt\''), + ] + # filter sweep line is not printed if it is unset + if i == 4: + log += [mocker.call('Filter sweep file is \'filter.txt\'')] + assert mock_log.call_args_list == log + mock_log.reset_mock() + + +def test_region_files(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'known_states': [ + 
{'name': 's1'}, + {'name': 's2'}], + }, + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + '--thresh 0.9 ' + '--introgress-filter int_{state}.txt ' + '--introgress-inter int_int_{state}.txt ' + '--ambiguous-filter amb_{state}.txt ' + '--ambiguous-inter amb_int_{state}.txt ' + '--filter-sweep filter.txt ' + '--region region_{state}.gz ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == 'No region_index provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + mocker.call('Introgressed filtered file ' + 'is \'int_{state}.txt\''), + mocker.call('Introgressed intermediate file ' + 'is \'int_int_{state}.txt\''), + mocker.call('Ambiguous filtered file ' + 'is \'amb_{state}.txt\''), + mocker.call('Ambiguous intermediate file ' + 'is \'amb_int_{state}.txt\''), + mocker.call('Filter sweep file is \'filter.txt\'') + ] + mock_log.reset_mock() + + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + '--thresh 0.9 ' + '--introgress-filter int_{state}.txt ' + '--introgress-inter int_int_{state}.txt ' + '--ambiguous-filter amb_{state}.txt ' + '--ambiguous-inter amb_int_{state}.txt ' + '--filter-sweep filter.txt ' + '--region region_{state}.gz ' + '--region-index region_{state}.pkl ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == 'No quality_blocks provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + mocker.call('Introgressed filtered file ' + 'is \'int_{state}.txt\''), + mocker.call('Introgressed intermediate file ' + 'is \'int_int_{state}.txt\''), + mocker.call('Ambiguous filtered file ' + 'is \'amb_{state}.txt\''), + mocker.call('Ambiguous intermediate file ' + 'is \'amb_int_{state}.txt\''), + mocker.call('Filter sweep file is \'filter.txt\''), + mocker.call('Region file is \'region_{state}.gz\''), + mocker.call('Region index file is \'region_{state}.pkl\'') + ] + + +def test_quality_files(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + mock_run = mocker.patch.object(Filterer, 'run') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + }, + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + '--thresh 0.9 ' + '--introgress-filter int_{state}.txt ' + '--introgress-inter int_int_{state}.txt ' + '--ambiguous-filter amb_{state}.txt ' + '--ambiguous-inter amb_int_{state}.txt ' + '--filter-sweep filter.txt ' + '--region region_{state}.gz ' + '--region-index region_{state}.pkl ' + '--quality quality_{state}.txt ' + ) + + assert result.exit_code == 0 + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + mocker.call('Introgressed filtered file ' + 'is \'int_{state}.txt\''), + mocker.call('Introgressed intermediate file ' + 'is \'int_int_{state}.txt\''), + mocker.call('Ambiguous filtered file ' + 'is \'amb_{state}.txt\''), + mocker.call('Ambiguous intermediate file ' + 'is \'amb_int_{state}.txt\''), + mocker.call('Filter sweep file is \'filter.txt\''), + mocker.call('Region file is \'region_{state}.gz\''), + mocker.call('Region index file is 
\'region_{state}.pkl\''),
+            mocker.call('Quality file is \'quality_{state}.txt\''),
+            mocker.call('Threshold sweep with: []'),
+        ]
+
+        mock_run.assert_called_once_with([])
+
+
+def test_thresholds_files(runner, mocker):
+    mock_log = mocker.patch('analyze.main.log.info')
+    mock_run = mocker.patch.object(Filterer, 'run')
+    with runner.isolated_filesystem():
+        with open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    'analysis_params': {
+                        'known_states': [
+                            {'name': 's1'},
+                            {'name': 's2'}],
+                    },
+                }, f)
+
+        result = runner.invoke(
+            main.cli,
+            '--config config.yaml filter-regions '
+            '--thresh 0.9 '
+            '--introgress-filter int_{state}.txt '
+            '--introgress-inter int_int_{state}.txt '
+            '--ambiguous-filter amb_{state}.txt '
+            '--ambiguous-inter amb_int_{state}.txt '
+            '--filter-sweep filter.txt '
+            '--region region_{state}.gz '
+            '--region-index region_{state}.pkl '
+            '--quality quality_{state}.txt '
+            '1.0 .99 .98 .1 .01'
+        )
+
+        assert result.exit_code == 0
+        assert mock_log.call_args_list == [
+            mocker.call('Verbosity set to WARNING'),
+            mocker.call('Read in 1 config file'),
+            mocker.call('Filter threshold set to \'0.9\''),
+            mocker.call('Introgressed filtered file '
+                        'is \'int_{state}.txt\''),
+            mocker.call('Introgressed intermediate file '
+                        'is \'int_int_{state}.txt\''),
+            mocker.call('Ambiguous filtered file '
+                        'is \'amb_{state}.txt\''),
+            mocker.call('Ambiguous intermediate file '
+                        'is \'amb_int_{state}.txt\''),
+            mocker.call('Filter sweep file is \'filter.txt\''),
+            mocker.call('Region file is \'region_{state}.gz\''),
+            mocker.call('Region index file is \'region_{state}.pkl\''),
+            mocker.call('Quality file is \'quality_{state}.txt\''),
+            mocker.call('Threshold sweep with: '
+                        '[1.0, 0.99, 0.98, 0.1, 0.01]'),
+        ]
+
+        mock_run.assert_called_once_with([1.0, 0.99, 0.98, 0.1, 0.01])
diff --git a/code/test/analyze/test_main_filter_regions_config.py b/code/test/analyze/test_main_filter_regions_config.py
new file mode 100644
index 0000000..edbe091
--- /dev/null
+++ b/code/test/analyze/test_main_filter_regions_config.py
@@ -0,0 +1,282 @@
+import pytest
+from click.testing import CliRunner
+import analyze.main as main
+import yaml
+from analyze.filter_regions import Filterer
+
+
+'''
+Unit tests for the filter_regions command of main.py when parameters are
+provided by the config file
+'''
+
+
+@pytest.fixture
+def runner():
+    return CliRunner()
+
+
+def test_empty(runner, mocker):
+    mock_log = mocker.patch('analyze.main.log.info')
+    with runner.isolated_filesystem():
+        result = runner.invoke(
+            main.cli,
+            'filter-regions')
+
+        assert result.exit_code != 0
+        assert str(result.exception) == 'No states specified'
+        assert mock_log.call_args_list == [
+            mocker.call('Verbosity set to WARNING'),
+            mocker.call('Read in 0 config files'),
+        ]
+
+
+def test_states(runner, mocker):
+    mock_log = mocker.patch('analyze.main.log.info')
+    with runner.isolated_filesystem():
+        with open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    'analysis_params': {
+                        'known_states': [
+                            {'name': 's1'},
+                            {'name': 's2'}],
+                    },
+                }, f)
+
+        result = runner.invoke(
+            main.cli,
+            '--config config.yaml filter-regions')
+
+        assert result.exit_code != 0
+        assert str(result.exception) == 'No filter_threshold provided'
+        assert mock_log.call_args_list == [
+            mocker.call('Verbosity set to WARNING'),
+            mocker.call('Read in 1 config file'),
+        ]
+
+
+def test_threshold(runner, mocker):
+    mock_log = mocker.patch('analyze.main.log.info')
+    with runner.isolated_filesystem():
+        with open('config.yaml', 'w') as f:
+            yaml.dump(
+                {
+                    
'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'filter_threshold': 0.9, + }, + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == 'No introgressed provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + ] + + +def test_filter_files(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'filter_threshold': 0.9, + }, + 'paths': {'analysis': { + 'introgressed': 'int_{state}.txt', + 'introgressed_intermediate': 'int_int_{state}.txt', + 'ambiguous': 'amb_{state}.txt', + 'ambiguous_intermediate': 'amb_int_{state}.txt', + }} + }, f) + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == 'No regions provided' + log = [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + mocker.call('Introgressed filtered file ' + 'is \'int_{state}.txt\''), + mocker.call('Introgressed intermediate file ' + 'is \'int_int_{state}.txt\''), + mocker.call('Ambiguous filtered file ' + 'is \'amb_{state}.txt\''), + mocker.call('Ambiguous intermediate file ' + 'is \'amb_int_{state}.txt\''), + ] + assert mock_log.call_args_list == log + + +def test_region_files(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'filter_threshold': 0.9, + }, + 'paths': {'analysis': { + 'introgressed': 'int_{state}.txt', + 'introgressed_intermediate': 'int_int_{state}.txt', + 'ambiguous': 'amb_{state}.txt', + 'ambiguous_intermediate': 'amb_int_{state}.txt', + 'filter_sweep': 'filter.txt', + 'regions': 'region_{state}.gz', + 'region_index': 'region_{state}.pkl', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + ) + + assert result.exit_code != 0 + assert str(result.exception) == 'No quality_blocks provided' + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + mocker.call('Introgressed filtered file ' + 'is \'int_{state}.txt\''), + mocker.call('Introgressed intermediate file ' + 'is \'int_int_{state}.txt\''), + mocker.call('Ambiguous filtered file ' + 'is \'amb_{state}.txt\''), + mocker.call('Ambiguous intermediate file ' + 'is \'amb_int_{state}.txt\''), + mocker.call('Filter sweep file is \'filter.txt\''), + mocker.call('Region file is \'region_{state}.gz\''), + mocker.call('Region index file is \'region_{state}.pkl\'') + ] + + +def test_quality_files(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + mock_run = mocker.patch.object(Filterer, 'run') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'filter_threshold': 0.9, + }, + 'paths': {'analysis': { + 'introgressed': 'int_{state}.txt', + 'introgressed_intermediate': 
'int_int_{state}.txt', + 'ambiguous': 'amb_{state}.txt', + 'ambiguous_intermediate': 'amb_int_{state}.txt', + 'filter_sweep': 'filter.txt', + 'regions': 'region_{state}.gz', + 'region_index': 'region_{state}.pkl', + 'quality_blocks': 'quality_{state}.txt', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + ) + + assert result.exit_code == 0 + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + mocker.call('Introgressed filtered file ' + 'is \'int_{state}.txt\''), + mocker.call('Introgressed intermediate file ' + 'is \'int_int_{state}.txt\''), + mocker.call('Ambiguous filtered file ' + 'is \'amb_{state}.txt\''), + mocker.call('Ambiguous intermediate file ' + 'is \'amb_int_{state}.txt\''), + mocker.call('Filter sweep file is \'filter.txt\''), + mocker.call('Region file is \'region_{state}.gz\''), + mocker.call('Region index file is \'region_{state}.pkl\''), + mocker.call('Quality file is \'quality_{state}.txt\''), + mocker.call('Threshold sweep with: []'), + ] + + mock_run.assert_called_once_with([]) + + +def test_thresholds_files(runner, mocker): + mock_log = mocker.patch('analyze.main.log.info') + mock_run = mocker.patch.object(Filterer, 'run') + with runner.isolated_filesystem(): + with open('config.yaml', 'w') as f: + yaml.dump( + { + 'analysis_params': { + 'known_states': [ + {'name': 's1'}, + {'name': 's2'}], + 'filter_threshold': 0.9, + }, + 'paths': {'analysis': { + 'introgressed': 'int_{state}.txt', + 'introgressed_intermediate': 'int_int_{state}.txt', + 'ambiguous': 'amb_{state}.txt', + 'ambiguous_intermediate': 'amb_int_{state}.txt', + 'filter_sweep': 'filter.txt', + 'regions': 'region_{state}.gz', + 'region_index': 'region_{state}.pkl', + 'quality_blocks': 'quality_{state}.txt', + }} + }, f) + + result = runner.invoke( + main.cli, + '--config config.yaml filter-regions ' + '1.0 .99 .98 .1 .01' + ) + + assert result.exit_code == 0 + assert mock_log.call_args_list == [ + mocker.call('Verbosity set to WARNING'), + mocker.call('Read in 1 config file'), + mocker.call('Filter threshold set to \'0.9\''), + mocker.call('Introgressed filtered file ' + 'is \'int_{state}.txt\''), + mocker.call('Introgressed intermediate file ' + 'is \'int_int_{state}.txt\''), + mocker.call('Ambiguous filtered file ' + 'is \'amb_{state}.txt\''), + mocker.call('Ambiguous intermediate file ' + 'is \'amb_int_{state}.txt\''), + mocker.call('Filter sweep file is \'filter.txt\''), + mocker.call('Region file is \'region_{state}.gz\''), + mocker.call('Region index file is \'region_{state}.pkl\''), + mocker.call('Quality file is \'quality_{state}.txt\''), + mocker.call('Threshold sweep with: ' + '[1.0, 0.99, 0.98, 0.1, 0.01]'), + ] + + mock_run.assert_called_once_with([1.0, 0.99, 0.98, 0.1, 0.01]) diff --git a/code/test/analyze/test_main_id_args.py b/code/test/analyze/test_main_id_args.py index c35b4f0..6c4db93 100644 --- a/code/test/analyze/test_main_id_args.py +++ b/code/test/analyze/test_main_id_args.py @@ -30,7 +30,7 @@ def test_states(runner, mocker): '--config config.yaml id-regions --state s1 --state s2') assert result.exit_code != 0 - assert str(result.exception) == 'No block file provided' + assert str(result.exception) == 'No blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -55,7 +55,7 @@ def test_block_file(runner, mocker): '--blocks block_{state}.txt ') assert 
result.exit_code != 0 - assert str(result.exception) == 'No labeled block file provided' + assert str(result.exception) == 'No labeled_blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), diff --git a/code/test/analyze/test_main_id_config.py b/code/test/analyze/test_main_id_config.py index eb41378..60024d2 100644 --- a/code/test/analyze/test_main_id_config.py +++ b/code/test/analyze/test_main_id_config.py @@ -21,7 +21,7 @@ def test_empty(runner): main.cli, 'id-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No chromosomes specified in config file!' + assert str(result.exception) == 'No chromosomes provided' def test_chroms(runner, mocker): @@ -65,7 +65,7 @@ def test_states(runner, mocker): '--config config.yaml id-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No block file provided' + assert str(result.exception) == 'No blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -96,7 +96,7 @@ def test_block_file(runner, mocker): '--config config.yaml id-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No labeled block file provided' + assert str(result.exception) == 'No labeled_blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), diff --git a/code/test/analyze/test_main_predict_args.py b/code/test/analyze/test_main_predict_args.py index c5668f8..c108ef7 100644 --- a/code/test/analyze/test_main_predict_args.py +++ b/code/test/analyze/test_main_predict_args.py @@ -31,7 +31,7 @@ def test_threshold(runner, mocker): '--config config.yaml predict --threshold 0.05') assert result.exit_code != 0 - assert str(result.exception) == 'No block file provided' + assert str(result.exception) == 'No blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -136,7 +136,7 @@ def test_test_strains(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No initial hmm file provided' + 'No hmm_initial provided' print(mock_log.call_args_list) assert mock_log.call_args_list == [ @@ -188,7 +188,7 @@ def test_outputs(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No trained hmm file provided' + 'No hmm_trained provided' assert mock_log.call_args_list == mock_calls with runner.isolated_filesystem(): @@ -215,7 +215,7 @@ def test_outputs(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No positions file provided' + 'No positions provided' assert mock_log.call_args_list == mock_calls with runner.isolated_filesystem(): @@ -244,7 +244,7 @@ def test_outputs(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No alignment file provided' + 'No alignment provided' assert mock_log.call_args_list == mock_calls with runner.isolated_filesystem(): diff --git a/code/test/analyze/test_main_predict_config.py b/code/test/analyze/test_main_predict_config.py index f53bd40..af5dcb8 100644 --- a/code/test/analyze/test_main_predict_config.py +++ b/code/test/analyze/test_main_predict_config.py @@ -22,7 +22,7 @@ def test_chroms(runner, mocker): main.cli, 'predict') assert result.exit_code != 0 - assert str(result.exception) == 'No chromosomes specified in config file!' 
+ assert str(result.exception) == 'No chromosomes provided' mock_log = mocker.patch('analyze.main.log.info') with runner.isolated_filesystem(): @@ -62,7 +62,7 @@ def test_threshold(runner, mocker): '--config config.yaml predict') assert result.exit_code != 0 - assert str(result.exception) == 'No block file provided' + assert str(result.exception) == 'No blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -166,7 +166,7 @@ def test_strains(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No initial hmm file provided' + 'No hmm_initial provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -212,7 +212,7 @@ def test_test_strains(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No initial hmm file provided' + 'No hmm_initial provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), @@ -265,7 +265,7 @@ def test_outputs(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No trained hmm file provided' + 'No hmm_trained provided' assert mock_log.call_args_list == mock_calls with runner.isolated_filesystem(): @@ -294,7 +294,7 @@ def test_outputs(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No positions file provided' + 'No positions provided' assert mock_log.call_args_list == mock_calls with runner.isolated_filesystem(): @@ -325,7 +325,7 @@ def test_outputs(runner, mocker): assert result.exit_code != 0 assert str(result.exception) == \ - 'No alignment file provided' + 'No alignment provided' assert mock_log.call_args_list == mock_calls with runner.isolated_filesystem(): diff --git a/code/test/analyze/test_main_summarize_regions_args.py b/code/test/analyze/test_main_summarize_regions_args.py index cc03ea1..c36322d 100644 --- a/code/test/analyze/test_main_summarize_regions_args.py +++ b/code/test/analyze/test_main_summarize_regions_args.py @@ -39,7 +39,7 @@ def test_empty(runner, mocker): '--config config.yaml summarize-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No labeled block file provided' + assert str(result.exception) == 'No labeled_blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -70,7 +70,7 @@ def test_labeled(runner, mocker): '--config config.yaml summarize-regions --labeled {state}lbl.txt') assert result.exit_code != 0 - assert str(result.exception) == 'No quality block file provided' + assert str(result.exception) == 'No quality_blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -103,7 +103,7 @@ def test_quality(runner, mocker): '--quality {state}qual.txt') assert result.exit_code != 0 - assert str(result.exception) == 'No masked interval file provided' + assert str(result.exception) == 'No masks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -137,7 +137,7 @@ def test_masked(runner, mocker): '--quality {state}qual.txt --masks {strain}_{chrom}mask.txt') assert result.exit_code != 0 - assert str(result.exception) == 'No alignment file provided' + assert str(result.exception) == 'No alignment provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 
config file'), @@ -174,7 +174,7 @@ def test_alignment(runner, mocker): '--alignment {strain}_{chrom}_align.txt') assert result.exit_code != 0 - assert str(result.exception) == 'No positions file provided' + assert str(result.exception) == 'No positions provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -194,7 +194,7 @@ def test_alignment(runner, mocker): '--alignment {prefix}_{strain}_{chrom}_align.txt') assert result.exit_code != 0 - assert str(result.exception) == 'No positions file provided' + assert str(result.exception) == 'No positions provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -232,7 +232,7 @@ def test_positions(runner, mocker): '--alignment {strain}_{chrom}_align.txt --positions pos.txt') assert result.exit_code != 0 - assert str(result.exception) == 'No region file provided' + assert str(result.exception) == 'No regions provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -272,7 +272,7 @@ def test_region(runner, mocker): ) assert result.exit_code != 0 - assert str(result.exception) == 'No region index file provided' + assert str(result.exception) == 'No region_index provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), diff --git a/code/test/analyze/test_main_summarize_regions_config.py b/code/test/analyze/test_main_summarize_regions_config.py index 5419d35..cecbe0e 100644 --- a/code/test/analyze/test_main_summarize_regions_config.py +++ b/code/test/analyze/test_main_summarize_regions_config.py @@ -39,7 +39,7 @@ def test_empty(runner, mocker): '--config config.yaml summarize-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No labeled block file provided' + assert str(result.exception) == 'No labeled_blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -73,7 +73,7 @@ def test_labeled(runner, mocker): '--config config.yaml summarize-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No quality block file provided' + assert str(result.exception) == 'No quality_blocks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -100,7 +100,7 @@ def test_quality(runner, mocker): 'chromosomes': 'I II III'.split(), 'paths': {'analysis': { 'labeled_blocks': '{state}lbl.txt', - 'quality': '{state}qual.txt', + 'quality_blocks': '{state}qual.txt', }} }, f) @@ -109,7 +109,7 @@ def test_quality(runner, mocker): '--config config.yaml summarize-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No masked interval file provided' + assert str(result.exception) == 'No masks provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -137,7 +137,7 @@ def test_masked(runner, mocker): 'chromosomes': 'I II III'.split(), 'paths': {'analysis': { 'labeled_blocks': '{state}lbl.txt', - 'quality': '{state}qual.txt', + 'quality_blocks': '{state}qual.txt', 'masked_intervals': '{strain}_{chrom}mask.txt', }} }, f) @@ -147,7 +147,7 @@ def test_masked(runner, mocker): '--config config.yaml summarize-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No alignment file provided' + assert str(result.exception) 
== 'No alignment provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -176,7 +176,7 @@ def test_alignment(runner, mocker): 'chromosomes': 'I II III'.split(), 'paths': {'analysis': { 'labeled_blocks': '{state}lbl.txt', - 'quality': '{state}qual.txt', + 'quality_blocks': '{state}qual.txt', 'masked_intervals': '{strain}_{chrom}mask.txt', 'alignment': '{strain}_{chrom}_align.txt', }} @@ -188,7 +188,7 @@ def test_alignment(runner, mocker): '--config config.yaml summarize-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No positions file provided' + assert str(result.exception) == 'No positions provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -218,7 +218,7 @@ def test_positions(runner, mocker): 'chromosomes': 'I II III'.split(), 'paths': {'analysis': { 'labeled_blocks': '{state}lbl.txt', - 'quality': '{state}qual.txt', + 'quality_blocks': '{state}qual.txt', 'masked_intervals': '{strain}_{chrom}mask.txt', 'alignment': '{strain}_{chrom}_align.txt', 'positions': 'pos.txt', @@ -230,7 +230,7 @@ def test_positions(runner, mocker): '--config config.yaml summarize-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No region file provided' + assert str(result.exception) == 'No regions provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -261,7 +261,7 @@ def test_region(runner, mocker): 'chromosomes': 'I II III'.split(), 'paths': {'analysis': { 'labeled_blocks': '{state}lbl.txt', - 'quality': '{state}qual.txt', + 'quality_blocks': '{state}qual.txt', 'masked_intervals': '{strain}_{chrom}mask.txt', 'alignment': '{strain}_{chrom}_align.txt', 'positions': 'pos.txt', @@ -274,7 +274,7 @@ def test_region(runner, mocker): '--config config.yaml summarize-regions') assert result.exit_code != 0 - assert str(result.exception) == 'No region index file provided' + assert str(result.exception) == 'No region_index provided' assert mock_log.call_args_list == [ mocker.call('Verbosity set to WARNING'), mocker.call('Read in 1 config file'), @@ -306,7 +306,7 @@ def test_run(runner, mocker): 'chromosomes': 'I II III'.split(), 'paths': {'analysis': { 'labeled_blocks': '{state}lbl.txt', - 'quality': '{state}qual.txt', + 'quality_blocks': '{state}qual.txt', 'masked_intervals': '{strain}_{chrom}mask.txt', 'alignment': '{strain}_{chrom}_align.txt', 'positions': 'pos.txt', diff --git a/code/test/analyze/test_predict_hmm_builder.py b/code/test/analyze/test_predict_hmm_builder.py index bb70b68..ee92c7e 100644 --- a/code/test/analyze/test_predict_hmm_builder.py +++ b/code/test/analyze/test_predict_hmm_builder.py @@ -39,7 +39,7 @@ def default_builder(config): } } builder = predict.HMM_Builder(config) - config.set_states() + config.set('states') builder.set_expected_values() builder.update_expected_length(1e5) return builder @@ -185,7 +185,7 @@ def test_set_expected_values(builder, config): ] } } - config.set_states() + config.set('states') builder.set_expected_values() assert builder.expected_lengths == { 'CBS432': 10, @@ -230,7 +230,7 @@ def test_update_expected_length(builder, config): ] } } - config.set_states() + config.set('states') builder.set_expected_values() assert builder.expected_lengths == { diff --git a/code/test/analyze/test_predict_predictor.py b/code/test/analyze/test_predict_predictor.py index d8f0928..c1d19cb 100644 --- 
a/code/test/analyze/test_predict_predictor.py +++ b/code/test/analyze/test_predict_predictor.py @@ -30,7 +30,7 @@ def config(): @pytest.fixture def predictor(config): result = predict.Predictor(config) - config.set_states() + config.set('states') return result @@ -569,14 +569,14 @@ def test_write_state_probs(predictor): def test_process_path(predictor, config, hm): probs = hm.posterior_decoding()[0] - config.set_threshold(0.8) + config.set(threshold=0.8) config.states = 'N E'.split() config.known_states = 'N E'.split() path, probability = predictor.process_path(hm) assert (probability == probs).all() assert path == 'E E N E E N E E N N'.split() - config.set_threshold('viterbi') + config.set(threshold='viterbi') path, probability = predictor.process_path(hm) assert (probability == probs).all() diff --git a/code/test/analyze/test_summarize_region_quality.py b/code/test/analyze/test_summarize_region_quality.py index c7db4f7..068bfd0 100644 --- a/code/test/analyze/test_summarize_region_quality.py +++ b/code/test/analyze/test_summarize_region_quality.py @@ -23,7 +23,7 @@ def test_states_to_process(summarizer, mocker): 'unknown_states': [{'name': 'unknown'}] } }) - summarizer.config.set_states() + summarizer.config.set('states') assert summarizer.states_to_process() == \ (0, 'S288c CBS432 N_45 unknown'.split()) @@ -57,16 +57,14 @@ def test_run(summarizer, mocker): 'unknown_states': [{'name': 'unknown'}] } }) - summarizer.config.set_states() - summarizer.config.set_HMM_symbols() - summarizer.config.set_positions('positions.txt.gz') - summarizer.config.set_labeled_blocks_file( - 'dir/tag/blocks_{state}_labeled.txt') - summarizer.config.set_quality_file('dir/tag/blocks_{state}_quality.txt') - summarizer.config.set_alignment('dir/tag/blocks_{chrom}_{strain}.txt') - summarizer.config.set_regions_files('dir/tag/regions/{state}.fa.gz', - 'dir/tag/regions/{state}.pkl') - summarizer.config.set_masked_file('dir/masked/{strain}_chr{chrom}.txt') + summarizer.config.set('symbols', 'states', + positions='positions.txt.gz', + labeled_blocks='dir/tag/blocks_{state}_labeled.txt', + quality_blocks='dir/tag/blocks_{state}_quality.txt', + alignment='dir/tag/blocks_{chrom}_{strain}.txt', + regions='dir/tag/regions/{state}.fa.gz', + region_index='dir/tag/regions/{state}.pkl', + masks='dir/masked/{strain}_chr{chrom}.txt') summarizer.config.chromosomes = ['I', 'II'] summarizer.validate_arguments() # for region database @@ -246,16 +244,15 @@ def test_run_all_states(summarizer, mocker): 'unknown_states': [{'name': 'unknown'}] } }) - summarizer.config.set_states() - summarizer.config.set_HMM_symbols() - summarizer.config.set_positions('positions.txt.gz') - summarizer.config.set_labeled_blocks_file( - 'dir/tag/blocks_{state}_labeled.txt') - summarizer.config.set_quality_file('dir/tag/blocks_{state}_quality.txt') - summarizer.config.set_alignment('dir/tag/blocks_{chrom}_{strain}.txt') - summarizer.config.set_regions_files('dir/tag/regions/{state}.fa.gz', - 'dir/tag/regions/{state}.pkl') - summarizer.config.set_masked_file('dir/masked/{strain}_chr{chrom}.txt') + summarizer.config.set('states', + 'symbols', + positions='positions.txt.gz', + labeled_blocks='dir/tag/blocks_{state}_labeled.txt', + quality_blocks='dir/tag/blocks_{state}_quality.txt', + alignment='dir/tag/blocks_{chrom}_{strain}.txt', + regions='dir/tag/regions/{state}.fa.gz', + region_index='dir/tag/regions/{state}.pkl', + masks='dir/masked/{strain}_chr{chrom}.txt') summarizer.config.chromosomes = ['I', 'II'] assert summarizer.validate_arguments() diff --git 
a/code/test/helper_scripts/compare_filter_outputs.sh b/code/test/helper_scripts/compare_filter_outputs.sh index 533888b..f0894e8 100755 --- a/code/test/helper_scripts/compare_filter_outputs.sh +++ b/code/test/helper_scripts/compare_filter_outputs.sh @@ -5,27 +5,31 @@ expected=/tigress/tcomi/aclark4_temp/results/analysisp4e2/ echo starting comarison of $(basename $actual) to $(basename $expected) for file in $(ls ${expected}*_filtered1.txt); do - act=$(echo $file | sed 's/p4e2/_test/g') + act=$(echo $file | sed 's/p4e2_filtered1.txt/filter1.txt/g') cmp <(sort $act) <(sort $file) \ - && echo $file passed! || echo $file failed #&& exit + && echo YAY! $file passed! || echo $file failed #&& exit done for file in $(ls ${expected}*_filtered1intermediate.txt); do - act=$(echo $file | sed 's/p4e2/_test/g') + act=$(echo $file | sed 's/p4e2_filtered1intermediate.txt/filter1inter.txt/g') cmp <(sort $act | python intermediate_format_1.py) \ <(sort $file | python intermediate_format_1.py) \ - && echo $file passed! || echo $file failed #&& exit + && echo YAY! $file passed! || echo $file failed #&& exit done for file in $(ls ${expected}*_filtered2.txt); do - act=$(echo $file | sed 's/p4e2/_test/g') + act=$(echo $file | sed 's/p4e2_filtered2.txt/filter2.txt/g') cmp <(sort $act) <(sort $file) \ - && echo $file passed! || echo $file failed #&& exit + && echo YAY! $file passed! || echo $file failed #&& exit done for file in $(ls ${expected}*_filtered2intermediate.txt); do - act=$(echo $file | sed 's/p4e2/_test/g') + act=$(echo $file | sed 's/p4e2_filtered2intermediate.txt/filter2inter.txt/g') cmp <(sort $act | python intermediate_format_2.py) \ <(sort $file | python intermediate_format_2.py) \ - && echo $file passed! || echo $file failed && exit + && echo YAY! $file passed! || echo $file failed && exit done + +cmp <(sort ${expected}/filter_2_thresholds_p4e2.txt) \ + <(sort ${expected}/filter2_thresholds.txt) \ + && echo YAY! thresholds passed! 
|| echo thresholds failed diff --git a/code/test/helper_scripts/run_filter_2_thresholds.sh b/code/test/helper_scripts/run_filter_2_thresholds.sh index e6144ab..9cb7708 100755 --- a/code/test/helper_scripts/run_filter_2_thresholds.sh +++ b/code/test/helper_scripts/run_filter_2_thresholds.sh @@ -4,11 +4,14 @@ #SBATCH -n 1 #SBATCH -o "/tigress/tcomi/aclark4_temp/results/thresh_%A" -export PYTHONPATH=/home/tcomi/projects/aclark4_introgression/code/ +config=/home/tcomi/projects/aclark4_introgression/code/config.yaml module load anaconda3 conda activate introgression3 -ARGS="_test .001 viterbi 10000 .025 10000 .025 10000 .025 10000 .025 unknown 1000 .01" - -python ${PYTHONPATH}analyze/filter_2_thresholds_main.py $ARGS +introgression \ + --config $config \ + --log-file test.log \ + -vv \ + filter-regions \ + .999 .995 .985 .975 .965 .955 .945 .935 .925 .915 .905 .89 .87 .86 diff --git a/code/test/hmm/test_hmm_bw.py b/code/test/hmm/test_hmm_bw.py index b271bc0..7e6bf21 100644 --- a/code/test/hmm/test_hmm_bw.py +++ b/code/test/hmm/test_hmm_bw.py @@ -221,15 +221,15 @@ def test_forward(hm): for o in range(1, len(hm.observations[seq])): row = [] for current in range(len(hm.hidden_states)): - total = -np.inf - for prev in range(len(hm.hidden_states)): - total = np.logaddexp( - total, - alpha_current[o-1][prev] + - np.log(hm.transitions[prev][current])) - total += np.log( - hm.emissions[current][hm.observations[seq][o]]) - row.append(total) + total = -np.inf + for prev in range(len(hm.hidden_states)): + total = np.logaddexp( + total, + alpha_current[o-1][prev] + + np.log(hm.transitions[prev][current])) + total += np.log( + hm.emissions[current][hm.observations[seq][o]]) + row.append(total) alpha_current.append(row) alpha_iter.append(alpha_current) @@ -312,8 +312,8 @@ def test_bw(hm): norm = np.logaddexp(norm, prob) for i in range(len(hm.hidden_states)): - for j in range(len(hm.hidden_states)): - matrix[i][j] = matrix[i][j] - norm + for j in range(len(hm.hidden_states)): + matrix[i][j] = matrix[i][j] - norm xi_current.append(matrix) @@ -346,13 +346,13 @@ def test_calc_probs(hm): max_prob = np.NINF max_state = -1 for prev_state in range(len(hm.hidden_states)): - trans_prob = hm.transitions[prev_state][end_state] - emis_prob = hm.emissions[end_state][hm.observations[pos]] - prob = iter_probs[pos - 1][prev_state] + \ - np.log(trans_prob) + np.log(emis_prob) - if prob > max_prob: - max_prob = prob - max_state = prev_state + trans_prob = hm.transitions[prev_state][end_state] + emis_prob = hm.emissions[end_state][hm.observations[pos]] + prob = iter_probs[pos - 1][prev_state] + \ + np.log(trans_prob) + np.log(emis_prob) + if prob > max_prob: + max_prob = prob + max_state = prev_state iter_probs[pos].append(max_prob) iter_states[pos].append(max_state) From b00140a7d5dc8b4205f0c984ee446e35e83d9ef6 Mon Sep 17 00:00:00 2001 From: Troy Comi Date: Sat, 1 Jun 2019 15:09:56 -0400 Subject: [PATCH 24/33] Flake8 passing, summarize_strains refactored Finished refactor and testing of summarize_strain_states Finished formatting code consistent with FLAKE8 --- code/analyze/introgression_configuration.py | 2 + code/analyze/main.py | 36 ++ code/analyze/summarize_strain_states.py | 208 +++++++ code/analyze/summarize_strain_states_main.py | 124 ---- code/beer_strains/fastq_to_fasta.py | 46 +- code/beer_strains/vcf_to_fasta.py | 8 +- code/config.yaml | 16 +- code/hmm/hmm_bw.py | 7 +- code/misc/mystats.py | 51 +- code/misc/overlap.py | 15 +- code/misc/read_maf.py | 32 +- code/misc/read_origin.py | 6 +- code/misc/seq_id.py | 85 +-- 
code/misc/write_fasta.py | 3 +- code/phylo-hmm/gen_phylo_hmm_input_file.py | 239 ++++---- code/phylo-hmm/gen_sim_seqs.py | 32 +- code/phylo-hmm/remove_gaps.py | 3 +- ...un_gen_sim_seqs.py => run_gen_sim_seqs.sh} | 16 +- code/phylo-hmm/sim_analyze_phylo.py | 391 ++++++------ code/phylo-hmm/sim_analyze_phylo_main.py | 97 +-- code/phylogeny/format_for_phylip_main.py | 6 +- ...ormat_shared_introgression_for_plotting.py | 28 +- code/phylogeny/get_gene_set.py | 8 +- code/phylogeny/get_gene_set_main.py | 115 ++-- code/phylogeny/get_sequences_main.py | 72 +-- code/phylogeny/make_dollop_input_main.py | 28 +- .../make_shared_introgression_matrix_main.py | 37 +- ...gression_nonsingleton_polymorphism_main.py | 11 - .../shared_introgression_polymorphism_main.py | 61 +- code/sim/aggregate.py | 37 +- code/sim/aggregate_main.py | 11 +- code/sim/compare_introgressed.py | 199 +++--- code/sim/compare_introgressed_main.py | 30 +- code/sim/concordance_functions.py | 72 ++- code/sim/concordance_functions_old.py | 98 +-- code/sim/fix_summary.py | 2 +- code/sim/ils_rosenberg.py | 134 +++-- code/sim/ils_rosenberg_main.py | 54 +- code/sim/plot_rep_1_setup.py | 18 +- code/sim/plot_rep_1_setup_main.py | 55 +- code/sim/plot_rep_setup.py | 32 +- code/sim/plot_rep_setup_main.py | 71 ++- code/sim/prediction_functions.py | 189 +++--- code/sim/process.py | 21 +- code/sim/process_args.py | 37 +- code/sim/roc.py | 20 +- code/sim/roc_main.py | 79 ++- code/sim/sim_actual.py | 119 ++-- code/sim/sim_actual_main.py | 66 +- code/sim/sim_analyze_hmm_bw.py | 564 ++++++++++-------- code/sim/sim_analyze_hmm_bw_main.py | 152 ++--- code/sim/sim_multi_model.py | 39 +- code/sim/sim_predict.py | 195 +++--- code/sim/sim_predict_main.py | 73 ++- code/sim/sim_predict_phylohmm.py | 331 +++++----- code/sim/sim_predict_phylohmm_main.py | 66 +- code/sim/sim_predict_viterbi_main.py | 66 +- code/sim/sim_process.py | 51 +- code/sim/sim_stats.py | 5 +- code/sim/summarize.py | 6 +- code/sim/summarize_ils.py | 25 +- code/sim/summarize_power_fpr.py | 72 ++- code/sim/summary_stats.py | 95 +-- .../test_main_summarize_strains_args.py | 136 +++++ .../test_main_summarize_strains_config.py | 95 +++ .../analyze/test_summarize_strain_states.py | 415 +++++++++++++ .../test_summarize_strain_states_main.py | 87 --- .../helper_scripts/compare_filter_outputs.sh | 10 +- .../helper_scripts/run_summarize_strain.sh | 10 +- 69 files changed, 3360 insertions(+), 2260 deletions(-) create mode 100644 code/analyze/summarize_strain_states.py delete mode 100644 code/analyze/summarize_strain_states_main.py rename code/phylo-hmm/{run_gen_sim_seqs.py => run_gen_sim_seqs.sh} (82%) delete mode 100644 code/phylogeny/shared_introgression_nonsingleton_polymorphism_main.py create mode 100644 code/test/analyze/test_main_summarize_strains_args.py create mode 100644 code/test/analyze/test_main_summarize_strains_config.py create mode 100644 code/test/analyze/test_summarize_strain_states.py delete mode 100644 code/test/analyze/test_summarize_strain_states_main.py diff --git a/code/analyze/introgression_configuration.py b/code/analyze/introgression_configuration.py index c8ccda0..61a50ff 100644 --- a/code/analyze/introgression_configuration.py +++ b/code/analyze/introgression_configuration.py @@ -29,6 +29,8 @@ def __init__(self): 'hmm_trained', 'positions', 'probabilities', + 'strain_info', + 'state_counts', ] var_list = [ Variable('chromosomes'), diff --git a/code/analyze/main.py b/code/analyze/main.py index 46a10bc..bff61bd 100644 --- a/code/analyze/main.py +++ b/code/analyze/main.py @@ 
-6,6 +6,7 @@ from analyze.id_regions import ID_producer
 from analyze.summarize_region_quality import Summarizer
 from analyze.filter_regions import Filterer
+from analyze.summarize_strain_states import Strain_Summarizer
 
 
 # TODO also check for snakemake object?
@@ -297,5 +298,40 @@ def filter_regions(ctx,
     filterer.run(thresholds)
 
 
+@cli.command()
+@click.option('--introgress-inter', default='',
+              help='Filtered block file location with {state}.'
+              ' Contains all regions with reasons they failed filtering')
+@click.option('--ambiguous-inter', default='',
+              help='Filtered block file location with {state}.'
+              ' Contains all regions passing introgression filtering, '
+              'with reasons they failed ambiguous filtering')
+@click.option('--strain-info', default='',
+              help='Tab separated table with strain name, alternate name, '
+              'location, environment, and population')
+@click.option('--state-counts', default='',
+              help='Output state summary file')
+@click.pass_context
+def summarize_strains(ctx,
+                      introgress_inter,
+                      ambiguous_inter,
+                      strain_info,
+                      state_counts):
+    config = ctx.obj  # type: Configuration
+    config.set('states')
+    config.set(introgressed_intermediate=introgress_inter,
+               ambiguous_intermediate=ambiguous_inter,
+               strain_info=strain_info,
+               state_counts=state_counts)
+    log.info('Introgressed intermediate file is '
+             f"'{config.introgressed_intermediate}'")
+    log.info('Ambiguous intermediate file is '
+             f"'{config.ambiguous_intermediate}'")
+    log.info(f"Strain information from '{config.strain_info}'")
+    log.info(f"State counts saved to '{config.state_counts}'")
+    strain_summarizer = Strain_Summarizer(config)
+    strain_summarizer.run()
+
+
 if __name__ == '__main__':
     cli()
diff --git a/code/analyze/summarize_strain_states.py b/code/analyze/summarize_strain_states.py
new file mode 100644
index 0000000..24a1b44
--- /dev/null
+++ b/code/analyze/summarize_strain_states.py
@@ -0,0 +1,208 @@
+from analyze.introgression_configuration import Configuration
+import logging as log
+import itertools
+from misc import read_table
+from typing import List
+from contextlib import ExitStack
+import click
+
+
+class Strain_Summarizer():
+    def __init__(self, configuration: Configuration):
+        self.config = configuration
+
+    def validate_arguments(self):
+        '''
+        Check that all required instance variables are set to perform a
+        strain summary run. 
Returns true if valid, raises value error otherwise + ''' + args = [ + 'known_states', + 'introgressed_intermediate', + 'ambiguous_intermediate', + 'strain_info', + 'state_counts', + ] + variables = self.config.__dict__ + for arg in args: + if arg not in variables or variables[arg] is None: + err = ('Failed to validate strain summarizer,' + f" required argument '{arg}' was unset") + log.exception(err) + raise ValueError(err) + + return True + + def run(self): + ''' + Generate summary information for the state of + each position in the sequence + ''' + self.validate_arguments() + + summary = Summary_Table() + + states = self.config.known_states[1:] + with ExitStack() as stack: + progress_bar = None + if self.config.log_file: + progress_bar = stack.enter_context( + click.progressbar( + length=len(states), + label='State')) + for species_from in states: + + log.info(species_from) + + regions1, _ = read_table.read_table_rows( + self.config.introgressed_intermediate.format( + state=species_from), '\t') + regions2, _ = read_table.read_table_rows( + self.config.ambiguous_intermediate.format( + state=species_from), '\t') + + for region_id in regions1: + region1 = regions1[region_id] + + strain = region1['strain'] + length = int(region1['end']) - int(region1['start']) + 1 + + summary.set_region(strain, species_from, length) + summary.region_found() + + if region1['reason'] != '': # failed filter + continue + + summary.region_passes_filter1() + + region2 = regions2[region_id] + summary.record_alt_species( + region2['alternative_states'].split(',')) + + if progress_bar: + progress_bar.update(1) + + with open(self.config.strain_info, 'r') as reader: + summary.add_strain_info(reader) + + with open(self.config.state_counts, 'w') as writer: + summary.write_summary(states, writer) + + +class Summary_Table(): + def __init__(self): + self.table = {} + + def set_region(self, strain, species, length): + self.strain = strain + self.species = species + self.length = length + + def record_element(self, + strain: str, + key: str, + count: int = 1): + ''' + Increment the count of table[strain][key], adding new values as needed + ''' + + if strain not in self.table: + self.table[strain] = {} + + t = self.table[strain] + if key not in t: + t[key] = 0 + + t[key] += count + + def record_region(self, + strain: str, + species: str, + length: int, + suffix: str = "", + update_total: bool = True): + ''' + Record a region of provided length. 
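+        A leading underscore is added to suffix when missing. The
+        per-species num_regions/num_bases counters are incremented;
+        when update_total is True, the matching num_regions_total and
+        num_bases_total counters (with the same suffix) are updated too.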
+ ''' + if suffix and suffix[0] != '_': + suffix = '_' + suffix + + self.record_element(strain, f'num_regions_{species}{suffix}', 1) + self.record_element(strain, f'num_bases_{species}{suffix}', length) + if update_total: + self.record_element(strain, f'num_bases_total{suffix}', length) + self.record_element(strain, f'num_regions_total{suffix}', 1) + + def record_alt_species(self, alt_states: List): + for species in alt_states: + self.record_alt(species) + + if len(alt_states) == 1: + self.record_region(self.strain, self.species, + self.length, '_filtered2') + else: + self.record_element(self.strain, + ('num_bases_' + + '_or_'.join(sorted(alt_states)) + + '_filtered2i'), + self.length) + + self.record_element(self.strain, + f'num_bases_{len(alt_states)}_filtered2i', + self.length) + + def region_found(self): + self.record_region(self.strain, self.species, self.length) + + def region_passes_filter1(self): + self.record_region(self.strain, self.species, + self.length, '_filtered1') + + def record_alt(self, alt_species): + self.record_region(self.strain, alt_species, + self.length, '_filtered2_inclusive', + self.species == alt_species) + + def add_strain_info(self, reader): + for line in reader: + strain, _, _, geo, env, pop = line[:-1].split('\t') + strain = strain.lower() + if strain in self.table: + d = self.table[strain] + d['population'] = pop + d['geographic_origin'] = geo + d['environmental_origin'] = env + + def write_summary(self, states, writer): + fields = self.get_fields(states) + + # write header + writer.write('strain\t' + '\t'.join(fields) + '\n') + + for strain in sorted(self.table.keys()): + row = self.table[strain] + entries = [row[field] + if field in row + else 0 + for field in fields] + + entries = [str(s) for s in [strain] + entries] + + writer.write('\t'.join(entries) + '\n') + + def get_fields(self, states): + fields = ['population', 'geographic_origin', 'environmental_origin'] +\ + [f'num_{thing}_{state}{value}' + for thing in ('regions', 'bases') + for value in ('', '_filtered1', + '_filtered2', '_filtered2_inclusive') + for state in states + ['total'] + ] + + r = sorted(states) + for n in range(2, len(r)+1): + fields += [f'num_bases_{"_or_".join(combo)}_filtered2i' + for combo in itertools.combinations(r, n)] + fields += [f'num_bases_{n}_filtered2i'] + + return fields diff --git a/code/analyze/summarize_strain_states_main.py b/code/analyze/summarize_strain_states_main.py deleted file mode 100644 index 344dae7..0000000 --- a/code/analyze/summarize_strain_states_main.py +++ /dev/null @@ -1,124 +0,0 @@ -import sys -import itertools -from analyze import predict -from collections import defaultdict -import global_params as gp -from misc import read_table - - -def main() -> None: - ''' - Generate summary information for the state of each position in the sequence - Input files: - -blocks_{species}_filtered1intermediate.txt - -blocks_{species}_filtered2intermediate.txt - -100_genomes_info.txt - - Output files: - -state_counts_by_strain.txt - ''' - args = predict.process_predict_args(sys.argv[1:]) - - d = defaultdict(lambda: defaultdict(int)) - outdir = gp.analysis_out_dir_absolute + args['tag'] - states = args['known_states'][1:] - for species_from in states: - - print(species_from) - - regions1, _ = read_table.read_table_rows( - f'{outdir}/blocks_{species_from}_' - f'{args["tag"]}_filtered1intermediate.txt', '\t') - regions2, _ = read_table.read_table_rows( - f'{outdir}/blocks_{species_from}_' - f'{args["tag"]}_filtered2intermediate.txt', '\t') - - for region_id, region1 
in regions1.items(): - - strain = region1['strain'] - length = int(region1['end']) - int(region1['start']) + 1 - - d[strain][f'num_regions_{species_from}'] += 1 - d[strain]['num_regions_total'] += 1 - d[strain][f'num_bases_{species_from}'] += length - d[strain]['num_bases_total'] += length - - if regions1[region_id]['reason'] != '': - continue - - d[strain][f'num_regions_{species_from}_filtered1'] += 1 - d[strain]['num_regions_total_filtered1'] += 1 - d[strain][f'num_bases_{species_from}_filtered1'] += length - d[strain]['num_bases_total_filtered1'] += length - - alt_states = regions2[region_id]['alternative_states'].split(',') - for species_from_alt in alt_states: - d[strain][f'num_regions_{species_from_alt}' - '_filtered2_inclusive'] += 1 - d[strain][f'num_bases_{species_from_alt}' - '_filtered2_inclusive'] += length - if species_from_alt == species_from: - d[strain]['num_regions_total_filtered2_inclusive'] += 1 - d[strain]['num_bases_total_filtered2_inclusive'] += length - - if len(alt_states) == 1: - d[strain][f'num_regions_{species_from}' - '_filtered2'] += 1 - d[strain]['num_regions_total_filtered2'] += 1 - d[strain][f'num_bases_{species_from}' - '_filtered2'] += length - d[strain]['num_bases_total_filtered2'] += length - - else: - d[strain]['num_bases_' + - '_or_'.join(sorted(alt_states)) + - '_filtered2i'] += length - - d[strain][f'num_bases_{len(alt_states)}_filtered2i'] += length - - with open( - '/home/tcomi/projects/aclark4_introgression/100_genomes_info.txt', - 'r') as reader: - strain_info = [line[:-1].split('\t') for line in reader] - strain_info = {x[0].lower(): (x[5], x[3], x[4]) for x in strain_info} - - for strain in d.keys(): - d[strain]['population'] = strain_info[strain][0] - d[strain]['geographic_origin'] = strain_info[strain][1] - d[strain]['environmental_origin'] = strain_info[strain][2] - - fields = ['population', 'geographic_origin', 'environmental_origin'] +\ - [f'num_regions_{x}' for x in states] +\ - ['num_regions_total'] +\ - [f'num_regions_{x}_filtered1' for x in states] +\ - ['num_regions_total_filtered1'] +\ - [f'num_regions_{x}_filtered2' for x in states] +\ - ['num_regions_total_filtered2'] +\ - [f'num_regions_{x}_filtered2_inclusive' for x in states] +\ - ['num_regions_total_filtered2_inclusive'] +\ - [f'num_bases_{x}' for x in states] +\ - ['num_bases_total'] +\ - [f'num_bases_{x}_filtered1' for x in states] +\ - ['num_bases_total_filtered1'] +\ - [f'num_bases_{x}_filtered2' for x in states] +\ - ['num_bases_total_filtered2'] +\ - [f'num_bases_{x}_filtered2_inclusive' for x in states] +\ - ['num_bases_total_filtered2_inclusive'] - - r = sorted(args['known_states'][1:]) - for n in range(2, len(r)+1): - for combo in itertools.combinations(r, n): - fields += ['num_bases_' + '_or_'.join(combo) + '_filtered2i'] - fields += ['num_bases_' + str(n) + '_filtered2i'] - - with open(f'{outdir}/state_counts_by_strain.txt', 'w') as writer: - writer.write('strain\t' + '\t'.join(fields) + '\n') - - for strain in sorted(d.keys()): - writer.write(f'{strain}\t' + - '\t'.join([str(d[strain][x]) for x in fields]) + - '\n') - - -if __name__ == '__main__': - main() diff --git a/code/beer_strains/fastq_to_fasta.py b/code/beer_strains/fastq_to_fasta.py index 838af93..d7959ac 100644 --- a/code/beer_strains/fastq_to_fasta.py +++ b/code/beer_strains/fastq_to_fasta.py @@ -1,12 +1,15 @@ # take fastq files containing reads and quality information, along -# with reference genome, and convert to fasta file ... or vcf file and then fasta?? 
+# with reference genome, and convert to fasta file ... +# or vcf file and then fasta?? import os import sys fastq_dir = '/net/dunham/vol2/Giang/DunhamBeer/DunhamBeer' -quality_chars = list('!"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~') +quality_chars = list('!"#$%&\'()*+,-./0123456789:;<=>?@\ + ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`\ + abcdefghijklmnopqrstuvwxyz{|}~') char_to_score = dict(zip(quality_chars, range(1, len(quality_chars)))) fastq_dir = '/net/dunham/vol2/Giang/DunhamBeer/DunhamBeer/' @@ -16,34 +19,41 @@ if '.1.fastq' in l and 'stats' not in l and l[0] != 'N': fns.append(l[:-8]) -ref_fasta = '/net/akey/vol2/aclark4/nobackup/100_genomes/genomes/S288c_SGD-R64.fa' +ref_fasta = '/net/akey/vol2/aclark4/nobackup/\ + 100_genomes/genomes/S288c_SGD-R64.fa' ##### # align reads with bwa ##### samdir = '/net/akey/vol2/aclark4/nobackup/introgression/data/beer/dunham/sam/' -os.system('module load bwa/latest') # this doesn't actually work because it makes a new shell instance every time - TODO fix this -cmd = 'bwa index ' + ref_fasta -#print cmd -#os.system(cmd) +# this doesn't actually work because it makes a new shell instance every time +# TODO fix this +os.system('module load bwa/latest') +cmd = 'bwa index ' + ref_fasta +# print cmd +# os.system(cmd) for fn in fns: - cmd = 'bwa mem ' + ref_fasta + ' ' + fastq_dir + fn + '.1.fastq ' + fastq_dir + fn + '.2.fastq' + ' > ' + samdir + fn + '.sam' - print cmd + cmd = 'bwa mem ' + ref_fasta + ' ' + fastq_dir + fn + \ + '.1.fastq ' + fastq_dir + fn + '.2.fastq' + ' > ' + \ + samdir + fn + '.sam' + print(cmd) os.system(cmd) sys.exit() -##### +# ### # run base recalibrator -##### +# ### -outdir = '/net/akey/vol2/aclark4/nobackup/introgression/data/beer/dunham/fasta/' +outdir = '/net/akey/vol2/aclark4/nobackup/introgression/\ + data/beer/dunham/fasta/' for fn in fns: # -knownSites database of previously known polymorphisms - os.system('java -jar ~/software/GenomeAnalysisTK.jar -T BaseRecalibrator -R ' + ref_fasta + ' -I ' + fastq_dir + fn + ' -o ' + outdir + fn[:-1] + 'a') - -##### -# run -##### - + os.system('java -jar ~/software/GenomeAnalysisTK.jar ' + '-T BaseRecalibrator -R ' + ref_fasta + ' -I ' + + fastq_dir + fn + ' -o ' + outdir + fn[:-1] + 'a') + +# ### +# run +# ### diff --git a/code/beer_strains/vcf_to_fasta.py b/code/beer_strains/vcf_to_fasta.py index 01296d1..689ffe4 100644 --- a/code/beer_strains/vcf_to_fasta.py +++ b/code/beer_strains/vcf_to_fasta.py @@ -1,5 +1,6 @@ import sys + def read_vcf(fn): f = open(fn, 'r') @@ -14,14 +15,15 @@ def read_vcf(fn): f.close() return v + def vcf_to_fasta(v, fn_ref, fn_out): f_ref = open(fn_ref, 'r') - f_out = open(fn_out, 'w') line = f_ref.readline() while line != '': - + line = f_ref.readline() - + + v = read_vcf(sys.argv[1]) vcf_to_fasta(v, sys.argv[2]) diff --git a/code/config.yaml b/code/config.yaml index 1da5546..5a0fc47 100644 --- a/code/config.yaml +++ b/code/config.yaml @@ -40,13 +40,13 @@ paths: suffix: .txt analysis: - analysis_base: __OUTPUT_ROOT__/analysisp4e2 - regions: __OUTPUT_ROOT__/analysis_test/regions/{state}.fa.gz - region_index: __OUTPUT_ROOT__/analysis_test/regions/{state}.pkl + analysis_base: __OUTPUT_ROOT__/analysis + regions: __ANALYSIS_BASE__/regions/{state}.fa.gz + region_index: __ANALYSIS_BASE__/regions/{state}.pkl genes: __ANALYSIS_BASE__/genes/ blocks: __ANALYSIS_BASE__/blocks_{state}.txt - labeled_blocks: __ANALYSIS_BASE__/blocks_{state}_p4e2_labeled.txt - quality_blocks: __ANALYSIS_BASE__/blocks_{state}_p4e2_quality.txt 
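+    # note: the {state} placeholder in these paths is a wildcard filled
+    # in per analysis state at runtime (cf. the wildcard checks in the
+    # configuration tests)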
+    labeled_blocks: __ANALYSIS_BASE__/blocks_{state}_labeled.txt
+    quality_blocks: __ANALYSIS_BASE__/blocks_{state}_quality.txt
     hmm_initial: __ANALYSIS_BASE__/hmm_initial.txt
     hmm_trained: __ANALYSIS_BASE__/hmm_trained.txt
     probabilities: __ANALYSIS_BASE__/probabilities.txt.gz
@@ -60,6 +60,11 @@ paths:
     ambiguous_intermediate: "__ANALYSIS_BASE__/\
       blocks_{state}_filter2inter.txt"
     filter_sweep: __ANALYSIS_BASE__/filter2_thresholds.txt
+    # strain_info is a tsv file with columns: strain, _, _,
+    # geographic origin, environmental origin, population
+    strain_info: "/home/tcomi/projects/aclark4_introgression/\
+      100_genomes_info.txt"
+    state_counts: __ANALYSIS_BASE__/state_counts.txt
 
 # software install locations
 software:
@@ -87,7 +92,6 @@ chromosomes: ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII',
 
 # if blank will be the reference and known state names joined with '_'
 analysis_params:
-  tag: p2e4
   convergence_threshold: 0.001
   # threshold can be 'viterbi' or a float to threshold HMM probabilities
   threshold: viterbi
diff --git a/code/hmm/hmm_bw.py b/code/hmm/hmm_bw.py
index 2d5982d..e787916 100644
--- a/code/hmm/hmm_bw.py
+++ b/code/hmm/hmm_bw.py
@@ -23,7 +23,7 @@ def set_hidden_states(self, states: List[str]) -> None:
     def set_observed_states(self, states: List[str]) -> None:
         '''
         Sets the observed states of the HMM to the supplied list of strings
-        If not supplied will set to list of keys provided by emissions 
+        If not supplied will set to list of keys provided by emissions
         '''
         self.observed_states = states
 
@@ -87,7 +87,8 @@ def set_initial_p(self, initial_p: List[float]) -> None:
         '''
         self.initial_p = np.array(initial_p)
 
-        assert np.isclose(np.sum(initial_p), 1), f"{initial_p} {sum(initial_p)}"
+        assert np.isclose(np.sum(initial_p), 1), \
+            f"{initial_p} {sum(initial_p)}"
 
     def print_results(self, iterations: int, LL: float) -> None:
         '''
@@ -327,7 +328,7 @@ def backward(self) -> np.array:
 
     def calculate_max_states(self) -> Tuple[np.array, np.array]:
         '''
-        Find the maximum likelihood hidden states and the corresponding 
+        Find the maximum likelihood hidden states and the corresponding
         log probability for each state.
         Returned tuple is (probability, states)
         '''
diff --git a/code/misc/mystats.py b/code/misc/mystats.py
index c18d7a9..f71e2fb 100644
--- a/code/misc/mystats.py
+++ b/code/misc/mystats.py
@@ -1,43 +1,48 @@
 import math
 import numpy.random
 
-def mean(l):
-    l = filter(lambda x: x != 'NA' and not math.isnan(x), l)
-    if len(l) == 0:
-        #TODO float('nan') ?
+
+def mean(values):
+    # list() is required here: python 3's filter is a lazy iterator
+    # with no len()
+    values = list(filter(lambda x: x != 'NA' and not math.isnan(x), values))
+    if len(values) == 0:
+        # TODO float('nan') ?
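+        # the string 'NA' is the missing-data sentinel used throughout
+        # these helpers (the filter above strips it from inputs too)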
         return 'NA'
-    return float(sum(l)) / len(l)
+    return float(sum(values)) / len(values)
+
 
-def std_dev(l):
-    l = filter(lambda x: x != 'NA' and not math.isnan(x), l)
-    if len(l) == 0:
+def std_dev(values):
+    values = list(filter(lambda x: x != 'NA' and not math.isnan(x), values))
+    if len(values) == 0:
         return 'NA'
-    if len(l) == 1:
+    if len(values) == 1:
         return 0
-    m = mean(l)
-    return math.sqrt(sum([(x - m)**2 for x in l]) / (len(l) - 1))
+    m = mean(values)
+    return math.sqrt(sum([(x - m)**2 for x in values]) / (len(values) - 1))
+
 
-def std_err(l):
-    l = filter(lambda x: x != 'NA' and not math.isnan(x), l)
-    if len(l) == 0:
+def std_err(values):
+    values = list(filter(lambda x: x != 'NA' and not math.isnan(x), values))
+    if len(values) == 0:
         return 'NA'
-    return std_dev(l) / math.sqrt(len(l))
+    return std_dev(values) / math.sqrt(len(values))
+
 
-def bootstrap(l, n = 100, alpha = .05):
-    l = filter(lambda x: x != 'NA' and not math.isnan(x), l)
-    x = len(l)
+def bootstrap(values, n=100, alpha=.05):
+    values = list(filter(lambda x: x != 'NA' and not math.isnan(x), values))
+    x = len(values)
     if x == 0:
         return 'NA', 'NA'
     a = []
     for i in range(n):
-        a.append(mean(numpy.random.choice(l, size = x, replace = True)))
+        a.append(mean(numpy.random.choice(values, size=x, replace=True)))
     a.sort()
-    #print len(a), a.count(0)
-    #print mean(a)
+    # print(len(a), a.count(0))
+    # print(mean(a))
     return a[int(alpha * n * .5)], a[int((1 - alpha * .5) * n)]
 
-def median(l):
-    m = sorted(l)
+
+def median(values):
+    m = sorted(values)
     x = len(m)
     if x % 2 == 0:
-        return mean([m[x/2], m[x/2-1]])
+        # x // 2 keeps the index an int under python 3 division
+        return mean([m[x // 2], m[x // 2 - 1]])
diff --git a/code/misc/overlap.py b/code/misc/overlap.py
index d19ed2a..06001fb 100644
--- a/code/misc/overlap.py
+++ b/code/misc/overlap.py
@@ -17,10 +17,11 @@ def overlap(start1, end1, start2, end2):
             return end2 - start1 + 1
     else:
         return 0
-    
-    #if start1 < start2:
-    #    return max(end1 - start2 + 1, 0) - max(end1 - end2, 0)
-    #return max(end2 - start1 + 1, 0) - max(end2 - end1, 0)
+
+    # if start1 < start2:
+    #     return max(end1 - start2 + 1, 0) - max(end1 - end2, 0)
+    # return max(end2 - start1 + 1, 0) - max(end2 - end1, 0)
+
 
 def overlap_any(start1, end1, coords):
     for start2, end2 in coords:
@@ -28,24 +29,28 @@ def overlap_any(start1, end1, coords):
             return True
     return False
 
+
 def contained(i, start, end):
     return i >= start and i <= end
 
+
 def contained_any(i, coords):
     for start2, end2 in coords:
         if contained(i, start2, end2):
             return True
     return False
 
+
 def contained_any_named(i, coords):
     for start2, end2 in coords.keys():
        if contained(i, start2, end2):
            return coords[(start2, end2)]
     return None
 
+
 def overlap_region(start1, end1, start2, end2):
     o_start = max(start1, start2)
     o_end = min(end1, end2)
     if o_start > o_end:
-        return -1, -1 # disjoint ranges
+        return -1, -1  # disjoint ranges
     return o_start, o_end
diff --git a/code/misc/read_maf.py b/code/misc/read_maf.py
index 69bcac4..7950f08 100644
--- a/code/misc/read_maf.py
+++ b/code/misc/read_maf.py
@@ -1,6 +1,7 @@
 import re
 
-def read_mugsy(fn, required_mult = 1):
+
+def read_mugsy(fn, required_mult=1):
     f = open(fn, 'r')
     line = f.readline()
     while line[0] == '#':
@@ -9,9 +10,9 @@
     while line != '':
         assert line[0] == 'a', line
         block = {}
-        m = re.search('a score=(?P<score>[0-9]+) ' +\
-                      'label=(?P<label>