From 440a78dbf9982d6af3964890c169c44e9de97ad0 Mon Sep 17 00:00:00 2001 From: priscila Date: Wed, 6 Nov 2024 16:41:35 +1100 Subject: [PATCH] Adding simulated data notebook Co-authored-by: Gabriella Chan --- ZQ003/scripts/make_simulation.ipynb | 313 ++++++++++++++++++++++++++++ 1 file changed, 313 insertions(+) create mode 100644 ZQ003/scripts/make_simulation.ipynb diff --git a/ZQ003/scripts/make_simulation.ipynb b/ZQ003/scripts/make_simulation.ipynb new file mode 100644 index 0000000..9344a50 --- /dev/null +++ b/ZQ003/scripts/make_simulation.ipynb @@ -0,0 +1,313 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from scipy.signal import lfilter, butter \n", + "from scipy.stats import poisson\n", + "import matplotlib.pyplot as plt\n", + "from scipy.interpolate import make_interp_spline\n", + "import random\n", + "import pandas as pd\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parameters\n", + "sample_rate = 1017.25 # (Hz) - based on normal data recording rate\n", + "t = 1800 # (s) - based on std experiment\n", + "cutoff = 0.1 # based on OG simulation paper\n", + "n_dtpts = int(sample_rate*t) # Number of data points\n", + "movement_attenuation = 50 # Example attenuation percentage as per OG sim paper\n", + "noise_factor = 2 # as per OG sim paper\n", + "time_pts = np.linspace(0,t,n_dtpts)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_movement_component(cutoff = 0.1, sample_rate = sample_rate, movement_attenuation = 50):\n", + " '''\n", + " Calculate the movement component of the signal, \n", + " based on a lowpass filtered random data and movement attenuation parameter\n", + " '''\n", + "\n", + " b, a = butter(N=4, Wn=cutoff / (sample_rate / 2), btype='low') # check cutofffffffff\n", + "\n", + " # Apply the filter\n", + " lowpass_values = lfilter(b, a, np.random.rand(n_dtpts))\n", + "\n", + " movement_component = 1 - (lowpass_values * (movement_attenuation / 100))\n", + " return movement_component\n", + "\n", + "def calculate_decay_component(time_pts, decay_rate1 = 0.02, decay_rate2 = 0.002, decay_base = 40):\n", + " '''\n", + " Make a double exponatial decaying curve, sampled at every time_pts\n", + " '''\n", + " decay_rate = ((1 - decay_rate1) ** time_pts + (1 - decay_rate2) ** time_pts) / 2\n", + " print(np.shape(decay_rate))\n", + " decay = decay_rate*(decay_base/100)+(1-decay_base/100)\n", + " \n", + " return decay\n", + " \n", + "\n", + "def calculate_ERT(lambda_val = 2, peak = 1, scale = 5, vis = False):\n", + " '''\n", + " Makes a Poisson distribution, \n", + " with mean = lambda_val, range = t, max value = peak\n", + " '''\n", + " # evaluate lambda over a duration 5 times longer to capture the whole distribution\n", + " t = lambda_val*5\n", + "\n", + " # Generate discrete values of the theoretical Poisson probability mass\n", + " # function (pmf) from 0 to t\n", + " x = np.arange(0, t)\n", + " pmf = poisson.pmf(x, lambda_val)\n", + " # Rescale x axis. The lowest reasonable value of lambda is 2,\n", + " # corresponding to t = 10, our response timescale is >50ms\n", + " x = x * scale\n", + " # Rescale y axis\n", + " pmf = pmf/max(pmf)*peak\n", + " # print(max(pmf))\n", + "\n", + " # Interpolate pmf\n", + " b = make_interp_spline(x, pmf, k=2) # b spline interpolation\n", + " x = np.arange(0, t * scale)\n", + " pmf = b(x)\n", + " # print(max(pmf))\n", + "\n", + " # Reindex where pmf values are >= 0.01\n", + " indices = np.where(pmf >= 0.01)[0]\n", + " pmf = pmf[indices]\n", + " # x = x[indices]\n", + " # x = np.arange(len(x))\n", + " # print(max(pmf))\n", + " \n", + " if vis:\n", + " plt.plot(x, pmf)\n", + " plt.xlabel('time (ms; 1017.25Hz)')\n", + " plt.show()\n", + " \n", + " return pmf\n", + "\n", + "\n", + "def calculate_noise_component(n_dtpts, sample_rate, noise_factor=8):\n", + " '''\n", + " Make a vector of length n_dtpts with random noised scaled by noise_factor\n", + " '''\n", + " noise_component = np.random.randn(n_dtpts) * noise_factor\n", + "\n", + " # b = sig.firwin(noise_component, cutoff=[1], fs=data.attrs['fs'],\n", + " # pass_zero=False)\n", + " # noise_component = detrend.filter_b = b\n", + " # b, a = butter(N=6, Wn=0.99, btype='low')\n", + "\n", + " # # Apply the filter\n", + " # noise_component = lfilter(b, a, np.random.rand(n_dtpts))\n", + " \n", + " return noise_component\n", + "\n", + "\n", + "\n", + "\n", + " \n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "def make_event(n_dtpts = 1000, n_events = False, lambda_val = False, peak_m = False, vis = False, delay_og = 0):\n", + "\n", + " true_signal = np.zeros(n_dtpts)\n", + " if not n_events: n_events = random.randint(2,3)\n", + " events = np.zeros(n_dtpts)\n", + " if not lambda_val: lambda_val = random.randint(2, 5)\n", + " if not peak_m: peak_m = random.uniform(5,15)\n", + "\n", + " for i in range(n_events):\n", + " delay = delay_og + random.randint(0,5)\n", + " peak = peak_m + random.uniform(-2, 2)\n", + " print(peak)\n", + " ert = calculate_ERT(lambda_val, peak, scale=10)\n", + " event_duration = len(ert)\n", + "\n", + " initial_response = random.randint(delay, len(true_signal)-event_duration)\n", + " events[initial_response] = 1\n", + "\n", + " true_signal[initial_response:initial_response+event_duration] += ert\n", + "\n", + " if vis: plt.plot(true_signal)\n", + " \n", + " return events, true_signal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ert = calculate_ERT(20, 7, scale=10)\n", + "max(ert)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Put it all together to make the simmulated signal made of noise, underlying true signal, photobleaching decay and movement \n", + "\n", + "events1, true_signal1 = make_event(n_dtpts=n_dtpts, n_events=20, delay_og=1, peak_m = 8, vis = False, lambda_val=2)\n", + "events2, true_signal2 = make_event(n_dtpts=n_dtpts, n_events=15, delay_og=2, peak_m = 10, vis = False, lambda_val=20)\n", + "events3, true_signal3 = make_event(n_dtpts=n_dtpts, n_events=21, delay_og=0, peak_m = 12, vis = False, lambda_val=50)\n", + "\n", + "true_signal = true_signal1 + true_signal2 + true_signal3\n", + "\n", + "movement_component = calculate_movement_component(cutoff, sample_rate, movement_attenuation)\n", + "\n", + "noise_component = calculate_noise_component(n_dtpts, sample_rate)\n", + "noise_component_iso = calculate_noise_component(n_dtpts, sample_rate)\n", + "\n", + "decay_component = calculate_decay_component(time_pts)\n", + "\n", + "data = (true_signal + 200) * movement_component * decay_component + noise_component\n", + "\n", + "isob = 100 * movement_component * decay_component + noise_component_iso\n", + "\n", + "# np.save('C:\\Users\\levip\\Desktop\\NSB\\BrainHack\\behapy\\SIM\\rawdata\\sub-test1\\ses-TEST1\\sub-test1_ses-TEST.2_task-TEST_run-1_label-LNAc_channel-iso.npy', isob)\n", + "\n", + "p, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2)\n", + "ax1.plot(true_signal)\n", + "ax2 = plt.subplot(2,2, 2)\n", + "ax2.plot(decay_component)\n", + "ax3 = plt.subplot(2,2, 3)\n", + "ax3.plot(movement_component)\n", + "ax4 = plt.subplot(2,2, 4)\n", + "ax4.plot(noise_component)\n", + "plt.show()\n", + "\n", + "# plt.plot(decay_component)\n", + "# plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(true_signal1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Rescale onset of each event from index to seconds\n", + "aa = np.where(events1 == 1)[0] / sample_rate\n", + "bb = np.where(events2 == 1)[0] / sample_rate\n", + "cc = np.where(events3 == 1)[0] / sample_rate\n", + "\n", + "# Combine all event times and labels\n", + "onsets = np.concatenate([aa, bb, cc])\n", + "duration = [0.1] * len(onsets)\n", + "event_ids = ['event1'] * len(aa) + ['event2'] * len(bb) + ['event3'] * len(cc)\n", + "\n", + "# Create the DataFrame and sort by time\n", + "df = pd.DataFrame({'onset': onsets, 'duration': duration, 'event_id': event_ids}).sort_values(by='onset').reset_index(drop=True)\n", + "df = df.set_index('onset')\n", + "\n", + "\n", + "df.to_csv(r'\\Users\\levip\\Desktop\\NSB\\BrainHack\\behapy\\SIM\\rawdata\\sub-test1\\ses-TEST1\\sub-test1_ses-TEST1_task-TEST_run-1_events.csv')\n", + "\n", + "# print(df)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.save('/Users/levip/Desktop/NSB/BrainHack/behapy/SIM/rawdata/sub-test1/ses-TEST1/fp/sub-test1_ses-TEST1_task-TEST_run-1_label-LNAc_channel-ACh.npy', data)\n", + "np.save('/Users/levip/Desktop/NSB/BrainHack/behapy/SIM/rawdata/sub-test1/ses-TEST1/fp/sub-test1_ses-TEST1_task-TEST_run-1_label-LNAc_channel-iso.npy', isob)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#%%\n", + "import numpy as np\n", + "import holoviews as hv\n", + "import datashader as ds\n", + "from holoviews.operation.datashader import datashade\n", + "from bokeh.plotting import output_notebook\n", + "\n", + "# Enable Bokeh and Holoviews support in the notebook\n", + "hv.extension('bokeh')\n", + "# output_notebook()\n", + "\n", + "# Convert data to a Holoviews Curve\n", + "curve = hv.Curve((np.arange(len(true_signal3)), true_signal3))\n", + "shaded_curve = datashade(curve).opts(width=800)\n", + "\n", + "shaded_curve" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "behapy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}