diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 57557b4d..1eefdb00 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,6 +48,10 @@ jobs: wget https://cardiac.nottingham.ac.uk/syncropatch_export/test_data.tar.xz -P tests/ tar xvf tests/test_data.tar.xz -C tests/ + - name: Install TeX dependencies for run_herg_qc test + timeout-minutes: 5 + run: sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super -y + - name: Run unit tests (without coverage testing) if: ${{ success() && matrix.python-version != env.python-latest }} run: python -m unittest @@ -56,14 +60,6 @@ jobs: if: ${{ success() && matrix.python-version == env.python-latest }} run: coverage run -m unittest - - name: Install TeX dependencies for run_herg_qc test - timeout-minutes: 15 - run: sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super -y - - - name: Run `run_herg_qc` script - timeout-minutes: 15 - run: pcpostprocess run_herg_qc tests/test_data/13112023_MW2_FF -w A01 A02 A03 - - name: Report coverage to codecov uses: codecov/codecov-action@v4 if: ${{ success() && matrix.python-version == env.python-latest }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 73af3425..7ea00af3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -76,6 +76,7 @@ Isort is configured in [pyproject.toml](./pyproject.toml) under the section `too ## Documentation Every method and every class should have a [docstring](https://www.python.org/dev/peps/pep-0257/) that describes in plain terms what it does, and what the expected input and output is. +The only exception are unit test methods starting with `test_` - unit test classes and other methods in unit tests should all have docstrings. Each docstring should start with a one-line explanation. If more explanation is needed, this one-liner is followed by a blank line and more information in the following paragraphs. diff --git a/pcpostprocess/scripts/__main__.py b/pcpostprocess/scripts/__main__.py index 140e5c72..b6252818 100644 --- a/pcpostprocess/scripts/__main__.py +++ b/pcpostprocess/scripts/__main__.py @@ -21,7 +21,7 @@ def main(): run_herg_qc.run_from_command_line() elif args.subcommand == "summarise_herg_export": - summarise_herg_export.main() + summarise_herg_export.run_from_command_line() if __name__ == "__main__": diff --git a/pcpostprocess/scripts/run_herg_qc.py b/pcpostprocess/scripts/run_herg_qc.py index 8808224c..ff37ff9c 100644 --- a/pcpostprocess/scripts/run_herg_qc.py +++ b/pcpostprocess/scripts/run_herg_qc.py @@ -11,7 +11,6 @@ import string import sys -import cycler import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -28,17 +27,11 @@ from pcpostprocess.leak_correct import fit_linear_leak, get_leak_corrected from pcpostprocess.subtraction_plots import do_subtraction_plot -# TODO: Remove this -color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"] -plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle) - -# TODO: Not sure we need to explicitly set this! -matplotlib.use('Agg') - def run_from_command_line(): """ - Reads arguments from the command line and runs herg QC. + Reads arguments from the command line and an ``export_config.py`` and then + runs herg QC. """ parser = argparse.ArgumentParser() @@ -109,11 +102,7 @@ def run(data_path, output_path, qc_map, wells=None, write_traces=False, write_failed_traces=False, write_map={}, reversal_potential=-90, reversal_spread_threshold=10, max_processes=1, figure_size=None, - debug=False, - - save_id=None, - - ): + debug=False, save_id=None): """ Imports traces and runs QC. @@ -148,6 +137,9 @@ def run(data_path, output_path, qc_map, wells=None, # TODO Remove protocol selection here: this is done via the export file! # Only protocols listed there are accepted + # TODO: Find some way around setting this? + matplotlib.use('Agg') + # Select wells to use all_wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16] for i in range(1, 25)] @@ -503,8 +495,6 @@ def agg_func(x): qc_df['protocol'] = ['staircaseramp1_2' if p == 'staircaseramp2' else p for p in qc_df.protocol] - print(qc_df.protocol.unique()) - fails_dict = {} no_wells = 384 @@ -1255,7 +1245,7 @@ def fit_func(x, args=None): ] # TESTING ONLY - np.random.seed(1) + # np.random.seed(1) #  Repeat optimisation with different starting guesses x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] diff --git a/pcpostprocess/scripts/summarise_herg_export.py b/pcpostprocess/scripts/summarise_herg_export.py index 84e19c81..a9664a50 100644 --- a/pcpostprocess/scripts/summarise_herg_export.py +++ b/pcpostprocess/scripts/summarise_herg_export.py @@ -1,10 +1,8 @@ import argparse import json -import logging import os import string -import cycler import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -17,16 +15,8 @@ from pcpostprocess.directory_builder import setup_output_directory from pcpostprocess.scripts.run_herg_qc import create_qc_table -matplotlib.use('Agg') -pool_kws = {'maxtasksperchild': 1} - -color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"] -plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle) -sns.set_palette(sns.color_palette(color_cycle)) - - -def get_wells_list(input_dir): +def get_wells_list(input_dir, experiment_name): regex = re.compile(f"{experiment_name}-([a-z|A-Z|0-9]*)-([A-Z][0-9][0-9])-after") wells = [] @@ -37,59 +27,65 @@ def get_wells_list(input_dir): return list(np.unique(wells)) -def get_protocol_list(input_dir): +def get_protocol_list(input_dir, experiment_name): regex = re.compile(f"{experiment_name}-([a-z|A-Z|0-9]*)-([A-Z][0-9][0-9])-after") protocols = [] for f in filter(regex.match, os.listdir(input_dir)): well = re.search(regex, f).groups(3)[0] - if protocols not in protocols: + if protocols not in protocols: # TODO This has GOT to be a bug protocols.append(well) return list(np.unique(protocols)) -def main(): +def run_from_command_line(): + """ + Parses arguments from the command line and then ??? + """ - description = "" + description = '' # TODO Describe what this does parser = argparse.ArgumentParser(description) - - parser.add_argument('data_dir', type=str, help="path to the directory containing the subtract_leak results") - parser.add_argument('--cpus', '-c', default=1, type=int) - parser.add_argument('--wells', '-w', nargs='+', default=None) - parser.add_argument('--output_dir', '-o', default='output') - parser.add_argument('--protocols', type=str, default=[], nargs='+') - parser.add_argument('-r', '--reversal', type=float, default=np.nan) - # parser.add_argument('--selection_file', default=None, type=str) - parser.add_argument('--experiment_name', default='newtonrun4') - parser.add_argument('--figsize', type=int, nargs=2, default=[5, 3]) - parser.add_argument('--output_all', action='store_true') - parser.add_argument('--log_level', default='INFO') - - global args + parser.add_argument( + 'data_directory', help='path to the run_herg_qc results') + parser.add_argument( + 'experiment_name', help='the name of the experiment') + parser.add_argument('-o', '--output_dir', default='output', + help='The path to write output to') + parser.add_argument( + '--Erev', default=None, type=float, + help='The calculated or estimated reversal potential.') + parser.add_argument( + '--figsize', type=int, nargs=2, default=(5, 3), + help='A figure size, to pass to matplotlib') args = parser.parse_args() - # Setup logging - logging.basicConfig(level=args.log_level) - global logger - logger = logging.getLogger(__name__) - logger.setLevel(args.log_level) + run(args.data_directory, args.output_dir, args.experiment_name, + args.Erev, args.figsize) - global experiment_name - experiment_name = args.experiment_name - global output_dir - output_dir = setup_output_directory(args.output_dir, "summarise_herg_export") +def run(data_path, output_path, experiment_name, reversal_potential=None, + figsize=None): + """ + Does whatever this does. - leak_parameters_df = pd.read_csv(os.path.join(args.data_dir, 'subtraction_qc.csv')) + @param data_path The path to read data from + @param output_path A root path, will be appended with "summarise_herg_export" + @param experiment_name + @param reversal_potential The calculated reversal potential, or ``None`` + @param figsize The matplotlib figure size, or ``None``. + """ + # TODO: Find some way around setting this + matplotlib.use('Agg') - qc_df = pd.read_csv(os.path.join(args.data_dir, f"QC-{experiment_name}.csv")) + output_path = setup_output_directory(output_path) + + leak_parameters_df = pd.read_csv(os.path.join(data_path, 'subtraction_qc.csv')) + + qc_df = pd.read_csv(os.path.join(data_path, f"QC-{experiment_name}.csv")) qc_styled_df = create_qc_table(qc_df) qc_styled_df = qc_styled_df.pivot(columns='protocol', index='crit') - qc_styled_df.to_excel(os.path.join(output_dir, 'qc_table.xlsx')) - qc_styled_df.to_latex(os.path.join(output_dir, 'qc_table.tex')) - - qc_estimates_file = os.path.join(args.save_dir, f"{args.experiment_name}_subtraction_qc.csv") - qc_vals_df = pd.read_csv(os.path.join(qc_estimates_file)) + # qc_styled_df.to_excel(os.path.join(output_path, 'qc_table.xlsx')) + qc_styled_df.to_latex(os.path.join(output_path, 'qc_table.tex')) qc_df.protocol = ['staircaseramp1' if protocol == 'staircaseramp' else protocol for protocol in qc_df.protocol] @@ -101,22 +97,14 @@ def main(): leak_parameters_df.protocol = ['staircaseramp1_2' if protocol == 'staircaseramp_2' else protocol for protocol in leak_parameters_df.protocol] - print(leak_parameters_df.protocol.unique()) - - with open(os.path.join(args.data_dir, 'passed_wells.txt')) as fin: - global passed_wells + with open(os.path.join(data_path, 'passed_wells.txt')) as fin: passed_wells = fin.read().splitlines() # Compute new variables leak_parameters_df = compute_leak_magnitude(leak_parameters_df) - global wells - wells = leak_parameters_df.well.unique() - global protocols - protocols = leak_parameters_df.protocol.unique() - try: - chrono_fname = os.path.join(args.data_dir, 'chrono.txt') + chrono_fname = os.path.join(data_path, 'chrono.txt') with open(chrono_fname, 'r') as fin: lines = fin.read().splitlines() protocol_order = [line.split(' ')[0] for line in lines] @@ -127,48 +115,41 @@ def main(): protocol_order = ['staircaseramp1_2' if p == 'staircaseramp_2' else p for p in protocol_order] - leak_parameters_df['protocol'] = pd.Categorical(leak_parameters_df['protocol'], - categories=protocol_order, - ordered=True) - - qc_vals_df['protocol'] = pd.Categorical(qc_vals_df['protocol'], - categories=protocol_order, - ordered=True) + leak_parameters_df['protocol'] = pd.Categorical( + leak_parameters_df['protocol'], categories=protocol_order, ordered=True) leak_parameters_df.sort_values(['protocol', 'sweep'], inplace=True) - except FileNotFoundError as exc: - logging.warning(str(exc)) - logger.warning('no chronological information provided. Sorting alphabetically') + except FileNotFoundError: leak_parameters_df.sort_values(['protocol', 'sweep']) - scatterplot_timescale_E_obs(leak_parameters_df) + scatterplot_timescale_E_obs(output_path, leak_parameters_df, passed_wells, figsize) - do_chronological_plots(leak_parameters_df) - do_chronological_plots(leak_parameters_df, normalise=True) + do_chronological_plots(leak_parameters_df, output_path, reversal_potential, + figsize=figsize, normalise=False) + do_chronological_plots(leak_parameters_df, output_path, reversal_potential, + figsize=figsize, normalise=True) attrition_df = create_attrition_table(qc_df, leak_parameters_df) - attrition_df.to_latex(os.path.join(output_dir, 'attrition.tex')) + attrition_df.to_latex(os.path.join(output_path, 'attrition.tex')) if 'passed QC' not in leak_parameters_df.columns and\ 'passed QC6a' in leak_parameters_df.columns: leak_parameters_df['passed QC'] = leak_parameters_df['passed QC6a'] - plot_leak_conductance_change_sweep_to_sweep(leak_parameters_df) - plot_reversal_change_sweep_to_sweep(leak_parameters_df) - plot_spatial_passed(leak_parameters_df) - plot_reversal_spread(leak_parameters_df) - if np.isfinite(args.reversal): - plot_spatial_Erev(leak_parameters_df) - - leak_parameters_df['passed QC'] = [well in passed_wells for well in leak_parameters_df.well] - qc_vals_df['passed QC'] = [well in passed_wells for well in qc_vals_df.well] + plot_leak_conductance_change_sweep_to_sweep( + leak_parameters_df, output_path, passed_wells, figsize) + plot_reversal_change_sweep_to_sweep( + leak_parameters_df, output_path, passed_wells, figsize) + plot_spatial_passed(leak_parameters_df, output_path, passed_wells) + plot_reversal_spread(leak_parameters_df, output_path, figsize) + if reversal_potential is not None: + plot_spatial_Erev(leak_parameters_df, output_path, figsize) - # do_scatter_matrices(leak_parameters_df, qc_vals_df) - plot_histograms(leak_parameters_df, qc_vals_df) + leak_parameters_df['passed QC'] = [ + well in passed_wells for well in leak_parameters_df.well] - # Very resource intensive - # overlay_reversal_plots(leak_parameters_df) - # do_combined_plots(leak_parameters_df) + # do_scatter_matrices(leak_parameters_df, qc_vals_df, output_path, reversal_potential) + plot_histograms(leak_parameters_df, output_path, reversal_potential, figsize) def compute_leak_magnitude(df, lims=[-120, 60]): @@ -197,8 +178,17 @@ def compute_magnitude(g, E, lims=lims): return df -def scatterplot_timescale_E_obs(df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) +def scatterplot_timescale_E_obs(output_path, df, passed_wells, figsize=None): + """ + ??? + + @param output_path + @param df + @param passed_wells + @param figsize + + """ + fig = plt.figure(figsize=figsize, constrained_layout=True) ax = fig.subplots() df = df[(df.well.isin(passed_wells))].sort_values('protocol') @@ -235,7 +225,7 @@ def scatterplot_timescale_E_obs(df): ax.set_ylabel(r'$\tau$ (ms)') ax.set_xlabel(r'$E_\mathrm{obs}$') - fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_scatter.pdf")) + fig.savefig(os.path.join(output_path, 'decay_timescale_vs_E_rev_scatter.png')) ax.cla() sns.lineplot(data=plot_df, y='40mV decay time constant', @@ -245,7 +235,7 @@ def scatterplot_timescale_E_obs(df): ax.set_ylabel(r'$\tau$ (ms)') ax.set_xlabel(r'$E_\mathrm{obs}$') ax.spines[['top', 'right']].set_visible(False) - fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_line.pdf")) + fig.savefig(os.path.join(output_path, 'decay_timescale_vs_E_rev_line.png')) ax.cla() plot_df['E_rev'] = (plot_df.set_index('well')['E_rev'] - plot_df.groupby('well') @@ -257,15 +247,25 @@ def scatterplot_timescale_E_obs(df): ax.set_ylabel(r'$E_\mathrm{leak} - \bar E_\mathrm{leak}$ (ms)') ax.set_xlabel(r'$E_\mathrm{obs} - \bar E_\mathrm{obs}$') - fig.savefig(os.path.join(output_dir, "E_leak_vs_E_rev_scatter.pdf")) + fig.savefig(os.path.join(output_path, 'E_leak_vs_E_rev_scatter.png')) ax.cla() -def do_chronological_plots(df, normalise=False): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) +def do_chronological_plots(df, output_path, reversal_potential=None, + normalise=False, figsize=None): + """ + ??? + + @param df + @param output_path + @param reversal_potential + @param normalise + """ + + fig = plt.figure(figsize=figsize, constrained_layout=True) ax = fig.subplots() - sub_dir = os.path.join(output_dir, 'chrono_plots') + sub_dir = os.path.join(output_path, 'chrono_plots') if not os.path.exists(sub_dir): os.makedirs(sub_dir) @@ -325,8 +325,9 @@ def label_func(p, s): ax.legend(frameon=False, fontsize=8) - if var == 'E_rev' and np.isfinite(args.reversal): - ax.axhline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + if var == 'E_rev' and reversal_potential is not None: + ax.axhline(reversal_potential, linestyle='--', color='grey', + label='Calculated Nernst potential') ax.set_xlabel('') if var in pretty_vars and var in units: @@ -336,22 +337,31 @@ def label_func(p, s): legend_handles, _ = ax.get_legend_handles_labels() ax.legend(legend_handles, ['failed QC', 'passed QC'], bbox_to_anchor=(1.26, 1)) - fig.savefig(os.path.join(sub_dir, f"{var.replace(' ', '_')}.pdf"), - format='pdf') + fig.savefig(os.path.join(sub_dir, f'{var.replace(" ", "_")}.png')) ax.cla() plt.close(fig) -def do_combined_plots(leak_parameters_df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) +def do_combined_plots(data_path, output_path, experiment_name, + leak_parameters_df, passed_wells, figsize=None): + """ + ??? + + @param data_path + @param output_path + @param experiment_name + @param leak_parameters_df + @param passed_wells + @param figsize + + """ + fig = plt.figure(figsize=figsize, constrained_layout=True) ax = fig.subplots() wells = [well for well in leak_parameters_df.well.unique() if well in passed_wells] - logger.info(f"passed wells are {passed_wells}") - - protocol_overlaid_dir = os.path.join(output_dir, 'overlaid_by_protocol') + protocol_overlaid_dir = os.path.join(output_path, 'overlaid_by_protocol') if not os.path.exists(protocol_overlaid_dir): os.makedirs(protocol_overlaid_dir) @@ -359,9 +369,16 @@ def do_combined_plots(leak_parameters_df): palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['well', 'sweep']))) for protocol in leak_parameters_df.protocol.unique(): - times_fname = f"{experiment_name}-{protocol}-times.csv" + pname = protocol + if pname == 'staircaseramp1': + pname = 'staircaseramp' + elif pname == 'staircaseramp1_2': + pname = 'staircaseramp_2' + + times_fname = os.path.join(data_path, 'traces', + f'{experiment_name}-{pname}-times.csv') try: - times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)).astype(np.float64).flatten() + times = np.loadtxt(times_fname).astype(np.float64).flatten() except FileNotFoundError: continue @@ -372,9 +389,9 @@ def do_combined_plots(leak_parameters_df): i = 0 for sweep in leak_parameters_df.sweep.unique(): for well in wells: - fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" + fname = f"{experiment_name}-{pname}-{well}-sweep{sweep}.csv" try: - data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + data = pd.read_csv(os.path.join(data_path, 'traces', fname)) except FileNotFoundError: continue @@ -402,26 +419,30 @@ def do_combined_plots(leak_parameters_df): palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['protocol', 'sweep']))) - fig2 = plt.figure(figsize=args.figsize, constrained_layout=True) + fig2 = plt.figure(figsize=figsize, constrained_layout=True) axs2 = fig2.subplots(1, 2, sharey=True) - wells_overlaid_dir = os.path.join(output_dir, 'overlaid_by_well') + wells_overlaid_dir = os.path.join(output_path, 'overlaid_by_well') if not os.path.exists(wells_overlaid_dir): os.makedirs(wells_overlaid_dir) - logger.info('overlaying traces by well') - for well in passed_wells: i = 0 for sweep in leak_parameters_df.sweep.unique(): for protocol in leak_parameters_df.protocol.unique(): - times_fname = f"{experiment_name}-{protocol}-times.csv" - times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)) + pname = protocol + if pname == 'staircaseramp1': + pname = 'staircaseramp' + elif pname == 'staircaseramp1_2': + pname = 'staircaseramp_2' + + times_fname = f'{experiment_name}-{pname}-times.csv' + times = np.loadtxt(os.path.join(data_path, 'traces', times_fname)) times = times.flatten().astype(np.float64) fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" try: - data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + data = pd.read_csv(os.path.join(data_path, 'traces', fname)) except FileNotFoundError: continue @@ -462,31 +483,36 @@ def do_combined_plots(leak_parameters_df): plt.close(fig2) -def do_scatter_matrices(df, qc_df): +def do_scatter_matrices(df, qc_df, output_path, reversal_potential=None): + """ + ??? + + @param df + @param qc_df + @param output_path + @reversal_potential + """ grid = sns.pairplot(data=df, hue='passed QC', diag_kind='hist', plot_kws={'alpha': 0.4, 'edgecolor': None}, hue_order=[True, False]) - grid.savefig(os.path.join(output_dir, 'scatter_matrix_by_QC')) + grid.savefig(os.path.join(output_path, 'scatter_matrix_by_QC')) - if args.reversal: - true_reversal = args.reversal + if reversal_potential is not None: + true_reversal = reversal_potential else: + # TODO Clarify in plot label! true_reversal = df['E_rev'].values.mean() df['hue'] = df.E_rev.to_numpy() > true_reversal grid = sns.pairplot(data=df, hue='hue', diag_kind='hist', plot_kws={'alpha': 0.4, 'edgecolor': None}, hue_order=[True, False]) - grid.savefig(os.path.join(output_dir, 'scatter_matrix_by_reversal.pdf'), - format='pdf') + grid.savefig(os.path.join(output_path, 'scatter_matrix_by_reversal.png')) # Now do artefact parameters only if 'drug' in qc_df: qc_df = qc_df[qc_df.drug == 'before'] - # if args.selection_file and not args.output_all: - # qc_df = qc_df[qc_df.well.isin(passed_wells)] - first_sweep = sorted(list(qc_df.sweep.unique()))[0] qc_df = qc_df[(qc_df.protocol == 'staircaseramp1') & (qc_df.sweep == first_sweep)] @@ -500,10 +526,17 @@ def do_scatter_matrices(df, qc_df): 'edgecolor': None}, hue='passed QC', hue_order=[True, False]) - grid.savefig(os.path.join(output_dir, 'scatter_matrix_QC_params_by_QC')) + grid.savefig(os.path.join(output_path, 'scatter_matrix_QC_params_by_QC')) + +def plot_reversal_spread(df, output_path, figsize=None): + """ + ??? -def plot_reversal_spread(df): + @param df + @param output_path + @param figsize + """ df.E_rev = df.E_rev.values.astype(np.float64) failed_to_infer = [well for well in df.well.unique() if not @@ -522,7 +555,7 @@ def spread_func(x): }) group_df['E_Kr range'] = group_df['E_rev'] - fig = plt.figure(figsize=args.figsize, constrained_layout=True) + fig = plt.figure(figsize=figsize, constrained_layout=True) ax = fig.subplots() sns.histplot(data=group_df, x='E_Kr range', hue='passed QC', @@ -530,12 +563,21 @@ def spread_func(x): ax.set_xlabel(r'spread in inferred E_Kr / mV') - fig.savefig(os.path.join(output_dir, 'spread_of_fitted_E_Kr')) - df.to_csv(os.path.join(output_dir, 'spread_of_fitted_E_Kr.csv')) + fig.savefig(os.path.join(output_path, 'spread_of_fitted_E_Kr')) + df.to_csv(os.path.join(output_path, 'spread_of_fitted_E_Kr.csv')) -def plot_reversal_change_sweep_to_sweep(df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) +def plot_reversal_change_sweep_to_sweep( + df, output_path, passed_wells, figsize=None): + """ + ??? + + @param df + @param output_path + @param passed_wells + @param figsize + """ + fig = plt.figure(figsize=figsize, constrained_layout=True) ax = fig.subplots() for protocol in df.protocol.unique(): @@ -563,14 +605,23 @@ def plot_reversal_change_sweep_to_sweep(df): sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', stat='count', multiple='stack') - fig.savefig(os.path.join(output_dir, f"E_rev_sweep_to_sweep_{protocol}")) + fig.savefig(os.path.join(output_path, f"E_rev_sweep_to_sweep_{protocol}")) ax.cla() plt.close(fig) -def plot_leak_conductance_change_sweep_to_sweep(df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) +def plot_leak_conductance_change_sweep_to_sweep( + df, output_path, passed_wells, figsize=None): + """ + ??? + + @param df + @param output_path + @param passed_wells + @param figsize + """ + fig = plt.figure(figsize=figsize, constrained_layout=True) ax = fig.subplots() for protocol in df.protocol.unique(): @@ -598,12 +649,19 @@ def plot_leak_conductance_change_sweep_to_sweep(df): sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', stat='count', multiple='stack', ax=ax) - fig.savefig(os.path.join(output_dir, f"g_leak_sweep_to_sweep_{protocol}")) + fig.savefig(os.path.join(output_path, f"g_leak_sweep_to_sweep_{protocol}")) plt.close(fig) -def plot_spatial_Erev(df): +def plot_spatial_Erev(df, output_path, figsize=None): + """ + ??? + + @param df + @param output_path + @param figsize + """ def func(protocol, sweep): zs = [] for row in range(16): @@ -634,10 +692,11 @@ def func(protocol, sweep): zs[~np.isfinite(zs)] = 2 zs = np.array(zs).reshape((16, 24)) - fig = plt.figure(figsize=args.figsize) + fig = plt.figure(figsize=figsize) ax = fig.subplots() # add black color for NaNs + color_cycle = ["#5790fc", "#f89c20"] cmap = matplotlib.colors.ListedColormap([color_cycle[0], color_cycle[1]], 'indexed') ax.pcolormesh(zs, edgecolors='white', cmap=cmap, linewidths=1, antialiased=True) @@ -656,8 +715,8 @@ def func(protocol, sweep): # Put 'A' row at the top ax.invert_yaxis() - fig.savefig(os.path.join(output_dir, f"{protocol}_sweep{sweep}_E_Kr_map.pdf"), - format='pdf') + fig.savefig(os.path.join( + output_path, f'{protocol}_sweep{sweep}_E_Kr_map.png')) plt.close(fig) protocol = 'staircaseramp1' @@ -666,7 +725,14 @@ def func(protocol, sweep): func(protocol, sweep) -def plot_spatial_passed(df): +def plot_spatial_passed(df, output_path, passed_wells): + """ + ??? + + @param df + @param output_path + @param passed_wells + """ fig = plt.figure(figsize=(5, 3)) ax = fig.subplots() zs = [] @@ -679,6 +745,7 @@ def plot_spatial_passed(df): zs = np.array(zs).reshape(16, 24) + color_cycle = ["#5790fc", "#f89c20"] cmap = matplotlib.colors.ListedColormap([color_cycle[0], color_cycle[1]], 'indexed') _ = ax.pcolormesh(zs, edgecolors='white', linewidths=1, antialiased=True, cmap=cmap @@ -696,13 +763,21 @@ def plot_spatial_passed(df): ax.set_yticklabels(string.ascii_uppercase[:16]) ax.invert_yaxis() - fig.savefig(os.path.join(output_dir, "QC_map.pdf"), format='pdf') + fig.savefig(os.path.join(output_path, 'QC_map.png')) plt.close(fig) -def plot_histograms(df, qc_df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) +def plot_histograms(df, output_path, reversal_potential=None, figsize=None): + """ + ??? + + @param df + @param output_path + @param reversal_potential + @param figsize + """ + fig = plt.figure(figsize=figsize, constrained_layout=True) ax = fig.subplots() ax.spines[['top', 'right']].set_visible(False) @@ -714,12 +789,14 @@ def plot_histograms(df, qc_df): multiple='stack', stat='count', legend=False) ax.set_xlabel(r'$\mathrm{mean}(E_{\mathrm{obs}})$') - fig.savefig(os.path.join(output_dir, 'averaged_reversal_potential_histogram')) + fig.savefig(os.path.join( + output_path, 'averaged_reversal_potential_histogram')) - if np.isfinite(args.reversal): - ax.axvline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + if reversal_potential is not None: + ax.axvline(reversal_potential, linestyle='--', color='grey', + label='Calculated Nernst potential') - fig.savefig(os.path.join(output_dir, 'reversal_potential_histogram')) + fig.savefig(os.path.join(output_path, 'reversal_potential_histogram')) vars = ['pre-drug leak magnitude', 'post-drug leak magnitude', @@ -738,13 +815,13 @@ def plot_histograms(df, qc_df): x='pre-drug leak magnitude', hue='passed QC', multiple='stack', stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'pre_drug_leak_magnitude')) + fig.savefig(os.path.join(output_path, 'pre_drug_leak_magnitude')) ax.cla() sns.histplot(df, x='post-drug leak magnitude', hue='passed QC', stat='count', common_norm=False, multiple='stack') - fig.savefig(os.path.join(output_dir, 'post_drug_leak_magnitude')) + fig.savefig(os.path.join(output_path, 'post_drug_leak_magnitude')) ax.cla() ax.cla() @@ -757,55 +834,52 @@ def plot_histograms(df, qc_df): legend_handles, _ = ax.get_legend_handles_labels() ax.legend(legend_handles, ['failed QC', 'passed QC'], bbox_to_anchor=(1.26, 1)) - fig.savefig(os.path.join(output_dir, 'R_leftover')) + fig.savefig(os.path.join(output_path, 'R_leftover')) ax.cla() - sns.histplot(df, - x='gleak_before', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'g_leak_before')) + kwargs = dict( + hue='passed QC', multiple='stack', stat='count', common_norm=False) + sns.histplot(df, x='gleak_before', **kwargs) + fig.savefig(os.path.join(output_path, 'g_leak_before')) ax.cla() - sns.histplot(df, - x='gleak_after', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'g_leak_after')) + sns.histplot(df, x='gleak_after', **kwargs) + fig.savefig(os.path.join(output_path, 'g_leak_after')) ax.cla() - sns.histplot(df, - x='Rseries', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'Rseries_before')) + sns.histplot(df, x='Rseries', **kwargs) + fig.savefig(os.path.join(output_path, 'Rseries_before')) ax.cla() - sns.histplot(df, - x='Rseal', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'Rseal_before')) + sns.histplot(df, x='Rseal', **kwargs) + fig.savefig(os.path.join(output_path, 'Rseal_before')) ax.cla() - sns.histplot(df, - x='Cm', hue='passed QC', multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'Cm_before')) + sns.histplot(df, x='Cm', **kwargs) + fig.savefig(os.path.join(output_path, 'Cm_before')) plt.close(fig) -def overlay_reversal_plots(leak_parameters_df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) +def overlay_reversal_plots( + data_path, output_path, experiment_name, leak_parameters_df, wells, + reversal_potential=None, figsize=None): + """ + ??? + + @param data_path + @param output_path + @param experiment_name + @param leak_parameters_df + @param reversal_potential + @param figsize + """ + fig = plt.figure(figsize=figsize, constrained_layout=True) ax = fig.subplots() palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['protocol', 'sweep']))) - sub_dir = os.path.join(output_dir, 'overlaid_reversal_plots') - - # if args.selection_file and not args.output_all: - # leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] + sub_dir = os.path.join(output_path, 'overlaid_reversal_plots') if not os.path.exists(sub_dir): os.makedirs(sub_dir) @@ -813,8 +887,6 @@ def overlay_reversal_plots(leak_parameters_df): protocols_to_plot = ['staircaseramp1'] sweeps_to_plot = [1] - # leak_parameters_df = leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] - for well in wells: # Setup figure if False in leak_parameters_df[leak_parameters_df.well == well]['passed QC'].values: @@ -823,36 +895,44 @@ def overlay_reversal_plots(leak_parameters_df): for protocol in protocols_to_plot: if protocol == np.nan: continue - for sweep in sweeps_to_plot: - voltage_fname = os.path.join(args.data_dir, 'traces', - f"{experiment_name}-{protocol}-voltages.csv") - voltages = pd.read_csv(voltage_fname)['voltage'].values.flatten() - fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" + pname = protocol + if pname == 'staircaseramp1': + pname = 'staircaseramp' + elif pname == 'staircaseramp1_2': + pname = 'staircaseramp_2' + voltage_fname = os.path.join(data_path, 'traces', + f'{experiment_name}-{pname}-voltages.csv') + voltages = pd.read_csv(voltage_fname)['voltage'].values.flatten() + + times_fname = f"{experiment_name}-{pname}-times.csv" + times = np.loadtxt(os.path.join(data_path, 'traces', times_fname)) + times = times.flatten().astype(np.float64) + + # First, find the reversal ramp + json_name = os.path.join(data_path, 'traces', 'protocols', + f'{experiment_name}-{pname}.json') + with open(json_name, 'r') as f: + json_protocol = json.load(f) + v_protocol = VoltageProtocol.from_json(json_protocol) + ramps = v_protocol.get_ramps() + reversal_ramp = ramps[-1] + ramp_start, ramp_end = reversal_ramp[:2] + + # Next extract steps + istart = np.argmax(times >= ramp_start) + iend = np.argmax(times > ramp_end) + + if istart == 0 or iend == 0 or istart == iend: + raise Exception('Could not identify reversal ramp') + + for sweep in sweeps_to_plot: + fname = f"{experiment_name}-{pname}-{well}-sweep{sweep}.csv" try: - data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + data = pd.read_csv(os.path.join(data_path, 'traces', fname)) except FileNotFoundError: continue - times_fname = f"{experiment_name}-{protocol}-times.csv" - times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)) - times = times.flatten().astype(np.float64) - - # First, find the reversal ramp - json_protocol = json.load(os.path.join(args.data_dir, 'traces', 'protocols', - f"{experiment_name}-{protocol}.json")) - v_protocol = VoltageProtocol.from_json(json_protocol) - ramps = v_protocol.get_ramps() - reversal_ramp = ramps[-1] - ramp_start, ramp_end = reversal_ramp[:2] - - # Next extract steps - istart = np.argmax(times >= ramp_start) - iend = np.argmax(times > ramp_end) - - if istart == 0 or iend == 0 or istart == iend: - raise Exception("Couldn't identify reversal ramp") - # Plot voltage vs current current = data['current'].values.astype(np.float64) @@ -865,8 +945,9 @@ def overlay_reversal_plots(leak_parameters_df): ax.plot(voltages[istart:iend], fitted_poly(voltages[istart:iend]), color=col) i += 1 - if np.isfinite(args.reversal): - ax.axvline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + if reversal_potential is not None: + ax.axvline(reversal_potential, linestyle='--', color='grey', + label='Calculated Nernst potential') ax.legend() # Save figure @@ -888,6 +969,12 @@ def error2(p): def create_attrition_table(qc_df, subtraction_df): + """ + ??? + + @param qc_df + @param subtraction_df + """ original_qc_criteria = ['qc1.rseal', 'qc1.cm', 'qc1.rseries', 'qc2.raw', 'qc2.subtracted', 'qc3.raw', 'qc3.E4031', @@ -896,20 +983,19 @@ def create_attrition_table(qc_df, subtraction_df): 'qc6.subtracted', 'qc6.1.subtracted', 'qc6.2.subtracted'] - subtraction_df_sc = subtraction_df[subtraction_df.protocol.isin(['staircaseramp1', - 'staircaseramp1_2'])] - R_leftover_qc = subtraction_df_sc.groupby('well')['R_leftover'].max() < 0.4 - - qc_df['QC.R_leftover'] = [R_leftover_qc.loc[well] for well in subtraction_df.well.unique()] + # subtraction_df_sc = subtraction_df[ + # subtraction_df.protocol.isin(['staircaseramp1', 'staircaseramp1_2'])] + # R_leftover_qc = subtraction_df_sc.groupby('well')['R_leftover'].max() < 0.4 + # qc_df['QC.R_leftover'] = [R_leftover_qc.loc[well] for well in subtraction_df.well.unique()] stage_3_criteria = original_qc_criteria + ['QC1.all_protocols', 'QC4.all_protocols', 'QC6.all_protocols'] stage_4_criteria = stage_3_criteria + ['qc3.bookend'] stage_5_criteria = stage_4_criteria + ['QC.Erev.all_protocols', 'QC.Erev.spread'] - stage_6_criteria = stage_5_criteria + ['QC.R_leftover'] + # stage_6_criteria = stage_5_criteria + ['QC.R_leftover'] - agg_dict = {crit: 'min' for crit in stage_6_criteria} + agg_dict = {crit: 'min' for crit in stage_5_criteria} qc_df_sc1 = qc_df[qc_df.protocol == 'staircaseramp1'] print(qc_df_sc1.values.shape) @@ -935,11 +1021,11 @@ def create_attrition_table(qc_df, subtraction_df): .agg(agg_dict)[stage_5_criteria].values, axis=1)) - n_stage_6_wells = np.sum(np.all(qc_df.groupby('well') - .agg(agg_dict)[stage_6_criteria].values, - axis=1)) + # n_stage_6_wells = np.sum( + # np.all(qc_df.groupby('well').agg(agg_dict)[stage_6_criteria].values, + # axis=1)) - passed_qc_df = qc_df.groupby('well').agg(agg_dict)[stage_6_criteria] + passed_qc_df = qc_df.groupby('well').agg(agg_dict)[stage_5_criteria] print(passed_qc_df) passed_wells = [well for well, row in passed_qc_df.iterrows() if np.all(row.values)] @@ -951,7 +1037,7 @@ def create_attrition_table(qc_df, subtraction_df): 'stage3': [n_stage_3_wells], 'stage4': [n_stage_4_wells], 'stage5': [n_stage_5_wells], - 'stage6': [n_stage_6_wells], + # 'stage6': [n_stage_6_wells], } res_df = pd.DataFrame.from_records(res_dict) @@ -959,4 +1045,4 @@ def create_attrition_table(qc_df, subtraction_df): if __name__ == "__main__": - main() + run_from_command_line() diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..3c8bb15c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +# +# +# Unit tests for pcpostproces +# +# diff --git a/tests/test_scripts.py b/tests/test_scripts.py new file mode 100755 index 00000000..9fe79dfb --- /dev/null +++ b/tests/test_scripts.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +import os +import tempfile +import unittest + +from pcpostprocess.scripts.run_herg_qc import run as run_herg_qc +from pcpostprocess.scripts.summarise_herg_export import run as run_summarise + +store_output = False + + +class TestScripts(unittest.TestCase): + """ + Tests the scripts bundled with pcpostprocess. + """ + + def test_run_herg_qc_and_summarise_herg_export(self): + # Test run_herg_qc_, then summarise_herg_export + + data = os.path.join('tests', 'test_data', '13112023_MW2_FF') + with tempfile.TemporaryDirectory() as d: + if store_output: + d = 'test_output' + d1 = os.path.join(d, 'run_herg_qc') + d2 = os.path.join(d, 'summarise_herg_export') + + # Test run herg qc + erev = -90.71 + qc_map = {'staircaseramp (2)_2kHz': 'staircaseramp'} + write_map = {'staircaseramp2': 'staircaseramp2'} + run_herg_qc( + data, d1, qc_map, ('A03', 'A20', 'D16'), + write_traces=True, write_map=write_map, + save_id='13112023_MW2', reversal_potential=erev) + + with open(os.path.join(d1, 'passed_wells.txt'), 'r') as f: + self.assertEqual(f.read().strip(), 'A03') + + # Test summarise herg export + run_summarise(d1, d2, '13112023_MW2', reversal_potential=erev) + + +if __name__ == '__main__': + store_output = True + unittest.main() +