"""
Extract and cluster EOD waveforms of pulse-type electric fish.

## Main function

- `extract_pulsefish()`: checks for pulse-type fish based on the EOD amplitude and shape.

"""
import os
import numpy as np
from scipy import stats
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.mixture import BayesianGaussianMixture
from sklearn.metrics import pairwise_distances
from thunderlab.eventdetection import detect_peaks, median_std_threshold

from .pulseplots import *

import warnings


def warn(*args, **kwargs):
    """
    Ignore all warnings.
    """
    pass

warnings.warn = warn

try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        def decorator_jit(func):
            return func
        return decorator_jit


# upgrade numpy functions for backwards compatibility:
if not hasattr(np, 'isin'):
    np.isin = np.in1d


def unique_counts(ar):
    """ Find the unique elements of an array and their counts, ignoring shape.

    The code is condensed from numpy version 1.17.0.

    Parameters
    ----------
    ar : numpy array
        Input array

    Returns
    -------
    unique_values : numpy array
        Unique values in array ar.
    unique_counts : numpy array
        Number of instances for each unique value in ar.
    """
    try:
        return np.unique(ar, return_counts=True)
    except TypeError:
        ar = np.asanyarray(ar).flatten()
        ar.sort()
        mask = np.empty(ar.shape, dtype=bool)
        mask[:1] = True
        mask[1:] = ar[1:] != ar[:-1]
        idx = np.concatenate(np.nonzero(mask) + ([mask.size],))
        return ar[mask], np.diff(idx)
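

# Illustrative usage sketch (added for documentation, not part of the original
# API): both branches of unique_counts() behave like np.unique() with
# return_counts=True.
def _demo_unique_counts():
    """Return the unique values and their counts of a small example array."""
    values, counts = unique_counts(np.array([3, 1, 3, 3, 2, 1]))
    # values -> array([1, 2, 3]), counts -> array([2, 1, 3])
    return values, counts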


###########################################################################


def extract_pulsefish(data, rate, amax, width_factor_shape=3,
                      width_factor_wave=8, width_factor_display=4,
                      verbose=0, plot_level=0, save_plots=False,
                      save_path='', ftype='png', return_data=[]):
    """ Extract and cluster pulse-type fish EODs from single channel data.

    Takes recording data containing an unknown number of pulsefish and extracts the mean
    EOD and EOD timepoints for each fish present in the recording.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data in Hertz.
    amax: float
        Maximum amplitude of data range.
    width_factor_shape : float (optional)
        Width multiplier used for EOD shape analysis.
        EOD snippets are extracted based on width between the
        peak and trough multiplied by the width factor.
    width_factor_wave : float (optional)
        Width multiplier used for wavefish detection.
    width_factor_display : float (optional)
        Width multiplier used for EOD mean extraction and display.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path: string (optional)
        Path for saving plots.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Specify data that should be logged and returned in a dictionary. Each clustering
        step has a specific keyword that results in adding different variables to the log dictionary.
        Optional keys for return_data and the resulting additional key-value pairs to the log dictionary are:
        - 'all_eod_times':
            - 'all_times': list of two lists of floats.
                All peak (`all_times[0]`) and trough times (`all_times[1]`) extracted
                by the peak detection algorithm. Times are given in seconds.
            - 'eod_troughtimes': list of 1D arrays.
                The timepoints in seconds of each unique extracted EOD cluster,
                where each 1D array encodes one cluster.

        - 'peak_detection':
            - "data": 1D numpy array of floats.
                Quadratically interpolated data which was used for peak detection.
            - "interp_fac": float.
                Interpolation factor of raw data.
            - "peaks_1": 1D numpy array of ints.
                Peak indices on interpolated data after first peak detection step.
            - "troughs_1": 1D numpy array of ints.
                Trough indices on interpolated data after first peak detection step.
            - "peaks_2": 1D numpy array of ints.
                Peak indices on interpolated data after second peak detection step.
            - "troughs_2": 1D numpy array of ints.
                Trough indices on interpolated data after second peak detection step.
            - "peaks_3": 1D numpy array of ints.
                Peak indices on interpolated data after third peak detection step.
            - "troughs_3": 1D numpy array of ints.
                Trough indices on interpolated data after third peak detection step.
            - "peaks_4": 1D numpy array of ints.
                Peak indices on interpolated data after fourth peak detection step.
            - "troughs_4": 1D numpy array of ints.
                Trough indices on interpolated data after fourth peak detection step.
        - 'all_cluster_steps':
            - 'rate': float.
                Sampling rate of interpolated data.
            - 'EOD_widths': list of three 1D numpy arrays.
                The first list entry gives the unique labels of all width clusters
                as a list of ints.
                The second list entry gives the width values for each EOD in samples
                as a 1D numpy array of ints.
                The third list entry gives the width labels for each EOD
                as a 1D numpy array of ints.
            - 'EOD_heights': nested lists (2 layers) of three 1D numpy arrays.
                The first list entry gives the unique labels of all height clusters
                as a list of ints for each width cluster.
                The second list entry gives the height values for each EOD
                as a 1D numpy array of floats for each width cluster.
                The third list entry gives the height labels for each EOD
                as a 1D numpy array of ints for each width cluster.
            - 'EOD_shapes': nested lists (3 layers) of three 1D numpy arrays.
                The first list entry gives the raw EOD snippets as a 2D numpy array
                for each height cluster in a width cluster.
                The second list entry gives the snippet PCA values for each EOD
                as a 2D numpy array of floats for each height cluster in a width cluster.
                The third list entry gives the shape labels for each EOD as a 1D numpy array
                of ints for each height cluster in a width cluster.
            - 'discarding_masks': nested lists (2 layers) of 1D numpy arrays.
                The masks of EODs that are discarded by the discarding step of the algorithm.
                The masks are 1D boolean arrays where instances that are set to True are
                discarded by the algorithm. Discarding masks are saved in nested lists
                that represent the width and height clusters.
            - 'merge_masks': nested lists (2 layers) of 2D numpy arrays.
                The masks of EODs that are discarded by the merging step of the algorithm.
                The masks are 2D boolean arrays where for each sample point `i` either
                `merge_mask[i,0]` or `merge_mask[i,1]` is set to True. Here, `merge_mask[:,0]`
                represents the peak-centered clusters and `merge_mask[:,1]` represents the
                trough-centered clusters. Merge masks are saved in nested lists that
                represent the width and height clusters.
        - 'BGM_width':
            - 'BGM_width': dictionary
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD widths).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'BGM_height':
            This key adds a new dictionary for each width cluster.
            - 'BGM_height_*n*' : dictionary, where *n* defines the width cluster as an int.
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD heights).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).
        - 'snippet_clusters':
            This key adds a new dictionary for each height cluster.
            - 'snippet_clusters_*n*_*m*_*p*' : dictionary, where *n* defines the width cluster
              (int), *m* defines the height cluster (int) and *p* defines shape clustering
              on peak or trough centered EOD snippets (string: 'peak' or 'trough').
                - 'raw_snippets': 2D numpy array (nsamples, nfeatures).
                    Raw EOD snippets.
                - 'snippets': 2D numpy array.
                    Normalized EOD snippets.
                - 'features': 2D numpy array (nsamples, nfeatures).
                    PCA values for each normalized EOD snippet.
                - 'clusters': 1D numpy array of ints.
                    Cluster labels.
                - 'rate': float.
                    Sampling rate of snippets.

        - 'eod_deletion':
            This key adds two dictionaries for each (peak centered) shape cluster,
            where *cluster* (int) is the unique shape cluster label.
            - 'mask_*cluster*' : list of four booleans.
                The mask for each cluster discarding step.
                The first instance represents the artefact masks, where artefacts
                are set to True.
                The second instance represents the unreliable cluster masks,
                where unreliable clusters are set to True.
                The third instance represents the wavefish masks, where wavefish
                are set to True.
                The fourth instance represents the sidepeak masks, where sidepeaks
                are set to True.
            - 'vals_*cluster*' : list of lists.
                All variables that are used for each cluster deletion step.
                The first instance is a list of two 1D numpy arrays: the mean EOD and
                the FFT of that mean EOD.
                The second instance is a 1D numpy array with all EOD width to ISI ratios.
                The third instance is a list with three entries:
                The first entry is a zoomed-out version of the mean EOD as a 1D numpy array.
                The second entry is a list of two 1D numpy arrays that define the peak
                and trough indices of the zoomed-out mean EOD.
                The third entry contains a list of two values that represent the
                peak-trough pair in the zoomed-out mean EOD with the largest height
                difference.
            - 'rate' : float.
                EOD snippet sampling rate.
        - 'masks':
            - 'masks' : 2D numpy array (4, N).
                Each row contains masks for each EOD detected by the EOD peak detection step.
                The first row defines the artefact masks, the second row defines the
                unreliable EOD masks, the third row defines the wavefish masks and
                the fourth row defines the sidepeak masks.

        - 'moving_fish':
            - 'moving_fish': dictionary.
                - 'w' : list of floats.
                    Median width for each width cluster that the moving fish algorithm is
                    computed on (in seconds).
                - 'T' : list of floats.
                    Length of analyzed recording for each width cluster (in seconds).
                - 'dt' : list of floats.
                    Sliding window size (in seconds) for each width cluster.
                - 'clusters' : list of 1D numpy int arrays.
                    Cluster labels for each EOD cluster in a width cluster.
                - 't' : list of 1D numpy float arrays.
                    EOD emission times for each EOD in a width cluster.
                - 'fishcount' : list of lists.
                    Sliding window timepoints and fishcounts for each width cluster.
                - 'ignore_steps' : list of 1D int arrays.
                    Mask for fishcounts that were ignored (ignored if True) in the
                    moving_fish analysis.
    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First column is time in seconds,
        second column the mean EOD, third column the standard error.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD peaks or troughs in seconds.
        Use these timepoints for EOD averaging.
    eod_peaktimes: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    zoom_window: tuple of floats
        Start and end time of suggested window for plotting EOD timepoints.
    log_dict: dictionary
        Dictionary with logged variables, where variables to log are specified
        by `return_data`.
    """
    if verbose > 0:
        print('')
        if verbose > 1:
            print(70*'#')
        print('##### extract_pulsefish', 46*'#')

    if (save_plots and plot_level > 0 and save_path):
        # create folder to save things in.
        if not os.path.exists(save_path):
            os.makedirs(save_path)
    else:
        save_path = ''

    mean_eods, eod_times, eod_peaktimes, zoom_window = [], [], [], []
    log_dict = {}

    # interpolate:
    i_rate = 500000.0
    #i_rate = rate
    try:
        f = interp1d(np.arange(len(data))/rate, data, kind='quadratic')
        i_data = f(np.arange(0.0, (len(data)-1)/rate, 1.0/i_rate))
    except MemoryError:
        i_rate = rate
        i_data = data
    log_dict['data'] = i_data   # TODO: could be removed
    log_dict['rate'] = i_rate   # TODO: could be removed
    log_dict['i_data'] = i_data
    log_dict['i_rate'] = i_rate
    #log_dict["interp_fac"] = interp_fac   # TODO: is not set anymore

    # standard deviation of data in small snippets:
    win_size = int(0.002*rate)   # 2ms windows
    threshold = median_std_threshold(data, win_size)   # TODO: make this a parameter

    # extract peaks:
    if 'peak_detection' in return_data:
        x_peak, x_trough, eod_heights, eod_widths, pd_log_dict = \
            detect_pulses(i_data, i_rate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose, return_data=True)
        log_dict.update(pd_log_dict)
    else:
        x_peak, x_trough, eod_heights, eod_widths = \
            detect_pulses(i_data, i_rate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose, return_data=False)

    if len(x_peak) > 0:
        # cluster
        clusters, x_merge, c_log_dict = cluster(x_peak, x_trough,
                                                eod_heights,
                                                eod_widths, i_data,
                                                i_rate,
                                                width_factor_shape,
                                                width_factor_wave,
                                                merge_threshold_height=0.1*amax,
                                                verbose=verbose,
                                                plot_level=plot_level-1,
                                                save_plots=save_plots,
                                                save_path=save_path,
                                                ftype=ftype,
                                                return_data=return_data)

        # extract mean eods and times
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, clusters,
                          i_rate, width_factor_display, verbose=verbose)

        # determine clipped clusters (save them, but ignore in other steps)
        clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes = \
            find_clipped_clusters(clusters, mean_eods, eod_times,
                                  eod_peaktimes, eod_troughtimes,
                                  cluster_labels, width_factor_display,
                                  verbose=verbose)

        # delete the moving fish
        clusters, zoom_window, mf_log_dict = \
            delete_moving_fish(clusters, x_merge/i_rate, len(data)/rate,
                               eod_heights, eod_widths/i_rate, i_rate,
                               verbose=verbose, plot_level=plot_level-1,
                               save_plot=save_plots,
                               save_path=save_path, ftype=ftype,
                               return_data=return_data)

        if 'moving_fish' in return_data:
            log_dict['moving_fish'] = mf_log_dict

        clusters = remove_sparse_detections(clusters, eod_widths, i_rate,
                                            len(data)/rate, verbose=verbose)

        # extract mean eods
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths,
                          clusters, i_rate, width_factor_display,
                          verbose=verbose)

        mean_eods.extend(clipped_eods)
        eod_times.extend(clipped_times)
        eod_peaktimes.extend(clipped_peaktimes)
        eod_troughtimes.extend(clipped_troughtimes)

        if plot_level > 0:
            plot_all(data, eod_peaktimes, eod_troughtimes, rate, mean_eods)
            if save_plots:
                plt.savefig('%sextract_pulsefish_results.%s' % (save_path, ftype))
        if save_plots:
            plt.close('all')

        if 'all_eod_times' in return_data:
            log_dict['all_times'] = [x_peak/i_rate, x_trough/i_rate]
            log_dict['eod_troughtimes'] = eod_troughtimes

        log_dict.update(c_log_dict)

    if verbose > 0:
        print('')

    return mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict
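

# Hedged usage sketch for extract_pulsefish(): the synthetic biphasic pulse
# train below is made up for illustration only; real input would be a
# single-channel field recording.
def _demo_extract_pulsefish():
    """Run the full pipeline on a synthetic recording with one pulse train."""
    rate = 44100.0                                 # sampling rate in Hertz
    data = 0.01*np.random.randn(int(5*rate))       # 5 s of background noise
    eod = np.array([0.1, 0.8, 0.2, -0.6, -0.1])    # crude biphasic EOD shape
    for pulse_time in np.arange(0.1, 4.9, 0.05):   # 20 Hz pulse train
        i = int(pulse_time*rate)
        data[i:i+len(eod)] += eod
    mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict = \
        extract_pulsefish(data, rate, amax=1.0)
    return mean_eods, eod_times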


def detect_pulses(data, rate, thresh, min_rel_slope_diff=0.25,
                  min_width=0.00005, max_width=0.01, width_fac=5.0,
                  verbose=0, return_data=False):
    """Detect pulses in data.

    Was `def extract_eod_times(data, rate, width_factor,
    interp_freq=500000, max_peakwidth=0.01,
    min_peakwidth=None, verbose=0, return_data=[],
    save_path='')` before.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data.
    thresh: float
        Threshold for peak and trough detection via `detect_peaks()`.
        Must be a positive number that sets the minimum difference
        between a peak and a trough.
    min_rel_slope_diff: float
        Minimum required difference between left and right slope (between
        peak and troughs) relative to mean slope for deciding which trough
        to take based on slope difference.
    min_width: float
        Minimum width (peak-trough distance) of pulses in seconds.
    max_width: float
        Maximum width (peak-trough distance) of pulses in seconds.
    width_fac: float
        Pulses extend plus or minus `width_fac` times their width
        (distance between peak and assigned trough).
        Only pulses are returned that can fully be analysed with this width.
    verbose : int (optional)
        Verbosity level.
    return_data : bool
        If `True` data of this function is logged and returned (see
        extract_pulsefish()).

    Returns
    -------
    peak_indices: array of ints
        Indices of EOD peaks in data.
    trough_indices: array of ints
        Indices of EOD troughs in data. There is one x_trough for each x_peak.
    heights: array of floats
        EOD heights for each x_peak.
    widths: array of ints
        EOD widths for each x_peak (in samples).
    peak_detection_result : dictionary
        Key value pairs of logged data.
        This is only returned if `return_data` is `True`.
    """
    peak_detection_result = {}

    # detect peaks and troughs in the data:
    peak_indices, trough_indices = detect_peaks(data, thresh)
    if verbose > 0:
        print('Peaks/troughs detected in data: %5d %5d'
              % (len(peak_indices), len(trough_indices)))
    if return_data:
        peak_detection_result.update(peaks_1=np.array(peak_indices),
                                     troughs_1=np.array(trough_indices))
    if len(peak_indices) < 2 or \
       len(trough_indices) < 2 or \
       len(peak_indices) > len(data)/20:
        # TODO: if too many peaks increase threshold!
        if verbose > 0:
            print('No or too many peaks/troughs detected in data.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # assign troughs to peaks:
    peak_indices, trough_indices, heights, widths, slopes = \
        assign_side_peaks(data, peak_indices, trough_indices, min_rel_slope_diff)
    if verbose > 1:
        print('Number of peaks after assigning side-peaks: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_2=np.array(peak_indices),
                                     troughs_2=np.array(trough_indices))

    # check widths:
    keep = ((widths > min_width*rate) & (widths < max_width*rate))
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    slopes = slopes[keep]
    if verbose > 1:
        print('Number of peaks after checking pulse width: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_3=np.array(peak_indices),
                                     troughs_3=np.array(trough_indices))

    # discard connected peaks:
    same = np.nonzero(trough_indices[:-1] == trough_indices[1:])[0]
    keep = np.ones(len(trough_indices), dtype=bool)
    for i in same:
        # same troughs at trough_indices[i] and trough_indices[i+1]:
        s = slopes[i:i+2]
        rel_slopes = np.abs(np.diff(s))[0]/np.mean(s)
        if rel_slopes > min_rel_slope_diff:
            # discard the peak with the shallower slope:
            keep[i+(s[1] < s[0])] = False
        else:
            # discard the peak with the smaller height:
            keep[i+(heights[i+1] < heights[i])] = False
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    if verbose > 1:
        print('Number of peaks after merging pulses: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_4=np.array(peak_indices),
                                     troughs_4=np.array(trough_indices))
    if len(peak_indices) == 0:
        if verbose > 0:
            print('No peaks remain as pulse candidates.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # only take those where the maximum cutwidth does not cause issues -
    # if the width_fac times the width + x is more than length.
    keep = ((peak_indices - widths > 0) &
            (peak_indices + widths < len(data)) &
            (trough_indices - widths > 0) &
            (trough_indices + widths < len(data)))

    if verbose > 0:
        print('Remaining peaks after EOD extraction: %5d'
              % (np.sum(keep)))
        print('')

    if return_data:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep], peak_detection_result
    else:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep]
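

# Sketch of calling detect_pulses() directly (illustration only); the
# threshold is derived with median_std_threshold(), exactly as in
# extract_pulsefish() above.
def _demo_detect_pulses(data, rate):
    """Return peak and trough times (in seconds) of detected pulses."""
    thresh = median_std_threshold(data, int(0.002*rate))   # 2 ms windows
    peaks, troughs, heights, widths = detect_pulses(data, rate, thresh)
    return peaks/rate, troughs/rate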


@jit(nopython=True)
def assign_side_peaks(data, peak_indices, trough_indices,
                      min_rel_slope_diff=0.25):
    """Assign to each peak the trough resulting in a pulse with the steepest slope or largest height.

    The slope between a peak and a trough is computed as the height
    difference divided by the distance between peak and trough. If the
    slopes towards the left and the right trough differ by less than
    `min_rel_slope_diff`, then just the heights between the peak and the
    two troughs are compared.

    Was `def detect_eod_peaks(data, main_indices, side_indices,
    max_width=20, min_width=2, verbose=0)` before.

    Parameters
    ----------
    data: array of floats
        Data in which the events were detected.
    peak_indices: array of ints
        Indices of the detected peaks in the data time series.
    trough_indices: array of ints
        Indices of the detected troughs in the data time series.
    min_rel_slope_diff: float
        Minimum required difference of left and right slope relative
        to mean slope.

    Returns
    -------
    peak_indices: array of ints
        Peak indices. Same as input `peak_indices` but potentially shorter
        by one or two elements.
    trough_indices: array of ints
        Corresponding indices of the troughs to the left or right
        of the peaks.
    heights: array of floats
        Peak heights (distance between peak and corresponding trough amplitude).
    widths: array of ints
        Peak widths (distance between peak and corresponding trough indices).
    slopes: array of floats
        Peak slopes (height divided by width).
    """
    # is a main or side peak first?
    peak_first = int(peak_indices[0] < trough_indices[0])
    # is a main or side peak last?
    peak_last = int(peak_indices[-1] > trough_indices[-1])
    # ensure all peaks to have side peaks (troughs) at both sides,
    # i.e. troughs at same index and next index are before and after peak:
    peak_indices = peak_indices[peak_first:len(peak_indices)-peak_last]
    y = data[peak_indices]

    # indices of troughs on the left and right side of main peaks:
    l_indices = np.arange(len(peak_indices))
    r_indices = l_indices + 1

    # indices, distance to peak, height, and slope of left troughs:
    l_side_indices = trough_indices[l_indices]
    l_distance = np.abs(peak_indices - l_side_indices)
    l_height = np.abs(y - data[l_side_indices])
    l_slope = np.abs(l_height/l_distance)

    # indices, distance to peak, height, and slope of right troughs:
    r_side_indices = trough_indices[r_indices]
    r_distance = np.abs(r_side_indices - peak_indices)
    r_height = np.abs(y - data[r_side_indices])
    r_slope = np.abs(r_height/r_distance)

    # which trough to assign to the peak?
    # - either the one with the steepest slope, or
    # - when slopes are similar on both sides
    #   (within `min_rel_slope_diff` difference),
    #   the trough with the maximum height difference to the peak:
    rel_slopes = np.abs(l_slope-r_slope)/(0.5*(l_slope+r_slope))
    take_slopes = rel_slopes > min_rel_slope_diff
    take_left = l_height > r_height
    take_left[take_slopes] = l_slope[take_slopes] > r_slope[take_slopes]

    # assign troughs, heights, widths, and slopes:
    trough_indices = np.where(take_left,
                              trough_indices[:-1], trough_indices[1:])
    heights = np.where(take_left, l_height, r_height)
    widths = np.where(take_left, l_distance, r_distance)
    slopes = np.where(take_left, l_slope, r_slope)

    return peak_indices, trough_indices, heights, widths, slopes
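

# Worked numeric sketch of the trough-assignment rule above (illustration
# only): slopes that differ by more than min_rel_slope_diff relative to their
# mean decide the side, otherwise the larger peak-trough height does.
def _demo_trough_assignment(min_rel_slope_diff=0.25):
    """Decide between a left and a right trough for one hypothetical peak."""
    l_height, l_distance = 0.8, 10.0   # left trough: height, distance to peak
    r_height, r_distance = 0.9, 30.0   # right trough: higher but further away
    l_slope = l_height/l_distance      # 0.08
    r_slope = r_height/r_distance      # 0.03
    rel_slope = abs(l_slope - r_slope)/(0.5*(l_slope + r_slope))   # ~0.91
    if rel_slope > min_rel_slope_diff:
        take_left = l_slope > r_slope    # slopes clearly differ: steeper side wins
    else:
        take_left = l_height > r_height  # similar slopes: larger height wins
    return take_left   # True: the left trough wins on slope despite its smaller height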


def cluster(eod_xp, eod_xt, eod_heights, eod_widths, data, rate,
            width_factor_shape, width_factor_wave, n_gaus_height=10,
            merge_threshold_height=0.1, n_gaus_width=3,
            merge_threshold_width=0.5, minp=10, verbose=0,
            plot_level=0, save_plots=False, save_path='', ftype='pdf',
            return_data=[]):
    """Cluster EODs.

    First cluster on EOD widths using a Bayesian Gaussian
    Mixture (BGM) model, then cluster on EOD heights using a
    BGM model. Lastly, cluster on EOD waveform with DBSCAN.
    Clustering on EOD waveform is performed twice, once on
    peak-centered EODs and once on trough-centered EODs.
    Non-pulsetype EOD clusters are deleted, and clusters are
    merged afterwards.

    Parameters
    ----------
    eod_xp : list of ints
        Locations of EOD peaks in indices.
    eod_xt: list of ints
        Locations of EOD troughs in indices.
    eod_heights: list of floats
        EOD heights.
    eod_widths: list of ints
        EOD widths in samples.
    data: array of floats
        Data in which to detect pulse EODs.
    rate : float
        Sampling rate of `data`.
    width_factor_shape : float
        Multiplier for snippet extraction width. This factor is
        multiplied with the width between the peak and trough of a
        single EOD.
    width_factor_wave : float
        Multiplier for wavefish extraction width.
    n_gaus_height : int (optional)
        Number of gaussians to use for the clustering based on EOD height.
    merge_threshold_height : float (optional)
        Threshold for merging clusters that are similar in height.
    n_gaus_width : int (optional)
        Number of gaussians to use for the clustering based on EOD width.
    merge_threshold_width : float (optional)
        Threshold for merging clusters that are similar in width.
    minp : int (optional)
        Minimum number of points for core clusters (DBSCAN).
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save created plots.
    save_path : string (optional)
        Path to save plots to. Only used if save_plots==True.
    ftype : string (optional)
        Filetype to save plot images in.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'all_cluster_steps', 'BGM_width', 'BGM_height',
        'snippet_clusters', 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    labels : list of ints
        EOD cluster labels based on height and EOD waveform.
    x_merge : list of ints
        Locations of EODs in clusters.
    saved_data : dictionary
        Key value pairs of logged data. Data to be logged is specified
        by return_data.
    """
    saved_data = {}

    if plot_level > 0 or 'all_cluster_steps' in return_data:
        all_heightlabels = []
        all_shapelabels = []
        all_snippets = []
        all_features = []
        all_heights = []
        all_unique_heightlabels = []

    all_p_clusters = -1 * np.ones(len(eod_xp))
    all_t_clusters = -1 * np.ones(len(eod_xp))
    artefact_masks_p = np.ones(len(eod_xp), dtype=bool)
    artefact_masks_t = np.ones(len(eod_xp), dtype=bool)

    x_merge = -1 * np.ones(len(eod_xp))

    max_label_p = 0   # keep track of the labels so that no labels are overwritten
    max_label_t = 0

    # loop only over height clusters that are bigger than minp
    # first cluster on width
    width_labels, bgm_log_dict = BGM(1000*eod_widths/rate,
                                     merge_threshold_width,
                                     n_gaus_width, use_log=False,
                                     verbose=verbose-1,
                                     plot_level=plot_level-1,
                                     xlabel='width [ms]',
                                     save_plot=save_plots,
                                     save_path=save_path,
                                     save_name='width', ftype=ftype,
                                     return_data=return_data)
    saved_data.update(bgm_log_dict)

    if verbose > 0:
        print('Clusters generated based on EOD width:')
        for l in np.unique(width_labels):
            print(f'N_{l} = {len(width_labels[width_labels==l]):4d} h_{l} = {np.mean(eod_widths[width_labels==l]):.4f}')

    w_labels, w_counts = unique_counts(width_labels)
    unique_width_labels = w_labels[w_counts>minp]

    for wi, width_label in enumerate(unique_width_labels):

        # select only features in one width cluster at a time
        w_eod_widths = eod_widths[width_labels==width_label]
        w_eod_heights = eod_heights[width_labels==width_label]
        w_eod_xp = eod_xp[width_labels==width_label]
        w_eod_xt = eod_xt[width_labels==width_label]

        width = int(width_factor_shape*np.median(w_eod_widths))
        if width > w_eod_xp[0]:
            width = w_eod_xp[0]
        if width > w_eod_xt[0]:
            width = w_eod_xt[0]
        if width > len(data) - w_eod_xp[-1]:
            width = len(data) - w_eod_xp[-1]
        if width > len(data) - w_eod_xt[-1]:
            width = len(data) - w_eod_xt[-1]

        wp_clusters = -1 * np.ones(len(w_eod_xp))
        wt_clusters = -1 * np.ones(len(w_eod_xp))
        wartefact_mask = np.ones(len(w_eod_xp))

        # determine height labels
        raw_p_snippets, p_snippets, p_features, p_bg_ratio = \
            extract_snippet_features(data, w_eod_xp, w_eod_heights, width)
        raw_t_snippets, t_snippets, t_features, t_bg_ratio = \
            extract_snippet_features(data, w_eod_xt, w_eod_heights, width)

        height_labels, bgm_log_dict = \
            BGM(w_eod_heights,
                min(merge_threshold_height,
                    np.median(np.min(np.vstack([p_bg_ratio, t_bg_ratio]), axis=0))),
                n_gaus_height, use_log=True,
                verbose=verbose-1, plot_level=plot_level-1,
                xlabel='height [a.u.]', save_plot=save_plots,
                save_path=save_path, save_name='height_%d' % wi,
                ftype=ftype, return_data=return_data)
        saved_data.update(bgm_log_dict)

        if verbose > 0:
            print('Clusters generated based on EOD height:')
            for l in np.unique(height_labels):
                print(f'N_{l} = {len(height_labels[height_labels==l]):4d} h_{l} = {np.mean(w_eod_heights[height_labels==l]):.4f}')

        h_labels, h_counts = unique_counts(height_labels)
        unique_height_labels = h_labels[h_counts>minp]
        if plot_level > 0 or 'all_cluster_steps' in return_data:
            all_heightlabels.append(height_labels)
            all_heights.append(w_eod_heights)
            all_unique_heightlabels.append(unique_height_labels)
            shape_labels = []
            cfeatures = []
            csnippets = []

        for hi, height_label in enumerate(unique_height_labels):

            h_eod_widths = w_eod_widths[height_labels==height_label]
            h_eod_heights = w_eod_heights[height_labels==height_label]
            h_eod_xp = w_eod_xp[height_labels==height_label]
            h_eod_xt = w_eod_xt[height_labels==height_label]

            p_clusters = cluster_on_shape(p_features[height_labels==height_label],
                                          p_bg_ratio, minp, verbose=0)
            t_clusters = cluster_on_shape(t_features[height_labels==height_label],
                                          t_bg_ratio, minp, verbose=0)

            if plot_level > 1:
                plot_feature_extraction(raw_p_snippets[height_labels==height_label],
                                        p_snippets[height_labels==height_label],
                                        p_features[height_labels==height_label],
                                        p_clusters, 1/rate, 0)
                plt.savefig('%sDBSCAN_peak_w%i_h%i.%s' % (save_path, wi, hi, ftype))
                plot_feature_extraction(raw_t_snippets[height_labels==height_label],
                                        t_snippets[height_labels==height_label],
                                        t_features[height_labels==height_label],
                                        t_clusters, 1/rate, 1)
                plt.savefig('%sDBSCAN_trough_w%i_h%i.%s' % (save_path, wi, hi, ftype))

            if 'snippet_clusters' in return_data:
                saved_data[f'snippet_clusters_{width_label}_{height_label}_peak'] = {
                    'raw_snippets': raw_p_snippets[height_labels==height_label],
                    'snippets': p_snippets[height_labels==height_label],
                    'features': p_features[height_labels==height_label],
                    'clusters': p_clusters,
                    'rate': rate}
                saved_data[f'snippet_clusters_{width_label}_{height_label}_trough'] = {
                    'raw_snippets': raw_t_snippets[height_labels==height_label],
                    'snippets': t_snippets[height_labels==height_label],
                    'features': t_features[height_labels==height_label],
                    'clusters': t_clusters,
                    'rate': rate}

            if plot_level > 0 or 'all_cluster_steps' in return_data:
                shape_labels.append([p_clusters, t_clusters])
                cfeatures.append([p_features[height_labels==height_label],
                                  t_features[height_labels==height_label]])
                csnippets.append([p_snippets[height_labels==height_label],
                                  t_snippets[height_labels==height_label]])

            p_clusters[p_clusters==-1] = -max_label_p - 1
            wp_clusters[height_labels==height_label] = p_clusters + max_label_p
            max_label_p = max(np.max(wp_clusters), np.max(all_p_clusters)) + 1

            t_clusters[t_clusters==-1] = -max_label_t - 1
            wt_clusters[height_labels==height_label] = t_clusters + max_label_t
            max_label_t = max(np.max(wt_clusters), np.max(all_t_clusters)) + 1

        if verbose > 0:
            if np.max(wp_clusters) == -1:
                print(f'No EOD peaks in width cluster {width_label}')
            else:
                unique_clusters = np.unique(wp_clusters[wp_clusters!=-1])
                if len(unique_clusters) > 1:
                    print(f'{len(unique_clusters)} different EOD peaks in width cluster {width_label}')

        if plot_level > 0 or 'all_cluster_steps' in return_data:
            all_shapelabels.append(shape_labels)
            all_snippets.append(csnippets)
            all_features.append(cfeatures)
        # for each cluster, save fft + label
        # so I end up with features for each label, and the masks.
        # then I can extract e.g. first artefact or wave etc.

        # remove artefacts here, based on the mean snippets ffts.
        artefact_masks_p[width_labels==width_label], sdict = \
            remove_artefacts(p_snippets, wp_clusters, rate,
                             verbose=verbose-1, return_data=return_data)
        saved_data.update(sdict)
        artefact_masks_t[width_labels==width_label], _ = \
            remove_artefacts(t_snippets, wt_clusters, rate,
                             verbose=verbose-1, return_data=return_data)

        # update maxlab so that no clusters are overwritten
        all_p_clusters[width_labels==width_label] = wp_clusters
        all_t_clusters[width_labels==width_label] = wt_clusters

    # remove all non-reliable clusters
    unreliable_fish_mask_p, saved_data = \
        delete_unreliable_fish(all_p_clusters, eod_widths, eod_xp,
                               verbose=verbose-1, sdict=saved_data)
    unreliable_fish_mask_t, _ = \
        delete_unreliable_fish(all_t_clusters, eod_widths, eod_xt, verbose=verbose-1)

    wave_mask_p, sidepeak_mask_p, saved_data = \
        delete_wavefish_and_sidepeaks(data, all_p_clusters, eod_xp, eod_widths,
                                      width_factor_wave, verbose=verbose-1, sdict=saved_data)
    wave_mask_t, sidepeak_mask_t, _ = \
        delete_wavefish_and_sidepeaks(data, all_t_clusters, eod_xt, eod_widths,
                                      width_factor_wave, verbose=verbose-1)

    og_clusters = [np.copy(all_p_clusters), np.copy(all_t_clusters)]
    og_labels = np.copy(all_p_clusters + all_t_clusters)

    # go through all clusters and masks??
    all_p_clusters[(artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p)] = -1
    all_t_clusters[(artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)] = -1

    # merge here.
    all_clusters, x_merge, mask = merge_clusters(np.copy(all_p_clusters),
                                                 np.copy(all_t_clusters),
                                                 eod_xp, eod_xt,
                                                 verbose=verbose-1)

    if 'all_cluster_steps' in return_data or plot_level > 0:
        all_dmasks = []
        all_mmasks = []

        discarding_masks = \
            np.vstack(((artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p),
                       (artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)))
        merge_mask = mask

        # save the masks in the same formats as the snippets
        for wi, (width_label, w_shape_label, heightlabels, unique_height_labels) in enumerate(zip(unique_width_labels, all_shapelabels, all_heightlabels, all_unique_heightlabels)):
            w_dmasks = discarding_masks[:,width_labels==width_label]
            w_mmasks = merge_mask[:,width_labels==width_label]

            wd_2 = []
            wm_2 = []

            for hi, (height_label, h_shape_label) in enumerate(zip(unique_height_labels, w_shape_label)):
                h_dmasks = w_dmasks[:,heightlabels==height_label]
                h_mmasks = w_mmasks[:,heightlabels==height_label]

                wd_2.append(h_dmasks)
                wm_2.append(h_mmasks)

            all_dmasks.append(wd_2)
            all_mmasks.append(wm_2)

        if plot_level > 0:
            plot_clustering(rate, [unique_width_labels, eod_widths, width_labels],
                            [all_unique_heightlabels, all_heights, all_heightlabels],
                            [all_snippets, all_features, all_shapelabels],
                            all_dmasks, all_mmasks)
            if save_plots:
                plt.savefig('%sclustering.%s' % (save_path, ftype))

        if 'all_cluster_steps' in return_data:
            saved_data = {'rate': rate,
                          'EOD_widths': [unique_width_labels, eod_widths, width_labels],
                          'EOD_heights': [all_unique_heightlabels, all_heights, all_heightlabels],
                          'EOD_shapes': [all_snippets, all_features, all_shapelabels],
                          'discarding_masks': all_dmasks,
                          'merge_masks': all_mmasks
                          }

    if 'masks' in return_data:
        saved_data = {'masks' : np.vstack(((artefact_masks_p & artefact_masks_t),
                                           (unreliable_fish_mask_p & unreliable_fish_mask_t),
                                           (wave_mask_p & wave_mask_t),
                                           (sidepeak_mask_p & sidepeak_mask_t),
                                           (all_p_clusters+all_t_clusters)))}

    if verbose > 0:
        print('Clusters generated based on height, width and shape: ')
        for l in np.unique(all_clusters[all_clusters != -1]):
            print(f'N_{int(l)} = {len(all_clusters[all_clusters == l]):4d}')

    return all_clusters, x_merge, saved_data


def BGM(x, merge_threshold=0.1, n_gaus=5, max_iter=200, n_init=5,
        use_log=False, verbose=0, plot_level=0, xlabel='x [a.u.]',
        save_plot=False, save_path='', save_name='', ftype='pdf',
        return_data=[]):
    """ Use a Bayesian Gaussian Mixture Model to cluster one-dimensional data.

    Additional steps are used to merge clusters that are closer than
    `merge_threshold`. Broad gaussian fits that cover one or more other
    gaussian fits are split by their intersections with the other
    gaussians.

    Parameters
    ----------
    x : 1D numpy array
        Features to compute clustering on.
    merge_threshold : float (optional)
        Ratio for merging nearby gaussians.
    n_gaus: int (optional)
        Maximum number of gaussians to fit on data.
    max_iter : int (optional)
        Maximum number of iterations for gaussian fit.
    n_init : int (optional)
        Number of initializations for the gaussian fit.
    use_log: boolean (optional)
        Set to True to compute the gaussian fit on the logarithm of x.
        Can improve clustering on features with nonlinear relationships such as peak height.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    xlabel : string (optional)
        Xlabel for displaying BGM plot.
    save_plot : bool (optional)
        Set to True to save created plot.
    save_path : string (optional)
        Path to location where data should be saved. Only used if save_plot==True.
    save_name : string (optional)
        Filename of the saved plot. Useful as usually multiple BGM models are generated.
    ftype : string (optional)
        Filetype of plot image if save_plots==True.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'BGM_width' and/or 'BGM_height' (see extract_pulsefish()).

    Returns
    -------
    labels : 1D numpy array
        Cluster labels for each sample in x.
    bgm_dict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    bgm_dict = {}

    if len(np.unique(x)) > n_gaus:
        BGM_model = BayesianGaussianMixture(n_components=n_gaus, max_iter=max_iter, n_init=n_init)
        if use_log:
            labels = BGM_model.fit_predict(stats.zscore(np.log(x)).reshape(-1, 1))
        else:
            labels = BGM_model.fit_predict(stats.zscore(x).reshape(-1, 1))
    else:
        return np.zeros(len(x)), bgm_dict

    if verbose > 0:
        if not BGM_model.converged_:
            print('!!! Gaussian mixture did not converge !!!')

    cur_labels = np.unique(labels)

    # map labels to be increasing for increasing values for x
    maxlab = len(cur_labels)
    aso = np.argsort([np.median(x[labels == l]) for l in cur_labels]) + 100
    for i, a in zip(cur_labels, aso):
        labels[labels==i] = a
    labels = labels - 100

    # separate gaussian clusters that can be split by other clusters
    splits = np.sort(np.copy(x))[1:][np.diff(labels[np.argsort(x)])!=0]

    labels[:] = 0
    for i, split in enumerate(splits):
        labels[x>=split] = i+1

    labels_before_merge = np.copy(labels)

    # merge gaussian clusters that are closer than merge_threshold
    labels = merge_gaussians(x, labels, merge_threshold)

    if 'BGM_'+save_name.split('_')[0] in return_data or plot_level > 0:
        # sort model attributes by model means_
        means = [m[0] for m in BGM_model.means_]
        weights = [w for w in BGM_model.weights_]
        variances = [v[0][0] for v in BGM_model.covariances_]
        weights = [w for _, w in sorted(zip(means, weights))]
        variances = [v for _, v in sorted(zip(means, variances))]
        means = sorted(means)

        if plot_level > 0:
            plot_bgm(x, means, variances, weights, use_log, labels_before_merge,
                     labels, xlabel)
            if save_plot:
                plt.savefig('%sBGM_%s.%s' % (save_path, save_name, ftype))

        if 'BGM_'+save_name.split('_')[0] in return_data:
            bgm_dict['BGM_'+save_name] = {'x': x,
                                          'use_log': use_log,
                                          'BGM': [weights, means, variances],
                                          'labels': labels_before_merge,
                                          'xlab': xlabel}

    return labels, bgm_dict
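

# Usage sketch (illustrative): cluster a bimodal 1D feature with BGM(), as is
# done for EOD widths and heights above.
def _demo_bgm():
    """Cluster two well separated gaussian modes into labels 0 and 1."""
    x = np.concatenate((np.random.normal(1.0, 0.05, 100),
                        np.random.normal(2.0, 0.05, 100)))
    labels, _ = BGM(x, merge_threshold=0.1, n_gaus=5)
    return labels   # 0 for the mode around 1.0, 1 for the mode around 2.0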


def merge_gaussians(x, labels, merge_threshold=0.1):
    """ Merge all clusters whose medians are close to each other. Only works in 1D.

    Parameters
    ----------
    x : 1D array of ints or floats
        Features used for clustering.
    labels : 1D array of ints
        Labels for each sample in x.
    merge_threshold : float (optional)
        Similarity threshold to merge clusters.

    Returns
    -------
    labels : 1D array of ints
        Merged labels for each sample in x.
    """
    # compare all the medians of the gaussians. If they are too close, merge them.
    unique_labels = np.unique(labels[labels!=-1])
    x_medians = [np.median(x[labels==l]) for l in unique_labels]

    # fill a dict with the label mappings
    mapping = {}
    for label_1, x_m1 in zip(unique_labels, x_medians):
        for label_2, x_m2 in zip(unique_labels, x_medians):
            if label_1 != label_2:
                if np.abs(np.diff([x_m1, x_m2]))/np.max([x_m1, x_m2]) < merge_threshold:
                    mapping[label_1] = label_2
    # apply mapping
    for map_key, map_value in mapping.items():
        labels[labels==map_key] = map_value

    return labels
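

# Numeric sketch of the merge criterion (illustration only): two clusters
# merge when the relative difference of their medians is below merge_threshold.
def _demo_merge_gaussians():
    """Keep two clusters apart whose medians differ by ~49%."""
    x = np.array([1.0, 1.1, 1.05, 2.0, 2.1, 2.05])
    labels = np.array([0, 0, 0, 1, 1, 1])
    # |1.05 - 2.05|/2.05 ~ 0.49 > 0.1, so no merging happens:
    return merge_gaussians(x, labels, merge_threshold=0.1)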


def extract_snippet_features(data, eod_x, eod_heights, width, n_pc=5):
    """ Extract snippets from recording data, normalize them, and perform PCA.

    Parameters
    ----------
    data : 1D numpy array of floats
        Recording data.
    eod_x : 1D array of ints
        Locations of EODs as indices.
    eod_heights: 1D array of floats
        EOD heights.
    width : int
        Width to cut out to each side in samples.
    n_pc : int (optional)
        Number of PCs to use for PCA.

    Returns
    -------
    raw_snippets : 2D numpy array (N, EOD_width)
        Raw extracted EOD snippets.
    snippets : 2D numpy array (N, EOD_width)
        Normalized EOD snippets.
    features : 2D numpy array (N, n_pc)
        PC values of EOD snippets.
    bg_ratio : 1D numpy array (N)
        Ratio of the background activity slopes compared to EOD height.
    """
    # extract snippets with corresponding width
    raw_snippets = np.vstack([data[x-width:x+width] for x in eod_x])

    # subtract the slope and normalize the snippets
    snippets, bg_ratio = subtract_slope(np.copy(raw_snippets), eod_heights)
    snippets = StandardScaler().fit_transform(snippets.T).T

    # scale so that the absolute integral = 1.
    snippets = (snippets.T/np.sum(np.abs(snippets), axis=1)).T

    # compute features for clustering on waveform
    features = PCA(n_pc).fit_transform(snippets)

    return raw_snippets, snippets, features, bg_ratio
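

# Sketch (illustrative) of the invariant established above: every returned
# normalized snippet is scaled to unit absolute integral. The EOD locations
# used here are hypothetical.
def _demo_snippet_normalization():
    """Check that normalized snippets sum to one in absolute value."""
    data = np.random.randn(1000)
    eod_x = np.arange(50, 950, 50)     # hypothetical EOD locations
    heights = np.ones(len(eod_x))
    _, snippets, features, _ = extract_snippet_features(data, eod_x, heights, 20)
    return np.allclose(np.sum(np.abs(snippets), axis=1), 1.0)   # -> True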


def cluster_on_shape(features, bg_ratio, minp, percentile=80,
                     max_epsilon=0.01, slope_ratio_factor=4,
                     min_cluster_fraction=0.01, verbose=0):
    """Separate EODs by their shape using DBSCAN.

    Parameters
    ----------
    features : 2D numpy array of floats (N, n_pc)
        PCA features of each EOD in a recording.
    bg_ratio : 1D array of floats
        Ratio of background activity slope the EOD is superimposed on.
    minp : int
        Minimum number of points for core cluster (DBSCAN).
    percentile : int (optional)
        Percentile of KNN distribution, where K=minp, to use as epsilon for DBSCAN.
    max_epsilon : float (optional)
        Maximum epsilon to use for DBSCAN clustering. This is used to avoid adding
        noisy clusters.
    slope_ratio_factor : float (optional)
        Influence of the slope-to-EOD ratio on the epsilon parameter.
        A slope_ratio_factor of 4 means that slope-to-EOD ratios >1/4
        start influencing epsilon.
    min_cluster_fraction : float (optional)
        Minimum fraction of all evaluated datapoints that can form a single cluster.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    labels : 1D array of ints
        DBSCAN cluster labels for each EOD.
    """
    # determine clustering threshold from data
    minpc = max(minp, int(len(features)*min_cluster_fraction))
    knn = np.sort(pairwise_distances(features, features), axis=0)[minpc]
    eps = min(max(1, slope_ratio_factor*np.median(bg_ratio))*max_epsilon,
              np.percentile(knn, percentile))

    if verbose > 1:
        print('epsilon = %f' % eps)
        print('Slope to EOD ratio = %f' % np.median(bg_ratio))

    # cluster on EOD shape
    return DBSCAN(eps=eps, min_samples=minpc).fit(features).labels_
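

# Usage sketch (illustration only): DBSCAN labels for two artificial feature
# blobs; -1 marks noise points.
def _demo_cluster_on_shape():
    """Separate two tight blobs of PCA-like features into two clusters."""
    features = np.vstack((0.001*np.random.randn(50, 5),
                          0.001*np.random.randn(50, 5) + 0.05))
    bg_ratio = np.full(100, 0.01)   # low background activity
    return cluster_on_shape(features, bg_ratio, minp=10)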


def subtract_slope(snippets, heights):
    """ Subtract underlying slope from all EOD snippets.

    Parameters
    ----------
    snippets: 2-D numpy array
        All EODs in a recording, stacked as snippets.
        Shape = (number of EODs, EOD width)
    heights: 1D numpy array
        EOD heights.

    Returns
    -------
    snippets: 2-D numpy array
        EOD snippets with underlying slope subtracted.
    bg_ratio : 1-D numpy array
        EOD height/background activity height.
    """
    left_y = snippets[:,0]
    right_y = snippets[:,-1]

    try:
        slopes = np.linspace(left_y, right_y, snippets.shape[1])
    except ValueError:
        delta = (right_y - left_y)/snippets.shape[1]
        slopes = np.arange(0, snippets.shape[1], dtype=snippets.dtype).reshape((-1,) + (1,)*np.ndim(delta))*delta + left_y

    return snippets - slopes.T, np.abs(left_y-right_y)/heights
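

# Numeric sketch (illustrative): a snippet that is a pure linear ramp is
# flattened completely, and bg_ratio is |left - right| relative to EOD height.
def _demo_subtract_slope():
    """Remove the underlying slope from a single ramp-shaped snippet."""
    snippets = np.array([[0.0, 1.0, 2.0, 3.0]])
    heights = np.array([6.0])
    flat, bg_ratio = subtract_slope(np.copy(snippets), heights)
    # flat -> [[0., 0., 0., 0.]], bg_ratio -> [0.5] since |0 - 3|/6 = 0.5
    return flat, bg_ratio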


def remove_artefacts(all_snippets, clusters, rate,
                     freq_low=20000, threshold=0.75,
                     verbose=0, return_data=[]):
    """ Create a mask for EOD clusters that result from artefacts, based on power in low frequency spectrum.

    Parameters
    ----------
    all_snippets: 2D array
        EOD snippets. Shape=(nEODs, EOD length)
    clusters: list of ints
        EOD cluster labels.
    rate : float
        Sampling rate of original recording data.
    freq_low: float
        Frequency up to which low frequency components are summed up.
    threshold : float (optional)
        Minimum value for sum of low frequency components relative to
        sum over all spectral amplitudes that separates artefact from
        clean pulsefish clusters.
    verbose : int (optional)
        Verbosity level.
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log data in this function is
        'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    mask: numpy array of booleans
        Set to True for every EOD which is an artefact.
    adict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    adict = {}

    mask = np.zeros(clusters.shape, dtype=bool)

    for cluster in np.unique(clusters[clusters >= 0]):
        snippets = all_snippets[clusters == cluster]
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)
        mean_eod_fft = np.abs(np.fft.rfft(mean_eod))
        freqs = np.fft.rfftfreq(len(mean_eod), 1/rate)
        low_frequency_ratio = np.sum(mean_eod_fft[freqs < freq_low])/np.sum(mean_eod_fft)
        if low_frequency_ratio < threshold:   # TODO: check threshold!
            mask[clusters==cluster] = True

            if verbose > 0:
                print('Deleting cluster %i with low frequency ratio of %.3f (min %.3f)' % (cluster, low_frequency_ratio, threshold))

        if 'eod_deletion' in return_data:
            adict['vals_%d' % cluster] = [mean_eod, mean_eod_fft]
            adict['mask_%d' % cluster] = [np.any(mask[clusters==cluster])]

    return mask, adict
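

# Stand-alone sketch of the artefact criterion used above (illustration only):
# the fraction of spectral amplitude below freq_low must reach `threshold`,
# otherwise the cluster is marked as an artefact.
def _demo_low_frequency_ratio(mean_eod, rate, freq_low=20000.0):
    """Compute the low frequency ratio of a zero-mean EOD waveform."""
    mean_eod = mean_eod - np.mean(mean_eod)
    fft = np.abs(np.fft.rfft(mean_eod))
    freqs = np.fft.rfftfreq(len(mean_eod), 1/rate)
    return np.sum(fft[freqs < freq_low])/np.sum(fft)   # < 0.75 -> artefact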


def delete_unreliable_fish(clusters, eod_widths, eod_x, verbose=0, sdict={}):
    """ Create a mask for EOD clusters that are either mixed with noise or other fish, or wavefish.

    This is the case when the ratio between the EOD width and the ISI is too large.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of floats or ints
        EOD widths in samples or seconds.
    eod_x : list of ints or floats
        EOD times in samples or seconds.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps,
        see remove_artefacts().

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every unreliable EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask = np.zeros(clusters.shape, dtype=bool)
    for cluster in np.unique(clusters[clusters >= 0]):
        if len(eod_x[cluster == clusters]) < 2:
            mask[clusters == cluster] = True
            if verbose > 0:
                print('deleting unreliable cluster %i, number of EOD times %d < 2' % (cluster, len(eod_x[cluster==clusters])))
        elif np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) > 0.5:
            if verbose > 0:
                print('deleting unreliable cluster %i, score=%f' % (cluster, np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))))
            mask[clusters==cluster] = True
        if 'vals_%d' % cluster in sdict:
            sdict['vals_%d' % cluster].append(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))
            sdict['mask_%d' % cluster].append(any(mask[clusters==cluster]))
    return mask, sdict


def delete_wavefish_and_sidepeaks(data, clusters, eod_x, eod_widths,
                                  width_fac, max_slope_deviation=0.5,
                                  max_phases=4, verbose=0, sdict={}):
    """ Create a mask for EODs that are likely from wavefish, or sidepeaks of bigger EODs.

    Parameters
    ----------
    data : list of floats
        Raw recording data.
    clusters : list of ints
        Cluster labels.
    eod_x : list of ints
        Indices of EOD times.
    eod_widths : list of ints
        EOD widths in samples.
    width_fac : float
        Multiplier for EOD analysis width.
    max_slope_deviation: float (optional)
        Maximum deviation of position of maximum slope in snippets from
        center position in multiples of mean width of EOD.
    max_phases : int (optional)
        Maximum number of phases for any EOD.
        If the mean EOD has more phases than this, it is not a pulse EOD.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps, see remove_artefacts().

    Returns
    -------
    mask_wave: numpy array of booleans
        Set to True for every EOD which is a wavefish EOD.
    mask_sidepeak: numpy array of booleans
        Set to True for every snippet which is centered around a sidepeak of an EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask_wave = np.zeros(clusters.shape, dtype=bool)
    mask_sidepeak = np.zeros(clusters.shape, dtype=bool)

    for i, cluster in enumerate(np.unique(clusters[clusters >= 0])):
        mean_width = np.mean(eod_widths[clusters == cluster])
        cutwidth = mean_width*width_fac
        current_x = eod_x[(eod_x > cutwidth) & (eod_x < (len(data) - cutwidth))]
        current_clusters = clusters[(eod_x > cutwidth) & (eod_x < (len(data) - cutwidth))]
        snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)]
                              for x in current_x[current_clusters==cluster]])

        # extract information on main peaks and troughs:
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)

        # detect peaks and troughs on data + some maxima/minima at the
        # end, so that the sides are also considered for peak detection:
        pk, tr = detect_peaks(np.concatenate(([-10*mean_eod[0]], mean_eod, [10*mean_eod[-1]])),
                              np.std(mean_eod))
        pk = pk[(pk > 0) & (pk < len(mean_eod))]
        tr = tr[(tr > 0) & (tr < len(mean_eod))]

        if len(pk) > 0 and len(tr) > 0:
            idxs = np.sort(np.concatenate((pk, tr)))
            slopes = np.abs(np.diff(mean_eod[idxs]))
            m_slope = np.argmax(slopes)
            centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2))

            # compute all height differences of peaks and troughs within snippets.
            # if they are all similar, it is probably noise or a wavefish.
            idxs = np.sort(np.concatenate((pk, tr)))
            hdiffs = np.diff(mean_eod[idxs])

            if centered > max_slope_deviation*mean_width:   # TODO: check, factor was probably 0.16
                if verbose > 0:
                    print('Deleting cluster %i, which is a sidepeak' % cluster)
                mask_sidepeak[clusters==cluster] = True

            w_diff = np.abs(np.diff(np.sort(np.concatenate((pk, tr)))))

            if np.abs(np.diff(idxs[m_slope:m_slope+2])) < np.mean(eod_widths[clusters==cluster])*0.5 or len(pk) + len(tr) > max_phases or np.min(w_diff) > 2*cutwidth/width_fac: #or len(hdiffs[np.abs(hdiffs)>0.5*(np.max(mean_eod)-np.min(mean_eod))])>max_phases:
                if verbose > 0:
                    print('Deleting cluster %i, which is a wavefish' % cluster)
                mask_wave[clusters==cluster] = True
            if 'vals_%d' % cluster in sdict:
                sdict['vals_%d' % cluster].append([mean_eod, [pk, tr],
                                                   idxs[m_slope:m_slope+2]])
                sdict['mask_%d' % cluster].append(any(mask_wave[clusters==cluster]))
                sdict['mask_%d' % cluster].append(any(mask_sidepeak[clusters==cluster]))

    return mask_wave, mask_sidepeak, sdict
1475def merge_clusters(clusters_1, clusters_2, x_1, x_2, verbose=0):
1476 """ Merge clusters resulting from two clustering methods.
1478 This method only works if clustering is performed on the same EODs
1479 with the same ordering, where there is a one to one mapping from
1480 clusters_1 to clusters_2.
1482 Parameters
1483 ----------
1484 clusters_1: list of ints
1485 EOD cluster labels for cluster method 1.
1486 clusters_2: list of ints
1487 EOD cluster labels for cluster method 2.
1488 x_1: list of ints
1489 Indices of EODs for cluster method 1 (clusters_1).
1490 x_2: list of ints
1491 Indices of EODs for cluster method 2 (clusters_2).
1492 verbose : int (optional)
1493 Verbosity level.
1495 Returns
1496 -------
1497 clusters : list of ints
1498 Merged clusters.
1499 x_merged : list of ints
1500 Merged cluster indices.
1501 mask : 2D numpy array (2, N)
1502 Mask of the EODs that were selected from clusters_1 (mask[0]) and
1503 from clusters_2 (mask[1]).
1504 """
1505 if verbose > 0:
1506 print('\nMerge clusters:')
1508 # these arrays are set to 1 for each EOD that is kept from the respective clustering
1509 c1_keep = np.zeros(len(clusters_1))
1510 c2_keep = np.zeros(len(clusters_2))
1512 # offset the labels of the second clustering so that they do not overlap with the first
1513 ovl = np.max(clusters_1) + 1
1514 clusters_2[clusters_2!=-1] = clusters_2[clusters_2!=-1] + ovl
1516 remove_clusters = [[]]
1517 keep_clusters = []
1520 # loop until all clusters have been kept or discarded
1521 while True:
1523 # compute unique clusters and cluster sizes
1524 # of cluster that have not been iterated over:
1525 c1_labels, c1_size = unique_counts(clusters_1[(clusters_1 != -1) & (c1_keep == 0)])
1526 c2_labels, c2_size = unique_counts(clusters_2[(clusters_2 != -1) & (c2_keep == 0)])
1528 # if all clusters are done, break from loop:
1529 if len(c1_size) == 0 and len(c2_size) == 0:
1530 break
1532 # if the biggest cluster is in clusters_1, keep it and discard all
1533 # clusters of clusters_2 that overlap with it:
1534 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 0:
1536 # remove all the mappings from the other indices
1537 cluster_mappings, _ = unique_counts(clusters_2[clusters_1 == c1_labels[np.argmax(c1_size)]])
1539 clusters_2[np.isin(clusters_2, cluster_mappings)] = -1
1541 c1_keep[clusters_1==c1_labels[np.argmax(c1_size)]] = 1
1543 remove_clusters.append(cluster_mappings)
1544 keep_clusters.append(c1_labels[np.argmax(c1_size)])
1546 if verbose > 0:
1547 print('Keep cluster %i of group 1, delete clusters %s of group 2' % (c1_labels[np.argmax(c1_size)], str(cluster_mappings[cluster_mappings!=-1] - ovl)))
1549 # if the biggest cluster is in clusters_2, keep it and discard all overlapping clusters of clusters_1:
1550 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 1:
1552 # remove all the mappings from the other indices
1553 cluster_mappings, _ = unique_counts(clusters_1[clusters_2 == c2_labels[np.argmax(c2_size)]])
1555 clusters_1[np.isin(clusters_1, cluster_mappings)] = -1
1557 c2_keep[clusters_2==c2_labels[np.argmax(c2_size)]] = 1
1559 remove_clusters.append(cluster_mappings)
1560 keep_clusters.append(c2_labels[np.argmax(c2_size)])
1562 if verbose > 0:
1563 print('Keep cluster %i of group 2, delete clusters %s of group 1' % (c2_labels[np.argmax(c2_size)] - ovl, str(cluster_mappings[cluster_mappings!=-1])))
1565 # combine the kept clusters of both clusterings; the +1/-1 shift keeps unassigned EODs (label -1) at -1:
1566 clusters = (clusters_1+1)*c1_keep + (clusters_2+1)*c2_keep - 1
1567 x_merged = (x_1)*c1_keep + (x_2)*c2_keep
1569 return clusters, x_merged, np.vstack([c1_keep, c2_keep])
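# A small worked example (hypothetical, not part of this module) for
# merge_clusters(): two clusterings of the same five EODs. The largest
# remaining cluster is kept greedily from either side, and all clusters
# of the other method that overlap with it are discarded.
def _sketch_merge_clusters():
    import numpy as np
    clusters_1 = np.array([0, 0, 0, 1, 1])  # e.g. clustered on peak EODs
    clusters_2 = np.array([0, 0, 1, 1, 1])  # e.g. clustered on trough EODs
    x_1 = np.array([10, 20, 30, 40, 50])    # EOD positions in samples
    x_2 = np.array([12, 22, 32, 42, 52])
    clusters, x_merged, mask = merge_clusters(clusters_1, clusters_2,
                                              x_1, x_2, verbose=1)
    # cluster 0 of method 1 (3 EODs) wins the first round and discards
    # both overlapping clusters of method 2; cluster 1 of method 1 is
    # kept in the second round, so clusters == [0, 0, 0, 1, 1] and
    # x_merged == x_1.
    print(clusters, x_merged, mask)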
1572 def extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths,
1573 clusters, rate, width_fac, verbose=0):
1574 """ Extract mean EODs and EOD timepoints for each EOD cluster.
1576 Parameters
1577 ----------
1578 data: 1-D array of float
1579 Raw recording data.
1580 eod_x: list of ints
1581 Locations of EODs in samples.
1582 eod_peak_x : list of ints
1583 Locations of EOD peaks in samples.
1584 eod_tr_x : list of ints
1585 Locations of EOD troughs in samples.
1586 eod_widths: list of ints
1587 EOD widths in samples.
1588 clusters: list of ints
1589 EOD cluster labels.
1590 rate: float
1591 Sampling rate of the recording in Hertz.
1592 width_fac : float
1593 Multiplication factor for window used to extract EOD.
1595 verbose : int (optional)
1596 Verbosity level.
1598 Returns
1599 -------
1600 mean_eods: list of 2D arrays (3, eod_length)
1601 The average EOD for each detected fish. First row is time in seconds,
1602 second row the mean EOD, third row the standard deviation.
1603 eod_times: list of 1D arrays
1604 For each detected fish the times of EOD in seconds.
1605 eod_peak_times: list of 1D arrays
1606 For each detected fish the times of EOD peaks in seconds.
1607 eod_trough_times: list of 1D arrays
1608 For each detected fish the times of EOD troughs in seconds.
1609 eod_labels: list of ints
1610 Cluster label for each detected fish.
1611 """
1612 mean_eods, eod_times, eod_peak_times, eod_tr_times, eod_heights, cluster_labels = [], [], [], [], [], []
1614 for cluster in np.unique(clusters):
1615 if cluster!=-1:
1616 cutwidth = np.mean(eod_widths[clusters==cluster])*width_fac
1617 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
1618 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
1620 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] for x in current_x[current_clusters==cluster]])
1621 mean_eod = np.mean(snippets, axis=0)
1622 eod_time = np.arange(len(mean_eod))/rate - cutwidth/rate
1624 mean_eod = np.vstack([eod_time, mean_eod, np.std(snippets, axis=0)])
1626 mean_eods.append(mean_eod)
1627 eod_times.append(eod_x[clusters==cluster]/rate)
1628 eod_heights.append(np.min(mean_eod[1])-np.max(mean_eod[1])) # negative peak-to-peak amplitude, so that sorting puts the largest fish first
1629 eod_peak_times.append(eod_peak_x[clusters==cluster]/rate)
1630 eod_tr_times.append(eod_tr_x[clusters==cluster]/rate)
1631 cluster_labels.append(cluster)
1633 return [m for _, m in sorted(zip(eod_heights, mean_eods))], [t for _, t in sorted(zip(eod_heights, eod_times))], [pt for _, pt in sorted(zip(eod_heights, eod_peak_times))], [tt for _, tt in sorted(zip(eod_heights, eod_tr_times))], [c for _, c in sorted(zip(eod_heights, cluster_labels))]
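# A minimal sketch (hypothetical, not part of this module) running
# extract_means() on synthetic data: two one-sample pulses of a single
# cluster are averaged into a (3, n) array of time, mean, and standard
# deviation, and EOD positions are converted from samples to seconds.
def _sketch_extract_means():
    import numpy as np
    rate = 10000.0                  # Hz
    data = np.zeros(1000)
    eod_x = np.array([300, 700])    # EOD centers in samples
    data[eod_x] = 1.0
    widths = np.array([20, 20])     # EOD widths in samples
    clusters = np.array([0, 0])
    means, times, ptimes, ttimes, labels = extract_means(
        data, eod_x, eod_x, eod_x + 5, widths, clusters, rate,
        width_fac=3, verbose=0)
    print(means[0].shape)           # (3, 120): time, mean, standard deviation
    print(times[0])                 # [0.03 0.07] seconds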
1636 def find_clipped_clusters(clusters, mean_eods, eod_times,
1637 eod_peaktimes, eod_troughtimes,
1638 cluster_labels, width_factor,
1639 clip_threshold=0.9, verbose=0):
1640 """ Detect EODs that are clipped and set all clusterlabels of these clipped EODs to -1.
1642 Also return the mean EODs and timepoints of these clipped EODs.
1644 Parameters
1645 ----------
1646 clusters: array of ints
1647 Cluster labels for each EOD in a recording.
1648 mean_eods: list of numpy arrays
1649 Mean EOD waveform for each cluster.
1650 eod_times: list of numpy arrays
1651 EOD timepoints for each EOD cluster.
1652 eod_peaktimes: list of numpy arrays
1653 EOD peak times for each EOD cluster.
1654 eod_troughtimes: list of numpy arrays
1655 EOD trough times for each EOD cluster.
1656 cluster_labels: numpy array
1657 Unique EOD cluster labels.
width_factor: float
Multiplication factor for the window used to extract the EODs.
1658 clip_threshold: float (optional)
1659 Threshold for detecting clipped EODs.
1661 verbose: int (optional)
1662 Verbosity level.
1664 Returns
1665 -------
1666 clusters : array of ints
1667 Cluster labels for each EOD in the recording, where clipped EODs have been set to -1.
1668 clipped_eods : list of numpy arrays
1669 Mean EOD waveforms for each clipped EOD cluster.
1670 clipped_times : list of numpy arrays
1671 EOD timepoints for each clipped EOD cluster.
1672 clipped_peaktimes : list of numpy arrays
1673 EOD peaktimes for each clipped EOD cluster.
1674 clipped_troughtimes : list of numpy arrays
1675 EOD troughtimes for each clipped EOD cluster.
1676 """
1677 clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes, clipped_labels = [], [], [], [], []
1679 for mean_eod, eod_time, eod_peaktime, eod_troughtime, label in zip(mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels):
1681 if (np.count_nonzero(mean_eod[1]>clip_threshold) > len(mean_eod[1])/(width_factor*2)) or (np.count_nonzero(mean_eod[1] < -clip_threshold) > len(mean_eod[1])/(width_factor*2)):
1682 clipped_eods.append(mean_eod)
1683 clipped_times.append(eod_time)
1684 clipped_peaktimes.append(eod_peaktime)
1685 clipped_troughtimes.append(eod_troughtime)
1686 clipped_labels.append(label)
1687 if verbose>0:
1688 print('Cluster %i is clipped' % label)
1690 clusters[np.isin(clusters, clipped_labels)] = -1
1692 return clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes
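# The clipping criterion of find_clipped_clusters() in isolation
# (hypothetical sketch, not part of this module): a mean EOD counts as
# clipped if more than len/(2*width_factor) of its samples saturate
# beyond clip_threshold (for data normalized to a maximum amplitude of 1).
def _sketch_clip_detection():
    import numpy as np
    width_factor = 5
    clip_threshold = 0.9
    # a half sine wave that saturates at an amplitude of 1:
    mean_eod = np.clip(2*np.sin(np.linspace(0, np.pi, 100)), 0, 1)
    n_sat = np.count_nonzero(mean_eod > clip_threshold)
    print('clipped:', n_sat > len(mean_eod)/(width_factor*2))  # True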
1695 def delete_moving_fish(clusters, eod_t, T, eod_heights, eod_widths,
1696 rate, min_dt=0.25, stepsize=0.05,
1697 sliding_window_factor=2000, verbose=0,
1698 plot_level=0, save_plot=False, save_path='',
1699 ftype='pdf', return_data=[]):
1700 """
1701 Use a sliding window to find the minimum number of fish present
1702 simultaneously, then delete all other EOD clusters.
1704 Do this only for EODs within the same width clusters, as a
1705 moving fish will preserve its EOD width.
1707 Parameters
1708 ----------
1709 clusters: list of ints
1710 EOD cluster labels.
1711 eod_t: list of floats
1712 Timepoints of the EODs (in seconds).
1713 T: float
1714 Length of recording (in seconds).
1715 eod_heights: list of floats
1716 EOD amplitudes.
1717 eod_widths: list of floats
1718 EOD widths (in seconds).
1719 rate: float
1720 Recording data sampling rate.
1722 min_dt : float (optional)
1723 Minimum sliding window size (in seconds).
1724 stepsize : float (optional)
1725 Sliding window stepsize (in seconds).
1726 sliding_window_factor : float
1727 Multiplier for sliding window width,
1728 where the sliding window width = median(EOD_width)*sliding_window_factor.
1729 verbose : int (optional)
1730 Verbosity level.
1731 plot_level : int (optional)
1732 Similar to verbosity levels, but with plots.
1733 Only set to > 0 for debugging purposes.
1734 save_plot : bool (optional)
1735 Set to True to save the plots created by plot_level.
1736 save_path : string (optional)
1737 Path for saving plots. Only used if save_plot is True.
1738 ftype : string (optional)
1739 Define the filetype to save the plots in if save_plot is set to True.
1740 Options are: 'png', 'jpg', 'svg' ...
1741 return_data : list of strings (optional)
1742 Keys that specify data to be logged. The key that can be used to log data
1743 in this function is 'moving_fish' (see extract_pulsefish()).
1745 Returns
1746 -------
1747 clusters : list of ints
1748 Cluster labels, where deleted clusters have been set to -1.
1749 window : list of 2 floats
1750 Start and end of window selected for deleting moving fish in seconds.
1751 mf_dict : dictionary
1752 Key value pairs of logged data. Data to be logged is specified by return_data.
1753 """
1754 mf_dict = {}
1756 if len(np.unique(clusters[clusters != -1])) == 0:
1757 return clusters, [0, 1], {}
1759 all_keep_clusters = []
1760 width_classes = merge_gaussians(eod_widths, np.copy(clusters), 0.75)
1762 all_windows = []
1763 all_dts = []
1765 for w in np.unique(width_classes[clusters >= 0]):
1766 # initialize with sentinel values that any real window will beat
1767 min_clusters = 100
1768 average_height = 0
1769 sparse_clusters = 100
1770 keep_clusters = []
1772 dt = max(min_dt, np.median(eod_widths[width_classes==w])*sliding_window_factor)
1773 window_start = 0
1774 window_end = dt
1776 wclusters = clusters[width_classes==w]
1777 weod_t = eod_t[width_classes==w]
1778 weod_heights = eod_heights[width_classes==w]
1779 weod_widths = eod_widths[width_classes==w]
1781 all_dts.append(dt)
1783 if verbose>0:
1784 print('sliding window dt = %f'%dt)
1786 # TODO: make the window dependent on width?
1787 ignore_steps = np.zeros(len(np.arange(0, T-dt+stepsize, stepsize)))
1789 for i, t in enumerate(np.arange(0, T-dt+stepsize, stepsize)):
1790 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1791 if len(np.unique(current_clusters))==0:
1792 ignore_steps[max(0, i-int(dt/stepsize)):i+int(dt/stepsize)] = 1 # clamp start to avoid a negative slice index
1793 if verbose>0:
1794 print('No pulsefish in recording at T=%.2f:%.2f' % (t, t+dt))
1797 x = np.arange(0, T-dt+stepsize, stepsize)
1798 y = np.ones(len(x))
1803 # sliding window
1804 for j, (t, ignore_step) in enumerate(zip(x, ignore_steps)):
1805 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1806 current_widths = weod_widths[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1808 unique_clusters = np.unique(current_clusters)
1809 y[j] = len(unique_clusters)
1811 if (len(unique_clusters) <= min_clusters) and \
1812 (ignore_step==0) and \
1813 (len(unique_clusters) > 0):
1815 current_labels = np.isin(wclusters, unique_clusters)
1816 current_height = np.mean(weod_heights[current_labels])
1818 # compute nr of clusters that are too sparse
1819 clusters_after_deletion = np.unique(remove_sparse_detections(np.copy(clusters[np.isin(clusters, unique_clusters)]), rate*eod_widths[np.isin(clusters, unique_clusters)], rate, T))
1820 current_sparse_clusters = len(unique_clusters) - len(clusters_after_deletion[clusters_after_deletion!=-1])
1822 if current_sparse_clusters <= sparse_clusters and \
1823 ((current_sparse_clusters<sparse_clusters) or
1824 (current_height > average_height) or
1825 (len(unique_clusters) < min_clusters)):
1827 keep_clusters = unique_clusters
1828 min_clusters = len(unique_clusters)
1829 average_height = current_height
1830 window_end = t+dt
1831 sparse_clusters = current_sparse_clusters
1833 all_keep_clusters.append(keep_clusters)
1834 all_windows.append(window_end)
1836 if 'moving_fish' in return_data or plot_level>0:
1837 if 'w' in mf_dict:
1838 mf_dict['w'].append(np.median(eod_widths[width_classes==w]))
1839 mf_dict['T'] = T
1840 mf_dict['dt'].append(dt)
1841 mf_dict['clusters'].append(wclusters)
1842 mf_dict['t'].append(weod_t)
1843 mf_dict['fishcount'].append([x+0.5*(x[1]-x[0]), y])
1844 mf_dict['ignore_steps'].append(ignore_steps)
1845 else:
1846 mf_dict['w'] = [np.median(eod_widths[width_classes==w])]
1847 mf_dict['T'] = T
1848 mf_dict['dt'] = [dt]
1849 mf_dict['clusters'] = [wclusters]
1850 mf_dict['t'] = [weod_t]
1851 mf_dict['fishcount'] = [[x+0.5*(x[1]-x[0]), y]]
1852 mf_dict['ignore_steps'] = [ignore_steps]
1854 if verbose>0:
1855 print('Estimated nr of pulsefish in recording: %i' % len(np.concatenate(all_keep_clusters)))
1857 if plot_level>0:
1858 plot_moving_fish(mf_dict['w'], mf_dict['dt'], mf_dict['clusters'], mf_dict['t'],
1859 mf_dict['fishcount'], T, mf_dict['ignore_steps'])
1860 if save_plot:
1861 plt.savefig('%sdelete_moving_fish.%s' % (save_path, ftype))
1862 # return an empty dict if logging was not requested
1863 if 'moving_fish' not in return_data:
1864 mf_dict = {}
1866 # delete all clusters that are not selected
1867 clusters[np.invert(np.isin(clusters, np.concatenate(all_keep_clusters)))] = -1
1869 return clusters, [np.max(all_windows)-np.max(all_dts), np.max(all_windows)], mf_dict
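# The core idea of delete_moving_fish() as a stand-alone sketch
# (hypothetical, not part of this module): slide a window over the EOD
# times, count the distinct clusters per window, and take the smallest
# non-zero count as the number of fish actually present at one spot.
def _sketch_sliding_window_count():
    import numpy as np
    eod_t = np.array([0.1, 0.2, 0.3, 1.1, 1.2, 1.3, 1.4])
    clusters = np.array([0, 0, 0, 1, 1, 2, 2])  # fish 0 left, fish 1 and 2 appeared
    dt, stepsize, T = 0.5, 0.05, 1.5            # window width, step size, duration
    counts = []
    for t in np.arange(0, T - dt + stepsize, stepsize):
        in_window = (eod_t >= t) & (eod_t < t + dt)
        counts.append(len(np.unique(clusters[in_window])))
    counts = np.array(counts)
    print('minimum number of simultaneous fish:', counts[counts > 0].min())  # 1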
1872 def remove_sparse_detections(clusters, eod_widths, rate, T,
1873 min_density=0.0005, verbose=0):
1874 """ Remove all EOD clusters that are too sparse
1876 Parameters
1877 ----------
1878 clusters : list of ints
1879 Cluster labels.
1880 eod_widths : list of ints
1881 EOD widths in samples.
1882 rate : float
1883 Sampling rate.
1884 T : float
1885 Length of recording in seconds.
1886 min_density : float (optional)
1887 Minimum fraction of the recording that the EODs of a cluster must cover.
1888 verbose : int (optional)
1889 Verbosity level.
1891 Returns
1892 -------
1893 clusters : list of ints
1894 Cluster labels, where sparse clusters have been set to -1.
1895 """
1896 for c in np.unique(clusters):
1897 if c!=-1:
1899 n = len(clusters[clusters==c])
1900 w = np.median(eod_widths[clusters==c])/rate
1902 if n*w < T*min_density:
1903 if verbose>0:
1904 print('cluster %i is too sparse'%c)
1905 clusters[clusters==c] = -1
1906 return clusters
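# The sparseness criterion of remove_sparse_detections() as plain
# arithmetic (hypothetical sketch, not part of this module): a cluster
# is kept only if its EODs cover at least a fraction min_density of the
# recording, i.e. n_eods*median_width_in_seconds >= T*min_density.
def _sketch_sparse_density():
    n = 10                  # number of EODs in the cluster
    width = 0.0005          # median EOD width in seconds
    T = 60.0                # recording duration in seconds
    min_density = 0.0005    # minimum coverage fraction
    coverage = n*width      # seconds of signal covered by the EODs
    print('too sparse:', coverage < T*min_density)  # True: 5 ms < 30 ms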