Coverage for src/thunderfish/pulses.py: 0% of 598 statements
"""
Extract and cluster EOD waveforms of pulse-type electric fish.

## Main function

- `extract_pulsefish()`: checks for pulse-type fish based on the EOD amplitude and shape.

"""
import os
import numpy as np
from scipy import stats
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.mixture import BayesianGaussianMixture
from sklearn.metrics import pairwise_distances
from thunderlab.eventdetection import detect_peaks, median_std_threshold
from .pulseplots import *

import warnings
def warn(*args, **kwargs):
    """
    Ignore all warnings.
    """
    pass
warnings.warn = warn

try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        def decorator_jit(func):
            return func
        return decorator_jit


# upgrade numpy functions for backwards compatibility:
if not hasattr(np, 'isin'):
    np.isin = np.in1d

def unique_counts(ar):
    """ Find the unique elements of an array and their counts, ignoring shape.

    The code is condensed from numpy version 1.17.0.

    Parameters
    ----------
    ar : numpy array
        Input array

    Returns
    -------
    unique_values : numpy array
        Unique values in array ar.
    unique_counts : numpy array
        Number of instances for each unique value in ar.
    """
    try:
        return np.unique(ar, return_counts=True)
    except TypeError:
        ar = np.asanyarray(ar).flatten()
        ar.sort()
        mask = np.empty(ar.shape, dtype=bool)
        mask[:1] = True
        mask[1:] = ar[1:] != ar[:-1]
        idx = np.concatenate(np.nonzero(mask) + ([mask.size],))
        return ar[mask], np.diff(idx)
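
# Usage sketch (added illustration, not part of the original module):
# `unique_counts()` behaves like `np.unique(ar, return_counts=True)`,
# also for inputs where numpy needs the fallback path above:
#
# >>> unique_counts(np.array([3, 1, 3, 3, 2]))
# (array([1, 2, 3]), array([1, 1, 3]))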


###########################################################################


def extract_pulsefish(data, samplerate, width_factor_shape=3,
                      width_factor_wave=8, width_factor_display=4,
                      verbose=0, plot_level=0, save_plots=False,
                      save_path='', ftype='png', return_data=[]):
    """ Extract and cluster pulse-type fish EODs from single channel data.

    Takes recording data containing an unknown number of pulsefish and extracts the mean
    EOD and EOD timepoints for each fish present in the recording.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    samplerate: float
        Sampling rate of the data in Hertz.

    width_factor_shape : float (optional)
        Width multiplier used for EOD shape analysis.
        EOD snippets are extracted based on width between the
        peak and trough multiplied by the width factor.
    width_factor_wave : float (optional)
        Width multiplier used for wavefish detection.
    width_factor_display : float (optional)
        Width multiplier used for EOD mean extraction and display.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path : string (optional)
        Path for saving plots.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Specify data that should be logged and returned in a dictionary. Each clustering
        step has a specific keyword that results in adding different variables to the log dictionary.
        Optional keys for return_data and the resulting additional key-value pairs to the log dictionary are:

        - 'all_eod_times':
            - 'all_times': list of two lists of floats.
                All peak (`all_times[0]`) and trough times (`all_times[1]`) extracted
                by the peak detection algorithm. Times are given in seconds.
            - 'eod_troughtimes': list of 1D arrays.
                The timepoints in seconds of each unique extracted EOD cluster,
                where each 1D array encodes one cluster.

        - 'peak_detection':
            - "data": 1D numpy array of floats.
                Quadratically interpolated data which was used for peak detection.
            - "interp_fac": float.
                Interpolation factor of raw data.
            - "peaks_1": 1D numpy array of ints.
                Peak indices on interpolated data after first peak detection step.
            - "troughs_1": 1D numpy array of ints.
                Trough indices on interpolated data after first peak detection step.
            - "peaks_2": 1D numpy array of ints.
                Peak indices on interpolated data after second peak detection step.
            - "troughs_2": 1D numpy array of ints.
                Trough indices on interpolated data after second peak detection step.
            - "peaks_3": 1D numpy array of ints.
                Peak indices on interpolated data after third peak detection step.
            - "troughs_3": 1D numpy array of ints.
                Trough indices on interpolated data after third peak detection step.
            - "peaks_4": 1D numpy array of ints.
                Peak indices on interpolated data after fourth peak detection step.
            - "troughs_4": 1D numpy array of ints.
                Trough indices on interpolated data after fourth peak detection step.

        - 'all_cluster_steps':
            - 'samplerate': float.
                Samplerate of interpolated data.
            - 'EOD_widths': list of three 1D numpy arrays.
                The first list entry gives the unique labels of all width clusters
                as a list of ints.
                The second list entry gives the width values for each EOD in samples
                as a 1D numpy array of ints.
                The third list entry gives the width labels for each EOD
                as a 1D numpy array of ints.
            - 'EOD_heights': nested lists (2 layers) of three 1D numpy arrays.
                The first list entry gives the unique labels of all height clusters
                as a list of ints for each width cluster.
                The second list entry gives the height values for each EOD
                as a 1D numpy array of floats for each width cluster.
                The third list entry gives the height labels for each EOD
                as a 1D numpy array of ints for each width cluster.
            - 'EOD_shapes': nested lists (3 layers) of three 1D numpy arrays.
                The first list entry gives the raw EOD snippets as a 2D numpy array
                for each height cluster in a width cluster.
                The second list entry gives the snippet PCA values for each EOD
                as a 2D numpy array of floats for each height cluster in a width cluster.
                The third list entry gives the shape labels for each EOD as a 1D numpy array
                of ints for each height cluster in a width cluster.
            - 'discarding_masks': nested lists (2 layers) of 1D numpy arrays.
                The masks of EODs that are discarded by the discarding step of the algorithm.
                The masks are 1D boolean arrays where instances that are set to True are
                discarded by the algorithm. Discarding masks are saved in nested lists
                that represent the width and height clusters.
            - 'merge_masks': nested lists (2 layers) of 2D numpy arrays.
                The masks of EODs that are discarded by the merging step of the algorithm.
                The masks are 2D boolean arrays where for each sample point `i` either
                `merge_mask[i,0]` or `merge_mask[i,1]` is set to True. Here, `merge_mask[:,0]`
                represents the peak-centered clusters and `merge_mask[:,1]` represents the
                trough-centered clusters. Merge masks are saved in nested lists that
                represent the width and height clusters.

        - 'BGM_width':
            - 'BGM_width': dictionary
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD widths).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'BGM_height':
            This key adds a new dictionary for each width cluster.
            - 'BGM_height_*n*' : dictionary, where *n* defines the width cluster as an int.
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD heights).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'snippet_clusters':
            This key adds a new dictionary for each height cluster.
            - 'snippet_clusters*_n_m_p*' : dictionary, where *n* defines the width cluster
              (int), *m* defines the height cluster (int) and *p* defines shape clustering
              on peak or trough centered EOD snippets (string: 'peak' or 'trough').
                - 'raw_snippets': 2D numpy array (nsamples, nfeatures).
                    Raw EOD snippets.
                - 'snippets': 2D numpy array.
                    Normalized EOD snippets.
                - 'features': 2D numpy array (nsamples, nfeatures).
                    PCA values for each normalized EOD snippet.
                - 'clusters': 1D numpy array of ints.
                    Cluster labels.
                - 'samplerate': float.
                    Samplerate of snippets.

        - 'eod_deletion':
            This key adds two dictionaries for each (peak centered) shape cluster,
            where *cluster* (int) is the unique shape cluster label.
            - 'mask_*cluster*' : list of four booleans.
                The mask for each cluster discarding step.
                The first instance represents the artefact masks, where artefacts
                are set to True.
                The second instance represents the unreliable cluster masks,
                where unreliable clusters are set to True.
                The third instance represents the wavefish masks, where wavefish
                are set to True.
                The fourth instance represents the sidepeak masks, where sidepeaks
                are set to True.
            - 'vals_*cluster*' : list of lists.
                All variables that are used for each cluster deletion step.
                The first instance is a list of two 1D numpy arrays: the mean EOD and
                the FFT of that mean EOD.
                The second instance is a 1D numpy array with all EOD width to ISI ratios.
                The third instance is a list with three entries:
                The first entry is a 1D numpy array zoomed out version of the mean EOD.
                The second entry is a list of two 1D numpy arrays that define the peak
                and trough indices of the zoomed out mean EOD.
                The third entry contains a list of two values that represent the
                peak-trough pair in the zoomed out mean EOD with the largest height
                difference.
            - 'samplerate' : float.
                EOD snippet samplerate.

        - 'masks':
            - 'masks' : 2D numpy array (4,N).
                Each row contains masks for each EOD detected by the EOD peakdetection step.
                The first row defines the artefact masks, the second row defines the
                unreliable EOD masks,
                the third row defines the wavefish masks and the fourth row defines
                the sidepeak masks.

        - 'moving_fish':
            - 'moving_fish': dictionary.
                - 'w' : list of floats.
                    Median width for each width cluster that the moving fish algorithm is
                    computed on (in seconds).
                - 'T' : list of floats.
                    Length of analyzed recording for each width cluster (in seconds).
                - 'dt' : list of floats.
                    Sliding window size (in seconds) for each width cluster.
                - 'clusters' : list of 1D numpy int arrays.
                    Cluster labels for each EOD cluster in a width cluster.
                - 't' : list of 1D numpy float arrays.
                    EOD emission times for each EOD in a width cluster.
                - 'fishcount' : list of lists.
                    Sliding window timepoints and fishcounts for each width cluster.
                - 'ignore_steps' : list of 1D int arrays.
                    Mask for fishcounts that were ignored (ignored if True) in the
                    moving_fish analysis.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First column is time in seconds,
        second column the mean eod, third column the standard error.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD peaks or troughs in seconds.
        Use these timepoints for EOD averaging.
    eod_peaktimes: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    zoom_window: tuple of floats
        Start and end time of suggested window for plotting EOD timepoints.
    log_dict: dictionary
        Dictionary with logged variables, where variables to log are specified
        by `return_data`.
    """
    if verbose > 0:
        print('')
        if verbose > 1:
            print(70*'#')
        print('##### extract_pulsefish', 46*'#')

    if (save_plots and plot_level > 0 and save_path):
        # create folder to save things in.
        if not os.path.exists(save_path):
            os.makedirs(save_path)
    else:
        save_path = ''

    mean_eods, eod_times, eod_peaktimes, zoom_window = [], [], [], []
    log_dict = {}

    # interpolate:
    i_samplerate = 500000.0
    #i_samplerate = samplerate
    try:
        f = interp1d(np.arange(len(data))/samplerate, data, kind='quadratic')
        i_data = f(np.arange(0.0, (len(data)-1)/samplerate, 1.0/i_samplerate))
    except MemoryError:
        i_samplerate = samplerate
        i_data = data
    log_dict['data'] = i_data              # TODO: could be removed
    log_dict['samplerate'] = i_samplerate  # TODO: could be removed
    log_dict['i_data'] = i_data
    log_dict['i_samplerate'] = i_samplerate
    # log_dict["interp_fac"] = interp_fac  # TODO: is not set anymore

    # standard deviation of data in small snippets:
    threshold = median_std_threshold(data, samplerate)  # TODO: make this a parameter

    # extract peaks:
    if 'peak_detection' in return_data:
        x_peak, x_trough, eod_heights, eod_widths, pd_log_dict = \
            detect_pulses(i_data, i_samplerate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose-1, return_data=True)
        log_dict.update(pd_log_dict)
    else:
        x_peak, x_trough, eod_heights, eod_widths = \
            detect_pulses(i_data, i_samplerate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose-1, return_data=False)

    if len(x_peak) > 0:
        # cluster
        clusters, x_merge, c_log_dict = cluster(x_peak, x_trough,
                                                eod_heights,
                                                eod_widths, i_data,
                                                i_samplerate,
                                                width_factor_shape,
                                                width_factor_wave,
                                                verbose=verbose-1,
                                                plot_level=plot_level-1,
                                                save_plots=save_plots,
                                                save_path=save_path,
                                                ftype=ftype,
                                                return_data=return_data)

        # extract mean eods and times
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, clusters,
                          i_samplerate, width_factor_display, verbose=verbose-1)

        # determine clipped clusters (save them, but ignore in other steps)
        clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes = \
            find_clipped_clusters(clusters, mean_eods, eod_times, eod_peaktimes,
                                  eod_troughtimes, cluster_labels, width_factor_display,
                                  verbose=verbose-1)

        # delete the moving fish
        clusters, zoom_window, mf_log_dict = \
            delete_moving_fish(clusters, x_merge/i_samplerate, len(data)/samplerate,
                               eod_heights, eod_widths/i_samplerate, i_samplerate,
                               verbose=verbose-1, plot_level=plot_level-1, save_plot=save_plots,
                               save_path=save_path, ftype=ftype, return_data=return_data)

        if 'moving_fish' in return_data:
            log_dict['moving_fish'] = mf_log_dict

        clusters = remove_sparse_detections(clusters, eod_widths, i_samplerate,
                                            len(data)/samplerate, verbose=verbose-1)

        # extract mean eods
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths,
                          clusters, i_samplerate, width_factor_display, verbose=verbose-1)

        mean_eods.extend(clipped_eods)
        eod_times.extend(clipped_times)
        eod_peaktimes.extend(clipped_peaktimes)
        eod_troughtimes.extend(clipped_troughtimes)

        if plot_level > 0:
            plot_all(data, eod_peaktimes, eod_troughtimes, samplerate, mean_eods)
            if save_plots:
                plt.savefig('%sextract_pulsefish_results.%s' % (save_path, ftype))
        if save_plots:
            plt.close('all')

        if 'all_eod_times' in return_data:
            log_dict['all_times'] = [x_peak/i_samplerate, x_trough/i_samplerate]
            log_dict['eod_troughtimes'] = eod_troughtimes

        log_dict.update(c_log_dict)

    return mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict
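
# Usage sketch (illustration only, not part of the original module). `data`
# and `samplerate` are assumed to come from a recording loader such as
# thunderlab's DataLoader; neither is defined here:
#
# >>> mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict = \
# ...     extract_pulsefish(data, samplerate)
# >>> for times in eod_times:
# ...     rate = (len(times) - 1)/(times[-1] - times[0])  # mean EOD rate in Hz
# ...     print('fish with %4d EODs at %6.1f Hz' % (len(times), rate))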


def detect_pulses(data, samplerate, thresh, min_rel_slope_diff=0.25,
                  min_width=0.00005, max_width=0.01, width_fac=5.0,
                  verbose=0, return_data=False):
    """Detect pulses in data.

    Was `def extract_eod_times(data, samplerate, width_factor,
    interp_freq=500000, max_peakwidth=0.01, min_peakwidth=None,
    verbose=0, return_data=[], save_path='')` before.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    samplerate: float
        Sampling rate of the data.
    thresh: float
        Threshold for peak and trough detection via `detect_peaks()`.
        Must be a positive number that sets the minimum difference
        between a peak and a trough.
    min_rel_slope_diff: float
        Minimum required difference between left and right slope (between
        peak and troughs) relative to mean slope for deciding which trough
        to take based on slope difference.
    min_width: float
        Minimum width (peak-trough distance) of pulses in seconds.
    max_width: float
        Maximum width (peak-trough distance) of pulses in seconds.
    width_fac: float
        Pulses extend plus or minus `width_fac` times their width
        (distance between peak and assigned trough).
        Only pulses are returned that can fully be analysed with this width.
    verbose : int (optional)
        Verbosity level.
    return_data : bool
        If `True` data of this function is logged and returned (see
        extract_pulsefish()).

    Returns
    -------
    peak_indices: array of ints
        Indices of EOD peaks in data.
    trough_indices: array of ints
        Indices of EOD troughs in data. There is one x_trough for each x_peak.
    heights: array of floats
        EOD heights for each x_peak.
    widths: array of ints
        EOD widths for each x_peak (in samples).
    peak_detection_result : dictionary
        Key value pairs of logged data.
        This is only returned if `return_data` is `True`.

    """
    peak_detection_result = {}

    # detect peaks and troughs in the data:
    peak_indices, trough_indices = detect_peaks(data, thresh)
    if verbose > 0:
        print('Peaks/troughs detected in data: %5d %5d'
              % (len(peak_indices), len(trough_indices)))
    if return_data:
        peak_detection_result.update(peaks_1=np.array(peak_indices),
                                     troughs_1=np.array(trough_indices))
    if len(peak_indices) < 2 or \
       len(trough_indices) < 2 or \
       len(peak_indices) > len(data)/20:
        # TODO: if too many peaks increase threshold!
        if verbose > 0:
            print('No or too many peaks/troughs detected in data.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # assign troughs to peaks:
    peak_indices, trough_indices, heights, widths, slopes = \
        assign_side_peaks(data, peak_indices, trough_indices, min_rel_slope_diff)
    if verbose > 1:
        print('Number of peaks after assigning side-peaks: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_2=np.array(peak_indices),
                                     troughs_2=np.array(trough_indices))

    # check widths:
    keep = ((widths > min_width*samplerate) & (widths < max_width*samplerate))
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    slopes = slopes[keep]
    if verbose > 1:
        print('Number of peaks after checking pulse width: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_3=np.array(peak_indices),
                                     troughs_3=np.array(trough_indices))

    # discard connected peaks:
    same = np.nonzero(trough_indices[:-1] == trough_indices[1:])[0]
    keep = np.ones(len(trough_indices), dtype=bool)
    for i in same:
        # same troughs at trough_indices[i] and trough_indices[i+1]:
        s = slopes[i:i+2]
        rel_slopes = np.abs(np.diff(s))[0]/np.mean(s)
        if rel_slopes > min_rel_slope_diff:
            keep[i+(s[1]<s[0])] = False
        else:
            keep[i+(heights[i+1]<heights[i])] = False
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    if verbose > 1:
        print('Number of peaks after merging pulses: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_4=np.array(peak_indices),
                                     troughs_4=np.array(trough_indices))
    if len(peak_indices) == 0:
        if verbose > 0:
            print('No peaks remain as pulse candidates.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # only take those where the maximum cutwidth does not cause issues -
    # if the width_fac times the width + x is more than length.
    keep = ((peak_indices - widths > 0) &
            (peak_indices + widths < len(data)) &
            (trough_indices - widths > 0) &
            (trough_indices + widths < len(data)))

    if verbose > 0:
        print('Remaining peaks after EOD extraction: %5d'
              % (np.sum(keep)))
        print('')

    if return_data:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep], peak_detection_result
    else:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep]
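
# Usage sketch (illustration only): the threshold is typically derived from
# the data itself, as in extract_pulsefish() above; `data` and `samplerate`
# are assumed to be given:
#
# >>> thresh = median_std_threshold(data, samplerate)
# >>> x_peak, x_trough, heights, widths = detect_pulses(data, samplerate,
# ...                                                   thresh, width_fac=8.0)
# >>> peak_times = x_peak/samplerate   # peak indices to times in seconds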


@jit(nopython=True)
def assign_side_peaks(data, peak_indices, trough_indices,
                      min_rel_slope_diff=0.25):
    """Assign to each peak the trough resulting in a pulse with the steepest slope or largest height.

    The slope between a peak and a trough is computed as the height
    difference divided by the distance between peak and trough. If the
    slopes to the left and the right trough differ by less than
    `min_rel_slope_diff`, then just the heights of the two troughs
    relative to the peak are compared.

    Was `def detect_eod_peaks(data, main_indices, side_indices,
    max_width=20, min_width=2, verbose=0)` before.

    Parameters
    ----------
    data: array of floats
        Data in which the events were detected.
    peak_indices: array of ints
        Indices of the detected peaks in the data time series.
    trough_indices: array of ints
        Indices of the detected troughs in the data time series.
    min_rel_slope_diff: float
        Minimum required difference of left and right slope relative
        to mean slope.

    Returns
    -------
    peak_indices: array of ints
        Peak indices. Same as input `peak_indices` but potentially shorter
        by one or two elements.
    trough_indices: array of ints
        Corresponding indices of the trough to the left or right
        of each peak.
    heights: array of floats
        Peak heights (distance between peak and corresponding trough amplitude).
    widths: array of ints
        Peak widths (distance between peak and corresponding trough indices).
    slopes: array of floats
        Peak slopes (height divided by width).
    """
    # is a main or side peak first?
    peak_first = int(peak_indices[0] < trough_indices[0])
    # is a main or side peak last?
    peak_last = int(peak_indices[-1] > trough_indices[-1])
    # ensure all peaks to have side peaks (troughs) at both sides,
    # i.e. troughs at same index and next index are before and after peak:
    peak_indices = peak_indices[peak_first:len(peak_indices)-peak_last]
    y = data[peak_indices]

    # indices of troughs on the left and right side of main peaks:
    l_indices = np.arange(len(peak_indices))
    r_indices = l_indices + 1

    # indices, distance to peak, height, and slope of left troughs:
    l_side_indices = trough_indices[l_indices]
    l_distance = np.abs(peak_indices - l_side_indices)
    l_height = np.abs(y - data[l_side_indices])
    l_slope = np.abs(l_height/l_distance)

    # indices, distance to peak, height, and slope of right troughs:
    r_side_indices = trough_indices[r_indices]
    r_distance = np.abs(r_side_indices - peak_indices)
    r_height = np.abs(y - data[r_side_indices])
    r_slope = np.abs(r_height/r_distance)

    # which trough to assign to the peak?
    # - either the one with the steepest slope, or
    # - when slopes are similar on both sides
    #   (within `min_rel_slope_diff` difference),
    #   the trough with the maximum height difference to the peak:
    rel_slopes = np.abs(l_slope-r_slope)/(0.5*(l_slope+r_slope))
    take_slopes = rel_slopes > min_rel_slope_diff
    take_left = l_height > r_height
    take_left[take_slopes] = l_slope[take_slopes] > r_slope[take_slopes]

    # assign troughs, heights, widths, and slopes:
    trough_indices = np.where(take_left,
                              trough_indices[:-1], trough_indices[1:])
    heights = np.where(take_left, l_height, r_height)
    widths = np.where(take_left, l_distance, r_distance)
    slopes = np.where(take_left, l_slope, r_slope)

    return peak_indices, trough_indices, heights, widths, slopes
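
# Worked toy example (added illustration; all values are made up): the first
# peak has a much steeper left flank, so its left trough is assigned by the
# slope rule; the flank slopes of the second peak differ by less than 25%,
# so its higher right trough wins by the height rule:
#
# >>> data = np.array([0.0, 2.0, 1.0, 1.0, 2.5, 0.7])
# >>> pk, tr, h, w, s = assign_side_peaks(data, np.array([1, 4]),
# ...                                     np.array([0, 3, 5]))
# >>> # now tr is [0, 5], h is [2.0, 1.8], and w is [1, 1]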


def cluster(eod_xp, eod_xt, eod_heights, eod_widths, data, samplerate,
            width_factor_shape, width_factor_wave,
            n_gaus_height=10, merge_threshold_height=0.1, n_gaus_width=3,
            merge_threshold_width=0.5, minp=10,
            verbose=0, plot_level=0, save_plots=False, save_path='', ftype='pdf',
            return_data=[]):
    """Cluster EODs.

    First cluster on EOD widths using a Bayesian Gaussian
    Mixture (BGM) model, then cluster on EOD heights using a
    BGM model. Lastly, cluster on EOD waveform with DBSCAN.
    Clustering on EOD waveform is performed twice, once on
    peak-centered EODs and once on trough-centered EODs.
    Non-pulsetype EOD clusters are deleted, and clusters are
    merged afterwards.

    Parameters
    ----------
    eod_xp : list of ints
        Location of EOD peaks in indices.
    eod_xt: list of ints
        Locations of EOD troughs in indices.
    eod_heights: list of floats
        EOD heights.
    eod_widths: list of ints
        EOD widths in samples.
    data: array of floats
        Data in which to detect pulse EODs.
    samplerate : float
        Sample rate of `data`.
    width_factor_shape : float
        Multiplier for snippet extraction width. This factor is
        multiplied with the width between the peak and trough of a
        single EOD.
    width_factor_wave : float
        Multiplier for wavefish extraction width.
    n_gaus_height : int (optional)
        Number of gaussians to use for the clustering based on EOD height.
    merge_threshold_height : float (optional)
        Threshold for merging clusters that are similar in height.
    n_gaus_width : int (optional)
        Number of gaussians to use for the clustering based on EOD width.
    merge_threshold_width : float (optional)
        Threshold for merging clusters that are similar in width.
    minp : int (optional)
        Minimum number of points for core clusters (DBSCAN).
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save created plots.
    save_path : string (optional)
        Path to save plots to. Only used if save_plots==True.
    ftype : string (optional)
        Filetype to save plot images in.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'all_cluster_steps', 'BGM_width', 'BGM_height',
        'snippet_clusters', 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    labels : list of ints
        EOD cluster labels based on height and EOD waveform.
    x_merge : list of ints
        Locations of EODs in clusters.
    saved_data : dictionary
        Key value pairs of logged data. Data to be logged is specified
        by return_data.

    """
    saved_data = {}

    if plot_level>0 or 'all_cluster_steps' in return_data:
        all_heightlabels = []
        all_shapelabels = []
        all_snippets = []
        all_features = []
        all_heights = []
        all_unique_heightlabels = []

    all_p_clusters = -1 * np.ones(len(eod_xp))
    all_t_clusters = -1 * np.ones(len(eod_xp))
    artefact_masks_p = np.ones(len(eod_xp), dtype=bool)
    artefact_masks_t = np.ones(len(eod_xp), dtype=bool)

    x_merge = -1 * np.ones(len(eod_xp))

    max_label_p = 0   # keep track of the labels so that no labels are overwritten
    max_label_t = 0

    # loop only over height clusters that are bigger than minp
    # first cluster on width
    width_labels, bgm_log_dict = BGM(1000*eod_widths/samplerate, merge_threshold_width,
                                     n_gaus_width, use_log=False, verbose=verbose-1,
                                     plot_level=plot_level-1, xlabel='width [ms]',
                                     save_plot=save_plots, save_path=save_path,
                                     save_name='width', ftype=ftype, return_data=return_data)
    saved_data.update(bgm_log_dict)

    if verbose>0:
        print('Clusters generated based on EOD width:')
        [print('N_{} = {:>4} h_{} = {:.4f}'.format(l, len(width_labels[width_labels==l]), l, np.mean(eod_widths[width_labels==l]))) for l in np.unique(width_labels)]

    w_labels, w_counts = unique_counts(width_labels)
    unique_width_labels = w_labels[w_counts>minp]

    for wi, width_label in enumerate(unique_width_labels):

        # select only features in one width cluster at a time
        w_eod_widths = eod_widths[width_labels==width_label]
        w_eod_heights = eod_heights[width_labels==width_label]
        w_eod_xp = eod_xp[width_labels==width_label]
        w_eod_xt = eod_xt[width_labels==width_label]
        width = int(width_factor_shape*np.median(w_eod_widths))
        if width > w_eod_xp[0]:
            width = w_eod_xp[0]
        if width > w_eod_xt[0]:
            width = w_eod_xt[0]
        if width > len(data) - w_eod_xp[-1]:
            width = len(data) - w_eod_xp[-1]
        if width > len(data) - w_eod_xt[-1]:
            width = len(data) - w_eod_xt[-1]

        wp_clusters = -1 * np.ones(len(w_eod_xp))
        wt_clusters = -1 * np.ones(len(w_eod_xp))
        wartefact_mask = np.ones(len(w_eod_xp))

        # determine height labels
        raw_p_snippets, p_snippets, p_features, p_bg_ratio = \
            extract_snippet_features(data, w_eod_xp, w_eod_heights, width)
        raw_t_snippets, t_snippets, t_features, t_bg_ratio = \
            extract_snippet_features(data, w_eod_xt, w_eod_heights, width)

        height_labels, bgm_log_dict = \
            BGM(w_eod_heights,
                min(merge_threshold_height, np.median(np.min(np.vstack([p_bg_ratio, t_bg_ratio]), axis=0))),
                n_gaus_height, use_log=True, verbose=verbose-1, plot_level=plot_level-1,
                xlabel='height [a.u.]', save_plot=save_plots, save_path=save_path,
                save_name='height_%d' % wi, ftype=ftype, return_data=return_data)
        saved_data.update(bgm_log_dict)

        if verbose>0:
            print('Clusters generated based on EOD height:')
            [print('N_{} = {:>4} h_{} = {:.4f}'.format(l, len(height_labels[height_labels==l]), l, np.mean(w_eod_heights[height_labels==l]))) for l in np.unique(height_labels)]

        h_labels, h_counts = unique_counts(height_labels)
        unique_height_labels = h_labels[h_counts>minp]

        if plot_level>0 or 'all_cluster_steps' in return_data:
            all_heightlabels.append(height_labels)
            all_heights.append(w_eod_heights)
            all_unique_heightlabels.append(unique_height_labels)
            shape_labels = []
            cfeatures = []
            csnippets = []

        for hi, height_label in enumerate(unique_height_labels):

            h_eod_widths = w_eod_widths[height_labels==height_label]
            h_eod_heights = w_eod_heights[height_labels==height_label]
            h_eod_xp = w_eod_xp[height_labels==height_label]
            h_eod_xt = w_eod_xt[height_labels==height_label]

            p_clusters = cluster_on_shape(p_features[height_labels==height_label],
                                          p_bg_ratio, minp, verbose=0)
            t_clusters = cluster_on_shape(t_features[height_labels==height_label],
                                          t_bg_ratio, minp, verbose=0)

            if plot_level>1:
                plot_feature_extraction(raw_p_snippets[height_labels==height_label],
                                        p_snippets[height_labels==height_label],
                                        p_features[height_labels==height_label],
                                        p_clusters, 1/samplerate, 0)
                plt.savefig('%sDBSCAN_peak_w%i_h%i.%s' % (save_path, wi, hi, ftype))
                plot_feature_extraction(raw_t_snippets[height_labels==height_label],
                                        t_snippets[height_labels==height_label],
                                        t_features[height_labels==height_label],
                                        t_clusters, 1/samplerate, 1)
                plt.savefig('%sDBSCAN_trough_w%i_h%i.%s' % (save_path, wi, hi, ftype))

            if 'snippet_clusters' in return_data:
                saved_data['snippet_clusters_%i_%i_peak' % (width_label, height_label)] = {
                    'raw_snippets': raw_p_snippets[height_labels==height_label],
                    'snippets': p_snippets[height_labels==height_label],
                    'features': p_features[height_labels==height_label],
                    'clusters': p_clusters,
                    'samplerate': samplerate}
                saved_data['snippet_clusters_%i_%i_trough' % (width_label, height_label)] = {
                    'raw_snippets': raw_t_snippets[height_labels==height_label],
                    'snippets': t_snippets[height_labels==height_label],
                    'features': t_features[height_labels==height_label],
                    'clusters': t_clusters,
                    'samplerate': samplerate}

            if plot_level>0 or 'all_cluster_steps' in return_data:
                shape_labels.append([p_clusters, t_clusters])
                cfeatures.append([p_features[height_labels==height_label],
                                  t_features[height_labels==height_label]])
                csnippets.append([p_snippets[height_labels==height_label],
                                  t_snippets[height_labels==height_label]])

            p_clusters[p_clusters==-1] = -max_label_p - 1
            wp_clusters[height_labels==height_label] = p_clusters + max_label_p
            max_label_p = max(np.max(wp_clusters), np.max(all_p_clusters)) + 1

            t_clusters[t_clusters==-1] = -max_label_t - 1
            wt_clusters[height_labels==height_label] = t_clusters + max_label_t
            max_label_t = max(np.max(wt_clusters), np.max(all_t_clusters)) + 1

        if verbose > 0:
            if np.max(wp_clusters) == -1:
                print('No EOD peaks in width cluster %i' % width_label)
            else:
                unique_clusters = np.unique(wp_clusters[wp_clusters!=-1])
                if len(unique_clusters) > 1:
                    print('%i different EOD peaks in width cluster %i' % (len(unique_clusters), width_label))

        if plot_level>0 or 'all_cluster_steps' in return_data:
            all_shapelabels.append(shape_labels)
            all_snippets.append(csnippets)
            all_features.append(cfeatures)

        # for each cluster, save fft + label
        # so I end up with features for each label, and the masks.
        # then I can extract e.g. first artefact or wave etc.

        # remove artefacts here, based on the mean snippets ffts.
        artefact_masks_p[width_labels==width_label], sdict = \
            remove_artefacts(p_snippets, wp_clusters, samplerate,
                             verbose=verbose-1, return_data=return_data)
        saved_data.update(sdict)
        artefact_masks_t[width_labels==width_label], _ = \
            remove_artefacts(t_snippets, wt_clusters, samplerate,
                             verbose=verbose-1, return_data=return_data)

        # update maxlab so that no clusters are overwritten
        all_p_clusters[width_labels==width_label] = wp_clusters
        all_t_clusters[width_labels==width_label] = wt_clusters

    # remove all non-reliable clusters
    unreliable_fish_mask_p, saved_data = \
        delete_unreliable_fish(all_p_clusters, eod_widths, eod_xp,
                               verbose=verbose-1, sdict=saved_data)
    unreliable_fish_mask_t, _ = \
        delete_unreliable_fish(all_t_clusters, eod_widths, eod_xt, verbose=verbose-1)

    wave_mask_p, sidepeak_mask_p, saved_data = \
        delete_wavefish_and_sidepeaks(data, all_p_clusters, eod_xp, eod_widths,
                                      width_factor_wave, verbose=verbose-1, sdict=saved_data)
    wave_mask_t, sidepeak_mask_t, _ = \
        delete_wavefish_and_sidepeaks(data, all_t_clusters, eod_xt, eod_widths,
                                      width_factor_wave, verbose=verbose-1)

    og_clusters = [np.copy(all_p_clusters), np.copy(all_t_clusters)]
    og_labels = np.copy(all_p_clusters + all_t_clusters)

    # discard all EODs that are masked as artefact, unreliable, wavefish, or sidepeak:
    all_p_clusters[(artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p)] = -1
    all_t_clusters[(artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)] = -1

    # merge peak-centered and trough-centered clusters:
    all_clusters, x_merge, mask = merge_clusters(np.copy(all_p_clusters),
                                                 np.copy(all_t_clusters), eod_xp, eod_xt,
                                                 verbose=verbose-1)

    if 'all_cluster_steps' in return_data or plot_level>0:
        all_dmasks = []
        all_mmasks = []

        discarding_masks = \
            np.vstack(((artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p),
                       (artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)))
        merge_mask = mask

        # save the masks in the same formats as the snippets
        for wi, (width_label, w_shape_label, heightlabels, unique_height_labels) in enumerate(zip(unique_width_labels, all_shapelabels, all_heightlabels, all_unique_heightlabels)):
            w_dmasks = discarding_masks[:,width_labels==width_label]
            w_mmasks = merge_mask[:,width_labels==width_label]

            wd_2 = []
            wm_2 = []

            for hi, (height_label, h_shape_label) in enumerate(zip(unique_height_labels, w_shape_label)):

                h_dmasks = w_dmasks[:,heightlabels==height_label]
                h_mmasks = w_mmasks[:,heightlabels==height_label]

                wd_2.append(h_dmasks)
                wm_2.append(h_mmasks)

            all_dmasks.append(wd_2)
            all_mmasks.append(wm_2)

        if plot_level>0:
            plot_clustering(samplerate, [unique_width_labels, eod_widths, width_labels],
                            [all_unique_heightlabels, all_heights, all_heightlabels],
                            [all_snippets, all_features, all_shapelabels],
                            all_dmasks, all_mmasks)
            if save_plots:
                plt.savefig('%sclustering.%s' % (save_path, ftype))

        if 'all_cluster_steps' in return_data:
            saved_data = {'samplerate': samplerate,
                          'EOD_widths': [unique_width_labels, eod_widths, width_labels],
                          'EOD_heights': [all_unique_heightlabels, all_heights, all_heightlabels],
                          'EOD_shapes': [all_snippets, all_features, all_shapelabels],
                          'discarding_masks': all_dmasks,
                          'merge_masks': all_mmasks
                          }

    if 'masks' in return_data:
        saved_data = {'masks': np.vstack(((artefact_masks_p & artefact_masks_t),
                                          (unreliable_fish_mask_p & unreliable_fish_mask_t),
                                          (wave_mask_p & wave_mask_t),
                                          (sidepeak_mask_p & sidepeak_mask_t),
                                          (all_p_clusters + all_t_clusters)))}

    if verbose>0:
        print('Clusters generated based on height, width and shape: ')
        [print('N_{} = {:>4}'.format(int(l), len(all_clusters[all_clusters == l]))) for l in np.unique(all_clusters[all_clusters != -1])]

    return all_clusters, x_merge, saved_data


def BGM(x, merge_threshold=0.1, n_gaus=5, max_iter=200, n_init=5,
        use_log=False, verbose=0, plot_level=0, xlabel='x [a.u.]',
        save_plot=False, save_path='', save_name='', ftype='pdf', return_data=[]):
    """ Use a Bayesian Gaussian Mixture Model to cluster one-dimensional data.

    Additional steps are used to merge clusters that are closer than
    `merge_threshold`. Broad gaussian fits that cover one or more other
    gaussian fits are split by their intersections with the other
    gaussians.

    Parameters
    ----------
    x : 1D numpy array
        Features to compute clustering on.

    merge_threshold : float (optional)
        Ratio for merging nearby gaussians.
    n_gaus: int (optional)
        Maximum number of gaussians to fit on data.
    max_iter : int (optional)
        Maximum number of iterations for gaussian fit.
    n_init : int (optional)
        Number of initializations for the gaussian fit.
    use_log: boolean (optional)
        Set to True to compute the gaussian fit on the logarithm of x.
        Can improve clustering on features with nonlinear relationships such as peak height.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    xlabel : string (optional)
        Xlabel for displaying BGM plot.
    save_plot : bool (optional)
        Set to True to save created plot.
    save_path : string (optional)
        Path to location where data should be saved. Only used if save_plot==True.
    save_name : string (optional)
        Filename of the saved plot. Useful as usually multiple BGM models are generated.
    ftype : string (optional)
        Filetype of plot image if save_plot==True.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'BGM_width' and/or 'BGM_height' (see extract_pulsefish()).

    Returns
    -------
    labels : 1D numpy array
        Cluster labels for each sample in x.
    bgm_dict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """

    bgm_dict = {}

    if len(np.unique(x)) > n_gaus:
        BGM_model = BayesianGaussianMixture(n_components=n_gaus, max_iter=max_iter, n_init=n_init)
        if use_log:
            labels = BGM_model.fit_predict(stats.zscore(np.log(x)).reshape(-1, 1))
        else:
            labels = BGM_model.fit_predict(stats.zscore(x).reshape(-1, 1))
    else:
        return np.zeros(len(x)), bgm_dict

    if verbose>0:
        if not BGM_model.converged_:
            print('!!! Gaussian mixture did not converge !!!')

    cur_labels = np.unique(labels)

    # map labels to be increasing for increasing values for x
    maxlab = len(cur_labels)
    aso = np.argsort([np.median(x[labels == l]) for l in cur_labels]) + 100
    for i, a in zip(cur_labels, aso):
        labels[labels==i] = a
    labels = labels - 100

    # separate gaussian clusters that can be split by other clusters
    splits = np.sort(np.copy(x))[1:][np.diff(labels[np.argsort(x)])!=0]

    labels[:] = 0
    for i, split in enumerate(splits):
        labels[x>=split] = i+1

    labels_before_merge = np.copy(labels)

    # merge gaussian clusters that are closer than merge_threshold
    labels = merge_gaussians(x, labels, merge_threshold)

    if 'BGM_'+save_name.split('_')[0] in return_data or plot_level>0:

        # sort model attributes by model_means_
        means = [m[0] for m in BGM_model.means_]
        weights = [w for w in BGM_model.weights_]
        variances = [v[0][0] for v in BGM_model.covariances_]
        weights = [w for _, w in sorted(zip(means, weights))]
        variances = [v for _, v in sorted(zip(means, variances))]
        means = sorted(means)

        if plot_level>0:
            plot_bgm(x, means, variances, weights, use_log, labels_before_merge,
                     labels, xlabel)
            if save_plot:
                plt.savefig('%sBGM_%s.%s' % (save_path, save_name, ftype))

        if 'BGM_'+save_name.split('_')[0] in return_data:
            bgm_dict['BGM_'+save_name] = {'x': x,
                                          'use_log': use_log,
                                          'BGM': [weights, means, variances],
                                          'labels': labels_before_merge,
                                          'xlab': xlabel}

    return labels, bgm_dict
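
# Usage sketch (added illustration with made-up values): cluster a bimodal
# set of EOD widths in milliseconds into two groups:
#
# >>> widths = np.concatenate((np.random.normal(0.5, 0.01, 100),
# ...                          np.random.normal(2.0, 0.05, 100)))
# >>> labels, _ = BGM(widths, merge_threshold=0.5, n_gaus=3)
# >>> # labels holds one int per width; the two modes get different labels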


def merge_gaussians(x, labels, merge_threshold=0.1):
    """ Merge all clusters whose medians are close to each other. Only works in 1D.

    Parameters
    ----------
    x : 1D array of ints or floats
        Features used for clustering.
    labels : 1D array of ints
        Labels for each sample in x.
    merge_threshold : float (optional)
        Similarity threshold to merge clusters.

    Returns
    -------
    labels : 1D array of ints
        Merged labels for each sample in x.
    """

    # compare all the means of the gaussians. If they are too close, merge them.
    unique_labels = np.unique(labels[labels!=-1])
    x_medians = [np.median(x[labels==l]) for l in unique_labels]

    # fill a dict with the label mappings
    mapping = {}
    for label_1, x_m1 in zip(unique_labels, x_medians):
        for label_2, x_m2 in zip(unique_labels, x_medians):
            if label_1!=label_2:
                if np.abs(np.diff([x_m1, x_m2]))/np.max([x_m1, x_m2]) < merge_threshold:
                    mapping[label_1] = label_2
    # apply mapping
    for map_key, map_value in mapping.items():
        labels[labels==map_key] = map_value

    return labels
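
# Worked toy example (added illustration): the medians of the first two
# clusters, 1.0 and 1.1, differ by less than 20% of their maximum, so the
# two clusters are merged; the outlier at 5.0 keeps its own label:
#
# >>> x = np.array([1.0, 1.0, 1.1, 1.1, 5.0])
# >>> labels = merge_gaussians(x, np.array([0, 0, 1, 1, 2]), 0.2)
# >>> # the first four samples now share a single label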


def extract_snippet_features(data, eod_x, eod_heights, width, n_pc=5):
    """ Extract snippets from recording data, normalize them, and perform PCA.

    Parameters
    ----------
    data : 1D numpy array of floats
        Recording data.
    eod_x : 1D array of ints
        Locations of EODs as indices.
    eod_heights: 1D array of floats
        EOD heights.
    width : int
        Width to cut out to each side in samples.

    n_pc : int (optional)
        Number of PCs to use for PCA.

    Returns
    -------
    raw_snippets : 2D numpy array (N, EOD_width)
        Raw extracted EOD snippets.
    snippets : 2D numpy array (N, EOD_width)
        Normalized EOD snippets.
    features : 2D numpy array (N, n_pc)
        PC values of EOD snippets.
    bg_ratio : 1D numpy array (N)
        Ratio of the background activity slopes compared to EOD height.
    """
    # extract snippets with corresponding width
    raw_snippets = np.vstack([data[x-width:x+width] for x in eod_x])

    # subtract the slope and normalize the snippets
    snippets, bg_ratio = subtract_slope(np.copy(raw_snippets), eod_heights)
    snippets = StandardScaler().fit_transform(snippets.T).T

    # scale so that the absolute integral = 1.
    snippets = (snippets.T/np.sum(np.abs(snippets), axis=1)).T

    # compute features for clustering on waveform
    features = PCA(n_pc).fit(snippets).transform(snippets)

    return raw_snippets, snippets, features, bg_ratio
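
# Usage sketch (illustration only): `data`, `eod_x`, `eod_heights`, and
# `eod_widths` are assumed to come from detect_pulses() on the same data:
#
# >>> width = int(3*np.median(eod_widths))   # snippet half-width in samples
# >>> raw_snips, snips, feats, bg_ratio = \
# ...     extract_snippet_features(data, eod_x, eod_heights, width)
# >>> # feats has shape (len(eod_x), 5): five PCs per normalized snippet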


def cluster_on_shape(features, bg_ratio, minp, percentile=80, max_epsilon=0.01,
                     slope_ratio_factor=4, min_cluster_fraction=0.01, verbose=0):
    """Separate EODs by their shape using DBSCAN.

    Parameters
    ----------
    features : 2D numpy array of floats (N, n_pc)
        PCA features of each EOD in a recording.
    bg_ratio : 1D array of floats
        Ratio of background activity slope the EOD is superimposed on.
    minp : int
        Minimum number of points for core cluster (DBSCAN).

    percentile : int (optional)
        Percentile of KNN distribution, where K=minp, to use as epsilon for DBSCAN.
    max_epsilon : float (optional)
        Maximum epsilon to use for DBSCAN clustering. This is used to avoid adding
        noisy clusters.
    slope_ratio_factor : float (optional)
        Influence of the slope-to-EOD ratio on the epsilon parameter.
        A slope_ratio_factor of 4 means that slope-to-EOD ratios >1/4
        start influencing epsilon.
    min_cluster_fraction : float (optional)
        Minimum fraction of all evaluated datapoints that can form a single cluster.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    labels : 1D array of ints
        DBSCAN cluster labels for each sample in features.
    """

    # determine clustering threshold from data
    minpc = max(minp, int(len(features)*min_cluster_fraction))
    knn = np.sort(pairwise_distances(features, features), axis=0)[minpc]
    eps = min(max(1, slope_ratio_factor*np.median(bg_ratio))*max_epsilon,
              np.percentile(knn, percentile))

    if verbose>1:
        print('epsilon = %f' % eps)
        print('Slope to EOD ratio = %f' % np.median(bg_ratio))

    # cluster on EOD shape
    return DBSCAN(eps=eps, min_samples=minpc).fit(features).labels_
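
# Usage sketch (illustration only), following the call in cluster() above,
# with `feats` and `bg_ratio` from extract_snippet_features():
#
# >>> labels = cluster_on_shape(feats, bg_ratio, minp=10)
# >>> # one int label per EOD; -1 marks noise points (DBSCAN convention)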


def subtract_slope(snippets, heights):
    """ Subtract underlying slope from all EOD snippets.

    Parameters
    ----------
    snippets: 2-D numpy array
        All EODs of a recording stacked as snippets.
        Shape = (number of EODs, EOD width)
    heights: 1D numpy array
        EOD heights.

    Returns
    -------
    snippets: 2-D numpy array
        EOD snippets with underlying slope subtracted.
    bg_ratio : 1-D numpy array
        EOD height/background activity height.
    """

    left_y = snippets[:,0]
    right_y = snippets[:,-1]

    try:
        slopes = np.linspace(left_y, right_y, snippets.shape[1])
    except ValueError:
        delta = (right_y - left_y)/snippets.shape[1]
        slopes = np.arange(0, snippets.shape[1], dtype=snippets.dtype).reshape((-1,) + (1,) * np.ndim(delta))*delta + left_y

    return snippets - slopes.T, np.abs(left_y-right_y)/heights
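
# Worked toy example (added illustration): a purely linear snippet is
# flattened to zero, and the endpoint difference of 3.0 relative to the
# given EOD height of 2.0 yields a background ratio of 1.5:
#
# >>> snips = np.array([[0.0, 1.0, 2.0, 3.0]])
# >>> flat, bg = subtract_slope(np.copy(snips), np.array([2.0]))
# >>> # flat is [[0., 0., 0., 0.]] and bg is [1.5]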


def remove_artefacts(all_snippets, clusters, samplerate,
                     freq_low=20000, threshold=0.75,
                     verbose=0, return_data=[]):
    """ Create a mask for EOD clusters that result from artefacts, based on power in low frequency spectrum.

    Parameters
    ----------
    all_snippets: 2D array
        EOD snippets. Shape=(nEODs, EOD length)
    clusters: list of ints
        EOD cluster labels.
    samplerate : float
        Samplerate of original recording data.
    freq_low: float
        Frequency up to which low frequency components are summed up.
    threshold : float (optional)
        Minimum value for the sum of low frequency components relative to
        the sum over all spectral amplitudes that separates artefact from
        clean pulsefish clusters.
    verbose : int (optional)
        Verbosity level.
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log data in this function is
        'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    mask: numpy array of booleans
        Set to True for every EOD which is an artefact.
    adict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    adict = {}

    mask = np.zeros(clusters.shape, dtype=bool)

    for cluster in np.unique(clusters[clusters >= 0]):
        snippets = all_snippets[clusters == cluster]
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)
        mean_eod_fft = np.abs(np.fft.rfft(mean_eod))
        freqs = np.fft.rfftfreq(len(mean_eod), 1/samplerate)
        low_frequency_ratio = np.sum(mean_eod_fft[freqs<freq_low])/np.sum(mean_eod_fft)
        if low_frequency_ratio < threshold:  # TODO: check threshold!
            mask[clusters==cluster] = True

            if verbose > 0:
                print('Deleting cluster %i with low frequency ratio of %.3f (min %.3f)' % (cluster, low_frequency_ratio, threshold))

        if 'eod_deletion' in return_data:
            adict['vals_%d' % cluster] = [mean_eod, mean_eod_fft]
            adict['mask_%d' % cluster] = [np.any(mask[clusters==cluster])]

    return mask, adict
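
# Usage sketch (illustration only): a cluster whose mean EOD has too little
# spectral amplitude below `freq_low` is flagged as an artefact; `p_snippets`
# and `wp_clusters` as produced in cluster() above are assumed:
#
# >>> mask, _ = remove_artefacts(p_snippets, wp_clusters, samplerate)
# >>> wp_clusters[mask] = -1   # discard EODs belonging to artefact clusters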


def delete_unreliable_fish(clusters, eod_widths, eod_x, verbose=0, sdict={}):
    """ Create a mask for EOD clusters that are either mixed with noise or other fish, or wavefish.

    This is the case when the ratio between the EOD width and the ISI is too large.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of floats or ints
        EOD widths in samples or seconds.
    eod_x : list of ints or floats
        EOD times in samples or seconds.

    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps,
        see remove_artefacts().

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every unreliable EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask = np.zeros(clusters.shape, dtype=bool)
    for cluster in np.unique(clusters[clusters >= 0]):
        if len(eod_x[cluster == clusters]) < 2:
            mask[clusters == cluster] = True
            if verbose>0:
                print('deleting unreliable cluster %i, number of EOD times %d < 2' % (cluster, len(eod_x[cluster==clusters])))
        elif np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) > 0.5:
            if verbose>0:
                print('deleting unreliable cluster %i, score=%f' % (cluster, np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))))
            mask[clusters==cluster] = True
        if 'vals_%d' % cluster in sdict:
            sdict['vals_%d' % cluster].append(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))
            sdict['mask_%d' % cluster].append(any(mask[clusters==cluster]))
    return mask, sdict
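
# Worked toy example (added illustration): with EODs every 0.1 s and a
# median EOD width of 0.06 s the width-to-ISI ratio is 0.6 > 0.5, so the
# cluster is marked as unreliable:
#
# >>> mask, _ = delete_unreliable_fish(np.array([0, 0, 0]),
# ...                                  np.array([0.06, 0.06, 0.06]),
# ...                                  np.array([0.0, 0.1, 0.2]))
# >>> # mask is [True, True, True]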


def delete_wavefish_and_sidepeaks(data, clusters, eod_x, eod_widths,
                                  width_fac, max_slope_deviation=0.5,
                                  max_phases=4, verbose=0, sdict={}):
    """ Create a mask for EODs that are likely from wavefish, or sidepeaks of bigger EODs.

    Parameters
    ----------
    data : list of floats
        Raw recording data.
    clusters : list of ints
        Cluster labels.
    eod_x : list of ints
        Indices of EOD times.
    eod_widths : list of ints
        EOD widths in samples.
    width_fac : float
        Multiplier for EOD analysis width.

    max_slope_deviation: float (optional)
        Maximum deviation of position of maximum slope in snippets from
        center position in multiples of mean width of EOD.
    max_phases : int (optional)
        Maximum number of phases for any EOD.
        If the mean EOD has more phases than this, it is not a pulse EOD.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps, see remove_artefacts().

    Returns
    -------
    mask_wave: numpy array of booleans
        Set to True for every EOD which is a wavefish EOD.
    mask_sidepeak: numpy array of booleans
        Set to True for every snippet which is centered around a sidepeak of an EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask_wave = np.zeros(clusters.shape, dtype=bool)
    mask_sidepeak = np.zeros(clusters.shape, dtype=bool)

    for i, cluster in enumerate(np.unique(clusters[clusters >= 0])):
        mean_width = np.mean(eod_widths[clusters == cluster])
        cutwidth = mean_width*width_fac
        current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
        current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
        snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)]
                              for x in current_x[current_clusters==cluster]])

        # extract information on main peaks and troughs:
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)

        # detect peaks and troughs on data + some maxima/minima at the
        # end, so that the sides are also considered for peak detection:
        pk, tr = detect_peaks(np.concatenate(([-10*mean_eod[0]], mean_eod, [10*mean_eod[-1]])),
                              np.std(mean_eod))
        pk = pk[(pk>0)&(pk<len(mean_eod))]
        tr = tr[(tr>0)&(tr<len(mean_eod))]

        if len(pk)>0 and len(tr)>0:
            idxs = np.sort(np.concatenate((pk, tr)))
            slopes = np.abs(np.diff(mean_eod[idxs]))
            m_slope = np.argmax(slopes)
            centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2))

            # compute all height differences of peaks and troughs within snippets.
            # if they are all similar, it is probably noise or a wavefish.
            idxs = np.sort(np.concatenate((pk, tr)))
            hdiffs = np.diff(mean_eod[idxs])

            if centered > max_slope_deviation*mean_width:  # TODO: check, factor was probably 0.16
                if verbose > 0:
                    print('Deleting cluster %i, which is a sidepeak' % cluster)
                mask_sidepeak[clusters==cluster] = True

            w_diff = np.abs(np.diff(np.sort(np.concatenate((pk, tr)))))

            if np.abs(np.diff(idxs[m_slope:m_slope+2])) < np.mean(eod_widths[clusters==cluster])*0.5 or len(pk) + len(tr)>max_phases or np.min(w_diff)>2*cutwidth/width_fac:  # or len(hdiffs[np.abs(hdiffs)>0.5*(np.max(mean_eod)-np.min(mean_eod))])>max_phases:
                if verbose>0:
                    print('Deleting cluster %i, which is a wavefish' % cluster)
                mask_wave[clusters==cluster] = True
            if 'vals_%d' % cluster in sdict:
                sdict['vals_%d' % cluster].append([mean_eod, [pk, tr],
                                                   idxs[m_slope:m_slope+2]])
                sdict['mask_%d' % cluster].append(any(mask_wave[clusters==cluster]))
                sdict['mask_%d' % cluster].append(any(mask_sidepeak[clusters==cluster]))

    return mask_wave, mask_sidepeak, sdict


def merge_clusters(clusters_1, clusters_2, x_1, x_2, verbose=0):
    """ Merge clusters resulting from two clustering methods.

    This method only works if clustering is performed on the same EODs
    with the same ordering, where there is a one to one mapping from
    clusters_1 to clusters_2.

    Parameters
    ----------
    clusters_1: list of ints
        EOD cluster labels for cluster method 1.
    clusters_2: list of ints
        EOD cluster labels for cluster method 2.
    x_1: list of ints
        Indices of EODs for cluster method 1 (clusters_1).
    x_2: list of ints
        Indices of EODs for cluster method 2 (clusters_2).
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    clusters : list of ints
        Merged clusters.
    x_merged : list of ints
        Merged cluster indices.
    mask : 2d numpy array of ints (N, 2)
        Mask for clusters that are selected from clusters_1 (mask[:,0]) and
        from clusters_2 (mask[:,1]).
    """
1481 if verbose > 0:
1482 print('\nMerge cluster:')
1484 # these arrays become 1 for each EOD that is chosen from that array
1485 c1_keep = np.zeros(len(clusters_1))
1486 c2_keep = np.zeros(len(clusters_2))
1488 # add n to one of the cluster lists to avoid overlap
1489 ovl = np.max(clusters_1) + 1
1490 clusters_2[clusters_2!=-1] = clusters_2[clusters_2!=-1] + ovl
1492 remove_clusters = [[]]
1493 keep_clusters = []
1494 og_clusters = [np.copy(clusters_1), np.copy(clusters_2)]
1496 # loop untill done
1497 while True:
1499 # compute unique clusters and cluster sizes
1500 # of cluster that have not been iterated over:
1501 c1_labels, c1_size = unique_counts(clusters_1[(clusters_1 != -1) & (c1_keep == 0)])
1502 c2_labels, c2_size = unique_counts(clusters_2[(clusters_2 != -1) & (c2_keep == 0)])
1504 # if all clusters are done, break from loop:
1505 if len(c1_size) == 0 and len(c2_size) == 0:
1506 break
1508 # if the biggest cluster is in c_p, keep this one and discard all clusters
1509 # on the same indices in c_t:
1510 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 0:
1512 # remove all the mappings from the other indices
1513 cluster_mappings, _ = unique_counts(clusters_2[clusters_1 == c1_labels[np.argmax(c1_size)]])
1515 clusters_2[np.isin(clusters_2, cluster_mappings)] = -1
1517 c1_keep[clusters_1==c1_labels[np.argmax(c1_size)]] = 1
1519 remove_clusters.append(cluster_mappings)
1520 keep_clusters.append(c1_labels[np.argmax(c1_size)])
1522 if verbose > 0:
1523 print('Keep cluster %i of group 1, delete clusters %s of group 2' % (c1_labels[np.argmax(c1_size)], str(cluster_mappings[cluster_mappings!=-1] - ovl)))
1525 # if the biggest cluster is in c_t, keep this one and discard all mappings in c_p
1526 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 1:
1528 # remove all the mappings from the other indices
1529 cluster_mappings, _ = unique_counts(clusters_1[clusters_2 == c2_labels[np.argmax(c2_size)]])
1531 clusters_1[np.isin(clusters_1, cluster_mappings)] = -1
1533 c2_keep[clusters_2==c2_labels[np.argmax(c2_size)]] = 1
1535 remove_clusters.append(cluster_mappings)
1536 keep_clusters.append(c2_labels[np.argmax(c2_size)])
1538 if verbose > 0:
1539 print('Keep cluster %i of group 2, delete clusters %s of group 1' % (c2_labels[np.argmax(c2_size)] - ovl, str(cluster_mappings[cluster_mappings!=-1])))
1541 # combine results
1542 clusters = (clusters_1+1)*c1_keep + (clusters_2+1)*c2_keep - 1
1543 x_merged = (x_1)*c1_keep + (x_2)*c2_keep
1545 return clusters, x_merged, np.vstack([c1_keep, c2_keep])
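

# A minimal, hypothetical usage sketch (not part of the original module):
# both methods label the same five EODs; the biggest cluster comes from
# method 1, so its labels are kept and the overlapping clusters of method 2
# are discarded.
def _demo_merge_clusters():
    c1 = np.array([0, 0, 0, 1, -1])   # labels from clustering method 1
    c2 = np.array([5, 5, 6, 6, 6])    # labels from clustering method 2
    x = np.arange(5)                  # EOD indices, identical for both methods
    clusters, x_merged, mask = merge_clusters(c1, c2, x, x, verbose=1)
    print(clusters)                   # -> [ 0.  0.  0.  1. -1.]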

def extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths, clusters, samplerate,
                  width_fac, verbose=0):
    """ Extract mean EODs and EOD timepoints for each EOD cluster.

    Parameters
    ----------
    data: list of floats
        Raw recording data.
    eod_x: list of ints
        Locations of EODs in samples.
    eod_peak_x : list of ints
        Locations of EOD peaks in samples.
    eod_tr_x : list of ints
        Locations of EOD troughs in samples.
    eod_widths: list of ints
        EOD widths in samples.
    clusters: list of ints
        EOD cluster labels.
    samplerate: float
        Sampling rate of the recording in Hertz.
    width_fac : float
        Multiplication factor for the window used to extract each EOD.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First row is time in seconds,
        second row the mean EOD, third row the standard deviation.
    eod_times: list of 1D arrays
        For each detected fish the times of EODs in seconds.
    eod_peak_times: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    eod_trough_times: list of 1D arrays
        For each detected fish the times of EOD troughs in seconds.
    eod_labels: list of ints
        Cluster label for each detected fish.
    """
    mean_eods, eod_times, eod_peak_times, eod_tr_times, eod_heights, cluster_labels = [], [], [], [], [], []

    for cluster in np.unique(clusters):
        if cluster != -1:
            cutwidth = np.mean(eod_widths[clusters == cluster])*width_fac
            current_x = eod_x[(eod_x > cutwidth) & (eod_x < (len(data)-cutwidth))]
            current_clusters = clusters[(eod_x > cutwidth) & (eod_x < (len(data)-cutwidth))]

            snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)]
                                  for x in current_x[current_clusters == cluster]])
            mean_eod = np.mean(snippets, axis=0)
            eod_time = np.arange(len(mean_eod))/samplerate - cutwidth/samplerate

            mean_eod = np.vstack([eod_time, mean_eod, np.std(snippets, axis=0)])

            mean_eods.append(mean_eod)
            eod_times.append(eod_x[clusters == cluster]/samplerate)
            # negative peak-to-peak amplitude of the mean EOD (second row),
            # so that sorting in ascending order puts the largest EOD first:
            eod_heights.append(np.min(mean_eod[1]) - np.max(mean_eod[1]))
            eod_peak_times.append(eod_peak_x[clusters == cluster]/samplerate)
            eod_tr_times.append(eod_tr_x[clusters == cluster]/samplerate)
            cluster_labels.append(cluster)

    return ([m for _, m in sorted(zip(eod_heights, mean_eods), key=lambda p: p[0])],
            [t for _, t in sorted(zip(eod_heights, eod_times), key=lambda p: p[0])],
            [pt for _, pt in sorted(zip(eod_heights, eod_peak_times), key=lambda p: p[0])],
            [tt for _, tt in sorted(zip(eod_heights, eod_tr_times), key=lambda p: p[0])],
            [c for _, c in sorted(zip(eod_heights, cluster_labels), key=lambda p: p[0])])
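

# A minimal, hypothetical usage sketch (not part of the original module):
# a single simulated fish pulsing every 100 ms; extract_means() returns one
# mean EOD stacked with its time axis and standard deviation.
def _demo_extract_means():
    samplerate = 10000.0
    data = np.zeros(10000)
    eod_x = np.arange(500, 10000, 1000)              # EOD centers in samples
    for x in eod_x:
        data[x-10:x+10] = np.sin(np.linspace(0.0, 2.0*np.pi, 20))  # biphasic pulse
    n = len(eod_x)
    mean_eods, eod_times, _, _, labels = extract_means(
        data, eod_x, eod_x - 5, eod_x + 5, np.full(n, 20), np.zeros(n, dtype=int),
        samplerate, width_fac=4)
    print(mean_eods[0].shape)                        # (3, 160): time, mean, std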

def find_clipped_clusters(clusters, mean_eods, eod_times, eod_peaktimes, eod_troughtimes,
                          cluster_labels, width_factor, clip_threshold=0.9, verbose=0):
    """ Detect clipped EODs and set all cluster labels of these clipped EODs to -1.

    Also return the mean EODs and timepoints of these clipped EODs.

    Parameters
    ----------
    clusters: array of ints
        Cluster labels for each EOD in a recording.
    mean_eods: list of numpy arrays
        Mean EOD waveform for each cluster.
    eod_times: list of numpy arrays
        EOD timepoints for each EOD cluster.
    eod_peaktimes: list of numpy arrays
        EOD peaktimes for each EOD cluster.
    eod_troughtimes: list of numpy arrays
        EOD troughtimes for each EOD cluster.
    cluster_labels: numpy array
        Unique EOD cluster labels.
    width_factor: float
        Multiplication factor for the window used to extract each EOD.
    clip_threshold: float
        Threshold for detecting clipped EODs.
    verbose: int
        Verbosity level.

    Returns
    -------
    clusters : array of ints
        Cluster labels for each EOD in the recording, where clipped EODs have been set to -1.
    clipped_eods : list of numpy arrays
        Mean EOD waveforms for each clipped EOD cluster.
    clipped_times : list of numpy arrays
        EOD timepoints for each clipped EOD cluster.
    clipped_peaktimes : list of numpy arrays
        EOD peaktimes for each clipped EOD cluster.
    clipped_troughtimes : list of numpy arrays
        EOD troughtimes for each clipped EOD cluster.
    """
    clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes, clipped_labels = [], [], [], [], []

    for mean_eod, eod_time, eod_peaktime, eod_troughtime, label in zip(mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels):

        if (np.count_nonzero(mean_eod[1] > clip_threshold) > len(mean_eod[1])/(width_factor*2)) or \
           (np.count_nonzero(mean_eod[1] < -clip_threshold) > len(mean_eod[1])/(width_factor*2)):
            clipped_eods.append(mean_eod)
            clipped_times.append(eod_time)
            clipped_peaktimes.append(eod_peaktime)
            clipped_troughtimes.append(eod_troughtime)
            clipped_labels.append(label)
            if verbose > 0:
                print('clipped pulsefish')

    clusters[np.isin(clusters, clipped_labels)] = -1

    return clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes
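

# A minimal, hypothetical sketch (not part of the original module) of the
# clipping criterion above: a cluster is flagged when more than
# 1/(2*width_factor) of its mean waveform lies beyond +-clip_threshold,
# as happens for a waveform saturated at the recording range.
def _demo_clipping_check(width_factor=4.0, clip_threshold=0.9):
    amp = np.clip(2.0*np.sin(np.linspace(0.0, np.pi, 100)), 0.0, 1.0)  # saturates at 1
    clipped = np.count_nonzero(amp > clip_threshold) > len(amp)/(width_factor*2)
    print('clipped:', clipped)   # -> clipped: True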

def delete_moving_fish(clusters, eod_t, T, eod_heights, eod_widths, samplerate,
                       min_dt=0.25, stepsize=0.05, sliding_window_factor=2000,
                       verbose=0, plot_level=0, save_plot=False, save_path='',
                       ftype='pdf', return_data=[]):
    """
    Use a sliding window to find the time window in which the fewest fish are
    detected simultaneously, then delete all EOD clusters that are not present
    in that window.

    Do this only for EODs within the same width class, as a
    moving fish will preserve its EOD width.

    Parameters
    ----------
    clusters: list of ints
        EOD cluster labels.
    eod_t: list of floats
        Timepoints of the EODs (in seconds).
    T: float
        Length of the recording (in seconds).
    eod_heights: list of floats
        EOD amplitudes.
    eod_widths: list of floats
        EOD widths (in seconds).
    samplerate: float
        Sampling rate of the recording in Hertz.

    min_dt : float (optional)
        Minimum sliding window size (in seconds).
    stepsize : float (optional)
        Sliding window stepsize (in seconds).
    sliding_window_factor : float
        Multiplier for the sliding window width,
        where the sliding window width = median(EOD width)*sliding_window_factor.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plot : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path : string (optional)
        Path for saving plots. Only used if save_plot is set to True.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plot is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log data
        in this function is 'moving_fish' (see extract_pulsefish()).

    Returns
    -------
    clusters : list of ints
        Cluster labels, where deleted clusters have been set to -1.
    window : list of 2 floats
        Start and end of the window selected for deleting moving fish, in seconds.
    mf_dict : dictionary
        Key-value pairs of logged data. Data to be logged is specified by return_data.
    """
    mf_dict = {}

    if len(np.unique(clusters[clusters != -1])) == 0:
        return clusters, [0, 1], {}

    all_keep_clusters = []
    width_classes = merge_gaussians(eod_widths, np.copy(clusters), 0.75)

    all_windows = []
    all_dts = []
    ev_num = 0
    for iw, w in enumerate(np.unique(width_classes[clusters >= 0])):

        # initialize variables:
        min_clusters = 100
        average_height = 0
        sparse_clusters = 100
        keep_clusters = []

        dt = max(min_dt, np.median(eod_widths[width_classes == w])*sliding_window_factor)
        window_start = 0
        window_end = dt

        wclusters = clusters[width_classes == w]
        weod_t = eod_t[width_classes == w]
        weod_heights = eod_heights[width_classes == w]
        weod_widths = eod_widths[width_classes == w]

        all_dts.append(dt)

        if verbose > 0:
            print('sliding window dt = %f' % dt)

        # make W dependent on width??
        ignore_steps = np.zeros(len(np.arange(0, T-dt+stepsize, stepsize)))

        for i, t in enumerate(np.arange(0, T-dt+stepsize, stepsize)):
            current_clusters = wclusters[(weod_t >= t) & (weod_t < t+dt) & (wclusters != -1)]
            if len(np.unique(current_clusters)) == 0:
                # ignore all windows that overlap with this empty stretch
                # (max() prevents the slice from wrapping to negative indices):
                ignore_steps[max(0, i-int(dt/stepsize)):i+int(dt/stepsize)] = 1
                if verbose > 0:
                    print('No pulsefish in recording at T=%.2f:%.2f' % (t, t+dt))

        x = np.arange(0, T-dt+stepsize, stepsize)
        y = np.ones(len(x))

        running_sum = np.ones(len(np.arange(0, T+stepsize, stepsize)))
        ulabs = np.unique(wclusters[wclusters >= 0])

        # sliding window:
        for j, (t, ignore_step) in enumerate(zip(x, ignore_steps)):
            current_clusters = wclusters[(weod_t >= t) & (weod_t < t+dt) & (wclusters != -1)]
            current_widths = weod_widths[(weod_t >= t) & (weod_t < t+dt) & (wclusters != -1)]

            unique_clusters = np.unique(current_clusters)
            y[j] = len(unique_clusters)

            if (len(unique_clusters) <= min_clusters) and \
               (ignore_step == 0) and \
               (len(unique_clusters) > 0):

                current_labels = np.isin(wclusters, unique_clusters)
                current_height = np.mean(weod_heights[current_labels])

                # compute the number of clusters that are too sparse:
                clusters_after_deletion = np.unique(remove_sparse_detections(
                    np.copy(clusters[np.isin(clusters, unique_clusters)]),
                    samplerate*eod_widths[np.isin(clusters, unique_clusters)],
                    samplerate, T))
                current_sparse_clusters = len(unique_clusters) - len(clusters_after_deletion[clusters_after_deletion != -1])

                if current_sparse_clusters <= sparse_clusters and \
                   ((current_sparse_clusters < sparse_clusters) or
                    (current_height > average_height) or
                    (len(unique_clusters) < min_clusters)):

                    keep_clusters = unique_clusters
                    min_clusters = len(unique_clusters)
                    average_height = current_height
                    window_end = t+dt
                    sparse_clusters = current_sparse_clusters

        all_keep_clusters.append(keep_clusters)
        all_windows.append(window_end)

        if 'moving_fish' in return_data or plot_level > 0:
            if 'w' in mf_dict:
                mf_dict['w'].append(np.median(eod_widths[width_classes == w]))
                mf_dict['dt'].append(dt)
                mf_dict['clusters'].append(wclusters)
                mf_dict['t'].append(weod_t)
                mf_dict['fishcount'].append([x+0.5*(x[1]-x[0]), y])
                mf_dict['ignore_steps'].append(ignore_steps)
            else:
                mf_dict['w'] = [np.median(eod_widths[width_classes == w])]
                # T is a property of the whole recording, so store it once:
                mf_dict['T'] = T
                mf_dict['dt'] = [dt]
                mf_dict['clusters'] = [wclusters]
                mf_dict['t'] = [weod_t]
                mf_dict['fishcount'] = [[x+0.5*(x[1]-x[0]), y]]
                mf_dict['ignore_steps'] = [ignore_steps]

    if verbose > 0:
        print('Estimated nr of pulsefish in recording: %i' % len(np.concatenate(all_keep_clusters)))

    if plot_level > 0:
        plot_moving_fish(mf_dict['w'], mf_dict['dt'], mf_dict['clusters'], mf_dict['t'],
                         mf_dict['fishcount'], T, mf_dict['ignore_steps'])
        if save_plot:
            plt.savefig('%sdelete_moving_fish.%s' % (save_path, ftype))

    # empty dict:
    if 'moving_fish' not in return_data:
        mf_dict = {}

    # delete all clusters that are not selected:
    clusters[np.invert(np.isin(clusters, np.concatenate(all_keep_clusters)))] = -1

    return clusters, [np.max(all_windows)-np.max(all_dts), np.max(all_windows)], mf_dict
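

# A simplified, hypothetical sketch (not part of the original module) of the
# sliding-window step above, stripped of width classes, the sparse-cluster
# check, and the height tie-breaking: it just counts how many distinct fish
# are present in each window, which is what y holds in delete_moving_fish().
def _demo_fish_count(eod_t, clusters, T, dt=1.0, stepsize=0.05):
    starts = np.arange(0, T-dt+stepsize, stepsize)
    counts = np.array([len(np.unique(clusters[(eod_t >= t) & (eod_t < t+dt) &
                                              (clusters != -1)]))
                       for t in starts])
    # the window with the smallest count determines which clusters are kept:
    return starts, counts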

def remove_sparse_detections(clusters, eod_widths, samplerate, T,
                             min_density=0.0005, verbose=0):
    """ Remove all EOD clusters that are too sparse.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of ints
        EOD widths in samples.
    samplerate : float
        Sampling rate of the recording in Hertz.
    T : float
        Length of the recording in seconds.
    min_density : float (optional)
        Minimum density for realistic EOD detections.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    clusters : list of ints
        Cluster labels, where sparse clusters have been set to -1.
    """
    for c in np.unique(clusters):
        if c != -1:

            n = len(clusters[clusters == c])
            w = np.median(eod_widths[clusters == c])/samplerate

            # a cluster is sparse if its EODs cover less than a fraction
            # min_density of the total recording time:
            if n*w < T*min_density:
                if verbose > 0:
                    print('cluster %i is too sparse' % c)
                clusters[clusters == c] = -1
    return clusters
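

# A minimal, hypothetical worked example (not part of the original module):
# five detections of a 0.5 ms pulse in a 60 s recording cover only
# 5*0.0005 s = 0.0025 s, which is less than T*min_density = 0.03 s,
# so the cluster is discarded.
def _demo_sparse_detections():
    samplerate = 20000.0
    clusters = np.zeros(5, dtype=int)
    widths = np.full(5, 0.0005*samplerate)   # EOD widths in samples
    print(remove_sparse_detections(clusters, widths, samplerate, T=60.0, verbose=1))
    # -> cluster 0 is too sparse
    # -> [-1 -1 -1 -1 -1]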