1"""
2Extract and cluster EOD waverforms of pulse-type electric fish.
4## Main function
6- `extract_pulsefish()`: checks for pulse-type fish based on the EOD amplitude and shape.
8"""
import os
import numpy as np

from scipy import stats
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.mixture import BayesianGaussianMixture
from sklearn.metrics import pairwise_distances
from thunderlab.eventdetection import detect_peaks, median_std_threshold

from .pulseplots import *

import warnings
def warn(*args, **kwargs):
    """
    Ignore all warnings.
    """
    pass
warnings.warn = warn

try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        def decorator_jit(func):
            return func
        return decorator_jit


# upgrade numpy functions for backwards compatibility:
if not hasattr(np, 'isin'):
    np.isin = np.in1d


def unique_counts(ar):
    """ Find the unique elements of an array and their counts, ignoring shape.

    The code is condensed from numpy version 1.17.0.

    Parameters
    ----------
    ar : numpy array
        Input array

    Returns
    -------
    unique_values : numpy array
        Unique values in array ar.
    unique_counts : numpy array
        Number of instances for each unique value in ar.
    """
    try:
        return np.unique(ar, return_counts=True)
    except TypeError:
        ar = np.asanyarray(ar).flatten()
        ar.sort()
        mask = np.empty(ar.shape, dtype=bool)
        mask[:1] = True
        mask[1:] = ar[1:] != ar[:-1]
        idx = np.concatenate(np.nonzero(mask) + ([mask.size],))
        return ar[mask], np.diff(idx)


###########################################################################


def extract_pulsefish(data, rate, amax, width_factor_shape=3,
                      width_factor_wave=8, width_factor_display=4,
                      verbose=0, plot_level=0, save_plots=False,
                      save_path='', ftype='png', return_data=[]):
    """ Extract and cluster pulse-type fish EODs from single channel data.

    Takes recording data containing an unknown number of pulsefish and extracts the mean
    EOD and EOD timepoints for each fish present in the recording.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data in Hertz.
    amax: float
        Maximum amplitude of data range.
    width_factor_shape : float (optional)
        Width multiplier used for EOD shape analysis.
        EOD snippets are extracted based on the width between the
        peak and trough multiplied by the width factor.
    width_factor_wave : float (optional)
        Width multiplier used for wavefish detection.
    width_factor_display : float (optional)
        Width multiplier used for EOD mean extraction and display.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path: string (optional)
        Path for saving plots.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Specify data that should be logged and returned in a dictionary. Each clustering
        step has a specific keyword that results in adding different variables to the log dictionary.
        Optional keys for return_data and the resulting additional key-value pairs to the log dictionary are:

        - 'all_eod_times':
            - 'all_times': list of two lists of floats.
                All peak (`all_times[0]`) and trough times (`all_times[1]`) extracted
                by the peak detection algorithm. Times are given in seconds.
            - 'eod_troughtimes': list of 1D arrays.
                The timepoints in seconds of each unique extracted EOD cluster,
                where each 1D array encodes one cluster.

        - 'peak_detection':
            - 'data': 1D numpy array of floats.
                Quadratically interpolated data which was used for peak detection.
            - 'interp_fac': float.
                Interpolation factor of raw data.
            - 'peaks_1': 1D numpy array of ints.
                Peak indices on interpolated data after first peak detection step.
            - 'troughs_1': 1D numpy array of ints.
                Trough indices on interpolated data after first peak detection step.
            - 'peaks_2': 1D numpy array of ints.
                Peak indices on interpolated data after second peak detection step.
            - 'troughs_2': 1D numpy array of ints.
                Trough indices on interpolated data after second peak detection step.
            - 'peaks_3': 1D numpy array of ints.
                Peak indices on interpolated data after third peak detection step.
            - 'troughs_3': 1D numpy array of ints.
                Trough indices on interpolated data after third peak detection step.
            - 'peaks_4': 1D numpy array of ints.
                Peak indices on interpolated data after fourth peak detection step.
            - 'troughs_4': 1D numpy array of ints.
                Trough indices on interpolated data after fourth peak detection step.

        - 'all_cluster_steps':
            - 'rate': float.
                Sampling rate of interpolated data.
            - 'EOD_widths': list of three 1D numpy arrays.
                The first list entry gives the unique labels of all width clusters
                as a list of ints.
                The second list entry gives the width values for each EOD in samples
                as a 1D numpy array of ints.
                The third list entry gives the width labels for each EOD
                as a 1D numpy array of ints.
            - 'EOD_heights': nested lists (2 layers) of three 1D numpy arrays.
                The first list entry gives the unique labels of all height clusters
                as a list of ints for each width cluster.
                The second list entry gives the height values for each EOD
                as a 1D numpy array of floats for each width cluster.
                The third list entry gives the height labels for each EOD
                as a 1D numpy array of ints for each width cluster.
            - 'EOD_shapes': nested lists (3 layers) of three 1D numpy arrays.
                The first list entry gives the raw EOD snippets as a 2D numpy array
                for each height cluster in a width cluster.
                The second list entry gives the snippet PCA values for each EOD
                as a 2D numpy array of floats for each height cluster in a width cluster.
                The third list entry gives the shape labels for each EOD as a 1D numpy array
                of ints for each height cluster in a width cluster.
            - 'discarding_masks': nested lists (2 layers) of 1D numpy arrays.
                The masks of EODs that are discarded by the discarding step of the algorithm.
                The masks are 1D boolean arrays where instances that are set to True are
                discarded by the algorithm. Discarding masks are saved in nested lists
                that represent the width and height clusters.
            - 'merge_masks': nested lists (2 layers) of 2D numpy arrays.
                The masks of EODs that are discarded by the merging step of the algorithm.
                The masks are 2D boolean arrays where for each sample point `i` either
                `merge_mask[i,0]` or `merge_mask[i,1]` is set to True. Here, `merge_mask[:,0]`
                represents the peak-centered clusters and `merge_mask[:,1]` represents the
                trough-centered clusters. Merge masks are saved in nested lists that
                represent the width and height clusters.

        - 'BGM_width':
            - 'BGM_width': dictionary
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD widths).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'BGM_height':
            This key adds a new dictionary for each width cluster.
            - 'BGM_height_*n*' : dictionary, where *n* defines the width cluster as an int.
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD heights).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'snippet_clusters':
            This key adds a new dictionary for each height cluster.
            - 'snippet_clusters_*n*_*m*_*p*' : dictionary, where *n* defines the width
                cluster (int), *m* defines the height cluster (int) and *p* defines shape
                clustering on peak or trough centered EOD snippets (string: 'peak' or 'trough').
                - 'raw_snippets': 2D numpy array (nsamples, nfeatures).
                    Raw EOD snippets.
                - 'snippets': 2D numpy array.
                    Normalized EOD snippets.
                - 'features': 2D numpy array (nsamples, nfeatures).
                    PCA values for each normalized EOD snippet.
                - 'clusters': 1D numpy array of ints.
                    Cluster labels.
                - 'rate': float.
                    Sampling rate of snippets.

        - 'eod_deletion':
            This key adds two dictionaries for each (peak centered) shape cluster,
            where *cluster* (int) is the unique shape cluster label.
            - 'mask_*cluster*' : list of four booleans.
                The mask for each cluster discarding step.
                The first instance represents the artefact masks, where artefacts
                are set to True.
                The second instance represents the unreliable cluster masks,
                where unreliable clusters are set to True.
                The third instance represents the wavefish masks, where wavefish
                are set to True.
                The fourth instance represents the sidepeak masks, where sidepeaks
                are set to True.
            - 'vals_*cluster*' : list of lists.
                All variables that are used for each cluster deletion step.
                The first instance is a list of two 1D numpy arrays: the mean EOD and
                the FFT of that mean EOD.
                The second instance is a 1D numpy array with all EOD width to ISI ratios.
                The third instance is a list with three entries:
                The first entry is a 1D numpy array zoomed out version of the mean EOD.
                The second entry is a list of two 1D numpy arrays that define the peak
                and trough indices of the zoomed out mean EOD.
                The third entry contains a list of two values that represent the
                peak-trough pair in the zoomed out mean EOD with the largest height
                difference.
            - 'rate' : float.
                EOD snippet sampling rate.

        - 'masks':
            - 'masks' : 2D numpy array (4,N).
                Each row contains masks for each EOD detected by the EOD peakdetection step.
                The first row defines the artefact masks, the second row defines the
                unreliable EOD masks, the third row defines the wavefish masks and
                the fourth row defines the sidepeak masks.

        - 'moving_fish':
            - 'moving_fish': dictionary.
                - 'w' : list of floats.
                    Median width for each width cluster that the moving fish algorithm is
                    computed on (in seconds).
                - 'T' : list of floats.
                    Length of analyzed recording for each width cluster (in seconds).
                - 'dt' : list of floats.
                    Sliding window size (in seconds) for each width cluster.
                - 'clusters' : list of 1D numpy int arrays.
                    Cluster labels for each EOD cluster in a width cluster.
                - 't' : list of 1D numpy float arrays.
                    EOD emission times for each EOD in a width cluster.
                - 'fishcount' : list of lists.
                    Sliding window timepoints and fishcounts for each width cluster.
                - 'ignore_steps' : list of 1D int arrays.
                    Mask for fishcounts that were ignored (ignored if True) in the
                    moving_fish analysis.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First column is time in seconds,
        second column the mean eod, third column the standard error.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD peaks or troughs in seconds.
        Use these timepoints for EOD averaging.
    eod_peaktimes: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    zoom_window: tuple of floats
        Start and endtime of suggested window for plotting EOD timepoints.
    log_dict: dictionary
        Dictionary with logged variables, where variables to log are specified
        by `return_data`.
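
    Example
    -------
    A minimal usage sketch (assuming `data` is a 1-D recording sampled at
    `rate` Hertz with a full amplitude range of `amax`; loading the
    recording is up to the caller):

    >>> mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict = \
    ...     extract_pulsefish(data, rate, amax)
    >>> for i, times in enumerate(eod_times):
    ...     print(f'fish {i}: {len(times)} EODs detected')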
303 """
304 if verbose > 0:
305 print('')
306 if verbose > 1:
307 print(70*'#')
308 print('##### extract_pulsefish', 46*'#')
310 if (save_plots and plot_level>0 and save_path):
311 # create folder to save things in.
312 if not os.path.exists(save_path):
313 os.makedirs(save_path)
314 else:
315 save_path = ''
317 mean_eods, eod_times, eod_peaktimes, zoom_window = [], [], [], []
318 log_dict = {}
320 # interpolate:
321 i_rate = 500000.0
322 #i_rate = rate
323 try:
324 f = interp1d(np.arange(len(data))/rate, data, kind='quadratic')
325 i_data = f(np.arange(0.0, (len(data)-1)/rate, 1.0/i_rate))
326 except MemoryError:
327 i_rate = rate
328 i_data = data
329 log_dict['data'] = i_data # TODO: could be removed
330 log_dict['rate'] = i_rate # TODO: could be removed
331 log_dict['i_data'] = i_data
332 log_dict['i_rate'] = i_rate
333 # log_dict["interp_fac"] = interp_fac # TODO: is not set anymore
335 # standard deviation of data in small snippets:
336 win_size = int(0.002*rate) # 2ms windows
337 threshold = median_std_threshold(data, win_size) # TODO make this a parameter
339 # extract peaks:
340 if 'peak_detection' in return_data:
341 x_peak, x_trough, eod_heights, eod_widths, pd_log_dict = \
342 detect_pulses(i_data, i_rate, threshold,
343 width_fac=np.max([width_factor_shape,
344 width_factor_display,
345 width_factor_wave]),
346 verbose=verbose-1, return_data=True)
347 log_dict.update(pd_log_dict)
348 else:
349 x_peak, x_trough, eod_heights, eod_widths = \
350 detect_pulses(i_data, i_rate, threshold,
351 width_fac=np.max([width_factor_shape,
352 width_factor_display,
353 width_factor_wave]),
354 verbose=verbose-1, return_data=False)
356 if len(x_peak) > 0:
357 # cluster
358 clusters, x_merge, c_log_dict = cluster(x_peak, x_trough,
359 eod_heights,
360 eod_widths, i_data,
361 i_rate,
362 width_factor_shape,
363 width_factor_wave,
364 merge_threshold_height=0.1*amax,
365 verbose=verbose-1,
366 plot_level=plot_level-1,
367 save_plots=save_plots,
368 save_path=save_path,
369 ftype=ftype,
370 return_data=return_data)
372 # extract mean eods and times
373 mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
374 extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, clusters,
375 i_rate, width_factor_display, verbose=verbose-1)
377 # determine clipped clusters (save them, but ignore in other steps)
378 clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes = \
379 find_clipped_clusters(clusters, mean_eods, eod_times, eod_peaktimes,
380 eod_troughtimes, cluster_labels, width_factor_display,
381 verbose=verbose-1)
383 # delete the moving fish
384 clusters, zoom_window, mf_log_dict = \
385 delete_moving_fish(clusters, x_merge/i_rate, len(data)/rate,
386 eod_heights, eod_widths/i_rate, i_rate,
387 verbose=verbose-1, plot_level=plot_level-1, save_plot=save_plots,
388 save_path=save_path, ftype=ftype, return_data=return_data)
390 if 'moving_fish' in return_data:
391 log_dict['moving_fish'] = mf_log_dict
393 clusters = remove_sparse_detections(clusters, eod_widths, i_rate,
394 len(data)/rate, verbose=verbose-1)
396 # extract mean eods
397 mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
398 extract_means(i_data, x_merge, x_peak, x_trough, eod_widths,
399 clusters, i_rate, width_factor_display, verbose=verbose-1)
401 mean_eods.extend(clipped_eods)
402 eod_times.extend(clipped_times)
403 eod_peaktimes.extend(clipped_peaktimes)
404 eod_troughtimes.extend(clipped_troughtimes)
406 if plot_level > 0:
407 plot_all(data, eod_peaktimes, eod_troughtimes, rate, mean_eods)
408 if save_plots:
409 plt.savefig('%sextract_pulsefish_results.%s' % (save_path, ftype))
410 if save_plots:
411 plt.close('all')
413 if 'all_eod_times' in return_data:
414 log_dict['all_times'] = [x_peak/i_rate, x_trough/i_rate]
415 log_dict['eod_troughtimes'] = eod_troughtimes
417 log_dict.update(c_log_dict)
419 return mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict


def detect_pulses(data, rate, thresh, min_rel_slope_diff=0.25,
                  min_width=0.00005, max_width=0.01, width_fac=5.0,
                  verbose=0, return_data=False):
    """Detect pulses in data.

    Was `def extract_eod_times(data, rate, width_factor,
    interp_freq=500000, max_peakwidth=0.01, min_peakwidth=None,
    verbose=0, return_data=[], save_path='')` before.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data.
    thresh: float
        Threshold for peak and trough detection via `detect_peaks()`.
        Must be a positive number that sets the minimum difference
        between a peak and a trough.
    min_rel_slope_diff: float
        Minimum required difference between left and right slope (between
        peak and troughs) relative to mean slope for deciding which trough
        to take based on slope difference.
    min_width: float
        Minimum width (peak-trough distance) of pulses in seconds.
    max_width: float
        Maximum width (peak-trough distance) of pulses in seconds.
    width_fac: float
        Pulses extend plus or minus `width_fac` times their width
        (distance between peak and assigned trough).
        Only pulses are returned that can fully be analysed with this width.
    verbose : int (optional)
        Verbosity level.
    return_data : bool
        If `True` data of this function is logged and returned (see
        extract_pulsefish()).

    Returns
    -------
    peak_indices: array of ints
        Indices of EOD peaks in data.
    trough_indices: array of ints
        Indices of EOD troughs in data. There is one x_trough for each x_peak.
    heights: array of floats
        EOD heights for each x_peak.
    widths: array of ints
        EOD widths for each x_peak (in samples).
    peak_detection_result : dictionary
        Key value pairs of logged data.
        This is only returned if `return_data` is `True`.
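
    Example
    -------
    A sketch on synthetic data (a train of biphasic pulses in an otherwise
    silent recording; in `extract_pulsefish()` the threshold comes from
    `median_std_threshold()` instead of being hard-coded):

    >>> rate = 44100.0
    >>> data = np.zeros(20000)
    >>> for i in range(1000, 20000, 2000):
    ...     data[i] = 1.0
    ...     data[i+4] = -0.8
    >>> peaks, troughs, heights, widths = detect_pulses(data, rate, 0.1)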
474 """
475 peak_detection_result = {}
477 # detect peaks and troughs in the data:
478 peak_indices, trough_indices = detect_peaks(data, thresh)
479 if verbose > 0:
480 print('Peaks/troughs detected in data: %5d %5d'
481 % (len(peak_indices), len(trough_indices)))
482 if return_data:
483 peak_detection_result.update(peaks_1=np.array(peak_indices),
484 troughs_1=np.array(trough_indices))
485 if len(peak_indices) < 2 or \
486 len(trough_indices) < 2 or \
487 len(peak_indices) > len(data)/20:
488 # TODO: if too many peaks increase threshold!
489 if verbose > 0:
490 print('No or too many peaks/troughs detected in data.')
491 if return_data:
492 return np.array([], dtype=int), np.array([], dtype=int), \
493 np.array([]), np.array([], dtype=int), peak_detection_result
494 else:
495 return np.array([], dtype=int), np.array([], dtype=int), \
496 np.array([]), np.array([], dtype=int)
498 # assign troughs to peaks:
499 peak_indices, trough_indices, heights, widths, slopes = \
500 assign_side_peaks(data, peak_indices, trough_indices, min_rel_slope_diff)
501 if verbose > 1:
502 print('Number of peaks after assigning side-peaks: %5d'
503 % (len(peak_indices)))
504 if return_data:
505 peak_detection_result.update(peaks_2=np.array(peak_indices),
506 troughs_2=np.array(trough_indices))
508 # check widths:
509 keep = ((widths>min_width*rate) & (widths<max_width*rate))
510 peak_indices = peak_indices[keep]
511 trough_indices = trough_indices[keep]
512 heights = heights[keep]
513 widths = widths[keep]
514 slopes = slopes[keep]
515 if verbose > 1:
516 print('Number of peaks after checking pulse width: %5d'
517 % (len(peak_indices)))
518 if return_data:
519 peak_detection_result.update(peaks_3=np.array(peak_indices),
520 troughs_3=np.array(trough_indices))
522 # discard connected peaks:
523 same = np.nonzero(trough_indices[:-1] == trough_indices[1:])[0]
524 keep = np.ones(len(trough_indices), dtype=bool)
525 for i in same:
526 # same troughs at trough_indices[i] and trough_indices[i+1]:
527 s = slopes[i:i+2]
528 rel_slopes = np.abs(np.diff(s))[0]/np.mean(s)
529 if rel_slopes > min_rel_slope_diff:
530 keep[i+(s[1]<s[0])] = False
531 else:
532 keep[i+(heights[i+1]<heights[i])] = False
533 peak_indices = peak_indices[keep]
534 trough_indices = trough_indices[keep]
535 heights = heights[keep]
536 widths = widths[keep]
537 if verbose > 1:
538 print('Number of peaks after merging pulses: %5d'
539 % (len(peak_indices)))
540 if return_data:
541 peak_detection_result.update(peaks_4=np.array(peak_indices),
542 troughs_4=np.array(trough_indices))
543 if len(peak_indices) == 0:
544 if verbose > 0:
545 print('No peaks remain as pulse candidates.')
546 if return_data:
547 return np.array([], dtype=int), np.array([], dtype=int), \
548 np.array([]), np.array([], dtype=int), peak_detection_result
549 else:
550 return np.array([], dtype=int), np.array([], dtype=int), \
551 np.array([]), np.array([], dtype=int)
553 # only take those where the maximum cutwidth does not cause issues -
554 # if the width_fac times the width + x is more than length.
555 keep = ((peak_indices - widths > 0) &
556 (peak_indices + widths < len(data)) &
557 (trough_indices - widths > 0) &
558 (trough_indices + widths < len(data)))
560 if verbose > 0:
561 print('Remaining peaks after EOD extraction: %5d'
562 % (p.sum(keep)))
563 print('')
565 if return_data:
566 return peak_indices[keep], trough_indices[keep], \
567 heights[keep], widths[keep], peak_detection_result
568 else:
569 return peak_indices[keep], trough_indices[keep], \
570 heights[keep], widths[keep]


@jit(nopython=True)
def assign_side_peaks(data, peak_indices, trough_indices,
                      min_rel_slope_diff=0.25):
    """Assign to each peak the trough resulting in a pulse with the steepest slope or largest height.

    The slope between a peak and a trough is computed as the height
    difference divided by the distance between peak and trough. If the
    slopes towards the left and the right trough differ by less than
    `min_rel_slope_diff`, then the heights of the two troughs relative
    to the peak are compared instead.

    Was `def detect_eod_peaks(data, main_indices, side_indices,
    max_width=20, min_width=2, verbose=0)` before.

    Parameters
    ----------
    data: array of floats
        Data in which the events were detected.
    peak_indices: array of ints
        Indices of the detected peaks in the data time series.
    trough_indices: array of ints
        Indices of the detected troughs in the data time series.
    min_rel_slope_diff: float
        Minimum required difference of left and right slope relative
        to mean slope.

    Returns
    -------
    peak_indices: array of ints
        Peak indices. Same as input `peak_indices` but potentially shorter
        by one or two elements.
    trough_indices: array of ints
        Corresponding trough indices of trough to the left or right
        of the peaks.
    heights: array of floats
        Peak heights (distance between peak and corresponding trough amplitude).
    widths: array of ints
        Peak widths (distance between peak and corresponding trough indices).
    slopes: array of floats
        Peak slope (height divided by width).
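
    Example
    -------
    A worked sketch of the decision rule: for a peak of height 1.0 with a
    trough 10 samples to its left at -0.2 and a trough 2 samples to its
    right at -0.1, the left slope is 1.2/10 = 0.12 and the right slope is
    1.1/2 = 0.55. Their difference relative to the mean slope,
    0.43/0.335 ~ 1.28, exceeds `min_rel_slope_diff`, so the steeper right
    trough is assigned. Were the slopes within `min_rel_slope_diff` of
    each other, the trough with the larger height difference to the peak
    would be taken instead.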
613 """
614 # is a main or side peak first?
615 peak_first = int(peak_indices[0] < trough_indices[0])
616 # is a main or side peak last?
617 peak_last = int(peak_indices[-1] > trough_indices[-1])
618 # ensure all peaks to have side peaks (troughs) at both sides,
619 # i.e. troughs at same index and next index are before and after peak:
620 peak_indices = peak_indices[peak_first:len(peak_indices)-peak_last]
621 y = data[peak_indices]
623 # indices of troughs on the left and right side of main peaks:
624 l_indices = np.arange(len(peak_indices))
625 r_indices = l_indices + 1
627 # indices, distance to peak, height, and slope of left troughs:
628 l_side_indices = trough_indices[l_indices]
629 l_distance = np.abs(peak_indices - l_side_indices)
630 l_height = np.abs(y - data[l_side_indices])
631 l_slope = np.abs(l_height/l_distance)
633 # indices, distance to peak, height, and slope of right troughs:
634 r_side_indices = trough_indices[r_indices]
635 r_distance = np.abs(r_side_indices - peak_indices)
636 r_height = np.abs(y - data[r_side_indices])
637 r_slope = np.abs(r_height/r_distance)
639 # which trough to assign to the peak?
640 # - either the one with the steepest slope, or
641 # - when slopes are similar on both sides
642 # (within `min_rel_slope_diff` difference),
643 # the trough with the maximum height difference to the peak:
644 rel_slopes = np.abs(l_slope-r_slope)/(0.5*(l_slope+r_slope))
645 take_slopes = rel_slopes > min_rel_slope_diff
646 take_left = l_height > r_height
647 take_left[take_slopes] = l_slope[take_slopes] > r_slope[take_slopes]
649 # assign troughs, heights, widths, and slopes:
650 trough_indices = np.where(take_left,
651 trough_indices[:-1], trough_indices[1:])
652 heights = np.where(take_left, l_height, r_height)
653 widths = np.where(take_left, l_distance, r_distance)
654 slopes = np.where(take_left, l_slope, r_slope)
656 return peak_indices, trough_indices, heights, widths, slopes


def cluster(eod_xp, eod_xt, eod_heights, eod_widths, data, rate,
            width_factor_shape, width_factor_wave, n_gaus_height=10,
            merge_threshold_height=0.1, n_gaus_width=3,
            merge_threshold_width=0.5, minp=10, verbose=0,
            plot_level=0, save_plots=False, save_path='', ftype='pdf',
            return_data=[]):
    """Cluster EODs.

    First cluster on EOD widths using a Bayesian Gaussian
    Mixture (BGM) model, then cluster on EOD heights using a
    BGM model. Lastly, cluster on EOD waveform with DBSCAN.
    Clustering on EOD waveform is performed twice, once on
    peak-centered EODs and once on trough-centered EODs.
    Non-pulsetype EOD clusters are deleted, and clusters are
    merged afterwards.

    Parameters
    ----------
    eod_xp : list of ints
        Location of EOD peaks in indices.
    eod_xt: list of ints
        Locations of EOD troughs in indices.
    eod_heights: list of floats
        EOD heights.
    eod_widths: list of ints
        EOD widths in samples.
    data: array of floats
        Data in which to detect pulse EODs.
    rate : float
        Sampling rate of `data`.
    width_factor_shape : float
        Multiplier for snippet extraction width. This factor is
        multiplied with the width between the peak and trough of a
        single EOD.
    width_factor_wave : float
        Multiplier for wavefish extraction width.
    n_gaus_height : int (optional)
        Number of gaussians to use for the clustering based on EOD height.
    merge_threshold_height : float (optional)
        Threshold for merging clusters that are similar in height.
    n_gaus_width : int (optional)
        Number of gaussians to use for the clustering based on EOD width.
    merge_threshold_width : float (optional)
        Threshold for merging clusters that are similar in width.
    minp : int (optional)
        Minimum number of points for core clusters (DBSCAN).
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save created plots.
    save_path : string (optional)
        Path to save plots to. Only used if save_plots==True.
    ftype : string (optional)
        Filetype to save plot images in.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'all_cluster_steps', 'BGM_width', 'BGM_height',
        'snippet_clusters', 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    labels : list of ints
        EOD cluster labels based on height and EOD waveform.
    x_merge : list of ints
        Locations of EODs in clusters.
    saved_data : dictionary
        Key value pairs of logged data. Data to be logged is specified
        by return_data.
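
    Example
    -------
    A sketch of how `extract_pulsefish()` invokes this function on the
    interpolated data (`i_data`, `i_rate`) and the pulses returned by
    `detect_pulses()`:

    >>> clusters, x_merge, c_log_dict = cluster(
    ...     x_peak, x_trough, eod_heights, eod_widths, i_data, i_rate,
    ...     width_factor_shape, width_factor_wave,
    ...     merge_threshold_height=0.1*amax)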
731 """
732 saved_data = {}
734 if plot_level>0 or 'all_cluster_steps' in return_data:
735 all_heightlabels = []
736 all_shapelabels = []
737 all_snippets = []
738 all_features = []
739 all_heights = []
740 all_unique_heightlabels = []
742 all_p_clusters = -1 * np.ones(len(eod_xp))
743 all_t_clusters = -1 * np.ones(len(eod_xp))
744 artefact_masks_p = np.ones(len(eod_xp), dtype=bool)
745 artefact_masks_t = np.ones(len(eod_xp), dtype=bool)
747 x_merge = -1 * np.ones(len(eod_xp))
749 max_label_p = 0 # keep track of the labels so that no labels are overwritten
750 max_label_t = 0
752 # loop only over height clusters that are bigger than minp
753 # first cluster on width
754 width_labels, bgm_log_dict = BGM(1000*eod_widths/rate,
755 merge_threshold_width,
756 n_gaus_width, use_log=False,
757 verbose=verbose-1,
758 plot_level=plot_level-1,
759 xlabel='width [ms]',
760 save_plot=save_plots,
761 save_path=save_path,
762 save_name='width', ftype=ftype,
763 return_data=return_data)
764 saved_data.update(bgm_log_dict)
766 if verbose > 0:
767 print('Clusters generated based on EOD width:')
768 for l in np.unique(width_labels):
769 print(f'N_{l} = {len(width_labels[width_labels==l]):4d} h_{l} = {np.mean(eod_widths[width_labels==l]):.4f}')
771 w_labels, w_counts = unique_counts(width_labels)
772 unique_width_labels = w_labels[w_counts>minp]
774 for wi, width_label in enumerate(unique_width_labels):
776 # select only features in one width cluster at a time
777 w_eod_widths = eod_widths[width_labels==width_label]
778 w_eod_heights = eod_heights[width_labels==width_label]
779 w_eod_xp = eod_xp[width_labels==width_label]
780 w_eod_xt = eod_xt[width_labels==width_label]
781 width = int(width_factor_shape*np.median(w_eod_widths))
782 if width > w_eod_xp[0]:
783 width = w_eod_xp[0]
784 if width > w_eod_xt[0]:
785 width = w_eod_xt[0]
786 if width > len(data) - w_eod_xp[-1]:
787 width = len(data) - w_eod_xp[-1]
788 if width > len(data) - w_eod_xt[-1]:
789 width = len(data) - w_eod_xt[-1]
791 wp_clusters = -1 * np.ones(len(w_eod_xp))
792 wt_clusters = -1 * np.ones(len(w_eod_xp))
793 wartefact_mask = np.ones(len(w_eod_xp))
795 # determine height labels
796 raw_p_snippets, p_snippets, p_features, p_bg_ratio = \
797 extract_snippet_features(data, w_eod_xp, w_eod_heights, width)
798 raw_t_snippets, t_snippets, t_features, t_bg_ratio = \
799 extract_snippet_features(data, w_eod_xt, w_eod_heights, width)
801 height_labels, bgm_log_dict = \
802 BGM(w_eod_heights, min(merge_threshold_height,
803 np.median(np.min(np.vstack([p_bg_ratio, t_bg_ratio]),
804 axis=0))), n_gaus_height, use_log=True,
805 verbose=verbose-1, plot_level=plot_level-1, xlabel =
806 'height [a.u.]', save_plot=save_plots,
807 save_path=save_path, save_name = 'height_%d' % wi,
808 ftype=ftype, return_data=return_data)
809 saved_data.update(bgm_log_dict)
811 if verbose > 0:
812 print('Clusters generated based on EOD height:')
813 for l in np.unique(height_labels):
814 print(f'N_{l} = {len(height_labels[height_labels==l]):4d} h_{l} = {np.mean(w_eod_heights[height_labels==l]):.4f}')
816 h_labels, h_counts = unique_counts(height_labels)
817 unique_height_labels = h_labels[h_counts>minp]
819 if plot_level > 0 or 'all_cluster_steps' in return_data:
820 all_heightlabels.append(height_labels)
821 all_heights.append(w_eod_heights)
822 all_unique_heightlabels.append(unique_height_labels)
823 shape_labels = []
824 cfeatures = []
825 csnippets = []
827 for hi, height_label in enumerate(unique_height_labels):
829 h_eod_widths = w_eod_widths[height_labels==height_label]
830 h_eod_heights = w_eod_heights[height_labels==height_label]
831 h_eod_xp = w_eod_xp[height_labels==height_label]
832 h_eod_xt = w_eod_xt[height_labels==height_label]
834 p_clusters = cluster_on_shape(p_features[height_labels==height_label],
835 p_bg_ratio, minp, verbose=0)
836 t_clusters = cluster_on_shape(t_features[height_labels==height_label],
837 t_bg_ratio, minp, verbose=0)
839 if plot_level > 1:
840 plot_feature_extraction(raw_p_snippets[height_labels==height_label],
841 p_snippets[height_labels==height_label],
842 p_features[height_labels==height_label],
843 p_clusters, 1/rate, 0)
844 plt.savefig('%sDBSCAN_peak_w%i_h%i.%s' % (save_path, wi, hi, ftype))
845 plot_feature_extraction(raw_t_snippets[height_labels==height_label],
846 t_snippets[height_labels==height_label],
847 t_features[height_labels==height_label],
848 t_clusters, 1/rate, 1)
849 plt.savefig('%sDBSCAN_trough_w%i_h%i.%s' % (save_path, wi, hi, ftype))
851 if 'snippet_clusters' in return_data:
852 saved_data[f'snippet_clusters_{width_label}_{height_label}_peak'] = {
853 'raw_snippets': raw_p_snippets[height_labels==height_label],
854 'snippets': p_snippets[height_labels==height_label],
855 'features': p_features[height_labels==height_label],
856 'clusters': p_clusters,
857 'rate': rate}
858 saved_data['snippet_clusters_{width_label}_{height_label}_trough'] = {
859 'raw_snippets': raw_t_snippets[height_labels==height_label],
860 'snippets': t_snippets[height_labels==height_label],
861 'features': t_features[height_labels==height_label],
862 'clusters': t_clusters,
863 'rate': rate}
865 if plot_level > 0 or 'all_cluster_steps' in return_data:
866 shape_labels.append([p_clusters, t_clusters])
867 cfeatures.append([p_features[height_labels==height_label],
868 t_features[height_labels==height_label]])
869 csnippets.append([p_snippets[height_labels==height_label],
870 t_snippets[height_labels==height_label]])
872 p_clusters[p_clusters==-1] = -max_label_p - 1
873 wp_clusters[height_labels==height_label] = p_clusters + max_label_p
874 max_label_p = max(np.max(wp_clusters), np.max(all_p_clusters)) + 1
876 t_clusters[t_clusters==-1] = -max_label_t - 1
877 wt_clusters[height_labels==height_label] = t_clusters + max_label_t
878 max_label_t = max(np.max(wt_clusters), np.max(all_t_clusters)) + 1
880 if verbose > 0:
881 if np.max(wp_clusters) == -1:
882 print(f'No EOD peaks in width cluster {width_label}')
883 else:
884 unique_clusters = np.unique(wp_clusters[wp_clusters!=-1])
885 if len(unique_clusters) > 1:
886 print('{len(unique_clusters)} different EOD peaks in width cluster {width_label}')
888 if plot_level > 0 or 'all_cluster_steps' in return_data:
889 all_shapelabels.append(shape_labels)
890 all_snippets.append(csnippets)
891 all_features.append(cfeatures)
893 # for each cluster, save fft + label
894 # so I end up with features for each label, and the masks.
895 # then I can extract e.g. first artefact or wave etc.
897 # remove artefacts here, based on the mean snippets ffts.
898 artefact_masks_p[width_labels==width_label], sdict = \
899 remove_artefacts(p_snippets, wp_clusters, rate,
900 verbose=verbose-1, return_data=return_data)
901 saved_data.update(sdict)
902 artefact_masks_t[width_labels==width_label], _ = \
903 remove_artefacts(t_snippets, wt_clusters, rate,
904 verbose=verbose-1, return_data=return_data)
906 # update maxlab so that no clusters are overwritten
907 all_p_clusters[width_labels==width_label] = wp_clusters
908 all_t_clusters[width_labels==width_label] = wt_clusters
910 # remove all non-reliable clusters
911 unreliable_fish_mask_p, saved_data = \
912 delete_unreliable_fish(all_p_clusters, eod_widths, eod_xp,
913 verbose=verbose-1, sdict=saved_data)
914 unreliable_fish_mask_t, _ = \
915 delete_unreliable_fish(all_t_clusters, eod_widths, eod_xt, verbose=verbose-1)
917 wave_mask_p, sidepeak_mask_p, saved_data = \
918 delete_wavefish_and_sidepeaks(data, all_p_clusters, eod_xp, eod_widths,
919 width_factor_wave, verbose=verbose-1, sdict=saved_data)
920 wave_mask_t, sidepeak_mask_t, _ = \
921 delete_wavefish_and_sidepeaks(data, all_t_clusters, eod_xt, eod_widths,
922 width_factor_wave, verbose=verbose-1)
924 og_clusters = [np.copy(all_p_clusters), np.copy(all_t_clusters)]
925 og_labels = np.copy(all_p_clusters + all_t_clusters)
927 # go through all clusters and masks??
928 all_p_clusters[(artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p)] = -1
929 all_t_clusters[(artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)] = -1
931 # merge here.
932 all_clusters, x_merge, mask = merge_clusters(np.copy(all_p_clusters),
933 np.copy(all_t_clusters),
934 eod_xp, eod_xt,
935 verbose=verbose - 1)
937 if 'all_cluster_steps' in return_data or plot_level > 0:
938 all_dmasks = []
939 all_mmasks = []
941 discarding_masks = \
942 np.vstack(((artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p),
943 (artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)))
944 merge_mask = mask
946 # save the masks in the same formats as the snippets
947 for wi, (width_label, w_shape_label, heightlabels, unique_height_labels) in enumerate(zip(unique_width_labels, all_shapelabels, all_heightlabels, all_unique_heightlabels)):
948 w_dmasks = discarding_masks[:,width_labels==width_label]
949 w_mmasks = merge_mask[:,width_labels==width_label]
951 wd_2 = []
952 wm_2 = []
954 for hi, (height_label, h_shape_label) in enumerate(zip(unique_height_labels, w_shape_label)):
956 h_dmasks = w_dmasks[:,heightlabels==height_label]
957 h_mmasks = w_mmasks[:,heightlabels==height_label]
959 wd_2.append(h_dmasks)
960 wm_2.append(h_mmasks)
962 all_dmasks.append(wd_2)
963 all_mmasks.append(wm_2)
965 if plot_level > 0:
966 plot_clustering(rate, [unique_width_labels, eod_widths, width_labels],
967 [all_unique_heightlabels, all_heights, all_heightlabels],
968 [all_snippets, all_features, all_shapelabels],
969 all_dmasks, all_mmasks)
970 if save_plots:
971 plt.savefig('%sclustering.%s' % (save_path, ftype))
973 if 'all_cluster_steps' in return_data:
974 saved_data = {'rate': rate,
975 'EOD_widths': [unique_width_labels, eod_widths, width_labels],
976 'EOD_heights': [all_unique_heightlabels, all_heights, all_heightlabels],
977 'EOD_shapes': [all_snippets, all_features, all_shapelabels],
978 'discarding_masks': all_dmasks,
979 'merge_masks': all_mmasks
980 }
982 if 'masks' in return_data:
983 saved_data = {'masks' : np.vstack(((artefact_masks_p & artefact_masks_t),
984 (unreliable_fish_mask_p & unreliable_fish_mask_t),
985 (wave_mask_p & wave_mask_t),
986 (sidepeak_mask_p & sidepeak_mask_t),
987 (all_p_clusters+all_t_clusters)))}
989 if verbose > 0:
990 print('Clusters generated based on height, width and shape: ')
991 for l in np.unique(all_clusters[all_clusters != -1]):
992 print('N_{int(l)} = {len(all_clusters[all_clusters == l]):4d}')
994 return all_clusters, x_merge, saved_data


def BGM(x, merge_threshold=0.1, n_gaus=5, max_iter=200, n_init=5,
        use_log=False, verbose=0, plot_level=0, xlabel='x [a.u.]',
        save_plot=False, save_path='', save_name='', ftype='pdf',
        return_data=[]):
    """ Use a Bayesian Gaussian Mixture Model to cluster one-dimensional data.

    Additional steps are used to merge clusters that are closer than
    `merge_threshold`. Broad gaussian fits that cover one or more other
    gaussian fits are split by their intersections with the other
    gaussians.

    Parameters
    ----------
    x : 1D numpy array
        Features to compute clustering on.
    merge_threshold : float (optional)
        Ratio for merging nearby gaussians.
    n_gaus: int (optional)
        Maximum number of gaussians to fit on data.
    max_iter : int (optional)
        Maximum number of iterations for gaussian fit.
    n_init : int (optional)
        Number of initializations for the gaussian fit.
    use_log: boolean (optional)
        Set to True to compute the gaussian fit on the logarithm of x.
        Can improve clustering on features with nonlinear relationships such as peak height.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    xlabel : string (optional)
        Xlabel for displaying BGM plot.
    save_plot : bool (optional)
        Set to True to save created plot.
    save_path : string (optional)
        Path to location where data should be saved. Only used if save_plot==True.
    save_name : string (optional)
        Filename of the saved plot. Useful as usually multiple BGM models are generated.
    ftype : string (optional)
        Filetype of plot image if save_plots==True.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'BGM_width' and/or 'BGM_height' (see extract_pulsefish()).

    Returns
    -------
    labels : 1D numpy array
        Cluster labels for each sample in x.
    bgm_dict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
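
    Example
    -------
    A sketch on synthetic one-dimensional data with two well separated
    modes; after fitting, labels are remapped so that they increase with
    increasing x:

    >>> x = np.concatenate((np.random.normal(1.0, 0.05, 100),
    ...                     np.random.normal(3.0, 0.1, 100)))
    >>> labels, bgm_dict = BGM(x, merge_threshold=0.1, n_gaus=3)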
1049 """
1051 bgm_dict = {}
1053 if len(np.unique(x)) > n_gaus:
1054 BGM_model = BayesianGaussianMixture(n_components=n_gaus, max_iter=max_iter, n_init=n_init)
1055 if use_log:
1056 labels = BGM_model.fit_predict(stats.zscore(np.log(x)).reshape(-1, 1))
1057 else:
1058 labels = BGM_model.fit_predict(stats.zscore(x).reshape(-1, 1))
1059 else:
1060 return np.zeros(len(x)), bgm_dict
1062 if verbose>0:
1063 if not BGM_model.converged_:
1064 print('!!! Gaussian mixture did not converge !!!')
1066 cur_labels = np.unique(labels)
1068 # map labels to be increasing for increasing values for x
1069 maxlab = len(cur_labels)
1070 aso = np.argsort([np.median(x[labels == l]) for l in cur_labels]) + 100
1071 for i, a in zip(cur_labels, aso):
1072 labels[labels==i] = a
1073 labels = labels - 100
1075 # separate gaussian clusters that can be split by other clusters
1076 splits = np.sort(np.copy(x))[1:][np.diff(labels[np.argsort(x)])!=0]
1078 labels[:] = 0
1079 for i, split in enumerate(splits):
1080 labels[x>=split] = i+1
1082 labels_before_merge = np.copy(labels)
1084 # merge gaussian clusters that are closer than merge_threshold
1085 labels = merge_gaussians(x, labels, merge_threshold)
1087 if 'BGM_'+save_name.split('_')[0] in return_data or plot_level>0:
1089 #sort model attributes by model_means_
1090 means = [m[0] for m in BGM_model.means_]
1091 weights = [w for w in BGM_model.weights_]
1092 variances = [v[0][0] for v in BGM_model.covariances_]
1093 weights = [w for _, w in sorted(zip(means, weights))]
1094 variances = [v for _, v in sorted(zip(means, variances))]
1095 means = sorted(means)
1097 if plot_level>0:
1098 plot_bgm(x, means, variances, weights, use_log, labels_before_merge,
1099 labels, xlabel)
1100 if save_plot:
1101 plt.savefig('%sBGM_%s.%s' % (save_path, save_name, ftype))
1103 if 'BGM_'+save_name.split('_')[0] in return_data:
1104 bgm_dict['BGM_'+save_name] = {'x':x,
1105 'use_log':use_log,
1106 'BGM':[weights, means, variances],
1107 'labels':labels_before_merge,
1108 'xlab':xlabel}
1110 return labels, bgm_dict


def merge_gaussians(x, labels, merge_threshold=0.1):
    """ Merge all clusters whose medians are close to each other. Only works in 1D.

    Parameters
    ----------
    x : 1D array of ints or floats
        Features used for clustering.
    labels : 1D array of ints
        Labels for each sample in x.
    merge_threshold : float (optional)
        Similarity threshold to merge clusters.

    Returns
    -------
    labels : 1D array of ints
        Merged labels for each sample in x.
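
    Example
    -------
    A worked sketch of the merging criterion: clusters with medians 1.0
    and 1.05 have a relative difference of 0.05/1.05 ~ 0.048, which is
    below the default `merge_threshold` of 0.1, so they are merged;
    medians 1.0 and 1.2 (0.2/1.2 ~ 0.17) stay separate.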
1129 """
1131 # compare all the means of the gaussians. If they are too close, merge them.
1132 unique_labels = np.unique(labels[labels!=-1])
1133 x_medians = [np.median(x[labels==l]) for l in unique_labels]
1135 # fill a dict with the label mappings
1136 mapping = {}
1137 for label_1, x_m1 in zip(unique_labels, x_medians):
1138 for label_2, x_m2 in zip(unique_labels, x_medians):
1139 if label_1!=label_2:
1140 if np.abs(np.diff([x_m1, x_m2]))/np.max([x_m1, x_m2]) < merge_threshold:
1141 mapping[label_1] = label_2
1142 # apply mapping
1143 for map_key, map_value in mapping.items():
1144 labels[labels==map_key] = map_value
1146 return labels


def extract_snippet_features(data, eod_x, eod_heights, width, n_pc=5):
    """ Extract snippets from recording data, normalize them, and perform PCA.

    Parameters
    ----------
    data : 1D numpy array of floats
        Recording data.
    eod_x : 1D array of ints
        Locations of EODs as indices.
    eod_heights: 1D array of floats
        EOD heights.
    width : int
        Width to cut out to each side in samples.
    n_pc : int (optional)
        Number of PCs to use for PCA.

    Returns
    -------
    raw_snippets : 2D numpy array (N, EOD_width)
        Raw extracted EOD snippets.
    snippets : 2D numpy array (N, EOD_width)
        Normalized EOD snippets.
    features : 2D numpy array (N, n_pc)
        PC values of EOD snippets.
    bg_ratio : 1D numpy array (N)
        Ratio of the background activity slopes compared to EOD height.
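
    Example
    -------
    A sketch (assuming `data` is a 1-D recording and all EOD locations in
    `eod_x` are at least `width` samples away from both ends of the data):

    >>> raw, snips, feats, bg = extract_snippet_features(data, eod_x,
    ...                                                  eod_heights, 100)
    >>> feats.shape   # (len(eod_x), 5) with the default n_pc=5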
1176 """
1177 # extract snippets with corresponding width
1178 raw_snippets = np.vstack([data[x-width:x+width] for x in eod_x])
1180 # subtract the slope and normalize the snippets
1181 snippets, bg_ratio = subtract_slope(np.copy(raw_snippets), eod_heights)
1182 snippets = StandardScaler().fit_transform(snippets.T).T
1184 # scale so that the absolute integral = 1.
1185 snippets = (snippets.T/np.sum(np.abs(snippets), axis=1)).T
1187 # compute features for clustering on waveform
1188 features = PCA(n_pc).fit_transform(snippets)
1190 return raw_snippets, snippets, features, bg_ratio


def cluster_on_shape(features, bg_ratio, minp, percentile=80,
                     max_epsilon=0.01, slope_ratio_factor=4,
                     min_cluster_fraction=0.01, verbose=0):
    """Separate EODs by their shape using DBSCAN.

    Parameters
    ----------
    features : 2D numpy array of floats (N, n_pc)
        PCA features of each EOD in a recording.
    bg_ratio : 1D array of floats
        Ratio of background activity slope the EOD is superimposed on.
    minp : int
        Minimum number of points for core cluster (DBSCAN).
    percentile : int (optional)
        Percentile of KNN distribution, where K=minp, to use as epsilon for DBSCAN.
    max_epsilon : float (optional)
        Maximum epsilon to use for DBSCAN clustering. This is used to avoid adding
        noisy clusters.
    slope_ratio_factor : float (optional)
        Influence of the slope-to-EOD ratio on the epsilon parameter.
        A slope_ratio_factor of 4 means that slope-to-EOD ratios >1/4
        start influencing epsilon.
    min_cluster_fraction : float (optional)
        Minimum fraction of all evaluated datapoints that can form a single cluster.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    labels : 1D array of ints
        DBSCAN cluster labels for each EOD.
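
    Example
    -------
    A worked sketch of the epsilon computation: with `minp=10`, 2000 EODs
    and the defaults, `minpc = max(10, int(2000*0.01)) = 20`. If the
    median slope-to-EOD ratio is 0.5, the upper bound on epsilon is
    `max(1, 4*0.5)*0.01 = 0.02`, and epsilon is the smaller of that and
    the 80th percentile of the distances to the 20th nearest neighbor.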
1225 """
1227 # determine clustering threshold from data
1228 minpc = max(minp, int(len(features)*min_cluster_fraction))
1229 knn = np.sort(pairwise_distances(features, features), axis=0)[minpc]
1230 eps = min(max(1, slope_ratio_factor*np.median(bg_ratio))*max_epsilon,
1231 np.percentile(knn, percentile))
1233 if verbose>1:
1234 print('epsilon = %f'%eps)
1235 print('Slope to EOD ratio = %f'%np.median(bg_ratio))
1237 # cluster on EOD shape
1238 return DBSCAN(eps=eps, min_samples=minpc).fit(features).labels_


def subtract_slope(snippets, heights):
    """ Subtract underlying slope from all EOD snippets.

    Parameters
    ----------
    snippets: 2-D numpy array
        All EODs in a recording, stacked as snippets.
        Shape = (number of EODs, EOD width)
    heights: 1D numpy array
        EOD heights.

    Returns
    -------
    snippets: 2-D numpy array
        EOD snippets with underlying slope subtracted.
    bg_ratio : 1-D numpy array
        EOD height/background activity height.
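
    Example
    -------
    A sketch on a single snippet: the straight line connecting its first
    and last sample is removed, so the endpoints end up equal.

    >>> snips = np.array([[0.0, 1.0, 0.2]])
    >>> flat, bg = subtract_slope(np.copy(snips), np.array([1.0]))
    >>> # bg[0] is |0.0 - 0.2|/1.0 = 0.2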
1258 """
1260 left_y = snippets[:,0]
1261 right_y = snippets[:,-1]
1263 try:
1264 slopes = np.linspace(left_y, right_y, snippets.shape[1])
1265 except ValueError:
1266 delta = (right_y - left_y)/snippets.shape[1]
1267 slopes = np.arange(0, snippets.shape[1], dtype=snippets.dtype).reshape((-1,) + (1,) * np.ndim(delta))*delta + left_y
1269 return snippets - slopes.T, np.abs(left_y-right_y)/heights


def remove_artefacts(all_snippets, clusters, rate,
                     freq_low=20000, threshold=0.75,
                     verbose=0, return_data=[]):
    """ Create a mask for EOD clusters that result from artefacts, based on power in low frequency spectrum.

    Parameters
    ----------
    all_snippets: 2D array
        EOD snippets. Shape=(nEODs, EOD length)
    clusters: list of ints
        EOD cluster labels.
    rate : float
        Sampling rate of original recording data.
    freq_low: float
        Frequency up to which low frequency components are summed up.
    threshold : float (optional)
        Minimum value for the sum of low frequency components relative to
        the sum over all spectral amplitudes that separates artefact from
        clean pulsefish clusters.
    verbose : int (optional)
        Verbosity level.
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log data in this function is
        'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    mask: numpy array of booleans
        Set to True for every EOD which is an artefact.
    adict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
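
    Example
    -------
    A worked sketch of the criterion: if only 60% of the summed spectral
    amplitudes of a cluster's mean EOD lie below `freq_low` (20 kHz by
    default), the low frequency ratio of 0.6 falls below the default
    threshold of 0.75 and the whole cluster is marked as artefact.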
1303 """
1304 adict = {}
1306 mask = np.zeros(clusters.shape, dtype=bool)
1308 for cluster in np.unique(clusters[clusters >= 0]):
1309 snippets = all_snippets[clusters == cluster]
1310 mean_eod = np.mean(snippets, axis=0)
1311 mean_eod = mean_eod - np.mean(mean_eod)
1312 mean_eod_fft = np.abs(np.fft.rfft(mean_eod))
1313 freqs = np.fft.rfftfreq(len(mean_eod), 1/rate)
1314 low_frequency_ratio = np.sum(mean_eod_fft[freqs<freq_low])/np.sum(mean_eod_fft)
1315 if low_frequency_ratio < threshold: # TODO: check threshold!
1316 mask[clusters==cluster] = True
1318 if verbose > 0:
1319 print('Deleting cluster %i with low frequency ratio of %.3f (min %.3f)' % (cluster, low_frequency_ratio, threshold))
1321 if 'eod_deletion' in return_data:
1322 adict['vals_%d' % cluster] = [mean_eod, mean_eod_fft]
1323 adict['mask_%d' % cluster] = [np.any(mask[clusters==cluster])]
1325 return mask, adict


def delete_unreliable_fish(clusters, eod_widths, eod_x, verbose=0, sdict={}):
    """ Create a mask for EOD clusters that are either mixed with noise or other fish, or wavefish.

    This is the case when the ratio between the EOD width and the ISI is too large.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of floats or ints
        EOD widths in samples or seconds.
    eod_x : list of ints or floats
        EOD times in samples or seconds.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps,
        see remove_artefacts().

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every unreliable EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
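
    Example
    -------
    A worked sketch of the criterion: for a cluster with a median EOD
    width of 1 ms and a smallest inter-EOD interval of 1.5 ms, the
    largest width-to-ISI ratio is 1/1.5 ~ 0.67 > 0.5, so the cluster is
    discarded. A cluster with fewer than two EOD times is discarded as well.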
1357 """
1358 mask = np.zeros(clusters.shape, dtype=bool)
1359 for cluster in np.unique(clusters[clusters >= 0]):
1360 if len(eod_x[cluster == clusters]) < 2:
1361 mask[clusters == cluster] = True
1362 if verbose>0:
1363 print('deleting unreliable cluster %i, number of EOD times %d < 2' % (cluster, len(eod_x[cluster==clusters])))
1364 elif np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) > 0.5:
1365 if verbose>0:
1366 print('deleting unreliable cluster %i, score=%f' % (cluster, np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))))
1367 mask[clusters==cluster] = True
1368 if 'vals_%d' % cluster in sdict:
1369 sdict['vals_%d' % cluster].append(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))
1370 sdict['mask_%d' % cluster].append(any(mask[clusters==cluster]))
1371 return mask, sdict


def delete_wavefish_and_sidepeaks(data, clusters, eod_x, eod_widths,
                                  width_fac, max_slope_deviation=0.5,
                                  max_phases=4, verbose=0, sdict={}):
    """ Create a mask for EODs that are likely from wavefish, or sidepeaks of bigger EODs.

    Parameters
    ----------
    data : list of floats
        Raw recording data.
    clusters : list of ints
        Cluster labels.
    eod_x : list of ints
        Indices of EOD times.
    eod_widths : list of ints
        EOD widths in samples.
    width_fac : float
        Multiplier for EOD analysis width.
    max_slope_deviation: float (optional)
        Maximum deviation of position of maximum slope in snippets from
        center position in multiples of mean width of EOD.
    max_phases : int (optional)
        Maximum number of phases for any EOD.
        If the mean EOD has more phases than this, it is not a pulse EOD.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps, see remove_artefacts().

    Returns
    -------
    mask_wave: numpy array of booleans
        Set to True for every EOD which is a wavefish EOD.
    mask_sidepeak: numpy array of booleans
        Set to True for every snippet which is centered around a sidepeak of an EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
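
    Example
    -------
    A sketch of the two criteria: if the steepest peak-trough transition
    of a cluster's mean EOD sits further than `max_slope_deviation` mean
    widths away from the snippet center, the cluster is centered on a
    sidepeak; if the mean EOD has more than `max_phases` peaks and
    troughs in total (e.g. six extrema with the default of 4), it is most
    likely a wavefish and is discarded.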
1414 """
1415 mask_wave = np.zeros(clusters.shape, dtype=bool)
1416 mask_sidepeak = np.zeros(clusters.shape, dtype=bool)
1418 for i, cluster in enumerate(np.unique(clusters[clusters >= 0])):
1419 mean_width = np.mean(eod_widths[clusters == cluster])
1420 cutwidth = mean_width*width_fac
1421 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
1422 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
1423 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)]
1424 for x in current_x[current_clusters==cluster]])
1426 # extract information on main peaks and troughs:
1427 mean_eod = np.mean(snippets, axis=0)
1428 mean_eod = mean_eod - np.mean(mean_eod)
1430 # detect peaks and troughs on data + some maxima/minima at the
1431 # end, so that the sides are also considered for peak detection:
1432 pk, tr = detect_peaks(np.concatenate(([-10*mean_eod[0]], mean_eod, [10*mean_eod[-1]])),
1433 np.std(mean_eod))
1434 pk = pk[(pk>0)&(pk<len(mean_eod))]
1435 tr = tr[(tr>0)&(tr<len(mean_eod))]
1437 if len(pk)>0 and len(tr)>0:
1438 idxs = np.sort(np.concatenate((pk, tr)))
1439 slopes = np.abs(np.diff(mean_eod[idxs]))
1440 m_slope = np.argmax(slopes)
1441 centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2))
1443 # compute all height differences of peaks and troughs within snippets.
1444 # if they are all similar, it is probably noise or a wavefish.
1445 idxs = np.sort(np.concatenate((pk, tr)))
1446 hdiffs = np.diff(mean_eod[idxs])
1448 if centered > max_slope_deviation*mean_width: # TODO: check, factor was probably 0.16
1449 if verbose > 0:
1450 print('Deleting cluster %i, which is a sidepeak' % cluster)
1451 mask_sidepeak[clusters==cluster] = True
1453 w_diff = np.abs(np.diff(np.sort(np.concatenate((pk, tr)))))
1455 if np.abs(np.diff(idxs[m_slope:m_slope+2])) < np.mean(eod_widths[clusters==cluster])*0.5 or len(pk) + len(tr)>max_phases or np.min(w_diff)>2*cutwidth/width_fac: #or len(hdiffs[np.abs(hdiffs)>0.5*(np.max(mean_eod)-np.min(mean_eod))])>max_phases:
1456 if verbose>0:
1457 print('Deleting cluster %i, which is a wavefish' % cluster)
1458 mask_wave[clusters==cluster] = True
1459 if 'vals_%d' % cluster in sdict:
1460 sdict['vals_%d' % cluster].append([mean_eod, [pk, tr],
1461 idxs[m_slope:m_slope+2]])
1462 sdict['mask_%d' % cluster].append(any(mask_wave[clusters==cluster]))
1463 sdict['mask_%d' % cluster].append(any(mask_sidepeak[clusters==cluster]))
1465 return mask_wave, mask_sidepeak, sdict


def merge_clusters(clusters_1, clusters_2, x_1, x_2, verbose=0):
    """ Merge clusters resulting from two clustering methods.

    This method only works if clustering is performed on the same EODs
    with the same ordering, where there is a one to one mapping from
    clusters_1 to clusters_2.

    Parameters
    ----------
    clusters_1: list of ints
        EOD cluster labels for cluster method 1.
    clusters_2: list of ints
        EOD cluster labels for cluster method 2.
    x_1: list of ints
        Indices of EODs for cluster method 1 (clusters_1).
    x_2: list of ints
        Indices of EODs for cluster method 2 (clusters_2).
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    clusters : list of ints
        Merged clusters.
    x_merged : list of ints
        Merged cluster indices.
    mask : 2D numpy array (2, N)
        Mask for EODs that are selected from clusters_1 (mask[0]) and
        from clusters_2 (mask[1]).
1497 """
1498 if verbose > 0:
1499 print('\nMerge clusters:')
1501 # these arrays are set to 1 for each EOD that is kept from the corresponding method
1502 c1_keep = np.zeros(len(clusters_1))
1503 c2_keep = np.zeros(len(clusters_2))
1505 # offset the labels of clusters_2 so they cannot collide with those of clusters_1
1506 ovl = np.max(clusters_1) + 1
1507 clusters_2[clusters_2!=-1] = clusters_2[clusters_2!=-1] + ovl
1509 remove_clusters = [[]]
1510 keep_clusters = []
1511 og_clusters = [np.copy(clusters_1), np.copy(clusters_2)]
1513 # loop until done
1514 while True:
1516 # compute unique clusters and cluster sizes
1517 # of cluster that have not been iterated over:
1518 c1_labels, c1_size = unique_counts(clusters_1[(clusters_1 != -1) & (c1_keep == 0)])
1519 c2_labels, c2_size = unique_counts(clusters_2[(clusters_2 != -1) & (c2_keep == 0)])
1521 # if all clusters are done, break from loop:
1522 if len(c1_size) == 0 and len(c2_size) == 0:
1523 break
1525 # if the biggest cluster is in clusters_1, keep this one and discard all clusters
1526 # on the same indices in clusters_2:
1527 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 0:
1529 # remove all the mappings from the other indices
1530 cluster_mappings, _ = unique_counts(clusters_2[clusters_1 == c1_labels[np.argmax(c1_size)]])
1532 clusters_2[np.isin(clusters_2, cluster_mappings)] = -1
1534 c1_keep[clusters_1==c1_labels[np.argmax(c1_size)]] = 1
1536 remove_clusters.append(cluster_mappings)
1537 keep_clusters.append(c1_labels[np.argmax(c1_size)])
1539 if verbose > 0:
1540 print('Keep cluster %i of group 1, delete clusters %s of group 2' % (c1_labels[np.argmax(c1_size)], str(cluster_mappings[cluster_mappings!=-1] - ovl)))
1542 # if the biggest cluster is in clusters_2, keep this one and discard all mappings in clusters_1
1543 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 1:
1545 # remove all the mappings from the other indices
1546 cluster_mappings, _ = unique_counts(clusters_1[clusters_2 == c2_labels[np.argmax(c2_size)]])
1548 clusters_1[np.isin(clusters_1, cluster_mappings)] = -1
1550 c2_keep[clusters_2==c2_labels[np.argmax(c2_size)]] = 1
1552 remove_clusters.append(cluster_mappings)
1553 keep_clusters.append(c2_labels[np.argmax(c2_size)])
1555 if verbose > 0:
1556 print('Keep cluster %i of group 2, delete clusters %s of group 1' % (c2_labels[np.argmax(c2_size)] - ovl, str(cluster_mappings[cluster_mappings!=-1])))
1558 # combine results
1559 clusters = (clusters_1+1)*c1_keep + (clusters_2+1)*c2_keep - 1
1560 x_merged = x_1*c1_keep + x_2*c2_keep
1562 return clusters, x_merged, np.vstack([c1_keep, c2_keep])
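A usage sketch with made-up labels (assuming thunderfish is installed and the one-to-one mapping required above): six EODs clustered by two methods, where the merge keeps whichever method's cluster is largest and discards the overlapping clusters of the other method:

    import numpy as np
    from thunderfish.pulses import merge_clusters

    c1 = np.array([0, 0, 0, 1, 1, 1])  # labels from method 1
    c2 = np.array([0, 0, 1, 1, 1, 1])  # labels from method 2
    x = np.arange(6)                   # EOD indices, identical for both methods
    clusters, x_merged, mask = merge_clusters(c1, c2, x, x)
    # both kept clusters come from method 2 here; its labels are offset by
    # max(c1)+1, so clusters ends up as [2, 2, 3, 3, 3, 3]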
1565def extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths,
1566 clusters, rate, width_fac, verbose=0):
1567 """ Extract mean EODs and EOD timepoints for each EOD cluster.
1569 Parameters
1570 ----------
1571 data: list of floats
1572 Raw recording data.
1573 eod_x: list of ints
1574 Locations of EODs in samples.
1575 eod_peak_x : list of ints
1576 Locations of EOD peaks in samples.
1577 eod_tr_x : list of ints
1578 Locations of EOD troughs in samples.
1579 eod_widths: list of ints
1580 EOD widths in samples.
1581 clusters: list of ints
1582 EOD cluster labels
1583 rate: float
1584 Sampling rate of the recording in Hertz.
1585 width_fac : float
1586 Multiplication factor for window used to extract EOD.
1588 verbose : int (optional)
1589 Verbosity level.
1591 Returns
1592 -------
1593 mean_eods: list of 2D arrays (3, eod_length)
1594 The average EOD for each detected fish. First row is time in seconds,
1595 second row the mean EOD, third row the standard deviation.
1596 eod_times: list of 1D arrays
1597 For each detected fish the times of EOD in seconds.
1598 eod_peak_times: list of 1D arrays
1599 For each detected fish the times of EOD peaks in seconds.
1600 eod_trough_times: list of 1D arrays
1601 For each detected fish the times of EOD troughs in seconds.
1602 eod_labels: list of ints
1603 Cluster label for each detected fish.
1604 """
1605 mean_eods, eod_times, eod_peak_times, eod_tr_times, eod_heights, cluster_labels = [], [], [], [], [], []
1607 for cluster in np.unique(clusters):
1608 if cluster!=-1:
1609 cutwidth = np.mean(eod_widths[clusters==cluster])*width_fac
1610 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
1611 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
1613 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] for x in current_x[current_clusters==cluster]])
1614 mean_eod = np.mean(snippets, axis=0)
1615 eod_time = np.arange(len(mean_eod))/rate - cutwidth/rate
1617 mean_eod = np.vstack([eod_time, mean_eod, np.std(snippets, axis=0)])
1619 mean_eods.append(mean_eod)
1620 eod_times.append(eod_x[clusters==cluster]/rate)
1621 eod_heights.append(np.min(mean_eod[1])-np.max(mean_eod[1])) # negative peak-to-peak amplitude of the waveform row as sort key
1622 eod_peak_times.append(eod_peak_x[clusters==cluster]/rate)
1623 eod_tr_times.append(eod_tr_x[clusters==cluster]/rate)
1624 cluster_labels.append(cluster)
1626 idx = np.argsort(eod_heights) # sort by height: largest amplitude (most negative key) first
1627 return ([mean_eods[i] for i in idx], [eod_times[i] for i in idx],
1628 [eod_peak_times[i] for i in idx], [eod_tr_times[i] for i in idx], [cluster_labels[i] for i in idx])
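A minimal usage sketch (assuming thunderfish is installed; the pulse positions and widths are idealized): three unit pulses assigned to one cluster yield a single mean EOD with its time axis and standard deviation:

    import numpy as np
    from thunderfish.pulses import extract_means

    rate = 10000.0                     # Hz
    data = np.zeros(1000)
    eod_x = np.array([300, 500, 700])  # EOD centers in samples
    data[eod_x] = 1.0                  # idealized unit pulses
    widths = np.array([10, 10, 10])    # EOD widths in samples
    clusters = np.array([0, 0, 0])
    mean_eods, times, peaks, troughs, labels = extract_means(
        data, eod_x, eod_x, eod_x + 5, widths, clusters, rate, width_fac=4)
    # mean_eods[0] has shape (3, 80): time in seconds, mean waveform, standard deviation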
1629def find_clipped_clusters(clusters, mean_eods, eod_times,
1630 eod_peaktimes, eod_troughtimes,
1631 cluster_labels, width_factor,
1632 clip_threshold=0.9, verbose=0):
1633 """ Detect EODs that are clipped and set all clusterlabels of these clipped EODs to -1.
1635 Also return the mean EODs and timepoints of these clipped EODs.
1637 Parameters
1638 ----------
1639 clusters: array of ints
1640 Cluster labels for each EOD in a recording.
1641 mean_eods: list of numpy arrays
1642 Mean EOD waveform for each cluster.
1643 eod_times: list of numpy arrays
1644 EOD timepoints for each EOD cluster.
1645 eod_peaktimes : list of numpy arrays
1646 EOD peak times for each EOD cluster.
1647 eod_troughtimes : list of numpy arrays
1648 EOD trough times for each EOD cluster.
1649 cluster_labels: numpy array
1650 Unique EOD cluster labels.
1651 width_factor: float
1652 Multiplication factor for the window used to extract EODs.
1653 clip_threshold: float (optional)
1654 Threshold for detecting clipped EODs.
1655 verbose: int (optional)
1656 Verbosity level.
1657 Returns
1658 -------
1659 clusters : array of ints
1660 Cluster labels for each EOD in the recording, where clipped EODs have been set to -1.
1661 clipped_eods : list of numpy arrays
1662 Mean EOD waveforms for each clipped EOD cluster.
1663 clipped_times : list of numpy arrays
1664 EOD timepoints for each clipped EOD cluster.
1665 clipped_peaktimes : list of numpy arrays
1666 EOD peaktimes for each clipped EOD cluster.
1667 clipped_troughtimes : list of numpy arrays
1668 EOD troughtimes for each clipped EOD cluster.
1669 """
1670 clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes, clipped_labels = [], [], [], [], []
1672 for mean_eod, eod_time, eod_peaktime, eod_troughtime, label in zip(mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels):
1674 if (np.count_nonzero(mean_eod[1]>clip_threshold) > len(mean_eod[1])/(width_factor*2)) or (np.count_nonzero(mean_eod[1] < -clip_threshold) > len(mean_eod[1])/(width_factor*2)):
1675 clipped_eods.append(mean_eod)
1676 clipped_times.append(eod_time)
1677 clipped_peaktimes.append(eod_peaktime)
1678 clipped_troughtimes.append(eod_troughtime)
1679 clipped_labels.append(label)
1680 if verbose>0:
1681 print('Deleting cluster %i, which is clipped' % label)
1683 clusters[np.isin(clusters, clipped_labels)] = -1
1685 return clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes
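The clipping test above flags a cluster when more than len(mean_eod[1])/(2*width_factor) samples of its mean waveform exceed clip_threshold in either direction. The criterion in isolation, on a synthetic saturated waveform (illustrative values only):

    import numpy as np

    width_factor = 4
    clip_threshold = 0.9
    mean_eod = np.clip(1.5*np.sin(np.linspace(0, 2*np.pi, 100)), -1, 1)  # flat-topped sine
    n_limit = len(mean_eod)/(width_factor*2)  # 12.5 samples
    clipped = (np.count_nonzero(mean_eod > clip_threshold) > n_limit or
               np.count_nonzero(mean_eod < -clip_threshold) > n_limit)
    # clipped is True: the flat tops put about 30 samples above 0.9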
1688def delete_moving_fish(clusters, eod_t, T, eod_heights, eod_widths,
1689 rate, min_dt=0.25, stepsize=0.05,
1690 sliding_window_factor=2000, verbose=0,
1691 plot_level=0, save_plot=False, save_path='',
1692 ftype='pdf', return_data=[]):
1693 """
1694 Use a sliding window to find the time interval with the minimum number of fish
1695 detected simultaneously, then delete all EOD clusters not present in that interval.
1697 Do this only for EODs within the same width clusters, as a
1698 moving fish will preserve its EOD width.
1700 Parameters
1701 ----------
1702 clusters: list of ints
1703 EOD cluster labels.
1704 eod_t: list of floats
1705 Timepoints of the EODs (in seconds).
1706 T: float
1707 Length of recording (in seconds).
1708 eod_heights: list of floats
1709 EOD amplitudes.
1710 eod_widths: list of floats
1711 EOD widths (in seconds).
1712 rate: float
1713 Recording data sampling rate.
1715 min_dt : float (optional)
1716 Minimum sliding window size (in seconds).
1717 stepsize : float (optional)
1718 Sliding window stepsize (in seconds).
1719 sliding_window_factor : float
1720 Multiplier for sliding window width,
1721 where the sliding window width = median(EOD_width)*sliding_window_factor.
1722 verbose : int (optional)
1723 Verbosity level.
1724 plot_level : int (optional)
1725 Similar to verbosity levels, but with plots.
1726 Only set to > 0 for debugging purposes.
1727 save_plot : bool (optional)
1728 Set to True to save the plots created by plot_level.
1729 save_path : string (optional)
1730 Path for saving plots. Only used if save_plot is True.
1731 ftype : string (optional)
1732 Define the filetype to save the plots in if save_plots is set to True.
1733 Options are: 'png', 'jpg', 'svg' ...
1734 return_data : list of strings (optional)
1735 Keys that specify data to be logged. The key that can be used to log data
1736 in this function is 'moving_fish' (see extract_pulsefish()).
1738 Returns
1739 -------
1740 clusters : list of ints
1741 Cluster labels, where deleted clusters have been set to -1.
1742 window : list of 2 floats
1743 Start and end of window selected for deleting moving fish in seconds.
1744 mf_dict : dictionary
1745 Key value pairs of logged data. Data to be logged is specified by return_data.
1746 """
1747 mf_dict = {}
1749 if len(np.unique(clusters[clusters != -1])) == 0:
1750 return clusters, [0, 1], {}
1752 all_keep_clusters = []
1753 width_classes = merge_gaussians(eod_widths, np.copy(clusters), 0.75)
1755 all_windows = []
1756 all_dts = []
1757 ev_num = 0
1758 for iw, w in enumerate(np.unique(width_classes[clusters >= 0])):
1759 # initialize variables
1760 min_clusters = 100
1761 average_height = 0
1762 sparse_clusters = 100
1763 keep_clusters = []
1765 dt = max(min_dt, np.median(eod_widths[width_classes==w])*sliding_window_factor)
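# e.g., a median EOD width of 0.5 ms and sliding_window_factor=2000 give dt = max(0.25, 1.0) = 1.0 s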
1766 window_start = 0
1767 window_end = dt
1769 wclusters = clusters[width_classes==w]
1770 weod_t = eod_t[width_classes==w]
1771 weod_heights = eod_heights[width_classes==w]
1772 weod_widths = eod_widths[width_classes==w]
1774 all_dts.append(dt)
1776 if verbose>0:
1777 print('sliding window dt = %f'%dt)
1779 # make W dependent on width??
1780 ignore_steps = np.zeros(len(np.arange(0, T-dt+stepsize, stepsize)))
1782 for i, t in enumerate(np.arange(0, T-dt+stepsize, stepsize)):
1783 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1784 if len(np.unique(current_clusters))==0:
1785 ignore_steps[max(0, i-int(dt/stepsize)):i+int(dt/stepsize)] = 1 # clamp the start to avoid a negative slice index wrapping around
1786 if verbose>0:
1787 print('No pulsefish in recording at T=%.2f:%.2f' % (t, t+dt))
1790 x = np.arange(0, T-dt+stepsize, stepsize)
1791 y = np.ones(len(x))
1796 # sliding window
1797 for j, (t, ignore_step) in enumerate(zip(x, ignore_steps)):
1798 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1799 current_widths = weod_widths[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1801 unique_clusters = np.unique(current_clusters)
1802 y[j] = len(unique_clusters)
1804 if (len(unique_clusters) <= min_clusters) and \
1805 (ignore_step==0) and \
1806 (len(unique_clusters) > 0):
1808 current_labels = np.isin(wclusters, unique_clusters)
1809 current_height = np.mean(weod_heights[current_labels])
1811 # compute nr of clusters that are too sparse
1812 clusters_after_deletion = np.unique(remove_sparse_detections(np.copy(clusters[np.isin(clusters, unique_clusters)]), rate*eod_widths[np.isin(clusters, unique_clusters)], rate, T))
1813 current_sparse_clusters = len(unique_clusters) - len(clusters_after_deletion[clusters_after_deletion!=-1])
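# keep this window if it is no worse in sparse clusters and improves on sparseness, mean EOD height, or fish count: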
1815 if current_sparse_clusters <= sparse_clusters and \
1816 ((current_sparse_clusters<sparse_clusters) or
1817 (current_height > average_height) or
1818 (len(unique_clusters) < min_clusters)):
1820 keep_clusters = unique_clusters
1821 min_clusters = len(unique_clusters)
1822 average_height = current_height
1823 window_end = t+dt
1824 sparse_clusters = current_sparse_clusters
1826 all_keep_clusters.append(keep_clusters)
1827 all_windows.append(window_end)
1829 if 'moving_fish' in return_data or plot_level>0:
1830 if 'w' in mf_dict:
1831 mf_dict['w'].append(np.median(eod_widths[width_classes==w]))
1833 mf_dict['dt'].append(dt)
1834 mf_dict['clusters'].append(wclusters)
1835 mf_dict['t'].append(weod_t)
1836 mf_dict['fishcount'].append([x+0.5*(x[1]-x[0]), y])
1837 mf_dict['ignore_steps'].append(ignore_steps)
1838 else:
1839 mf_dict['w'] = [np.median(eod_widths[width_classes==w])]
1840 mf_dict['T'] = T # the recording length is the same for all width classes
1841 mf_dict['dt'] = [dt]
1842 mf_dict['clusters'] = [wclusters]
1843 mf_dict['t'] = [weod_t]
1844 mf_dict['fishcount'] = [[x+0.5*(x[1]-x[0]), y]]
1845 mf_dict['ignore_steps'] = [ignore_steps]
1847 if verbose>0:
1848 print('Estimated nr of pulsefish in recording: %i' % len(np.concatenate(all_keep_clusters)))
1850 if plot_level>0:
1851 plot_moving_fish(mf_dict['w'], mf_dict['dt'], mf_dict['clusters'],mf_dict['t'],
1852 mf_dict['fishcount'], T, mf_dict['ignore_steps'])
1853 if save_plot:
1854 plt.savefig('%sdelete_moving_fish.%s' % (save_path, ftype))
1855 # clear the log dict if logging was not requested:
1856 if 'moving_fish' not in return_data:
1857 mf_dict = {}
1859 # delete all clusters that are not selected
1860 clusters[np.invert(np.isin(clusters, np.concatenate(all_keep_clusters)))] = -1
1862 return clusters, [np.max(all_windows)-np.max(all_dts), np.max(all_windows)], mf_dict
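The core of the window selection is just counting unique cluster labels per window position. A condensed sketch with toy EOD times and labels (not the full sparse-cluster and height bookkeeping above):

    import numpy as np

    T, dt, stepsize = 10.0, 1.0, 0.05
    eod_t = np.array([1.0, 1.2, 1.4, 5.0, 5.1, 5.2, 5.3])
    clusters = np.array([0, 1, 0, 0, 0, 0, 0])  # a second fish appears only early on
    x = np.arange(0, T - dt + stepsize, stepsize)
    y = [len(np.unique(clusters[(eod_t >= t) & (eod_t < t + dt)])) for t in x]
    # occupied windows around t=5 s contain one cluster; empty windows (count 0)
    # are masked out via ignore_steps in delete_moving_fish()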
1865def remove_sparse_detections(clusters, eod_widths, rate, T,
1866 min_density=0.0005, verbose=0):
1867 """ Remove all EOD clusters that are too sparse
1869 Parameters
1870 ----------
1871 clusters : list of ints
1872 Cluster labels.
1873 eod_widths : list of ints
1874 Cluster widths in samples.
1875 rate : float
1876 Sampling rate.
1877 T : float
1878 Length of recording in seconds.
1879 min_density : float (optional)
1880 Minimum density for realistic EOD detections.
1881 verbose : int (optional)
1882 Verbosity level.
1884 Returns
1885 -------
1886 clusters : list of ints
1887 Cluster labels, where sparse clusters have been set to -1.
1888 """
1889 for c in np.unique(clusters):
1890 if c!=-1:
1892 n = len(clusters[clusters==c])
1893 w = np.median(eod_widths[clusters==c])/rate
1895 if n*w < T*min_density:
1896 if verbose>0:
1897 print('cluster %i is too sparse'%c)
1898 clusters[clusters==c] = -1
1899 return clusters
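The density test amounts to comparing the summed EOD duration of a cluster against a fraction of the recording length. Plugging in some illustrative numbers:

    rate, T, min_density = 40000.0, 60.0, 0.0005
    n = 50       # detections in the cluster
    w = 20/rate  # median width: 20 samples = 0.5 ms
    # n*w = 0.025 < T*min_density = 0.03, so this cluster would be discarded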