diff --git a/scripts/extract_tracts.py b/scripts/extract_tracts.py index d7327e6..40e479e 100755 --- a/scripts/extract_tracts.py +++ b/scripts/extract_tracts.py @@ -23,6 +23,9 @@ import logging import re import os +import shutil + +from io import StringIO logging.basicConfig(format="%(levelname)s (%(name)s %(lineno)s): %(message)s") logger = logging.getLogger(__name__) @@ -155,10 +158,14 @@ def extract_tracts(vcf=str, msp=str, num_ancs=int, output_dir=None, output_vcf=N pop_genos = {} for i in range(num_ancs): # Write entries into each output files' list - output_lines[f"dos{i}"] = dos_anc_out - output_lines[f"ancdos{i}"] = dos_anc_out + output_lines[f"dos{i}"] = StringIO() + output_lines[f"dos{i}"].write(dos_anc_out) + output_lines[f"ancdos{i}"] = StringIO() + output_lines[f"ancdos{i}"].write(dos_anc_out) + if output_vcf: - output_lines[f"vcf{i}"] = vcf_out + output_lines[f"vcf{i}"] = StringIO() + output_lines[f"vcf{i}"].write(vcf_out) # optimized for quicker runtime - only move to next line when out of the current msp window # saves the current line until out of the window, then checks next line. VCF and window switches file should be in incremental order. @@ -174,7 +181,7 @@ def extract_tracts(vcf=str, msp=str, num_ancs=int, output_dir=None, output_vcf=N break # when get to the end of the msp file, stop # chm, spos, epos, sgpos, egpos, nsnps, calls ancs_entry = ancs.strip().split("\t", 6) - calls = ancs_entry[6].split("\t") + calls = [int(x) for x in ancs_entry[6].split("\t")] window = (ancs_entry[0], int(ancs_entry[1]), int(ancs_entry[2])) if row[0] == window[0] and window[1]> pos: skip_line=True #Skip VCF line @@ -187,15 +194,15 @@ def extract_tracts(vcf=str, msp=str, num_ancs=int, output_dir=None, output_vcf=N for i, geno in enumerate(genos): geno_parts = geno.split(":")[0].split("|") # assert incase eagle leaves some genos unphased geno_a,geno_b = map(str, geno_parts[:2]) - call_a = str(calls[2*i]) - call_b = str(calls[2*i + 1]) + call_a = calls[2*i] + call_b = calls[2*i + 1] counts = {anc: 0 for anc in range(num_ancs)} anc_counts = {anc: 0 for anc in range(num_ancs)} for j in range(num_ancs): if output_vcf: pop_genos[j] = "" - if call_a == str(j): + if call_a == j: if output_vcf: pop_genos[j] += geno_a anc_counts[j] += 1 @@ -205,7 +212,7 @@ def extract_tracts(vcf=str, msp=str, num_ancs=int, output_dir=None, output_vcf=N if output_vcf: pop_genos[j] += "." - if call_b == str(j): + if call_b == j: if output_vcf: pop_genos[j] += "|" + geno_b anc_counts[j] += 1 @@ -215,19 +222,25 @@ def extract_tracts(vcf=str, msp=str, num_ancs=int, output_dir=None, output_vcf=N if output_vcf: pop_genos[j] += "|." - output_lines[f"dos{j}"] += "\t" + str(counts[j]) - output_lines[f"ancdos{j}"] += "\t" + str(anc_counts[j]) + output_lines[f"dos{j}"].write("\t" + str(counts[j])) + output_lines[f"ancdos{j}"].write("\t" + str(anc_counts[j])) if output_vcf: - output_lines[f"vcf{j}"] += "\t" + pop_genos[j] + output_lines[f"vcf{j}"].write("\t" + pop_genos[j]) for j in range(num_ancs): - output_lines[f"dos{j}"] += "\n" - output_lines[f"ancdos{j}"] += "\n" - files[f"dos{j}"].write(output_lines[f"dos{j}"]) - files[f"ancdos{j}"].write(output_lines[f"ancdos{j}"]) + output_lines[f"dos{j}"].write("\n") + output_lines[f"ancdos{j}"].write("\n") + + output_lines[f"dos{j}"].seek(0) + output_lines[f"ancdos{j}"].seek(0) + shutil.copyfileobj(output_lines[f"dos{j}"], files[f"dos{j}"]) + shutil.copyfileobj(output_lines[f"ancdos{j}"], files[f"ancdos{j}"]) + + if output_vcf: - output_lines[f"vcf{j}"] += "\n" - files[f"vcf{j}"].write(output_lines[f"vcf{j}"]) + output_lines[f"vcf{j}"].write("\n") + output_lines[f"vcf{j}"].seek(0) + shutil.copyfileobj(output_lines[f"vcf{j}"], files[f"vcf{j}"]) logger.info("Finished extracting tracts per %d ancestries", num_ancs)