1"""
2Extract and cluster EOD waverforms of pulse-type electric fish.
4## Main function
6- `extract_pulsefish()`: checks for pulse-type fish based on the EOD amplitude and shape.
8"""
import os
import numpy as np
from scipy import stats
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.mixture import BayesianGaussianMixture
from sklearn.metrics import pairwise_distances
from thunderlab.eventdetection import detect_peaks, median_std_threshold
from .pulseplots import *

import warnings

def warn(*args, **kwargs):
    """
    Ignore all warnings.
    """
    pass
warnings.warn = warn

try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        def decorator_jit(func):
            return func
        return decorator_jit

# upgrade numpy functions for backwards compatibility:
if not hasattr(np, 'isin'):
    np.isin = np.in1d
def unique_counts(ar):
    """ Find the unique elements of an array and their counts, ignoring shape.

    The code is condensed from numpy version 1.17.0.

    Parameters
    ----------
    ar : numpy array
        Input array

    Returns
    -------
    unique_values : numpy array
        Unique values in array ar.
    unique_counts : numpy array
        Number of instances for each unique value in ar.
    """
    try:
        return np.unique(ar, return_counts=True)
    except TypeError:
        ar = np.asanyarray(ar).flatten()
        ar.sort()
        mask = np.empty(ar.shape, dtype=bool)
        mask[:1] = True
        mask[1:] = ar[1:] != ar[:-1]
        idx = np.concatenate(np.nonzero(mask) + ([mask.size],))
        return ar[mask], np.diff(idx)
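
# Illustrative usage sketch, not part of the original module: unique_counts()
# mirrors np.unique(ar, return_counts=True) on a flattened array. The sample
# values below are assumptions chosen for demonstration.
def _example_unique_counts():
    values, counts = unique_counts(np.array([3, 1, 3, 2, 3, 2]))
    # values -> array([1, 2, 3]); counts -> array([1, 2, 3])
    return values, counts
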

###########################################################################

def extract_pulsefish(data, rate, amax, width_factor_shape=3,
                      width_factor_wave=8, width_factor_display=4,
                      verbose=0, plot_level=0, save_plots=False,
                      save_path='', ftype='png', return_data=[]):
    """ Extract and cluster pulse-type fish EODs from single channel data.

    Takes recording data containing an unknown number of pulsefish and extracts the mean
    EOD and EOD timepoints for each fish present in the recording.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data in Hertz.
    amax: float
        Maximum amplitude of data range.
    width_factor_shape : float (optional)
        Width multiplier used for EOD shape analysis.
        EOD snippets are extracted based on width between the
        peak and trough multiplied by the width factor.
    width_factor_wave : float (optional)
        Width multiplier used for wavefish detection.
    width_factor_display : float (optional)
        Width multiplier used for EOD mean extraction and display.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path: string (optional)
        Path for saving plots.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Specify data that should be logged and returned in a dictionary. Each clustering
        step has a specific keyword that results in adding different variables to the log dictionary.
        Optional keys for return_data and the resulting additional key-value pairs to the log dictionary are:

        - 'all_eod_times':
            - 'all_times': list of two lists of floats.
                All peak (`all_times[0]`) and trough times (`all_times[1]`) extracted
                by the peak detection algorithm. Times are given in seconds.
            - 'eod_troughtimes': list of 1D arrays.
                The timepoints in seconds of each unique extracted EOD cluster,
                where each 1D array encodes one cluster.

        - 'peak_detection':
            - "data": 1D numpy array of floats.
                Quadratically interpolated data which was used for peak detection.
            - "interp_fac": float.
                Interpolation factor of raw data.
            - "peaks_1": 1D numpy array of ints.
                Peak indices on interpolated data after first peak detection step.
            - "troughs_1": 1D numpy array of ints.
                Trough indices on interpolated data after first peak detection step.
            - "peaks_2": 1D numpy array of ints.
                Peak indices on interpolated data after second peak detection step.
            - "troughs_2": 1D numpy array of ints.
                Trough indices on interpolated data after second peak detection step.
            - "peaks_3": 1D numpy array of ints.
                Peak indices on interpolated data after third peak detection step.
            - "troughs_3": 1D numpy array of ints.
                Trough indices on interpolated data after third peak detection step.
            - "peaks_4": 1D numpy array of ints.
                Peak indices on interpolated data after fourth peak detection step.
            - "troughs_4": 1D numpy array of ints.
                Trough indices on interpolated data after fourth peak detection step.

        - 'all_cluster_steps':
            - 'rate': float.
                Sampling rate of interpolated data.
            - 'EOD_widths': list of three 1D numpy arrays.
                The first list entry gives the unique labels of all width clusters
                as a list of ints.
                The second list entry gives the width values for each EOD in samples
                as a 1D numpy array of ints.
                The third list entry gives the width labels for each EOD
                as a 1D numpy array of ints.
            - 'EOD_heights': nested lists (2 layers) of three 1D numpy arrays.
                The first list entry gives the unique labels of all height clusters
                as a list of ints for each width cluster.
                The second list entry gives the height values for each EOD
                as a 1D numpy array of floats for each width cluster.
                The third list entry gives the height labels for each EOD
                as a 1D numpy array of ints for each width cluster.
            - 'EOD_shapes': nested lists (3 layers) of three 1D numpy arrays.
                The first list entry gives the raw EOD snippets as a 2D numpy array
                for each height cluster in a width cluster.
                The second list entry gives the snippet PCA values for each EOD
                as a 2D numpy array of floats for each height cluster in a width cluster.
                The third list entry gives the shape labels for each EOD as a 1D numpy array
                of ints for each height cluster in a width cluster.
            - 'discarding_masks': nested lists (2 layers) of 1D numpy arrays.
                The masks of EODs that are discarded by the discarding step of the algorithm.
                The masks are 1D boolean arrays where instances that are set to True are
                discarded by the algorithm. Discarding masks are saved in nested lists
                that represent the width and height clusters.
            - 'merge_masks': nested lists (2 layers) of 2D numpy arrays.
                The masks of EODs that are discarded by the merging step of the algorithm.
                The masks are 2D boolean arrays where for each sample point `i` either
                `merge_mask[i,0]` or `merge_mask[i,1]` is set to True. Here, `merge_mask[:,0]`
                represents the peak-centered clusters and `merge_mask[:,1]` represents the
                trough-centered clusters. Merge masks are saved in nested lists that
                represent the width and height clusters.

        - 'BGM_width':
            - 'BGM_width': dictionary
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD widths),
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'BGM_height':
            This key adds a new dictionary for each width cluster.
            - 'BGM_height_*n*' : dictionary, where *n* defines the width cluster as an int.
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD heights),
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'snippet_clusters':
            This key adds a new dictionary for each height cluster.
            - 'snippet_clusters_*n*_*m*_*p*' : dictionary, where *n* defines the width cluster
              (int), *m* defines the height cluster (int) and *p* defines shape clustering
              on peak or trough centered EOD snippets (string: 'peak' or 'trough').
                - 'raw_snippets': 2D numpy array (nsamples, nfeatures).
                    Raw EOD snippets.
                - 'snippets': 2D numpy array.
                    Normalized EOD snippets.
                - 'features': 2D numpy array (nsamples, nfeatures).
                    PCA values for each normalized EOD snippet.
                - 'clusters': 1D numpy array of ints.
                    Cluster labels.
                - 'rate': float.
                    Sampling rate of snippets.

        - 'eod_deletion':
            This key adds two dictionaries for each (peak centered) shape cluster,
            where *cluster* (int) is the unique shape cluster label.
            - 'mask_*cluster*' : list of four booleans.
                The mask for each cluster discarding step.
                The first instance represents the artefact masks, where artefacts
                are set to True.
                The second instance represents the unreliable cluster masks,
                where unreliable clusters are set to True.
                The third instance represents the wavefish masks, where wavefish
                are set to True.
                The fourth instance represents the sidepeak masks, where sidepeaks
                are set to True.
            - 'vals_*cluster*' : list of lists.
                All variables that are used for each cluster deletion step.
                The first instance is a list of two 1D numpy arrays: the mean EOD and
                the FFT of that mean EOD.
                The second instance is a 1D numpy array with all EOD width to ISI ratios.
                The third instance is a list with three entries:
                The first entry is a 1D numpy array zoomed out version of the mean EOD.
                The second entry is a list of two 1D numpy arrays that define the peak
                and trough indices of the zoomed out mean EOD.
                The third entry contains a list of two values that represent the
                peak-trough pair in the zoomed out mean EOD with the largest height
                difference.
            - 'rate' : float.
                EOD snippet sampling rate.

        - 'masks':
            - 'masks' : 2D numpy array (4,N).
                Each row contains masks for each EOD detected by the EOD peakdetection step.
                The first row defines the artefact masks, the second row defines the
                unreliable EOD masks,
                the third row defines the wavefish masks and the fourth row defines
                the sidepeak masks.

        - 'moving_fish':
            - 'moving_fish': dictionary.
                - 'w' : list of floats.
                    Median width for each width cluster that the moving fish algorithm is
                    computed on (in seconds).
                - 'T' : list of floats.
                    Length of analyzed recording for each width cluster (in seconds).
                - 'dt' : list of floats.
                    Sliding window size (in seconds) for each width cluster.
                - 'clusters' : list of 1D numpy int arrays.
                    Cluster labels for each EOD cluster in a width cluster.
                - 't' : list of 1D numpy float arrays.
                    EOD emission times for each EOD in a width cluster.
                - 'fishcount' : list of lists.
                    Sliding window timepoints and fishcounts for each width cluster.
                - 'ignore_steps' : list of 1D int arrays.
                    Mask for fishcounts that were ignored (ignored if True) in the
                    moving_fish analysis.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First column is time in seconds,
        second column the mean eod, third column the standard error.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD peaks or troughs in seconds.
        Use these timepoints for EOD averaging.
    eod_peaktimes: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    zoom_window: tuple of floats
        Start and endtime of suggested window for plotting EOD timepoints.
    log_dict: dictionary
        Dictionary with logged variables, where variables to log are specified
        by `return_data`.
    """
    if verbose > 0:
        print('')
    if verbose > 1:
        print(70*'#')
        print('##### extract_pulsefish', 46*'#')

    if (save_plots and plot_level > 0 and save_path):
        # create folder to save things in.
        if not os.path.exists(save_path):
            os.makedirs(save_path)
    else:
        save_path = ''

    mean_eods, eod_times, eod_peaktimes, zoom_window = [], [], [], []
    log_dict = {}

    # interpolate:
    i_rate = 500000.0
    #i_rate = rate
    try:
        f = interp1d(np.arange(len(data))/rate, data, kind='quadratic')
        i_data = f(np.arange(0.0, (len(data)-1)/rate, 1.0/i_rate))
    except MemoryError:
        i_rate = rate
        i_data = data
    log_dict['data'] = i_data   # TODO: could be removed
    log_dict['rate'] = i_rate   # TODO: could be removed
    log_dict['i_data'] = i_data
    log_dict['i_rate'] = i_rate
    #log_dict["interp_fac"] = interp_fac  # TODO: is not set anymore

    # standard deviation of data in small snippets:
    win_size = int(0.002*rate)  # 2ms windows
    threshold = median_std_threshold(data, win_size)  # TODO: make this a parameter

    # extract peaks:
    if 'peak_detection' in return_data:
        x_peak, x_trough, eod_heights, eod_widths, pd_log_dict = \
            detect_pulses(i_data, i_rate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose-1, return_data=True)
        log_dict.update(pd_log_dict)
    else:
        x_peak, x_trough, eod_heights, eod_widths = \
            detect_pulses(i_data, i_rate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose-1, return_data=False)

    if len(x_peak) > 0:
        # cluster:
        clusters, x_merge, c_log_dict = cluster(x_peak, x_trough,
                                                eod_heights,
                                                eod_widths, i_data,
                                                i_rate,
                                                width_factor_shape,
                                                width_factor_wave,
                                                merge_threshold_height=0.1*amax,
                                                verbose=verbose-1,
                                                plot_level=plot_level-1,
                                                save_plots=save_plots,
                                                save_path=save_path,
                                                ftype=ftype,
                                                return_data=return_data)

        # extract mean eods and times:
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, clusters,
                          i_rate, width_factor_display, verbose=verbose-1)

        # determine clipped clusters (save them, but ignore in other steps):
        clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes = \
            find_clipped_clusters(clusters, mean_eods, eod_times, eod_peaktimes,
                                  eod_troughtimes, cluster_labels, width_factor_display,
                                  verbose=verbose-1)

        # delete the moving fish:
        clusters, zoom_window, mf_log_dict = \
            delete_moving_fish(clusters, x_merge/i_rate, len(data)/rate,
                               eod_heights, eod_widths/i_rate, i_rate,
                               verbose=verbose-1, plot_level=plot_level-1, save_plot=save_plots,
                               save_path=save_path, ftype=ftype, return_data=return_data)

        if 'moving_fish' in return_data:
            log_dict['moving_fish'] = mf_log_dict

        clusters = remove_sparse_detections(clusters, eod_widths, i_rate,
                                            len(data)/rate, verbose=verbose-1)

        # extract mean eods:
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths,
                          clusters, i_rate, width_factor_display, verbose=verbose-1)

        mean_eods.extend(clipped_eods)
        eod_times.extend(clipped_times)
        eod_peaktimes.extend(clipped_peaktimes)
        eod_troughtimes.extend(clipped_troughtimes)

        if plot_level > 0:
            plot_all(data, eod_peaktimes, eod_troughtimes, rate, mean_eods)
            if save_plots:
                plt.savefig('%sextract_pulsefish_results.%s' % (save_path, ftype))
        if save_plots:
            plt.close('all')

        if 'all_eod_times' in return_data:
            log_dict['all_times'] = [x_peak/i_rate, x_trough/i_rate]
            log_dict['eod_troughtimes'] = eod_troughtimes

        log_dict.update(c_log_dict)

    return mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict
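
# Minimal end-to-end usage sketch, not part of the original module: run
# extract_pulsefish() on a synthetic 50 Hz train of biphasic pulses in noise.
# Pulse shape, rate, and amplitudes are assumptions for demonstration only;
# real recordings are the intended input.
def _example_extract_pulsefish():
    rng = np.random.RandomState(2)
    rate = 44100.0
    data = rng.randn(int(2*rate))*0.005     # background noise
    for pt in np.arange(0.05, 1.95, 0.02):  # one pulse every 20 ms
        i = int(pt*rate)
        data[i:i+3] += [0.5, 1.0, 0.3]      # head-positive phase
        data[i+3:i+6] -= [0.6, 0.3, 0.1]    # head-negative phase
    mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict = \
        extract_pulsefish(data, rate, amax=1.0)
    return mean_eods, eod_times, eod_peaktimes
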
def detect_pulses(data, rate, thresh, min_rel_slope_diff=0.25,
                  min_width=0.00005, max_width=0.01, width_fac=5.0,
                  verbose=0, return_data=False):
    """Detect pulses in data.

    Was `def extract_eod_times(data, rate, width_factor,
    interp_freq=500000, max_peakwidth=0.01,
    min_peakwidth=None, verbose=0, return_data=[],
    save_path='')` before.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data.
    thresh: float
        Threshold for peak and trough detection via `detect_peaks()`.
        Must be a positive number that sets the minimum difference
        between a peak and a trough.
    min_rel_slope_diff: float
        Minimum required difference between left and right slope (between
        peak and troughs) relative to mean slope for deciding which trough
        to take based on slope difference.
    min_width: float
        Minimum width (peak-trough distance) of pulses in seconds.
    max_width: float
        Maximum width (peak-trough distance) of pulses in seconds.
    width_fac: float
        Pulses extend plus or minus `width_fac` times their width
        (distance between peak and assigned trough).
        Only pulses are returned that can fully be analysed with this width.
    verbose : int (optional)
        Verbosity level.
    return_data : bool
        If `True` data of this function is logged and returned (see
        extract_pulsefish()).

    Returns
    -------
    peak_indices: array of ints
        Indices of EOD peaks in data.
    trough_indices: array of ints
        Indices of EOD troughs in data. There is one x_trough for each x_peak.
    heights: array of floats
        EOD heights for each x_peak.
    widths: array of ints
        EOD widths for each x_peak (in samples).
    peak_detection_result : dictionary
        Key value pairs of logged data.
        This is only returned if `return_data` is `True`.
    """
    peak_detection_result = {}

    # detect peaks and troughs in the data:
    peak_indices, trough_indices = detect_peaks(data, thresh)
    if verbose > 0:
        print('Peaks/troughs detected in data: %5d %5d'
              % (len(peak_indices), len(trough_indices)))
    if return_data:
        peak_detection_result.update(peaks_1=np.array(peak_indices),
                                     troughs_1=np.array(trough_indices))
    if len(peak_indices) < 2 or \
       len(trough_indices) < 2 or \
       len(peak_indices) > len(data)/20:
        # TODO: if too many peaks increase threshold!
        if verbose > 0:
            print('No or too many peaks/troughs detected in data.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # assign troughs to peaks:
    peak_indices, trough_indices, heights, widths, slopes = \
        assign_side_peaks(data, peak_indices, trough_indices, min_rel_slope_diff)
    if verbose > 1:
        print('Number of peaks after assigning side-peaks: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_2=np.array(peak_indices),
                                     troughs_2=np.array(trough_indices))

    # check widths:
    keep = ((widths > min_width*rate) & (widths < max_width*rate))
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    slopes = slopes[keep]
    if verbose > 1:
        print('Number of peaks after checking pulse width: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_3=np.array(peak_indices),
                                     troughs_3=np.array(trough_indices))

    # discard connected peaks:
    same = np.nonzero(trough_indices[:-1] == trough_indices[1:])[0]
    keep = np.ones(len(trough_indices), dtype=bool)
    for i in same:
        # same troughs at trough_indices[i] and trough_indices[i+1]:
        s = slopes[i:i+2]
        rel_slopes = np.abs(np.diff(s))[0]/np.mean(s)
        if rel_slopes > min_rel_slope_diff:
            keep[i+(s[1]<s[0])] = False
        else:
            keep[i+(heights[i+1]<heights[i])] = False
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    if verbose > 1:
        print('Number of peaks after merging pulses: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_4=np.array(peak_indices),
                                     troughs_4=np.array(trough_indices))
    if len(peak_indices) == 0:
        if verbose > 0:
            print('No peaks remain as pulse candidates.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # only take those where the maximum cutwidth does not cause issues -
    # if the width_fac times the width + x is more than length.
    keep = ((peak_indices - widths > 0) &
            (peak_indices + widths < len(data)) &
            (trough_indices - widths > 0) &
            (trough_indices + widths < len(data)))

    if verbose > 0:
        print('Remaining peaks after EOD extraction: %5d'
              % (np.sum(keep)))
        print('')

    if return_data:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep], peak_detection_result
    else:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep]
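
# Illustrative sketch, not part of the original module: detect two synthetic
# triphasic pulses in noise with detect_pulses(). Amplitudes, rate, and the
# threshold are assumptions; spurious noise peaks are removed by the width
# check.
def _example_detect_pulses():
    rng = np.random.RandomState(1)
    rate = 40000.0
    data = rng.randn(4000)*0.01
    for i in (1000, 3000):
        data[i - 3] -= 0.5   # leading trough
        data[i] += 1.0       # main peak
        data[i + 3] -= 0.5   # trailing trough
    peaks, troughs, heights, widths = detect_pulses(data, rate, thresh=0.2)
    # expect the two synthetic pulses at indices near 1000 and 3000
    return peaks, troughs, heights, widths
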
@jit(nopython=True)
def assign_side_peaks(data, peak_indices, trough_indices,
                      min_rel_slope_diff=0.25):
    """Assign to each peak the trough resulting in a pulse with the steepest slope or largest height.

    The slope between a peak and a trough is computed as the height
    difference divided by the distance between peak and trough. If the
    slopes between the left and the right trough differ by less than
    `min_rel_slope_diff`, then just the heights of the two troughs
    relative to the peak are compared.

    Was `def detect_eod_peaks(data, main_indices, side_indices,
    max_width=20, min_width=2, verbose=0)` before.

    Parameters
    ----------
    data: array of floats
        Data in which the events were detected.
    peak_indices: array of ints
        Indices of the detected peaks in the data time series.
    trough_indices: array of ints
        Indices of the detected troughs in the data time series.
    min_rel_slope_diff: float
        Minimum required difference of left and right slope relative
        to mean slope.

    Returns
    -------
    peak_indices: array of ints
        Peak indices. Same as input `peak_indices` but potentially shorter
        by one or two elements.
    trough_indices: array of ints
        Corresponding trough indices of trough to the left or right
        of the peaks.
    heights: array of floats
        Peak heights (distance between peak and corresponding trough amplitude).
    widths: array of ints
        Peak widths (distance between peak and corresponding trough indices).
    slopes: array of floats
        Peak slopes (height divided by width).
    """
    # is a main or side peak first?
    peak_first = int(peak_indices[0] < trough_indices[0])
    # is a main or side peak last?
    peak_last = int(peak_indices[-1] > trough_indices[-1])
    # ensure that all peaks have side peaks (troughs) at both sides,
    # i.e. troughs at same index and next index are before and after peak:
    peak_indices = peak_indices[peak_first:len(peak_indices)-peak_last]
    y = data[peak_indices]

    # indices of troughs on the left and right side of main peaks:
    l_indices = np.arange(len(peak_indices))
    r_indices = l_indices + 1

    # indices, distance to peak, height, and slope of left troughs:
    l_side_indices = trough_indices[l_indices]
    l_distance = np.abs(peak_indices - l_side_indices)
    l_height = np.abs(y - data[l_side_indices])
    l_slope = np.abs(l_height/l_distance)

    # indices, distance to peak, height, and slope of right troughs:
    r_side_indices = trough_indices[r_indices]
    r_distance = np.abs(r_side_indices - peak_indices)
    r_height = np.abs(y - data[r_side_indices])
    r_slope = np.abs(r_height/r_distance)

    # which trough to assign to the peak?
    # - either the one with the steepest slope, or
    # - when slopes are similar on both sides
    #   (within `min_rel_slope_diff` difference),
    #   the trough with the maximum height difference to the peak:
    rel_slopes = np.abs(l_slope-r_slope)/(0.5*(l_slope+r_slope))
    take_slopes = rel_slopes > min_rel_slope_diff
    take_left = l_height > r_height
    take_left[take_slopes] = l_slope[take_slopes] > r_slope[take_slopes]

    # assign troughs, heights, widths, and slopes:
    trough_indices = np.where(take_left,
                              trough_indices[:-1], trough_indices[1:])
    heights = np.where(take_left, l_height, r_height)
    widths = np.where(take_left, l_distance, r_distance)
    slopes = np.where(take_left, l_slope, r_slope)

    return peak_indices, trough_indices, heights, widths, slopes
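
# Illustrative sketch, not part of the original module: pair each peak with
# its left or right trough on a tiny hand-made trace (values are assumptions).
def _example_assign_side_peaks():
    data = np.array([0.0, -0.2, 1.0, -0.8, 0.5, -0.1, 0.0])
    peaks = np.array([2, 4])
    troughs = np.array([1, 3, 5])
    peaks, troughs, heights, widths, slopes = \
        assign_side_peaks(data, peaks, troughs)
    # the peak at index 2 takes the steeper trough at index 3 (height 1.8);
    # the peak at index 4 also takes the trough at index 3 (height 1.3).
    # Such shared troughs are resolved later in detect_pulses().
    return peaks, troughs, heights, widths, slopes
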
def cluster(eod_xp, eod_xt, eod_heights, eod_widths, data, rate,
            width_factor_shape, width_factor_wave, n_gaus_height=10,
            merge_threshold_height=0.1, n_gaus_width=3,
            merge_threshold_width=0.5, minp=10, verbose=0,
            plot_level=0, save_plots=False, save_path='', ftype='pdf',
            return_data=[]):
    """Cluster EODs.

    First cluster on EOD widths using a Bayesian Gaussian
    Mixture (BGM) model, then cluster on EOD heights using a
    BGM model. Lastly, cluster on EOD waveform with DBSCAN.
    Clustering on EOD waveform is performed twice, once on
    peak-centered EODs and once on trough-centered EODs.
    Non-pulsetype EOD clusters are deleted, and clusters are
    merged afterwards.

    Parameters
    ----------
    eod_xp : list of ints
        Locations of EOD peaks in indices.
    eod_xt: list of ints
        Locations of EOD troughs in indices.
    eod_heights: list of floats
        EOD heights.
    eod_widths: list of ints
        EOD widths in samples.
    data: array of floats
        Data in which to detect pulse EODs.
    rate : float
        Sampling rate of `data`.
    width_factor_shape : float
        Multiplier for snippet extraction width. This factor is
        multiplied with the width between the peak and trough of a
        single EOD.
    width_factor_wave : float
        Multiplier for wavefish extraction width.
    n_gaus_height : int (optional)
        Number of gaussians to use for the clustering based on EOD height.
    merge_threshold_height : float (optional)
        Threshold for merging clusters that are similar in height.
    n_gaus_width : int (optional)
        Number of gaussians to use for the clustering based on EOD width.
    merge_threshold_width : float (optional)
        Threshold for merging clusters that are similar in width.
    minp : int (optional)
        Minimum number of points for core clusters (DBSCAN).
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save created plots.
    save_path : string (optional)
        Path to save plots to. Only used if save_plots==True.
    ftype : string (optional)
        Filetype to save plot images in.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'all_cluster_steps', 'BGM_width', 'BGM_height',
        'snippet_clusters', 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    labels : list of ints
        EOD cluster labels based on height and EOD waveform.
    x_merge : list of ints
        Locations of EODs in clusters.
    saved_data : dictionary
        Key value pairs of logged data. Data to be logged is specified
        by return_data.
    """
    saved_data = {}

    if plot_level > 0 or 'all_cluster_steps' in return_data:
        all_heightlabels = []
        all_shapelabels = []
        all_snippets = []
        all_features = []
        all_heights = []
        all_unique_heightlabels = []

    all_p_clusters = -1 * np.ones(len(eod_xp))
    all_t_clusters = -1 * np.ones(len(eod_xp))
    artefact_masks_p = np.ones(len(eod_xp), dtype=bool)
    artefact_masks_t = np.ones(len(eod_xp), dtype=bool)

    x_merge = -1 * np.ones(len(eod_xp))

    max_label_p = 0   # keep track of the labels so that no labels are overwritten
    max_label_t = 0

    # loop only over height clusters that are bigger than minp
    # first cluster on width:
    width_labels, bgm_log_dict = BGM(1000*eod_widths/rate,
                                     merge_threshold_width,
                                     n_gaus_width, use_log=False,
                                     verbose=verbose-1,
                                     plot_level=plot_level-1,
                                     xlabel='width [ms]',
                                     save_plot=save_plots,
                                     save_path=save_path,
                                     save_name='width', ftype=ftype,
                                     return_data=return_data)
    saved_data.update(bgm_log_dict)

    if verbose > 0:
        print('Clusters generated based on EOD width:')
        for l in np.unique(width_labels):
            print(f'N_{l} = {len(width_labels[width_labels==l]):4d} h_{l} = {np.mean(eod_widths[width_labels==l]):.4f}')

    w_labels, w_counts = unique_counts(width_labels)
    unique_width_labels = w_labels[w_counts>minp]

    for wi, width_label in enumerate(unique_width_labels):

        # select only features in one width cluster at a time:
        w_eod_widths = eod_widths[width_labels==width_label]
        w_eod_heights = eod_heights[width_labels==width_label]
        w_eod_xp = eod_xp[width_labels==width_label]
        w_eod_xt = eod_xt[width_labels==width_label]
        width = int(width_factor_shape*np.median(w_eod_widths))
        if width > w_eod_xp[0]:
            width = w_eod_xp[0]
        if width > w_eod_xt[0]:
            width = w_eod_xt[0]
        if width > len(data) - w_eod_xp[-1]:
            width = len(data) - w_eod_xp[-1]
        if width > len(data) - w_eod_xt[-1]:
            width = len(data) - w_eod_xt[-1]

        wp_clusters = -1 * np.ones(len(w_eod_xp))
        wt_clusters = -1 * np.ones(len(w_eod_xp))
        wartefact_mask = np.ones(len(w_eod_xp))

        # determine height labels:
        raw_p_snippets, p_snippets, p_features, p_bg_ratio = \
            extract_snippet_features(data, w_eod_xp, w_eod_heights, width)
        raw_t_snippets, t_snippets, t_features, t_bg_ratio = \
            extract_snippet_features(data, w_eod_xt, w_eod_heights, width)

        height_labels, bgm_log_dict = \
            BGM(w_eod_heights, min(merge_threshold_height,
                                   np.median(np.min(np.vstack([p_bg_ratio, t_bg_ratio]),
                                                    axis=0))), n_gaus_height, use_log=True,
                verbose=verbose-1, plot_level=plot_level-1,
                xlabel='height [a.u.]', save_plot=save_plots,
                save_path=save_path, save_name='height_%d' % wi,
                ftype=ftype, return_data=return_data)
        saved_data.update(bgm_log_dict)

        if verbose > 0:
            print('Clusters generated based on EOD height:')
            for l in np.unique(height_labels):
                print(f'N_{l} = {len(height_labels[height_labels==l]):4d} h_{l} = {np.mean(w_eod_heights[height_labels==l]):.4f}')

        h_labels, h_counts = unique_counts(height_labels)
        unique_height_labels = h_labels[h_counts>minp]

        if plot_level > 0 or 'all_cluster_steps' in return_data:
            all_heightlabels.append(height_labels)
            all_heights.append(w_eod_heights)
            all_unique_heightlabels.append(unique_height_labels)
            shape_labels = []
            cfeatures = []
            csnippets = []

        for hi, height_label in enumerate(unique_height_labels):

            h_eod_widths = w_eod_widths[height_labels==height_label]
            h_eod_heights = w_eod_heights[height_labels==height_label]
            h_eod_xp = w_eod_xp[height_labels==height_label]
            h_eod_xt = w_eod_xt[height_labels==height_label]

            p_clusters = cluster_on_shape(p_features[height_labels==height_label],
                                          p_bg_ratio, minp, verbose=0)
            t_clusters = cluster_on_shape(t_features[height_labels==height_label],
                                          t_bg_ratio, minp, verbose=0)

            if plot_level > 1:
                plot_feature_extraction(raw_p_snippets[height_labels==height_label],
                                        p_snippets[height_labels==height_label],
                                        p_features[height_labels==height_label],
                                        p_clusters, 1/rate, 0)
                plt.savefig('%sDBSCAN_peak_w%i_h%i.%s' % (save_path, wi, hi, ftype))
                plot_feature_extraction(raw_t_snippets[height_labels==height_label],
                                        t_snippets[height_labels==height_label],
                                        t_features[height_labels==height_label],
                                        t_clusters, 1/rate, 1)
                plt.savefig('%sDBSCAN_trough_w%i_h%i.%s' % (save_path, wi, hi, ftype))

            if 'snippet_clusters' in return_data:
                saved_data[f'snippet_clusters_{width_label}_{height_label}_peak'] = {
                    'raw_snippets': raw_p_snippets[height_labels==height_label],
                    'snippets': p_snippets[height_labels==height_label],
                    'features': p_features[height_labels==height_label],
                    'clusters': p_clusters,
                    'rate': rate}
                saved_data[f'snippet_clusters_{width_label}_{height_label}_trough'] = {
                    'raw_snippets': raw_t_snippets[height_labels==height_label],
                    'snippets': t_snippets[height_labels==height_label],
                    'features': t_features[height_labels==height_label],
                    'clusters': t_clusters,
                    'rate': rate}

            if plot_level > 0 or 'all_cluster_steps' in return_data:
                shape_labels.append([p_clusters, t_clusters])
                cfeatures.append([p_features[height_labels==height_label],
                                  t_features[height_labels==height_label]])
                csnippets.append([p_snippets[height_labels==height_label],
                                  t_snippets[height_labels==height_label]])

            p_clusters[p_clusters==-1] = -max_label_p - 1
            wp_clusters[height_labels==height_label] = p_clusters + max_label_p
            max_label_p = max(np.max(wp_clusters), np.max(all_p_clusters)) + 1

            t_clusters[t_clusters==-1] = -max_label_t - 1
            wt_clusters[height_labels==height_label] = t_clusters + max_label_t
            max_label_t = max(np.max(wt_clusters), np.max(all_t_clusters)) + 1

        if verbose > 0:
            if np.max(wp_clusters) == -1:
                print(f'No EOD peaks in width cluster {width_label}')
            else:
                unique_clusters = np.unique(wp_clusters[wp_clusters!=-1])
                if len(unique_clusters) > 1:
                    print(f'{len(unique_clusters)} different EOD peaks in width cluster {width_label}')

        if plot_level > 0 or 'all_cluster_steps' in return_data:
            all_shapelabels.append(shape_labels)
            all_snippets.append(csnippets)
            all_features.append(cfeatures)

        # for each cluster, save fft + label
        # so I end up with features for each label, and the masks.
        # then I can extract e.g. first artefact or wave etc.

        # remove artefacts here, based on the mean snippets ffts:
        artefact_masks_p[width_labels==width_label], sdict = \
            remove_artefacts(p_snippets, wp_clusters, rate,
                             verbose=verbose-1, return_data=return_data)
        saved_data.update(sdict)
        artefact_masks_t[width_labels==width_label], _ = \
            remove_artefacts(t_snippets, wt_clusters, rate,
                             verbose=verbose-1, return_data=return_data)

        # update maxlab so that no clusters are overwritten:
        all_p_clusters[width_labels==width_label] = wp_clusters
        all_t_clusters[width_labels==width_label] = wt_clusters

    # remove all non-reliable clusters:
    unreliable_fish_mask_p, saved_data = \
        delete_unreliable_fish(all_p_clusters, eod_widths, eod_xp,
                               verbose=verbose-1, sdict=saved_data)
    unreliable_fish_mask_t, _ = \
        delete_unreliable_fish(all_t_clusters, eod_widths, eod_xt, verbose=verbose-1)

    wave_mask_p, sidepeak_mask_p, saved_data = \
        delete_wavefish_and_sidepeaks(data, all_p_clusters, eod_xp, eod_widths,
                                      width_factor_wave, verbose=verbose-1, sdict=saved_data)
    wave_mask_t, sidepeak_mask_t, _ = \
        delete_wavefish_and_sidepeaks(data, all_t_clusters, eod_xt, eod_widths,
                                      width_factor_wave, verbose=verbose-1)

    og_clusters = [np.copy(all_p_clusters), np.copy(all_t_clusters)]
    og_labels = np.copy(all_p_clusters + all_t_clusters)

    # go through all clusters and masks??
    all_p_clusters[(artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p)] = -1
    all_t_clusters[(artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)] = -1

    # merge here:
    all_clusters, x_merge, mask = merge_clusters(np.copy(all_p_clusters),
                                                 np.copy(all_t_clusters),
                                                 eod_xp, eod_xt,
                                                 verbose=verbose-1)

    if 'all_cluster_steps' in return_data or plot_level > 0:
        all_dmasks = []
        all_mmasks = []

        discarding_masks = \
            np.vstack(((artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p),
                       (artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)))
        merge_mask = mask

        # save the masks in the same formats as the snippets:
        for wi, (width_label, w_shape_label, heightlabels, unique_height_labels) in enumerate(zip(unique_width_labels, all_shapelabels, all_heightlabels, all_unique_heightlabels)):
            w_dmasks = discarding_masks[:,width_labels==width_label]
            w_mmasks = merge_mask[:,width_labels==width_label]

            wd_2 = []
            wm_2 = []

            for hi, (height_label, h_shape_label) in enumerate(zip(unique_height_labels, w_shape_label)):

                h_dmasks = w_dmasks[:,heightlabels==height_label]
                h_mmasks = w_mmasks[:,heightlabels==height_label]

                wd_2.append(h_dmasks)
                wm_2.append(h_mmasks)

            all_dmasks.append(wd_2)
            all_mmasks.append(wm_2)

        if plot_level > 0:
            plot_clustering(rate, [unique_width_labels, eod_widths, width_labels],
                            [all_unique_heightlabels, all_heights, all_heightlabels],
                            [all_snippets, all_features, all_shapelabels],
                            all_dmasks, all_mmasks)
            if save_plots:
                plt.savefig('%sclustering.%s' % (save_path, ftype))

    if 'all_cluster_steps' in return_data:
        saved_data = {'rate': rate,
                      'EOD_widths': [unique_width_labels, eod_widths, width_labels],
                      'EOD_heights': [all_unique_heightlabels, all_heights, all_heightlabels],
                      'EOD_shapes': [all_snippets, all_features, all_shapelabels],
                      'discarding_masks': all_dmasks,
                      'merge_masks': all_mmasks
                      }

    if 'masks' in return_data:
        saved_data = {'masks' : np.vstack(((artefact_masks_p & artefact_masks_t),
                                           (unreliable_fish_mask_p & unreliable_fish_mask_t),
                                           (wave_mask_p & wave_mask_t),
                                           (sidepeak_mask_p & sidepeak_mask_t),
                                           (all_p_clusters+all_t_clusters)))}

    if verbose > 0:
        print('Clusters generated based on height, width and shape: ')
        for l in np.unique(all_clusters[all_clusters != -1]):
            print(f'N_{int(l)} = {len(all_clusters[all_clusters == l]):4d}')

    return all_clusters, x_merge, saved_data
def BGM(x, merge_threshold=0.1, n_gaus=5, max_iter=200, n_init=5,
        use_log=False, verbose=0, plot_level=0, xlabel='x [a.u.]',
        save_plot=False, save_path='', save_name='', ftype='pdf',
        return_data=[]):
    """ Use a Bayesian Gaussian Mixture Model to cluster one-dimensional data.

    Additional steps are used to merge clusters that are closer than
    `merge_threshold`. Broad gaussian fits that cover one or more other
    gaussian fits are split by their intersections with the other
    gaussians.

    Parameters
    ----------
    x : 1D numpy array
        Features to compute clustering on.
    merge_threshold : float (optional)
        Ratio for merging nearby gaussians.
    n_gaus: int (optional)
        Maximum number of gaussians to fit on data.
    max_iter : int (optional)
        Maximum number of iterations for gaussian fit.
    n_init : int (optional)
        Number of initializations for the gaussian fit.
    use_log: boolean (optional)
        Set to True to compute the gaussian fit on the logarithm of x.
        Can improve clustering on features with nonlinear relationships such as peak height.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    xlabel : string (optional)
        Xlabel for displaying BGM plot.
    save_plot : bool (optional)
        Set to True to save created plot.
    save_path : string (optional)
        Path to location where data should be saved. Only used if save_plot==True.
    save_name : string (optional)
        Filename of the saved plot. Useful as usually multiple BGM models are generated.
    ftype : string (optional)
        Filetype of plot image if save_plots==True.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'BGM_width' and/or 'BGM_height' (see extract_pulsefish()).

    Returns
    -------
    labels : 1D numpy array
        Cluster labels for each sample in x.
    bgm_dict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    bgm_dict = {}

    if len(np.unique(x)) > n_gaus:
        BGM_model = BayesianGaussianMixture(n_components=n_gaus, max_iter=max_iter,
                                            n_init=n_init)
        if use_log:
            labels = BGM_model.fit_predict(stats.zscore(np.log(x)).reshape(-1, 1))
        else:
            labels = BGM_model.fit_predict(stats.zscore(x).reshape(-1, 1))
    else:
        return np.zeros(len(x)), bgm_dict

    if verbose > 0:
        if not BGM_model.converged_:
            print('!!! Gaussian mixture did not converge !!!')

    cur_labels = np.unique(labels)

    # map labels to be increasing for increasing values for x:
    maxlab = len(cur_labels)
    aso = np.argsort([np.median(x[labels == l]) for l in cur_labels]) + 100
    for i, a in zip(cur_labels, aso):
        labels[labels==i] = a
    labels = labels - 100

    # separate gaussian clusters that can be split by other clusters:
    splits = np.sort(np.copy(x))[1:][np.diff(labels[np.argsort(x)])!=0]

    labels[:] = 0
    for i, split in enumerate(splits):
        labels[x>=split] = i+1

    labels_before_merge = np.copy(labels)

    # merge gaussian clusters that are closer than merge_threshold:
    labels = merge_gaussians(x, labels, merge_threshold)

    if 'BGM_'+save_name.split('_')[0] in return_data or plot_level > 0:

        # sort model attributes by model means:
        means = [m[0] for m in BGM_model.means_]
        weights = [w for w in BGM_model.weights_]
        variances = [v[0][0] for v in BGM_model.covariances_]
        weights = [w for _, w in sorted(zip(means, weights))]
        variances = [v for _, v in sorted(zip(means, variances))]
        means = sorted(means)

        if plot_level > 0:
            plot_bgm(x, means, variances, weights, use_log, labels_before_merge,
                     labels, xlabel)
            if save_plot:
                plt.savefig('%sBGM_%s.%s' % (save_path, save_name, ftype))

        if 'BGM_'+save_name.split('_')[0] in return_data:
            bgm_dict['BGM_'+save_name] = {'x': x,
                                          'use_log': use_log,
                                          'BGM': [weights, means, variances],
                                          'labels': labels_before_merge,
                                          'xlab': xlabel}

    return labels, bgm_dict
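
# Illustrative sketch, not part of the original module: cluster a bimodal 1-D
# feature set with BGM(). Sample values and thresholds are assumptions.
def _example_BGM():
    rng = np.random.RandomState(3)
    widths = np.concatenate((rng.normal(1.0, 0.05, 100),
                             rng.normal(3.0, 0.1, 100)))
    labels, _ = BGM(widths, merge_threshold=0.5, n_gaus=3)
    # after merging, two groups of labels remain, ordered by increasing width
    return labels
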
def merge_gaussians(x, labels, merge_threshold=0.1):
    """ Merge all clusters whose medians are close to each other. Only works in 1D.

    Parameters
    ----------
    x : 1D array of ints or floats
        Features used for clustering.
    labels : 1D array of ints
        Labels for each sample in x.
    merge_threshold : float (optional)
        Similarity threshold to merge clusters.

    Returns
    -------
    labels : 1D array of ints
        Merged labels for each sample in x.
    """
    # compare all the means of the gaussians. If they are too close, merge them.
    unique_labels = np.unique(labels[labels!=-1])
    x_medians = [np.median(x[labels==l]) for l in unique_labels]

    # fill a dict with the label mappings:
    mapping = {}
    for label_1, x_m1 in zip(unique_labels, x_medians):
        for label_2, x_m2 in zip(unique_labels, x_medians):
            if label_1 != label_2:
                if np.abs(np.diff([x_m1, x_m2]))/np.max([x_m1, x_m2]) < merge_threshold:
                    mapping[label_1] = label_2
    # apply mapping:
    for map_key, map_value in mapping.items():
        labels[labels==map_key] = map_value

    return labels
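
# Illustrative sketch, not part of the original module: clusters 0 and 1 have
# medians 1.0 and 1.05 and are merged into one label, cluster 2 stays apart.
# Sample values are assumptions.
def _example_merge_gaussians():
    x = np.array([1.0, 1.0, 1.05, 1.05, 3.0, 3.0])
    labels = np.array([0, 0, 1, 1, 2, 2])
    return merge_gaussians(x, labels, merge_threshold=0.1)
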
def extract_snippet_features(data, eod_x, eod_heights, width, n_pc=5):
    """ Extract snippets from recording data, normalize them, and perform PCA.

    Parameters
    ----------
    data : 1D numpy array of floats
        Recording data.
    eod_x : 1D array of ints
        Locations of EODs as indices.
    eod_heights: 1D array of floats
        EOD heights.
    width : int
        Width to cut out to each side in samples.
    n_pc : int (optional)
        Number of PCs to use for PCA.

    Returns
    -------
    raw_snippets : 2D numpy array (N, EOD_width)
        Raw extracted EOD snippets.
    snippets : 2D numpy array (N, EOD_width)
        Normalized EOD snippets.
    features : 2D numpy array (N, n_pc)
        PC values of EOD snippets.
    bg_ratio : 1D numpy array (N)
        Ratio of the background activity slopes compared to EOD height.
    """
    # extract snippets with corresponding width:
    raw_snippets = np.vstack([data[x-width:x+width] for x in eod_x])

    # subtract the slope and normalize the snippets:
    snippets, bg_ratio = subtract_slope(np.copy(raw_snippets), eod_heights)
    snippets = StandardScaler().fit_transform(snippets.T).T

    # scale so that the absolute integral = 1:
    snippets = (snippets.T/np.sum(np.abs(snippets), axis=1)).T

    # compute features for clustering on waveform:
    features = PCA(n_pc).fit_transform(snippets)

    return raw_snippets, snippets, features, bg_ratio
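
# Illustrative sketch, not part of the original module: cut out, normalize,
# and PCA-reduce snippets around assumed EOD locations in noise.
def _example_extract_snippet_features():
    rng = np.random.RandomState(4)
    data = rng.randn(10000)*0.01
    eod_x = np.arange(500, 9500, 500)   # 18 assumed EOD locations
    heights = np.ones(len(eod_x))
    raw, snippets, features, bg_ratio = \
        extract_snippet_features(data, eod_x, heights, width=100)
    # raw.shape -> (18, 200); features.shape -> (18, 5)
    return features
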
def cluster_on_shape(features, bg_ratio, minp, percentile=80,
                     max_epsilon=0.01, slope_ratio_factor=4,
                     min_cluster_fraction=0.01, verbose=0):
    """Separate EODs by their shape using DBSCAN.

    Parameters
    ----------
    features : 2D numpy array of floats (N, n_pc)
        PCA features of each EOD in a recording.
    bg_ratio : 1D array of floats
        Ratio of background activity slope the EOD is superimposed on.
    minp : int
        Minimum number of points for core cluster (DBSCAN).
    percentile : int (optional)
        Percentile of KNN distribution, where K=minp, to use as epsilon for DBSCAN.
    max_epsilon : float (optional)
        Maximum epsilon to use for DBSCAN clustering. This is used to avoid adding
        noisy clusters.
    slope_ratio_factor : float (optional)
        Influence of the slope-to-EOD ratio on the epsilon parameter.
        A slope_ratio_factor of 4 means that slope-to-EOD ratios >1/4
        start influencing epsilon.
    min_cluster_fraction : float (optional)
        Minimum fraction of all evaluated datapoints that can form a single cluster.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    labels : 1D array of ints
        Cluster labels for each EOD.
    """
    # determine clustering threshold from data:
    minpc = max(minp, int(len(features)*min_cluster_fraction))
    knn = np.sort(pairwise_distances(features, features), axis=0)[minpc]
    eps = min(max(1, slope_ratio_factor*np.median(bg_ratio))*max_epsilon,
              np.percentile(knn, percentile))

    if verbose > 1:
        print('epsilon = %f' % eps)
        print('Slope to EOD ratio = %f' % np.median(bg_ratio))

    # cluster on EOD shape:
    return DBSCAN(eps=eps, min_samples=minpc).fit(features).labels_
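
# Illustrative sketch, not part of the original module: DBSCAN separates two
# well-spaced groups of PCA features. Feature values are assumptions.
def _example_cluster_on_shape():
    rng = np.random.RandomState(5)
    features = np.vstack((rng.normal(0.0, 0.001, (50, 5)),
                          rng.normal(0.1, 0.001, (50, 5))))
    bg_ratio = np.full(100, 0.01)
    labels = cluster_on_shape(features, bg_ratio, minp=10)
    # expect two cluster labels, one per shape group
    return labels
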
def subtract_slope(snippets, heights):
    """ Subtract underlying slope from all EOD snippets.

    Parameters
    ----------
    snippets: 2-D numpy array
        All EODs in a recording, stacked as snippets.
        Shape = (number of EODs, EOD width)
    heights: 1D numpy array
        EOD heights.

    Returns
    -------
    snippets: 2-D numpy array
        EOD snippets with underlying slope subtracted.
    bg_ratio : 1-D numpy array
        EOD height/background activity height.
    """
    left_y = snippets[:,0]
    right_y = snippets[:,-1]

    try:
        slopes = np.linspace(left_y, right_y, snippets.shape[1])
    except ValueError:
        delta = (right_y - left_y)/snippets.shape[1]
        slopes = np.arange(0, snippets.shape[1], dtype=snippets.dtype).reshape(
            (-1,) + (1,)*np.ndim(delta))*delta + left_y

    return snippets - slopes.T, np.abs(left_y-right_y)/heights
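
# Illustrative sketch, not part of the original module: a purely linear
# snippet is flattened to zeros and its end-point difference is reported
# relative to the assumed EOD height.
def _example_subtract_slope():
    snippets = np.vstack((np.linspace(0.0, 0.2, 50),
                          np.full(50, 0.1)))
    heights = np.array([1.0, 1.0])
    flat, bg_ratio = subtract_slope(snippets, heights)
    # flat is (numerically) all zeros; bg_ratio -> [0.2, 0.0]
    return flat, bg_ratio
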
def remove_artefacts(all_snippets, clusters, rate,
                     freq_low=20000, threshold=0.75,
                     verbose=0, return_data=[]):
    """ Create a mask for EOD clusters that result from artefacts, based on power in low frequency spectrum.

    Parameters
    ----------
    all_snippets: 2D array
        EOD snippets. Shape=(nEODs, EOD length)
    clusters: list of ints
        EOD cluster labels.
    rate : float
        Sampling rate of original recording data.
    freq_low: float
        Frequency up to which low frequency components are summed up.
    threshold : float (optional)
        Minimum value for sum of low frequency components relative to
        sum over all spectral amplitudes that separates artefact from
        clean pulsefish clusters.
    verbose : int (optional)
        Verbosity level.
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to
        log data in this function is 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    mask: numpy array of booleans
        Set to True for every EOD which is an artefact.
    adict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    adict = {}

    mask = np.zeros(clusters.shape, dtype=bool)

    for cluster in np.unique(clusters[clusters >= 0]):
        snippets = all_snippets[clusters == cluster]
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)
        mean_eod_fft = np.abs(np.fft.rfft(mean_eod))
        freqs = np.fft.rfftfreq(len(mean_eod), 1/rate)
        low_frequency_ratio = np.sum(mean_eod_fft[freqs<freq_low])/np.sum(mean_eod_fft)
        if low_frequency_ratio < threshold:  # TODO: check threshold!
            mask[clusters==cluster] = True
            if verbose > 0:
                print('Deleting cluster %i with low frequency ratio of %.3f (min %.3f)'
                      % (cluster, low_frequency_ratio, threshold))

        if 'eod_deletion' in return_data:
            adict['vals_%d' % cluster] = [mean_eod, mean_eod_fft]
            adict['mask_%d' % cluster] = [np.any(mask[clusters==cluster])]

    return mask, adict
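
# Illustrative sketch, not part of the original module: a broad pulse-shaped
# cluster passes, a cluster of near-Nyquist oscillations is masked as
# artefact. Snippet construction and rate are assumptions.
def _example_remove_artefacts():
    rate = 500000.0
    t = np.arange(100)
    pulse = np.exp(-(t - 50)**2/200.0)   # low-frequency pulse
    noise = np.sin(2*np.pi*0.45*t)       # close to Nyquist
    snippets = np.vstack([pulse]*20 + [noise]*20)
    clusters = np.array([0]*20 + [1]*20)
    mask, _ = remove_artefacts(snippets, clusters, rate)
    # expect mask False for the pulse cluster, True for the noise cluster
    return mask
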
def delete_unreliable_fish(clusters, eod_widths, eod_x, verbose=0, sdict={}):
    """ Create a mask for EOD clusters that are either mixed with noise or other fish, or wavefish.

    This is the case when the ratio between the EOD width and the ISI is too large.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of floats or ints
        EOD widths in samples or seconds.
    eod_x : list of ints or floats
        EOD times in samples or seconds.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps,
        see remove_artefacts().

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every unreliable EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask = np.zeros(clusters.shape, dtype=bool)
    for cluster in np.unique(clusters[clusters >= 0]):
        if len(eod_x[cluster == clusters]) < 2:
            mask[clusters == cluster] = True
            if verbose > 0:
                print('deleting unreliable cluster %i, number of EOD times %d < 2'
                      % (cluster, len(eod_x[cluster==clusters])))
        elif np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) > 0.5:
            if verbose > 0:
                print('deleting unreliable cluster %i, score=%f'
                      % (cluster, np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))))
            mask[clusters==cluster] = True
        if 'vals_%d' % cluster in sdict:
            sdict['vals_%d' % cluster].append(
                np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))
            sdict['mask_%d' % cluster].append(any(mask[clusters==cluster]))
    return mask, sdict
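
# Illustrative sketch, not part of the original module: a cluster whose EOD
# width approaches its inter-spike interval is masked as unreliable. Widths
# and times are assumptions in samples.
def _example_delete_unreliable_fish():
    clusters = np.array([0, 0, 0, 1, 1, 1])
    widths = np.full(6, 10)
    x = np.array([100, 1100, 2100, 3000, 3015, 3030])
    mask, _ = delete_unreliable_fish(clusters, widths, x)
    # cluster 0: width/ISI = 10/1000 -> kept;
    # cluster 1: width/ISI = 10/15 > 0.5 -> masked
    return mask
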
def delete_wavefish_and_sidepeaks(data, clusters, eod_x, eod_widths,
                                  width_fac, max_slope_deviation=0.5,
                                  max_phases=4, verbose=0, sdict={}):
    """ Create a mask for EODs that are likely from wavefish, or sidepeaks of bigger EODs.

    Parameters
    ----------
    data : list of floats
        Raw recording data.
    clusters : list of ints
        Cluster labels.
    eod_x : list of ints
        Indices of EOD times.
    eod_widths : list of ints
        EOD widths in samples.
    width_fac : float
        Multiplier for EOD analysis width.
    max_slope_deviation: float (optional)
        Maximum deviation of position of maximum slope in snippets from
        center position in multiples of mean width of EOD.
    max_phases : int (optional)
        Maximum number of phases for any EOD.
        If the mean EOD has more phases than this, it is not a pulse EOD.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps, see remove_artefacts().

    Returns
    -------
    mask_wave: numpy array of booleans
        Set to True for every EOD which is a wavefish EOD.
    mask_sidepeak: numpy array of booleans
        Set to True for every snippet which is centered around a sidepeak of an EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask_wave = np.zeros(clusters.shape, dtype=bool)
    mask_sidepeak = np.zeros(clusters.shape, dtype=bool)

    for i, cluster in enumerate(np.unique(clusters[clusters >= 0])):
        mean_width = np.mean(eod_widths[clusters == cluster])
        cutwidth = mean_width*width_fac
        current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
        current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
        snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)]
                              for x in current_x[current_clusters==cluster]])

        # extract information on main peaks and troughs:
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)

        # detect peaks and troughs on data + some maxima/minima at the
        # end, so that the sides are also considered for peak detection:
        pk, tr = detect_peaks(np.concatenate(([-10*mean_eod[0]], mean_eod, [10*mean_eod[-1]])),
                              np.std(mean_eod))
        pk = pk[(pk>0)&(pk<len(mean_eod))]
        tr = tr[(tr>0)&(tr<len(mean_eod))]

        if len(pk)>0 and len(tr)>0:
            idxs = np.sort(np.concatenate((pk, tr)))
            slopes = np.abs(np.diff(mean_eod[idxs]))
            m_slope = np.argmax(slopes)
            centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2))

            # compute all height differences of peaks and troughs within snippets.
            # if they are all similar, it is probably noise or a wavefish.
            idxs = np.sort(np.concatenate((pk, tr)))
            hdiffs = np.diff(mean_eod[idxs])

            if centered > max_slope_deviation*mean_width:  # TODO: check, factor was probably 0.16
                if verbose > 0:
                    print('Deleting cluster %i, which is a sidepeak' % cluster)
                mask_sidepeak[clusters==cluster] = True

            w_diff = np.abs(np.diff(np.sort(np.concatenate((pk, tr)))))

            if np.abs(np.diff(idxs[m_slope:m_slope+2])) < np.mean(eod_widths[clusters==cluster])*0.5 or \
               len(pk) + len(tr) > max_phases or np.min(w_diff) > 2*cutwidth/width_fac:
                # or len(hdiffs[np.abs(hdiffs)>0.5*(np.max(mean_eod)-np.min(mean_eod))])>max_phases:
                if verbose > 0:
                    print('Deleting cluster %i, which is a wavefish' % cluster)
                mask_wave[clusters==cluster] = True
            if 'vals_%d' % cluster in sdict:
                sdict['vals_%d' % cluster].append([mean_eod, [pk, tr],
                                                   idxs[m_slope:m_slope+2]])
                sdict['mask_%d' % cluster].append(any(mask_wave[clusters==cluster]))
                sdict['mask_%d' % cluster].append(any(mask_sidepeak[clusters==cluster]))

    return mask_wave, mask_sidepeak, sdict
1466def merge_clusters(clusters_1, clusters_2, x_1, x_2, verbose=0):
1467 """ Merge clusters resulting from two clustering methods.
1469 This method only works if clustering is performed on the same EODs
1470 with the same ordering, where there is a one to one mapping from
1471 clusters_1 to clusters_2.
1473 Parameters
1474 ----------
1475 clusters_1: list of ints
1476 EOD cluster labels for cluster method 1.
1477 clusters_2: list of ints
1478 EOD cluster labels for cluster method 2.
1479 x_1: list of ints
1480 Indices of EODs for cluster method 1 (clusters_1).
1481 x_2: list of ints
1482 Indices of EODs for cluster method 2 (clusters_2).
1483 verbose : int (optional)
1484 Verbosity level.
1486 Returns
1487 -------
1488 clusters : list of ints
1489 Merged clusters.
1490 x_merged : list of ints
1491 Merged cluster indices.
1492 mask : 2d numpy array of ints (N, 2)
1493 Mask for clusters that are selected from clusters_1 (mask[:,0]) and
1494 from clusters_2 (mask[:,1]).
1495 """
1496 if verbose > 0:
1497 print('\nMerge clusters:')
1499 # these arrays become 1 for each EOD that is chosen from that array
1500 c1_keep = np.zeros(len(clusters_1))
1501 c2_keep = np.zeros(len(clusters_2))
1503 # offset the labels of the second clustering so they do not overlap with those of the first
1504 ovl = np.max(clusters_1) + 1
1505 clusters_2[clusters_2!=-1] = clusters_2[clusters_2!=-1] + ovl
1507 remove_clusters = [[]]
1508 keep_clusters = []
1509 og_clusters = [np.copy(clusters_1), np.copy(clusters_2)]
1511 # loop until all clusters have been kept or discarded
1512 while True:
1514 # compute unique clusters and cluster sizes
1515 # of clusters that have not been kept or discarded yet:
1516 c1_labels, c1_size = unique_counts(clusters_1[(clusters_1 != -1) & (c1_keep == 0)])
1517 c2_labels, c2_size = unique_counts(clusters_2[(clusters_2 != -1) & (c2_keep == 0)])
1519 # if all clusters are done, break from loop:
1520 if len(c1_size) == 0 and len(c2_size) == 0:
1521 break
1523 # if the biggest cluster is in clusters_1, keep it and discard all
1524 # clusters of clusters_2 that overlap with it:
1525 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 0:
1527 # remove all the mappings from the other indices
1528 cluster_mappings, _ = unique_counts(clusters_2[clusters_1 == c1_labels[np.argmax(c1_size)]])
1530 clusters_2[np.isin(clusters_2, cluster_mappings)] = -1
1532 c1_keep[clusters_1==c1_labels[np.argmax(c1_size)]] = 1
1534 remove_clusters.append(cluster_mappings)
1535 keep_clusters.append(c1_labels[np.argmax(c1_size)])
1537 if verbose > 0:
1538 print('Keep cluster %i of group 1, delete clusters %s of group 2' % (c1_labels[np.argmax(c1_size)], str(cluster_mappings[cluster_mappings!=-1] - ovl)))
1540 # if the biggest cluster is in clusters_2, keep it and discard all overlapping clusters of clusters_1:
1541 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 1:
1543 # remove all the mappings from the other indices
1544 cluster_mappings, _ = unique_counts(clusters_1[clusters_2 == c2_labels[np.argmax(c2_size)]])
1546 clusters_1[np.isin(clusters_1, cluster_mappings)] = -1
1548 c2_keep[clusters_2==c2_labels[np.argmax(c2_size)]] = 1
1550 remove_clusters.append(cluster_mappings)
1551 keep_clusters.append(c2_labels[np.argmax(c2_size)])
1553 if verbose > 0:
1554 print('Keep cluster %i of group 2, delete clusters %s of group 1' % (c2_labels[np.argmax(c2_size)] - ovl, str(cluster_mappings[cluster_mappings!=-1])))
1556 # combine results
1557 clusters = (clusters_1+1)*c1_keep + (clusters_2+1)*c2_keep - 1
1558 x_merged = (x_1)*c1_keep + (x_2)*c2_keep
1560 return clusters, x_merged, np.vstack([c1_keep, c2_keep])
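# Hedged usage sketch for merge_clusters() on two toy clusterings of the same
# six EODs with a one-to-one index mapping; the labels are made up for
# illustration and _demo_* names are not part of the module API.
def _demo_merge_clusters():
    c1 = np.array([0, 0, 0, 0, 1, 1])  # method 1: a large and a small cluster
    c2 = np.array([0, 0, 1, 1, 2, 2])  # method 2: splits the large cluster
    x1 = np.arange(6)                  # EOD indices for method 1
    x2 = np.arange(6)                  # EOD indices for method 2
    clusters, x_merged, mask = merge_clusters(c1, c2, x1, x2)
    # the larger clusters win: both clusters of method 1 are kept and the
    # overlapping clusters of method 2 are discarded
    return clusters, x_merged, mask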
1563def extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths,
1564 clusters, rate, width_fac, verbose=0):
1565 """ Extract mean EODs and EOD timepoints for each EOD cluster.
1567 Parameters
1568 ----------
1569 data: list of floats
1570 Raw recording data.
1571 eod_x: list of ints
1572 Locations of EODs in samples.
1573 eod_peak_x : list of ints
1574 Locations of EOD peaks in samples.
1575 eod_tr_x : list of ints
1576 Locations of EOD troughs in samples.
1577 eod_widths: list of ints
1578 EOD widths in samples.
1579 clusters: list of ints
1580 EOD cluster labels
1581 rate: float
1582 Sampling rate of recording
1583 width_fac : float
1584 Multiplication factor for window used to extract EOD.
1586 verbose : int (optional)
1587 Verbosity level.
1589 Returns
1590 -------
1591 mean_eods: list of 2D arrays (3, eod_length)
1592 The average EOD for each detected fish. First row is time in seconds,
1593 second row the mean EOD, third row the standard deviation.
1594 eod_times: list of 1D arrays
1595 For each detected fish the times of EOD in seconds.
1596 eod_peak_times: list of 1D arrays
1597 For each detected fish the times of EOD peaks in seconds.
1598 eod_trough_times: list of 1D arrays
1599 For each detected fish the times of EOD troughs in seconds.
1600 eod_labels: list of ints
1601 Cluster label for each detected fish.
1602 """
1603 mean_eods, eod_times, eod_peak_times, eod_tr_times, eod_heights, cluster_labels = [], [], [], [], [], []
1605 for cluster in np.unique(clusters):
1606 if cluster!=-1:
1607 cutwidth = np.mean(eod_widths[clusters==cluster])*width_fac
1608 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
1609 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
1611 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)]
for x in current_x[current_clusters==cluster]])
1612 mean_eod = np.mean(snippets, axis=0)
1613 eod_time = np.arange(len(mean_eod))/rate - cutwidth/rate
1615 mean_eod = np.vstack([eod_time, mean_eod, np.std(snippets, axis=0)])
1617 mean_eods.append(mean_eod)
1618 eod_times.append(eod_x[clusters==cluster]/rate)
1619 eod_heights.append(np.min(mean_eod[1]) - np.max(mean_eod[1]))  # negative height of the mean waveform, so that ascending sort puts the largest EODs first
1620 eod_peak_times.append(eod_peak_x[clusters==cluster]/rate)
1621 eod_tr_times.append(eod_tr_x[clusters==cluster]/rate)
1622 cluster_labels.append(cluster)
1624 # sort all outputs by EOD height (heights are negative, so the largest fish come first):
order = np.argsort(eod_heights)
return [mean_eods[i] for i in order], [eod_times[i] for i in order], \
[eod_peak_times[i] for i in order], [eod_tr_times[i] for i in order], \
[cluster_labels[i] for i in order]
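# Minimal usage sketch for extract_means(), assuming a 1-D recording `data`
# sampled at `rate` Hz plus detection arrays from the earlier clustering
# steps; all arguments here are placeholders.
def _demo_extract_means(data, rate, eod_x, eod_peak_x, eod_tr_x,
                        eod_widths, clusters):
    means, times, peak_times, trough_times, labels = \
        extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths,
                      clusters, rate, width_fac=4)
    for m, label in zip(means, labels):
        # m[0] is time in seconds, m[1] the mean EOD, m[2] the standard deviation
        print('cluster %d: mean EOD with %d samples' % (label, m.shape[1]))
    return means, times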
1627def find_clipped_clusters(clusters, mean_eods, eod_times,
1628 eod_peaktimes, eod_troughtimes,
1629 cluster_labels, width_factor,
1630 clip_threshold=0.9, verbose=0):
1631 """ Detect EODs that are clipped and set all clusterlabels of these clipped EODs to -1.
1633 Also return the mean EODs and timepoints of these clipped EODs.
1635 Parameters
1636 ----------
1637 clusters: array of ints
1638 Cluster labels for each EOD in a recording.
1639 mean_eods: list of numpy arrays
1640 Mean EOD waveform for each cluster.
1641 eod_times: list of numpy arrays
1642 EOD timepoints for each EOD cluster.
1643 eod_peaktimes: list of numpy arrays
1644 EOD peak times for each EOD cluster.
1645 eod_troughtimes: list of numpy arrays
1646 EOD trough times for each EOD cluster.
1647 cluster_labels: numpy array
1648 Unique EOD cluster labels.
width_factor: float
Multiplication factor that was used for extracting the EOD snippets.
1649 clip_threshold: float (optional)
1650 Amplitude threshold above which samples count as clipped.
1652 verbose: int (optional)
1653 Verbosity level.
1655 Returns
1656 -------
1657 clusters : array of ints
1658 Cluster labels for each EOD in the recording, where clipped EODs have been set to -1.
1659 clipped_eods : list of numpy arrays
1660 Mean EOD waveforms for each clipped EOD cluster.
1661 clipped_times : list of numpy arrays
1662 EOD timepoints for each clipped EOD cluster.
1663 clipped_peaktimes : list of numpy arrays
1664 EOD peaktimes for each clipped EOD cluster.
1665 clipped_troughtimes : list of numpy arrays
1666 EOD troughtimes for each clipped EOD cluster.
1667 """
1668 clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes, clipped_labels = [], [], [], [], []
1670 for mean_eod, eod_time, eod_peaktime, eod_troughtime, label in \
zip(mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels):
1672 if (np.count_nonzero(mean_eod[1] > clip_threshold) > len(mean_eod[1])/(width_factor*2)) or \
(np.count_nonzero(mean_eod[1] < -clip_threshold) > len(mean_eod[1])/(width_factor*2)):
1673 clipped_eods.append(mean_eod)
1674 clipped_times.append(eod_time)
1675 clipped_peaktimes.append(eod_peaktime)
1676 clipped_troughtimes.append(eod_troughtime)
1677 clipped_labels.append(label)
1678 if verbose>0:
1679 print('Deleting cluster %i, which is clipped' % label)
1681 clusters[np.isin(clusters, clipped_labels)] = -1
1683 return clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes
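# Sketch of the clipping criterion used above: a cluster counts as clipped if
# more than a 1/(2*width_factor) fraction of its mean waveform exceeds the
# clip threshold on either side; mean_eod is a (3, n) array as returned by
# extract_means(), and the _demo_* name is illustrative only.
def _demo_clipping_criterion(mean_eod, width_factor=4, clip_threshold=0.9):
    n = len(mean_eod[1])
    n_high = np.count_nonzero(mean_eod[1] > clip_threshold)
    n_low = np.count_nonzero(mean_eod[1] < -clip_threshold)
    return n_high > n/(width_factor*2) or n_low > n/(width_factor*2)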
1686def delete_moving_fish(clusters, eod_t, T, eod_heights, eod_widths,
1687 rate, min_dt=0.25, stepsize=0.05,
1688 sliding_window_factor=2000, verbose=0,
1689 plot_level=0, save_plot=False, save_path='',
1690 ftype='pdf', return_data=[]):
1691 """
1692 Use a sliding window to find the time interval in which the lowest number
1693 of fish is present simultaneously, then delete all EOD clusters that do
1695 not occur in this window. Do this separately for each EOD width class,
1696 as a moving fish will preserve its EOD width.
1698 Parameters
1699 ----------
1700 clusters: list of ints
1701 EOD cluster labels.
1702 eod_t: list of floats
1703 Timepoints of the EODs (in seconds).
1704 T: float
1705 Length of recording (in seconds).
1706 eod_heights: list of floats
1707 EOD amplitudes.
1708 eod_widths: list of floats
1709 EOD widths (in seconds).
1710 rate: float
1711 Recording data sampling rate.
1713 min_dt : float (optional)
1714 Minimum sliding window size (in seconds).
1715 stepsize : float (optional)
1716 Sliding window stepsize (in seconds).
1717 sliding_window_factor : float
1718 Multiplier for sliding window width,
1719 where the sliding window width = median(EOD_width)*sliding_window_factor.
1720 verbose : int (optional)
1721 Verbosity level.
1722 plot_level : int (optional)
1723 Similar to verbosity levels, but with plots.
1724 Only set to > 0 for debugging purposes.
1725 save_plot : bool (optional)
1726 Set to True to save the plots created by plot_level.
1727 save_path : string (optional)
1728 Path for saving plots. Only used if save_plot is set to True.
1729 ftype : string (optional)
1730 Define the filetype to save the plots in if save_plots is set to True.
1731 Options are: 'png', 'jpg', 'svg' ...
1732 return_data : list of strings (optional)
1733 Keys that specify data to be logged. The key that can be used to log data
1734 in this function is 'moving_fish' (see extract_pulsefish()).
1736 Returns
1737 -------
1738 clusters : list of ints
1739 Cluster labels, where deleted clusters have been set to -1.
1740 window : list of 2 floats
1741 Start and end of window selected for deleting moving fish in seconds.
1742 mf_dict : dictionary
1743 Key value pairs of logged data. Data to be logged is specified by return_data.
1744 """
1745 mf_dict = {}
1747 if len(np.unique(clusters[clusters != -1])) == 0:
1748 return clusters, [0, 1], {}
1750 all_keep_clusters = []
1751 width_classes = merge_gaussians(eod_widths, np.copy(clusters), 0.75)
1753 all_windows = []
1754 all_dts = []
1756 for iw, w in enumerate(np.unique(width_classes[clusters >= 0])):
1757 # initialize with sentinel values that any valid window will improve on
1758 min_clusters = 100
1759 average_height = 0
1760 sparse_clusters = 100
1761 keep_clusters = []
1763 dt = max(min_dt, np.median(eod_widths[width_classes==w])*sliding_window_factor)
1764 window_start = 0
1765 window_end = dt
1767 wclusters = clusters[width_classes==w]
1768 weod_t = eod_t[width_classes==w]
1769 weod_heights = eod_heights[width_classes==w]
1770 weod_widths = eod_widths[width_classes==w]
1772 all_dts.append(dt)
1774 if verbose>0:
1775 print('sliding window dt = %f'%dt)
1777 # make W dependent on width??
1778 ignore_steps = np.zeros(len(np.arange(0, T-dt+stepsize, stepsize)))
1780 for i, t in enumerate(np.arange(0, T-dt+stepsize, stepsize)):
1781 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1782 if len(np.unique(current_clusters))==0:
1783 ignore_steps[max(0, i-int(dt/stepsize)):i+int(dt/stepsize)] = 1
1784 if verbose>0:
1785 print('No pulsefish in recording at T=%.2f:%.2f' % (t, t+dt))
1788 x = np.arange(0, T-dt+stepsize, stepsize)
1789 y = np.ones(len(x))
1794 # sliding window
1795 for j, (t, ignore_step) in enumerate(zip(x, ignore_steps)):
1796 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1797 current_widths = weod_widths[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
1799 unique_clusters = np.unique(current_clusters)
1800 y[j] = len(unique_clusters)
1802 if (len(unique_clusters) <= min_clusters) and \
1803 (ignore_step == 0) and \
1804 (len(unique_clusters) > 0):
1806 current_labels = np.isin(wclusters, unique_clusters)
1807 current_height = np.mean(weod_heights[current_labels])
1809 # compute nr of clusters that are too sparse
1810 cidx = np.isin(clusters, unique_clusters)
clusters_after_deletion = np.unique(remove_sparse_detections(np.copy(clusters[cidx]),
rate*eod_widths[cidx], rate, T))
1811 current_sparse_clusters = len(unique_clusters) - len(clusters_after_deletion[clusters_after_deletion!=-1])
1813 if current_sparse_clusters <= sparse_clusters and \
1814 ((current_sparse_clusters<sparse_clusters) or
1815 (current_height > average_height) or
1816 (len(unique_clusters) < min_clusters)):
1818 keep_clusters = unique_clusters
1819 min_clusters = len(unique_clusters)
1820 average_height = current_height
1821 window_end = t+dt
1822 sparse_clusters = current_sparse_clusters
1824 all_keep_clusters.append(keep_clusters)
1825 all_windows.append(window_end)
1827 if 'moving_fish' in return_data or plot_level>0:
1828 if 'w' in mf_dict:
1829 mf_dict['w'].append(np.median(eod_widths[width_classes==w]))
1830 mf_dict['T'].append(T)
1831 mf_dict['dt'].append(dt)
1832 mf_dict['clusters'].append(wclusters)
1833 mf_dict['t'].append(weod_t)
1834 mf_dict['fishcount'].append([x+0.5*(x[1]-x[0]), y])
1835 mf_dict['ignore_steps'].append(ignore_steps)
1836 else:
1837 mf_dict['w'] = [np.median(eod_widths[width_classes==w])]
1838 mf_dict['T'] = [T]
1839 mf_dict['dt'] = [dt]
1840 mf_dict['clusters'] = [wclusters]
1841 mf_dict['t'] = [weod_t]
1842 mf_dict['fishcount'] = [[x+0.5*(x[1]-x[0]), y]]
1843 mf_dict['ignore_steps'] = [ignore_steps]
1845 if verbose>0:
1846 print('Estimated nr of pulsefish in recording: %i' % len(np.concatenate(all_keep_clusters)))
1848 if plot_level>0:
1849 plot_moving_fish(mf_dict['w'], mf_dict['dt'], mf_dict['clusters'],mf_dict['t'],
1850 mf_dict['fishcount'], T, mf_dict['ignore_steps'])
1851 if save_plot:
1852 plt.savefig('%sdelete_moving_fish.%s' % (save_path, ftype))
1853 # drop the log dictionary if it was not requested:
1854 if 'moving_fish' not in return_data:
1855 mf_dict = {}
1857 # delete all clusters that are not selected
1858 clusters[np.invert(np.isin(clusters, np.concatenate(all_keep_clusters)))] = -1
1860 return clusters, [np.max(all_windows)-np.max(all_dts), np.max(all_windows)], mf_dict
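# Reduced sketch of the sliding-window logic above: count the distinct
# clusters in each window and keep the window with the fewest simultaneously
# present fish. The full function additionally weighs EOD height, sparseness
# and width classes; toy arrays are assumed here and the _demo_* name is
# illustrative only.
def _demo_sliding_window_count(eod_t, clusters, T, dt=0.25, stepsize=0.05):
    best_count, best_window = np.inf, (0.0, dt)
    for t in np.arange(0, T - dt + stepsize, stepsize):
        in_window = (eod_t >= t) & (eod_t < t + dt) & (clusters != -1)
        n = len(np.unique(clusters[in_window]))
        if 0 < n < best_count:
            best_count, best_window = n, (t, t + dt)
    return best_count, best_window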
1863def remove_sparse_detections(clusters, eod_widths, rate, T,
1864 min_density=0.0005, verbose=0):
1865 """ Remove all EOD clusters that are too sparse
1867 Parameters
1868 ----------
1869 clusters : list of ints
1870 Cluster labels.
1871 eod_widths : list of ints
1872 EOD widths in samples.
1873 rate : float
1874 Sampling rate.
1875 T : float
1876 Length of recording in seconds.
1877 min_density : float (optional)
1878 Minimum fraction of the recording (number of EODs times median EOD width
divided by T) that a cluster must cover to count as a realistic detection.
1879 verbose : int (optional)
1880 Verbosity level.
1882 Returns
1883 -------
1884 clusters : list of ints
1885 Cluster labels, where sparse clusters have been set to -1.
1886 """
1887 for c in np.unique(clusters):
1888 if c!=-1:
1890 n = len(clusters[clusters==c])
1891 w = np.median(eod_widths[clusters==c])/rate
1893 if n*w < T*min_density:
1894 if verbose>0:
1895 print('cluster %i is too sparse'%c)
1896 clusters[clusters==c] = -1
1897 return clusters
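# Worked example of the density criterion above with made-up numbers: a
# cluster of n EODs with median width w (in seconds) covers a fraction n*w/T
# of a recording of duration T and is discarded below min_density.
def _demo_sparse_criterion():
    T = 60.0           # recording length in seconds
    n = 20             # number of EODs in the cluster
    w = 0.0005         # median EOD width in seconds
    min_density = 0.0005
    return n*w < T*min_density  # True: 0.01 < 0.03, cluster is too sparse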