1""" 

2Extract and cluster EOD waverforms of pulse-type electric fish. 

3 

4## Main function 

5 

6- `extract_pulsefish()`: checks for pulse-type fish based on the EOD amplitude and shape. 

7 

8""" 

import os
import numpy as np

from scipy import stats
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.mixture import BayesianGaussianMixture
from sklearn.metrics import pairwise_distances
from thunderlab.eventdetection import detect_peaks, median_std_threshold

from .pulseplots import *

import warnings
def warn(*args, **kwargs):
    """
    Ignore all warnings.
    """
    pass
warnings.warn = warn

try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        def decorator_jit(func):
            return func
        return decorator_jit


# upgrade numpy functions for backwards compatibility:
if not hasattr(np, 'isin'):
    np.isin = np.in1d

def unique_counts(ar):
    """ Find the unique elements of an array and their counts, ignoring shape.

    The code is condensed from numpy version 1.17.0.

    Parameters
    ----------
    ar : numpy array
        Input array.

    Returns
    -------
    unique_values : numpy array
        Unique values in array ar.
    unique_counts : numpy array
        Number of instances for each unique value in ar.
    """
    try:
        return np.unique(ar, return_counts=True)
    except TypeError:
        ar = np.asanyarray(ar).flatten()
        ar.sort()
        mask = np.empty(ar.shape, dtype=bool)
        mask[:1] = True
        mask[1:] = ar[1:] != ar[:-1]
        idx = np.concatenate(np.nonzero(mask) + ([mask.size],))
        return ar[mask], np.diff(idx)
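
# A tiny usage sketch (illustration only, not part of the original module):
# unique values and their counts from a small integer array.
def _demo_unique_counts():
    values, counts = unique_counts(np.array([1, 1, 2, 3, 3, 3]))
    print(values, counts)   # [1 2 3] [2 1 3]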


###########################################################################

def extract_pulsefish(data, rate, amax, width_factor_shape=3,
                      width_factor_wave=8, width_factor_display=4,
                      verbose=0, plot_level=0, save_plots=False,
                      save_path='', ftype='png', return_data=[]):
    """ Extract and cluster pulse-type fish EODs from single channel data.

    Takes recording data containing an unknown number of pulsefish and extracts the mean
    EOD and EOD timepoints for each fish present in the recording.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data in Hertz.
    amax: float
        Maximum amplitude of data range.
    width_factor_shape : float (optional)
        Width multiplier used for EOD shape analysis.
        EOD snippets are extracted based on the width between the
        peak and trough multiplied by the width factor.
    width_factor_wave : float (optional)
        Width multiplier used for wavefish detection.
    width_factor_display : float (optional)
        Width multiplier used for EOD mean extraction and display.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path : string (optional)
        Path for saving plots.
    ftype : string (optional)
        Filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Specify data that should be logged and returned in a dictionary. Each clustering
        step has a specific keyword that results in adding different variables to the log dictionary.
        Optional keys for return_data and the resulting additional key-value pairs in the log dictionary are:

        - 'all_eod_times':
            - 'all_times': list of two lists of floats.
                All peak (`all_times[0]`) and trough times (`all_times[1]`) extracted
                by the peak detection algorithm. Times are given in seconds.
            - 'eod_troughtimes': list of 1D arrays.
                The timepoints in seconds of each unique extracted EOD cluster,
                where each 1D array encodes one cluster.

        - 'peak_detection':
            - "data": 1D numpy array of floats.
                Quadratically interpolated data which was used for peak detection.
            - "interp_fac": float.
                Interpolation factor of raw data.
            - "peaks_1": 1D numpy array of ints.
                Peak indices on interpolated data after first peak detection step.
            - "troughs_1": 1D numpy array of ints.
                Trough indices on interpolated data after first peak detection step.
            - "peaks_2": 1D numpy array of ints.
                Peak indices on interpolated data after second peak detection step.
            - "troughs_2": 1D numpy array of ints.
                Trough indices on interpolated data after second peak detection step.
            - "peaks_3": 1D numpy array of ints.
                Peak indices on interpolated data after third peak detection step.
            - "troughs_3": 1D numpy array of ints.
                Trough indices on interpolated data after third peak detection step.
            - "peaks_4": 1D numpy array of ints.
                Peak indices on interpolated data after fourth peak detection step.
            - "troughs_4": 1D numpy array of ints.
                Trough indices on interpolated data after fourth peak detection step.

        - 'all_cluster_steps':
            - 'rate': float.
                Sampling rate of interpolated data.
            - 'EOD_widths': list of three 1D numpy arrays.
                The first list entry gives the unique labels of all width clusters
                as a list of ints.
                The second list entry gives the width values for each EOD in samples
                as a 1D numpy array of ints.
                The third list entry gives the width labels for each EOD
                as a 1D numpy array of ints.
            - 'EOD_heights': nested lists (2 layers) of three 1D numpy arrays.
                The first list entry gives the unique labels of all height clusters
                as a list of ints for each width cluster.
                The second list entry gives the height values for each EOD
                as a 1D numpy array of floats for each width cluster.
                The third list entry gives the height labels for each EOD
                as a 1D numpy array of ints for each width cluster.
            - 'EOD_shapes': nested lists (3 layers) of three 1D numpy arrays.
                The first list entry gives the raw EOD snippets as a 2D numpy array
                for each height cluster in a width cluster.
                The second list entry gives the snippet PCA values for each EOD
                as a 2D numpy array of floats for each height cluster in a width cluster.
                The third list entry gives the shape labels for each EOD as a 1D numpy array
                of ints for each height cluster in a width cluster.
            - 'discarding_masks': nested lists (two layers) of 1D numpy arrays.
                The masks of EODs that are discarded by the discarding step of the algorithm.
                The masks are 1D boolean arrays where instances that are set to True are
                discarded by the algorithm. Discarding masks are saved in nested lists
                that represent the width and height clusters.
            - 'merge_masks': nested lists (two layers) of 2D numpy arrays.
                The masks of EODs that are discarded by the merging step of the algorithm.
                The masks are 2D boolean arrays where for each sample point `i` either
                `merge_mask[i,0]` or `merge_mask[i,1]` is set to True. Here, `merge_mask[:,0]`
                represents the peak-centered clusters and `merge_mask[:,1]` represents the
                trough-centered clusters. Merge masks are saved in nested lists that
                represent the width and height clusters.

        - 'BGM_width':
            - 'BGM_width': dictionary
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD widths).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by the BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'BGM_height':
            This key adds a new dictionary for each width cluster.
            - 'BGM_height_*n*' : dictionary, where *n* defines the width cluster as an int.
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD heights).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by the BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'snippet_clusters':
            This key adds a new dictionary for each height cluster.
            - 'snippet_clusters_*n*_*m*_*p*' : dictionary, where *n* defines the width cluster
              (int), *m* defines the height cluster (int) and *p* defines shape clustering
              on peak or trough centered EOD snippets (string: 'peak' or 'trough').
                - 'raw_snippets': 2D numpy array (nsamples, nfeatures).
                    Raw EOD snippets.
                - 'snippets': 2D numpy array.
                    Normalized EOD snippets.
                - 'features': 2D numpy array (nsamples, nfeatures).
                    PCA values for each normalized EOD snippet.
                - 'clusters': 1D numpy array of ints.
                    Cluster labels.
                - 'rate': float.
                    Sampling rate of snippets.

        - 'eod_deletion':
            This key adds two dictionaries for each (peak centered) shape cluster,
            where *cluster* (int) is the unique shape cluster label.
            - 'mask_*cluster*' : list of four booleans.
                The mask for each cluster discarding step.
                The first instance represents the artefact masks, where artefacts
                are set to True.
                The second instance represents the unreliable cluster masks,
                where unreliable clusters are set to True.
                The third instance represents the wavefish masks, where wavefish
                are set to True.
                The fourth instance represents the sidepeak masks, where sidepeaks
                are set to True.
            - 'vals_*cluster*' : list of lists.
                All variables that are used for each cluster deletion step.
                The first instance is a list of two 1D numpy arrays: the mean EOD and
                the FFT of that mean EOD.
                The second instance is a 1D numpy array with all EOD width to ISI ratios.
                The third instance is a list with three entries:
                The first entry is a zoomed-out version of the mean EOD
                as a 1D numpy array.
                The second entry is a list of two 1D numpy arrays that define the peak
                and trough indices of the zoomed-out mean EOD.
                The third entry contains a list of two values that represent the
                peak-trough pair in the zoomed-out mean EOD with the largest height
                difference.
            - 'rate' : float.
                EOD snippet sampling rate.

        - 'masks':
            - 'masks' : 2D numpy array (4,N).
                Each row contains masks for each EOD detected by the EOD peak detection step.
                The first row defines the artefact masks, the second row defines the
                unreliable EOD masks,
                the third row defines the wavefish masks and the fourth row defines
                the sidepeak masks.

        - 'moving_fish':
            - 'moving_fish': dictionary.
                - 'w' : list of floats.
                    Median width for each width cluster that the moving fish algorithm is
                    computed on (in seconds).
                - 'T' : list of floats.
                    Length of analyzed recording for each width cluster (in seconds).
                - 'dt' : list of floats.
                    Sliding window size (in seconds) for each width cluster.
                - 'clusters' : list of 1D numpy int arrays.
                    Cluster labels for each EOD cluster in a width cluster.
                - 't' : list of 1D numpy float arrays.
                    EOD emission times for each EOD in a width cluster.
                - 'fishcount' : list of lists.
                    Sliding window timepoints and fishcounts for each width cluster.
                - 'ignore_steps' : list of 1D int arrays.
                    Mask for fishcounts that were ignored (ignored if True) in the
                    moving_fish analysis.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First column is time in seconds,
        second column the mean EOD, third column the standard error.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD peaks or troughs in seconds.
        Use these timepoints for EOD averaging.
    eod_peaktimes: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    zoom_window: tuple of floats
        Start and end time of the suggested window for plotting EOD timepoints.
    log_dict: dictionary
        Dictionary with logged variables, where variables to log are specified
        by `return_data`.
    """
    if verbose > 0:
        print('')
        if verbose > 1:
            print(70*'#')
        print('##### extract_pulsefish', 46*'#')

    if (save_plots and plot_level > 0 and save_path):
        # create folder to save things in:
        if not os.path.exists(save_path):
            os.makedirs(save_path)
    else:
        save_path = ''

    mean_eods, eod_times, eod_peaktimes, zoom_window = [], [], [], []
    log_dict = {}

    # interpolate:
    i_rate = 500000.0
    #i_rate = rate
    try:
        f = interp1d(np.arange(len(data))/rate, data, kind='quadratic')
        i_data = f(np.arange(0.0, (len(data)-1)/rate, 1.0/i_rate))
    except MemoryError:
        i_rate = rate
        i_data = data
    log_dict['data'] = i_data   # TODO: could be removed
    log_dict['rate'] = i_rate   # TODO: could be removed
    log_dict['i_data'] = i_data
    log_dict['i_rate'] = i_rate
    #log_dict["interp_fac"] = interp_fac  # TODO: is not set anymore

    # standard deviation of data in small snippets:
    win_size = int(0.002*rate)  # 2ms windows
    threshold = median_std_threshold(data, win_size)  # TODO: make this a parameter

    # extract peaks:
    if 'peak_detection' in return_data:
        x_peak, x_trough, eod_heights, eod_widths, pd_log_dict = \
            detect_pulses(i_data, i_rate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose, return_data=True)
        log_dict.update(pd_log_dict)
    else:
        x_peak, x_trough, eod_heights, eod_widths = \
            detect_pulses(i_data, i_rate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose, return_data=False)

    if len(x_peak) > 0:
        # cluster:
        clusters, x_merge, c_log_dict = cluster(x_peak, x_trough,
                                                eod_heights,
                                                eod_widths, i_data,
                                                i_rate,
                                                width_factor_shape,
                                                width_factor_wave,
                                                merge_threshold_height=0.1*amax,
                                                verbose=verbose,
                                                plot_level=plot_level-1,
                                                save_plots=save_plots,
                                                save_path=save_path,
                                                ftype=ftype,
                                                return_data=return_data)

        # extract mean eods and times:
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, clusters,
                          i_rate, width_factor_display, verbose=verbose)

        # determine clipped clusters (save them, but ignore in other steps):
        clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes = \
            find_clipped_clusters(clusters, mean_eods, eod_times,
                                  eod_peaktimes, eod_troughtimes,
                                  cluster_labels, width_factor_display,
                                  verbose=verbose)

        # delete the moving fish:
        clusters, zoom_window, mf_log_dict = \
            delete_moving_fish(clusters, x_merge/i_rate, len(data)/rate,
                               eod_heights, eod_widths/i_rate, i_rate,
                               verbose=verbose, plot_level=plot_level-1,
                               save_plot=save_plots,
                               save_path=save_path, ftype=ftype,
                               return_data=return_data)

        if 'moving_fish' in return_data:
            log_dict['moving_fish'] = mf_log_dict

        clusters = remove_sparse_detections(clusters, eod_widths, i_rate,
                                            len(data)/rate, verbose=verbose)

        # extract mean eods:
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths,
                          clusters, i_rate, width_factor_display,
                          verbose=verbose)

        mean_eods.extend(clipped_eods)
        eod_times.extend(clipped_times)
        eod_peaktimes.extend(clipped_peaktimes)
        eod_troughtimes.extend(clipped_troughtimes)

        if plot_level > 0:
            plot_all(data, eod_peaktimes, eod_troughtimes, rate, mean_eods)
            if save_plots:
                plt.savefig('%sextract_pulsefish_results.%s' % (save_path, ftype))
        if save_plots:
            plt.close('all')

        if 'all_eod_times' in return_data:
            log_dict['all_times'] = [x_peak/i_rate, x_trough/i_rate]
            log_dict['eod_troughtimes'] = eod_troughtimes

        log_dict.update(c_log_dict)

    if verbose > 0:
        print('')

    return mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict
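
# Minimal usage sketch (illustration only, not part of the original module):
# build a synthetic single-channel recording with regularly repeated biphasic
# pulses and run the full extraction. All values are made up for illustration.
def _demo_extract_pulsefish():
    rate = 44100.0                          # sampling rate in Hertz
    data = 0.01*np.random.randn(int(rate))  # one second of background noise
    t = np.arange(-0.0005, 0.0005, 1/rate)  # 1 ms time axis for a single pulse
    pulse = np.sin(2*np.pi*2000*t)*np.exp(-(t/0.0002)**2)
    for i in np.arange(2000, len(data)-2000, 2205):  # one pulse every 50 ms
        data[int(i):int(i)+len(pulse)] += pulse
    amax = 1.0                              # amplitude range of the recording
    mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict = \
        extract_pulsefish(data, rate, amax, verbose=1)
    for eod, times in zip(mean_eods, eod_times):
        print('fish with %d EODs, mean EOD of %d samples'
              % (len(times), eod.shape[1]))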


def detect_pulses(data, rate, thresh, min_rel_slope_diff=0.25,
                  min_width=0.00005, max_width=0.01, width_fac=5.0,
                  verbose=0, return_data=False):
    """Detect pulses in data.

    Was `def extract_eod_times(data, rate, width_factor,
    interp_freq=500000, max_peakwidth=0.01, min_peakwidth=None,
    verbose=0, return_data=[], save_path='')` before.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data.
    thresh: float
        Threshold for peak and trough detection via `detect_peaks()`.
        Must be a positive number that sets the minimum difference
        between a peak and a trough.
    min_rel_slope_diff: float
        Minimum required difference between left and right slope (between
        peak and troughs) relative to mean slope for deciding which trough
        to take based on slope difference.
    min_width: float
        Minimum width (peak-trough distance) of pulses in seconds.
    max_width: float
        Maximum width (peak-trough distance) of pulses in seconds.
    width_fac: float
        Pulses extend plus or minus `width_fac` times their width
        (distance between peak and assigned trough).
        Only pulses are returned that can fully be analysed with this width.
    verbose : int (optional)
        Verbosity level.
    return_data : bool
        If `True`, data of this function is logged and returned (see
        extract_pulsefish()).

    Returns
    -------
    peak_indices: array of ints
        Indices of EOD peaks in data.
    trough_indices: array of ints
        Indices of EOD troughs in data. There is one x_trough for each x_peak.
    heights: array of floats
        EOD heights for each x_peak.
    widths: array of ints
        EOD widths for each x_peak (in samples).
    peak_detection_result : dictionary
        Key value pairs of logged data.
        This is only returned if `return_data` is `True`.
    """
    peak_detection_result = {}

    # detect peaks and troughs in the data:
    peak_indices, trough_indices = detect_peaks(data, thresh)
    if verbose > 0:
        print('Peaks/troughs detected in data: %5d %5d'
              % (len(peak_indices), len(trough_indices)))
    if return_data:
        peak_detection_result.update(peaks_1=np.array(peak_indices),
                                     troughs_1=np.array(trough_indices))
    if len(peak_indices) < 2 or \
       len(trough_indices) < 2 or \
       len(peak_indices) > len(data)/20:
        # TODO: if too many peaks increase threshold!
        if verbose > 0:
            print('No or too many peaks/troughs detected in data.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # assign troughs to peaks:
    peak_indices, trough_indices, heights, widths, slopes = \
        assign_side_peaks(data, peak_indices, trough_indices, min_rel_slope_diff)
    if verbose > 1:
        print('Number of peaks after assigning side-peaks: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_2=np.array(peak_indices),
                                     troughs_2=np.array(trough_indices))

    # check widths:
    keep = ((widths > min_width*rate) & (widths < max_width*rate))
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    slopes = slopes[keep]
    if verbose > 1:
        print('Number of peaks after checking pulse width: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_3=np.array(peak_indices),
                                     troughs_3=np.array(trough_indices))

    # discard connected peaks:
    same = np.nonzero(trough_indices[:-1] == trough_indices[1:])[0]
    keep = np.ones(len(trough_indices), dtype=bool)
    for i in same:
        # same troughs at trough_indices[i] and trough_indices[i+1]:
        s = slopes[i:i+2]
        rel_slopes = np.abs(np.diff(s))[0]/np.mean(s)
        if rel_slopes > min_rel_slope_diff:
            keep[i+(s[1] < s[0])] = False
        else:
            keep[i+(heights[i+1] < heights[i])] = False
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    if verbose > 1:
        print('Number of peaks after merging pulses: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_4=np.array(peak_indices),
                                     troughs_4=np.array(trough_indices))
    if len(peak_indices) == 0:
        if verbose > 0:
            print('No peaks remain as pulse candidates.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # only keep those pulses whose cut-out window around both peak and
    # trough stays within the data range:
    keep = ((peak_indices - widths > 0) &
            (peak_indices + widths < len(data)) &
            (trough_indices - widths > 0) &
            (trough_indices + widths < len(data)))

    if verbose > 0:
        print('Remaining peaks after EOD extraction: %5d'
              % (np.sum(keep)))
        print('')

    if return_data:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep], peak_detection_result
    else:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep]
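
# Short sketch (illustration only): calling detect_pulses() directly with a
# threshold derived by median_std_threshold(), as extract_pulsefish() does.
# The pulse train is synthetic; all names and values are illustrative.
def _demo_detect_pulses():
    rate = 44100.0
    data = 0.01*np.random.randn(int(rate))
    t = np.arange(-0.0005, 0.0005, 1/rate)
    pulse = np.sin(2*np.pi*2000*t)*np.exp(-(t/0.0002)**2)
    for i in np.arange(2000, len(data)-2000, 2205):
        data[int(i):int(i)+len(pulse)] += pulse
    thresh = median_std_threshold(data, int(0.002*rate))
    peaks, troughs, heights, widths = detect_pulses(data, rate, thresh)
    print('%d pulses detected' % len(peaks))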


@jit(nopython=True)
def assign_side_peaks(data, peak_indices, trough_indices,
                      min_rel_slope_diff=0.25):
    """Assign to each peak the trough resulting in a pulse with the steepest slope or largest height.

    The slope between a peak and a trough is computed as the height
    difference divided by the distance between peak and trough. If the
    slopes between the left and the right trough differ by less than
    `min_rel_slope_diff`, then just the heights of the two troughs
    relative to the peak are compared.

    Was `def detect_eod_peaks(data, main_indices, side_indices,
    max_width=20, min_width=2, verbose=0)` before.

    Parameters
    ----------
    data: array of floats
        Data in which the events were detected.
    peak_indices: array of ints
        Indices of the detected peaks in the data time series.
    trough_indices: array of ints
        Indices of the detected troughs in the data time series.
    min_rel_slope_diff: float
        Minimum required difference of left and right slope relative
        to mean slope.

    Returns
    -------
    peak_indices: array of ints
        Peak indices. Same as input `peak_indices` but potentially shorter
        by one or two elements.
    trough_indices: array of ints
        Indices of the corresponding troughs to the left or right
        of the peaks.
    heights: array of floats
        Peak heights (distance between peak and corresponding trough amplitude).
    widths: array of ints
        Peak widths (distance between peak and corresponding trough indices).
    slopes: array of floats
        Peak slopes (height divided by width).
    """
    # is a main or side peak first?
    peak_first = int(peak_indices[0] < trough_indices[0])
    # is a main or side peak last?
    peak_last = int(peak_indices[-1] > trough_indices[-1])
    # ensure that all peaks have side peaks (troughs) on both sides,
    # i.e. troughs at the same index and the next index are before and after the peak:
    peak_indices = peak_indices[peak_first:len(peak_indices)-peak_last]
    y = data[peak_indices]

    # indices of troughs on the left and right side of main peaks:
    l_indices = np.arange(len(peak_indices))
    r_indices = l_indices + 1

    # indices, distance to peak, height, and slope of left troughs:
    l_side_indices = trough_indices[l_indices]
    l_distance = np.abs(peak_indices - l_side_indices)
    l_height = np.abs(y - data[l_side_indices])
    l_slope = np.abs(l_height/l_distance)

    # indices, distance to peak, height, and slope of right troughs:
    r_side_indices = trough_indices[r_indices]
    r_distance = np.abs(r_side_indices - peak_indices)
    r_height = np.abs(y - data[r_side_indices])
    r_slope = np.abs(r_height/r_distance)

    # which trough to assign to the peak?
    # - either the one with the steepest slope, or
    # - when slopes are similar on both sides
    #   (within `min_rel_slope_diff` difference),
    #   the trough with the maximum height difference to the peak:
    rel_slopes = np.abs(l_slope-r_slope)/(0.5*(l_slope+r_slope))
    take_slopes = rel_slopes > min_rel_slope_diff
    take_left = l_height > r_height
    take_left[take_slopes] = l_slope[take_slopes] > r_slope[take_slopes]

    # assign troughs, heights, widths, and slopes:
    trough_indices = np.where(take_left,
                              trough_indices[:-1], trough_indices[1:])
    heights = np.where(take_left, l_height, r_height)
    widths = np.where(take_left, l_distance, r_distance)
    slopes = np.where(take_left, l_slope, r_slope)

    return peak_indices, trough_indices, heights, widths, slopes
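
# Tiny sketch (illustration only) of the trough assignment rule: a single peak
# flanked by a shallow left trough and a deep right trough; the right trough
# wins because its slope is clearly steeper.
def _demo_assign_side_peaks():
    data = np.array([0.0, -0.1, 0.0, 1.0, 0.0, -0.8, 0.0])
    peaks = np.array([3])
    troughs = np.array([1, 5])
    p, t, heights, widths, slopes = assign_side_peaks(data, peaks, troughs)
    print(t, heights, widths)   # [5] [1.8] [2]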


def cluster(eod_xp, eod_xt, eod_heights, eod_widths, data, rate,
            width_factor_shape, width_factor_wave, n_gaus_height=10,
            merge_threshold_height=0.1, n_gaus_width=3,
            merge_threshold_width=0.5, minp=10, verbose=0,
            plot_level=0, save_plots=False, save_path='', ftype='pdf',
            return_data=[]):
    """Cluster EODs.

    First cluster on EOD widths using a Bayesian Gaussian
    Mixture (BGM) model, then cluster on EOD heights using a
    BGM model. Lastly, cluster on EOD waveform with DBSCAN.
    Clustering on EOD waveform is performed twice, once on
    peak-centered EODs and once on trough-centered EODs.
    Non-pulsetype EOD clusters are deleted, and clusters are
    merged afterwards.

    Parameters
    ----------
    eod_xp : list of ints
        Locations of EOD peaks in indices.
    eod_xt: list of ints
        Locations of EOD troughs in indices.
    eod_heights: list of floats
        EOD heights.
    eod_widths: list of ints
        EOD widths in samples.
    data: array of floats
        Data in which to detect pulse EODs.
    rate : float
        Sampling rate of `data`.
    width_factor_shape : float
        Multiplier for snippet extraction width. This factor is
        multiplied with the width between the peak and trough of a
        single EOD.
    width_factor_wave : float
        Multiplier for wavefish extraction width.
    n_gaus_height : int (optional)
        Number of gaussians to use for the clustering based on EOD height.
    merge_threshold_height : float (optional)
        Threshold for merging clusters that are similar in height.
    n_gaus_width : int (optional)
        Number of gaussians to use for the clustering based on EOD width.
    merge_threshold_width : float (optional)
        Threshold for merging clusters that are similar in width.
    minp : int (optional)
        Minimum number of points for core clusters (DBSCAN).
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save created plots.
    save_path : string (optional)
        Path to save plots to. Only used if save_plots==True.
    ftype : string (optional)
        Filetype to save plot images in.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'all_cluster_steps', 'BGM_width', 'BGM_height',
        'snippet_clusters', 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    labels : list of ints
        EOD cluster labels based on height and EOD waveform.
    x_merge : list of ints
        Locations of EODs in clusters.
    saved_data : dictionary
        Key value pairs of logged data. Data to be logged is specified
        by return_data.
    """
    saved_data = {}

    if plot_level > 0 or 'all_cluster_steps' in return_data:
        all_heightlabels = []
        all_shapelabels = []
        all_snippets = []
        all_features = []
        all_heights = []
        all_unique_heightlabels = []

    all_p_clusters = -1 * np.ones(len(eod_xp))
    all_t_clusters = -1 * np.ones(len(eod_xp))
    artefact_masks_p = np.ones(len(eod_xp), dtype=bool)
    artefact_masks_t = np.ones(len(eod_xp), dtype=bool)

    x_merge = -1 * np.ones(len(eod_xp))

    max_label_p = 0   # keep track of the labels so that no labels are overwritten
    max_label_t = 0

    # loop only over height clusters that are bigger than minp
    # first cluster on width:
    width_labels, bgm_log_dict = BGM(1000*eod_widths/rate,
                                     merge_threshold_width,
                                     n_gaus_width, use_log=False,
                                     verbose=verbose-1,
                                     plot_level=plot_level-1,
                                     xlabel='width [ms]',
                                     save_plot=save_plots,
                                     save_path=save_path,
                                     save_name='width', ftype=ftype,
                                     return_data=return_data)
    saved_data.update(bgm_log_dict)

    if verbose > 0:
        print('Clusters generated based on EOD width:')
        for l in np.unique(width_labels):
            print(f'N_{l} = {len(width_labels[width_labels==l]):4d} w_{l} = {np.mean(eod_widths[width_labels==l]):.4f}')

    w_labels, w_counts = unique_counts(width_labels)
    unique_width_labels = w_labels[w_counts > minp]

    for wi, width_label in enumerate(unique_width_labels):

        # select only features in one width cluster at a time:
        w_eod_widths = eod_widths[width_labels==width_label]
        w_eod_heights = eod_heights[width_labels==width_label]
        w_eod_xp = eod_xp[width_labels==width_label]
        w_eod_xt = eod_xt[width_labels==width_label]
        width = int(width_factor_shape*np.median(w_eod_widths))
        if width > w_eod_xp[0]:
            width = w_eod_xp[0]
        if width > w_eod_xt[0]:
            width = w_eod_xt[0]
        if width > len(data) - w_eod_xp[-1]:
            width = len(data) - w_eod_xp[-1]
        if width > len(data) - w_eod_xt[-1]:
            width = len(data) - w_eod_xt[-1]

        wp_clusters = -1 * np.ones(len(w_eod_xp))
        wt_clusters = -1 * np.ones(len(w_eod_xp))
        wartefact_mask = np.ones(len(w_eod_xp))

        # determine height labels:
        raw_p_snippets, p_snippets, p_features, p_bg_ratio = \
            extract_snippet_features(data, w_eod_xp, w_eod_heights, width)
        raw_t_snippets, t_snippets, t_features, t_bg_ratio = \
            extract_snippet_features(data, w_eod_xt, w_eod_heights, width)

        height_labels, bgm_log_dict = \
            BGM(w_eod_heights,
                min(merge_threshold_height,
                    np.median(np.min(np.vstack([p_bg_ratio, t_bg_ratio]), axis=0))),
                n_gaus_height, use_log=True,
                verbose=verbose-1, plot_level=plot_level-1,
                xlabel='height [a.u.]', save_plot=save_plots,
                save_path=save_path, save_name='height_%d' % wi,
                ftype=ftype, return_data=return_data)
        saved_data.update(bgm_log_dict)

        if verbose > 0:
            print('Clusters generated based on EOD height:')
            for l in np.unique(height_labels):
                print(f'N_{l} = {len(height_labels[height_labels==l]):4d} h_{l} = {np.mean(w_eod_heights[height_labels==l]):.4f}')

        h_labels, h_counts = unique_counts(height_labels)
        unique_height_labels = h_labels[h_counts > minp]

        if plot_level > 0 or 'all_cluster_steps' in return_data:
            all_heightlabels.append(height_labels)
            all_heights.append(w_eod_heights)
            all_unique_heightlabels.append(unique_height_labels)
            shape_labels = []
            cfeatures = []
            csnippets = []

        for hi, height_label in enumerate(unique_height_labels):

            h_eod_widths = w_eod_widths[height_labels==height_label]
            h_eod_heights = w_eod_heights[height_labels==height_label]
            h_eod_xp = w_eod_xp[height_labels==height_label]
            h_eod_xt = w_eod_xt[height_labels==height_label]

            p_clusters = cluster_on_shape(p_features[height_labels==height_label],
                                          p_bg_ratio, minp, verbose=0)
            t_clusters = cluster_on_shape(t_features[height_labels==height_label],
                                          t_bg_ratio, minp, verbose=0)

            if plot_level > 1:
                plot_feature_extraction(raw_p_snippets[height_labels==height_label],
                                        p_snippets[height_labels==height_label],
                                        p_features[height_labels==height_label],
                                        p_clusters, 1/rate, 0)
                plt.savefig('%sDBSCAN_peak_w%i_h%i.%s' % (save_path, wi, hi, ftype))
                plot_feature_extraction(raw_t_snippets[height_labels==height_label],
                                        t_snippets[height_labels==height_label],
                                        t_features[height_labels==height_label],
                                        t_clusters, 1/rate, 1)
                plt.savefig('%sDBSCAN_trough_w%i_h%i.%s' % (save_path, wi, hi, ftype))

            if 'snippet_clusters' in return_data:
                saved_data[f'snippet_clusters_{width_label}_{height_label}_peak'] = {
                    'raw_snippets': raw_p_snippets[height_labels==height_label],
                    'snippets': p_snippets[height_labels==height_label],
                    'features': p_features[height_labels==height_label],
                    'clusters': p_clusters,
                    'rate': rate}
                saved_data[f'snippet_clusters_{width_label}_{height_label}_trough'] = {
                    'raw_snippets': raw_t_snippets[height_labels==height_label],
                    'snippets': t_snippets[height_labels==height_label],
                    'features': t_features[height_labels==height_label],
                    'clusters': t_clusters,
                    'rate': rate}

            if plot_level > 0 or 'all_cluster_steps' in return_data:
                shape_labels.append([p_clusters, t_clusters])
                cfeatures.append([p_features[height_labels==height_label],
                                  t_features[height_labels==height_label]])
                csnippets.append([p_snippets[height_labels==height_label],
                                  t_snippets[height_labels==height_label]])

            p_clusters[p_clusters==-1] = -max_label_p - 1
            wp_clusters[height_labels==height_label] = p_clusters + max_label_p
            max_label_p = max(np.max(wp_clusters), np.max(all_p_clusters)) + 1

            t_clusters[t_clusters==-1] = -max_label_t - 1
            wt_clusters[height_labels==height_label] = t_clusters + max_label_t
            max_label_t = max(np.max(wt_clusters), np.max(all_t_clusters)) + 1

        if verbose > 0:
            if np.max(wp_clusters) == -1:
                print(f'No EOD peaks in width cluster {width_label}')
            else:
                unique_clusters = np.unique(wp_clusters[wp_clusters!=-1])
                if len(unique_clusters) > 1:
                    print(f'{len(unique_clusters)} different EOD peaks in width cluster {width_label}')

        if plot_level > 0 or 'all_cluster_steps' in return_data:
            all_shapelabels.append(shape_labels)
            all_snippets.append(csnippets)
            all_features.append(cfeatures)

        # for each cluster, save fft + label
        # so I end up with features for each label, and the masks.
        # then I can extract e.g. first artefact or wave etc.

        # remove artefacts here, based on the mean snippets ffts:
        artefact_masks_p[width_labels==width_label], sdict = \
            remove_artefacts(p_snippets, wp_clusters, rate,
                             verbose=verbose-1, return_data=return_data)
        saved_data.update(sdict)
        artefact_masks_t[width_labels==width_label], _ = \
            remove_artefacts(t_snippets, wt_clusters, rate,
                             verbose=verbose-1, return_data=return_data)

        # update maxlab so that no clusters are overwritten:
        all_p_clusters[width_labels==width_label] = wp_clusters
        all_t_clusters[width_labels==width_label] = wt_clusters

    # remove all non-reliable clusters:
    unreliable_fish_mask_p, saved_data = \
        delete_unreliable_fish(all_p_clusters, eod_widths, eod_xp,
                               verbose=verbose-1, sdict=saved_data)
    unreliable_fish_mask_t, _ = \
        delete_unreliable_fish(all_t_clusters, eod_widths, eod_xt, verbose=verbose-1)

    wave_mask_p, sidepeak_mask_p, saved_data = \
        delete_wavefish_and_sidepeaks(data, all_p_clusters, eod_xp, eod_widths,
                                      width_factor_wave, verbose=verbose-1,
                                      sdict=saved_data)
    wave_mask_t, sidepeak_mask_t, _ = \
        delete_wavefish_and_sidepeaks(data, all_t_clusters, eod_xt, eod_widths,
                                      width_factor_wave, verbose=verbose-1)

    og_clusters = [np.copy(all_p_clusters), np.copy(all_t_clusters)]
    og_labels = np.copy(all_p_clusters + all_t_clusters)

    # go through all clusters and masks:
    all_p_clusters[(artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p)] = -1
    all_t_clusters[(artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)] = -1

    # merge here:
    all_clusters, x_merge, mask = merge_clusters(np.copy(all_p_clusters),
                                                 np.copy(all_t_clusters),
                                                 eod_xp, eod_xt,
                                                 verbose=verbose - 1)

    if 'all_cluster_steps' in return_data or plot_level > 0:
        all_dmasks = []
        all_mmasks = []

        discarding_masks = \
            np.vstack(((artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p),
                       (artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)))
        merge_mask = mask

        # save the masks in the same formats as the snippets:
        for wi, (width_label, w_shape_label, heightlabels, unique_height_labels) in \
                enumerate(zip(unique_width_labels, all_shapelabels,
                              all_heightlabels, all_unique_heightlabels)):
            w_dmasks = discarding_masks[:,width_labels==width_label]
            w_mmasks = merge_mask[:,width_labels==width_label]

            wd_2 = []
            wm_2 = []

            for hi, (height_label, h_shape_label) in enumerate(zip(unique_height_labels, w_shape_label)):

                h_dmasks = w_dmasks[:,heightlabels==height_label]
                h_mmasks = w_mmasks[:,heightlabels==height_label]

                wd_2.append(h_dmasks)
                wm_2.append(h_mmasks)

            all_dmasks.append(wd_2)
            all_mmasks.append(wm_2)

        if plot_level > 0:
            plot_clustering(rate, [unique_width_labels, eod_widths, width_labels],
                            [all_unique_heightlabels, all_heights, all_heightlabels],
                            [all_snippets, all_features, all_shapelabels],
                            all_dmasks, all_mmasks)
            if save_plots:
                plt.savefig('%sclustering.%s' % (save_path, ftype))

        if 'all_cluster_steps' in return_data:
            saved_data = {'rate': rate,
                          'EOD_widths': [unique_width_labels, eod_widths, width_labels],
                          'EOD_heights': [all_unique_heightlabels, all_heights, all_heightlabels],
                          'EOD_shapes': [all_snippets, all_features, all_shapelabels],
                          'discarding_masks': all_dmasks,
                          'merge_masks': all_mmasks
                          }

    if 'masks' in return_data:
        saved_data = {'masks' : np.vstack(((artefact_masks_p & artefact_masks_t),
                                           (unreliable_fish_mask_p & unreliable_fish_mask_t),
                                           (wave_mask_p & wave_mask_t),
                                           (sidepeak_mask_p & sidepeak_mask_t),
                                           (all_p_clusters+all_t_clusters)))}

    if verbose > 0:
        print('Clusters generated based on height, width and shape: ')
        for l in np.unique(all_clusters[all_clusters != -1]):
            print(f'N_{int(l)} = {len(all_clusters[all_clusters == l]):4d}')

    return all_clusters, x_merge, saved_data
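
# Conceptual sketch (illustration only, not the module's API): the
# width -> height -> shape hierarchy that cluster() implements, reduced to
# its first two stages on synthetic 1-D features.
def _demo_cluster_hierarchy():
    np.random.seed(0)
    widths = np.concatenate((np.random.normal(0.5, 0.01, 50),
                             np.random.normal(2.0, 0.05, 50)))
    width_labels, _ = BGM(widths, merge_threshold=0.5, n_gaus=3)
    for wl in np.unique(width_labels):
        heights = np.random.normal(1.0, 0.01, np.sum(width_labels == wl))
        height_labels, _ = BGM(heights, merge_threshold=0.1, n_gaus=10,
                               use_log=True)
        print('width cluster %d: %d height cluster(s)'
              % (int(wl), len(np.unique(height_labels))))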


def BGM(x, merge_threshold=0.1, n_gaus=5, max_iter=200, n_init=5,
        use_log=False, verbose=0, plot_level=0, xlabel='x [a.u.]',
        save_plot=False, save_path='', save_name='', ftype='pdf',
        return_data=[]):
    """ Use a Bayesian Gaussian Mixture Model to cluster one-dimensional data.

    Additional steps are used to merge clusters that are closer than
    `merge_threshold`. Broad gaussian fits that cover one or more other
    gaussian fits are split by their intersections with the other
    gaussians.

    Parameters
    ----------
    x : 1D numpy array
        Features to compute clustering on.

    merge_threshold : float (optional)
        Ratio for merging nearby gaussians.
    n_gaus: int (optional)
        Maximum number of gaussians to fit on data.
    max_iter : int (optional)
        Maximum number of iterations for gaussian fit.
    n_init : int (optional)
        Number of initializations for the gaussian fit.
    use_log: boolean (optional)
        Set to True to compute the gaussian fit on the logarithm of x.
        Can improve clustering on features with nonlinear relationships
        such as peak height.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    xlabel : string (optional)
        Xlabel for displaying BGM plot.
    save_plot : bool (optional)
        Set to True to save created plot.
    save_path : string (optional)
        Path to location where data should be saved. Only used if save_plot==True.
    save_name : string (optional)
        Filename of the saved plot. Useful as usually multiple BGM models
        are generated.
    ftype : string (optional)
        Filetype of plot image if save_plot==True.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'BGM_width' and/or 'BGM_height' (see extract_pulsefish()).

    Returns
    -------
    labels : 1D numpy array
        Cluster labels for each sample in x.
    bgm_dict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    bgm_dict = {}

    if len(np.unique(x)) > n_gaus:
        BGM_model = BayesianGaussianMixture(n_components=n_gaus,
                                            max_iter=max_iter, n_init=n_init)
        if use_log:
            labels = BGM_model.fit_predict(stats.zscore(np.log(x)).reshape(-1, 1))
        else:
            labels = BGM_model.fit_predict(stats.zscore(x).reshape(-1, 1))
    else:
        return np.zeros(len(x)), bgm_dict

    if verbose > 0:
        if not BGM_model.converged_:
            print('!!! Gaussian mixture did not converge !!!')

    cur_labels = np.unique(labels)

    # map labels to be increasing for increasing values of x:
    maxlab = len(cur_labels)
    aso = np.argsort([np.median(x[labels == l]) for l in cur_labels]) + 100
    for i, a in zip(cur_labels, aso):
        labels[labels==i] = a
    labels = labels - 100

    # separate gaussian clusters that can be split by other clusters:
    splits = np.sort(np.copy(x))[1:][np.diff(labels[np.argsort(x)])!=0]

    labels[:] = 0
    for i, split in enumerate(splits):
        labels[x>=split] = i+1

    labels_before_merge = np.copy(labels)

    # merge gaussian clusters that are closer than merge_threshold:
    labels = merge_gaussians(x, labels, merge_threshold)

    if 'BGM_'+save_name.split('_')[0] in return_data or plot_level > 0:

        # sort model attributes by the model means:
        means = [m[0] for m in BGM_model.means_]
        weights = [w for w in BGM_model.weights_]
        variances = [v[0][0] for v in BGM_model.covariances_]
        weights = [w for _, w in sorted(zip(means, weights))]
        variances = [v for _, v in sorted(zip(means, variances))]
        means = sorted(means)

        if plot_level > 0:
            plot_bgm(x, means, variances, weights, use_log, labels_before_merge,
                     labels, xlabel)
            if save_plot:
                plt.savefig('%sBGM_%s.%s' % (save_path, save_name, ftype))

        if 'BGM_'+save_name.split('_')[0] in return_data:
            bgm_dict['BGM_'+save_name] = {'x': x,
                                          'use_log': use_log,
                                          'BGM': [weights, means, variances],
                                          'labels': labels_before_merge,
                                          'xlab': xlabel}

    return labels, bgm_dict
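
# Minimal sketch (illustration only): BGM() on a clearly bimodal 1-D sample.
# The two modes are far apart relative to merge_threshold, so two cluster
# labels typically remain after merging.
def _demo_bgm():
    np.random.seed(42)
    x = np.concatenate((np.random.normal(1.0, 0.05, 100),
                        np.random.normal(2.0, 0.05, 100)))
    labels, _ = BGM(x, merge_threshold=0.1, n_gaus=5)
    print(len(np.unique(labels)))   # typically 2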


def merge_gaussians(x, labels, merge_threshold=0.1):
    """ Merge all clusters whose medians are close. Only works in 1D.

    Parameters
    ----------
    x : 1D array of ints or floats
        Features used for clustering.
    labels : 1D array of ints
        Labels for each sample in x.
    merge_threshold : float (optional)
        Similarity threshold to merge clusters.

    Returns
    -------
    labels : 1D array of ints
        Merged labels for each sample in x.
    """
    # compare all the medians of the clusters. If they are too close, merge them.
    unique_labels = np.unique(labels[labels!=-1])
    x_medians = [np.median(x[labels==l]) for l in unique_labels]

    # fill a dict with the label mappings:
    mapping = {}
    for label_1, x_m1 in zip(unique_labels, x_medians):
        for label_2, x_m2 in zip(unique_labels, x_medians):
            if label_1 != label_2:
                if np.abs(np.diff([x_m1, x_m2]))/np.max([x_m1, x_m2]) < merge_threshold:
                    mapping[label_1] = label_2
    # apply mapping:
    for map_key, map_value in mapping.items():
        labels[labels==map_key] = map_value

    return labels
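
# Tiny sketch (illustration only): clusters 0 and 1 have medians within 10%
# of each other and get merged; cluster 2 is left untouched.
def _demo_merge_gaussians():
    x = np.array([1.0, 1.02, 1.04, 1.06, 1.08, 5.0, 5.1])
    labels = np.array([0, 0, 0, 1, 1, 2, 2])
    print(merge_gaussians(x, labels, merge_threshold=0.1))
    # -> [0 0 0 0 0 2 2]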


def extract_snippet_features(data, eod_x, eod_heights, width, n_pc=5):
    """ Extract snippets from recording data, normalize them, and perform PCA.

    Parameters
    ----------
    data : 1D numpy array of floats
        Recording data.
    eod_x : 1D array of ints
        Locations of EODs as indices.
    eod_heights: 1D array of floats
        EOD heights.
    width : int
        Width to cut out to each side in samples.

    n_pc : int (optional)
        Number of PCs to use for PCA.

    Returns
    -------
    raw_snippets : 2D numpy array (N, EOD_width)
        Raw extracted EOD snippets.
    snippets : 2D numpy array (N, EOD_width)
        Normalized EOD snippets.
    features : 2D numpy array (N, n_pc)
        PC values of EOD snippets.
    bg_ratio : 1D numpy array (N)
        Ratio of the background activity slopes compared to EOD height.
    """
    # extract snippets with corresponding width:
    raw_snippets = np.vstack([data[x-width:x+width] for x in eod_x])

    # subtract the slope and normalize the snippets:
    snippets, bg_ratio = subtract_slope(np.copy(raw_snippets), eod_heights)
    snippets = StandardScaler().fit_transform(snippets.T).T

    # scale so that the absolute integral = 1:
    snippets = (snippets.T/np.sum(np.abs(snippets), axis=1)).T

    # compute features for clustering on waveform:
    features = PCA(n_pc).fit_transform(snippets)

    return raw_snippets, snippets, features, bg_ratio
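
# Short sketch (illustration only): cut snippets around three event locations
# in random data and reduce them to two principal components.
def _demo_extract_snippet_features():
    data = np.random.randn(1000)
    eod_x = np.array([200, 500, 800])
    eod_heights = np.array([1.0, 1.2, 0.9])
    raw, snippets, features, bg_ratio = \
        extract_snippet_features(data, eod_x, eod_heights, width=50, n_pc=2)
    print(raw.shape, features.shape)   # (3, 100) (3, 2)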


def cluster_on_shape(features, bg_ratio, minp, percentile=80,
                     max_epsilon=0.01, slope_ratio_factor=4,
                     min_cluster_fraction=0.01, verbose=0):
    """Separate EODs by their shape using DBSCAN.

    Parameters
    ----------
    features : 2D numpy array of floats (N, n_pc)
        PCA features of each EOD in a recording.
    bg_ratio : 1D array of floats
        Ratio of background activity slope the EOD is superimposed on.
    minp : int
        Minimum number of points for core clusters (DBSCAN).

    percentile : int (optional)
        Percentile of the KNN distribution, where K=minp, to use as epsilon
        for DBSCAN.
    max_epsilon : float (optional)
        Maximum epsilon to use for DBSCAN clustering. This is used to avoid
        adding noisy clusters.
    slope_ratio_factor : float (optional)
        Influence of the slope-to-EOD ratio on the epsilon parameter.
        A slope_ratio_factor of 4 means that slope-to-EOD ratios > 1/4
        start influencing epsilon.
    min_cluster_fraction : float (optional)
        Minimum fraction of all evaluated datapoints that can form a single cluster.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    labels : 1D array of ints
        DBSCAN cluster labels for each EOD in `features`.
    """
    # determine clustering threshold from data:
    minpc = max(minp, int(len(features)*min_cluster_fraction))
    knn = np.sort(pairwise_distances(features, features), axis=0)[minpc]
    eps = min(max(1, slope_ratio_factor*np.median(bg_ratio))*max_epsilon,
              np.percentile(knn, percentile))

    if verbose > 1:
        print('epsilon = %f' % eps)
        print('Slope to EOD ratio = %f' % np.median(bg_ratio))

    # cluster on EOD shape:
    return DBSCAN(eps=eps, min_samples=minpc).fit(features).labels_
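
# Short sketch (illustration only): DBSCAN on toy PCA features with two tight
# groups; epsilon is bounded by max_epsilon scaled with the background ratio.
def _demo_cluster_on_shape():
    np.random.seed(1)
    features = np.vstack((np.random.normal(0.0, 0.001, (20, 2)),
                          np.random.normal(0.05, 0.001, (20, 2))))
    bg_ratio = np.full(40, 0.1)
    labels = cluster_on_shape(features, bg_ratio, minp=5)
    print(np.unique(labels))   # typically two clusters: [0 1]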


def subtract_slope(snippets, heights):
    """ Subtract underlying slope from all EOD snippets.

    Parameters
    ----------
    snippets: 2-D numpy array
        All EODs in a recording, stacked as snippets.
        Shape = (number of EODs, EOD width).
    heights: 1D numpy array
        EOD heights.

    Returns
    -------
    snippets: 2-D numpy array
        EOD snippets with underlying slope subtracted.
    bg_ratio : 1-D numpy array
        EOD height/background activity height.
    """
    left_y = snippets[:,0]
    right_y = snippets[:,-1]

    try:
        slopes = np.linspace(left_y, right_y, snippets.shape[1])
    except ValueError:
        delta = (right_y - left_y)/snippets.shape[1]
        slopes = np.arange(0, snippets.shape[1], dtype=snippets.dtype).reshape(
            (-1,) + (1,) * np.ndim(delta))*delta + left_y

    return snippets - slopes.T, np.abs(left_y-right_y)/heights
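
# Tiny sketch (illustration only): a pure linear ramp is removed completely,
# and bg_ratio is the ramp height relative to the EOD height.
def _demo_subtract_slope():
    snippets = np.array([[0.0, 1.0, 2.0, 3.0]])
    heights = np.array([2.0])
    flat, bg_ratio = subtract_slope(np.copy(snippets), heights)
    print(flat, bg_ratio)   # [[0. 0. 0. 0.]] [1.5]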


def remove_artefacts(all_snippets, clusters, rate,
                     freq_low=20000, threshold=0.75,
                     verbose=0, return_data=[]):
    """ Create a mask for EOD clusters that result from artefacts, based on power in the low frequency spectrum.

    Parameters
    ----------
    all_snippets: 2D array
        EOD snippets. Shape=(nEODs, EOD length).
    clusters: list of ints
        EOD cluster labels.
    rate : float
        Sampling rate of original recording data.
    freq_low: float
        Frequency up to which low frequency components are summed up.
    threshold : float (optional)
        Minimum sum of low frequency components relative to the sum over
        all spectral amplitudes that separates artefacts from clean
        pulsefish clusters.
    verbose : int (optional)
        Verbosity level.
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log
        data in this function is 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    mask: numpy array of booleans
        Set to True for every EOD which is an artefact.
    adict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    adict = {}

    mask = np.zeros(clusters.shape, dtype=bool)

    for cluster in np.unique(clusters[clusters >= 0]):
        snippets = all_snippets[clusters == cluster]
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)
        mean_eod_fft = np.abs(np.fft.rfft(mean_eod))
        freqs = np.fft.rfftfreq(len(mean_eod), 1/rate)
        low_frequency_ratio = np.sum(mean_eod_fft[freqs<freq_low])/np.sum(mean_eod_fft)
        if low_frequency_ratio < threshold:  # TODO: check threshold!
            mask[clusters==cluster] = True

            if verbose > 0:
                print('Deleting cluster %i with low frequency ratio of %.3f (min %.3f)'
                      % (cluster, low_frequency_ratio, threshold))

        if 'eod_deletion' in return_data:
            adict['vals_%d' % cluster] = [mean_eod, mean_eod_fft]
            adict['mask_%d' % cluster] = [np.any(mask[clusters==cluster])]

    return mask, adict
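
# Short sketch (illustration only): a cluster of high-frequency artefacts is
# masked, a cluster of slow pulse-like snippets is kept. Frequencies in Hertz.
def _demo_remove_artefacts():
    rate = 500000.0
    t = np.arange(100)/rate
    pulse = np.sin(2*np.pi*5000*t)        # power well below freq_low
    artefact = np.sin(2*np.pi*100000*t)   # power far above freq_low
    snippets = np.vstack((pulse, pulse, artefact, artefact))
    clusters = np.array([0, 0, 1, 1])
    mask, _ = remove_artefacts(snippets, clusters, rate)
    print(mask)   # [False False  True  True]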


def delete_unreliable_fish(clusters, eod_widths, eod_x, verbose=0, sdict={}):
    """ Create a mask for EOD clusters that are either mixed with noise or other fish, or wavefish.

    This is the case when the ratio between the EOD width and the ISI is too large.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of floats or ints
        EOD widths in samples or seconds.
    eod_x : list of ints or floats
        EOD times in samples or seconds.

    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps,
        see remove_artefacts().

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every unreliable EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask = np.zeros(clusters.shape, dtype=bool)
    for cluster in np.unique(clusters[clusters >= 0]):
        if len(eod_x[cluster == clusters]) < 2:
            mask[clusters == cluster] = True
            if verbose > 0:
                print('deleting unreliable cluster %i, number of EOD times %d < 2'
                      % (cluster, len(eod_x[cluster==clusters])))
        elif np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) > 0.5:
            if verbose > 0:
                print('deleting unreliable cluster %i, score=%f'
                      % (cluster, np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))))
            mask[clusters==cluster] = True
        if 'vals_%d' % cluster in sdict:
            sdict['vals_%d' % cluster].append(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))
            sdict['mask_%d' % cluster].append(any(mask[clusters==cluster]))
    return mask, sdict
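
# Tiny sketch (illustration only): three EODs whose width is more than half
# of their inter-spike interval are flagged as unreliable.
def _demo_delete_unreliable_fish():
    clusters = np.array([0, 0, 0])
    eod_widths = np.array([100, 100, 100])   # in samples
    eod_x = np.array([1000, 1120, 1240])     # ISIs of 120 samples
    mask, _ = delete_unreliable_fish(clusters, eod_widths, eod_x)
    print(mask)   # width/ISI = 100/120 > 0.5 -> [ True  True  True]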

1379 

1380 

1381def delete_wavefish_and_sidepeaks(data, clusters, eod_x, eod_widths, 

1382 width_fac, max_slope_deviation=0.5, 

1383 max_phases=4, verbose=0, sdict={}): 

1384 """ Create a mask for EODs that are likely from wavefish, or sidepeaks of bigger EODs. 

1385 

1386 Parameters 

1387 ---------- 

1388 data : list of floats 

1389 Raw recording data. 

1390 clusters : list of ints 

1391 Cluster labels. 

1392 eod_x : list of ints 

1393 Indices of EOD times. 

1394 eod_widths : list of ints 

1395 EOD widths in samples. 

1396 width_fac : float 

1397 Multiplier for EOD analysis width. 

1398 

1399 max_slope_deviation: float (optional) 

1400 Maximum deviation of position of maximum slope in snippets from 

1401 center position in multiples of mean width of EOD. 

1402 max_phases : int (optional) 

1403 Maximum number of phases for any EOD.  

1404 If the mean EOD has more phases than this, it is not a pulse EOD. 

1405 verbose : int (optional)  

1406 Verbosity level. 

1407 sdict : dictionary 

1408 Dictionary used for logging data. Only used if a dictionary 

1409 was created by remove_artefacts(). 

1410 See remove_artefacts() for how data of the noise and wavefish discarding steps is logged. 

1411 

1412 Returns 

1413 ------- 

1414 mask_wave: numpy array of booleans 

1415 Set to True for every EOD which is a wavefish EOD. 

1416 mask_sidepeak: numpy array of booleans 

1417 Set to True for every snippet which is centered around a sidepeak of an EOD. 

1418 sdict : dictionary 

1419 Key value pairs of logged data. Data is only logged if a dictionary 

1420 was instantiated by remove_artefacts(). 

1421 """ 

1422 mask_wave = np.zeros(clusters.shape, dtype=bool) 

1423 mask_sidepeak = np.zeros(clusters.shape, dtype=bool) 

1424 

1425 for i, cluster in enumerate(np.unique(clusters[clusters >= 0])): 

1426 mean_width = np.mean(eod_widths[clusters == cluster]) 

1427 cutwidth = mean_width*width_fac 

1428 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1429 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1430 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] 

1431 for x in current_x[current_clusters==cluster]]) 

1432 

1433 # extract information on main peaks and troughs: 

1434 mean_eod = np.mean(snippets, axis=0) 

1435 mean_eod = mean_eod - np.mean(mean_eod) 

1436 

1437 # detect peaks and troughs on data + some maxima/minima at the 

1438 # end, so that the sides are also considered for peak detection: 

1439 pk, tr = detect_peaks(np.concatenate(([-10*mean_eod[0]], mean_eod, [10*mean_eod[-1]])), 

1440 np.std(mean_eod)) 

1441 pk = pk[(pk>0)&(pk<len(mean_eod))] 

1442 tr = tr[(tr>0)&(tr<len(mean_eod))] 

1443 

1444 if len(pk)>0 and len(tr)>0: 

1445 idxs = np.sort(np.concatenate((pk, tr))) 

1446 slopes = np.abs(np.diff(mean_eod[idxs])) 

1447 m_slope = np.argmax(slopes) 

1448 centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2)) 

1449 

1450 # compute all height differences of peaks and troughs within snippets. 

1451 # if they are all similar, it is probably noise or a wavefish. 

1452 # (idxs, the sorted peak and trough positions, was already computed above) 

1453 hdiffs = np.diff(mean_eod[idxs]) 

1454 

1455 if centered > max_slope_deviation*mean_width: # TODO: check, factor was probably 0.16 

1456 if verbose > 0: 

1457 print('Deleting cluster %i, which is a sidepeak' % cluster) 

1458 mask_sidepeak[clusters==cluster] = True 

1459 

1460 w_diff = np.abs(np.diff(idxs)) # idxs is already the sorted array of peak and trough positions 

1461 

1462 if (np.abs(np.diff(idxs[m_slope:m_slope+2])) < np.mean(eod_widths[clusters==cluster])*0.5 or 

 len(pk) + len(tr) > max_phases or 

 np.min(w_diff) > 2*cutwidth/width_fac): #or len(hdiffs[np.abs(hdiffs)>0.5*(np.max(mean_eod)-np.min(mean_eod))])>max_phases: 

1463 if verbose>0: 

1464 print('Deleting cluster %i, which is a wavefish' % cluster) 

1465 mask_wave[clusters==cluster] = True 

1466 if 'vals_%d' % cluster in sdict: 

1467 sdict['vals_%d' % cluster].append([mean_eod, [pk, tr], 

1468 idxs[m_slope:m_slope+2]]) 

1469 sdict['mask_%d' % cluster].append(any(mask_wave[clusters==cluster])) 

1470 sdict['mask_%d' % cluster].append(any(mask_sidepeak[clusters==cluster])) 

1471 

1472 return mask_wave, mask_sidepeak, sdict 

1473 

1474 
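# Sketch of the sidepeak test in delete_wavefish_and_sidepeaks() on a toy
# snippet: the steepest flank of the mean EOD (the largest height difference
# between adjacent extrema) should sit near the snippet center; if it is
# further away than max_slope_deviation mean widths, the snippets were cut
# around a sidepeak.
import numpy as np

mean_eod = np.concatenate((np.zeros(40), [0.2, 1.0, -0.8, 0.1], np.zeros(40)))
idxs = np.array([41, 42])                 # sorted peak and trough positions
slopes = np.abs(np.diff(mean_eod[idxs]))  # height differences between extrema
m_slope = np.argmax(slopes)
centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2))
print('steepest flank is %d samples from the snippet center' % centered)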

1475def merge_clusters(clusters_1, clusters_2, x_1, x_2, verbose=0): 

1476 """ Merge clusters resulting from two clustering methods. 

1477 

1478 This method only works if clustering is performed on the same EODs 

1479 with the same ordering, where there is a one-to-one mapping from 

1480 clusters_1 to clusters_2.  

1481 

1482 Parameters 

1483 ---------- 

1484 clusters_1: list of ints 

1485 EOD cluster labels for cluster method 1. 

1486 clusters_2: list of ints 

1487 EOD cluster labels for cluster method 2. 

1488 x_1: list of ints 

1489 Indices of EODs for cluster method 1 (clusters_1). 

1490 x_2: list of ints 

1491 Indices of EODs for cluster method 2 (clusters_2). 

1492 verbose : int (optional) 

1493 Verbosity level. 

1494 

1495 Returns 

1496 ------- 

1497 clusters : list of ints 

1498 Merged clusters. 

1499 x_merged : list of ints 

1500 Merged cluster indices. 

1501 mask : 2d numpy array of ints (N, 2) 

1502 Mask for clusters that are selected from clusters_1 (mask[:,0]) and 

1503 from clusters_2 (mask[:,1]). 

1504 """ 

1505 if verbose > 0: 

1506 print('\nMerge clusters:') 

1507 

1508 # these arrays become 1 for each EOD that is chosen from that array 

1509 c1_keep = np.zeros(len(clusters_1)) 

1510 c2_keep = np.zeros(len(clusters_2)) 

1511 

1512 # offset the labels of the second cluster list so they cannot collide with those of the first 

1513 ovl = np.max(clusters_1) + 1 

1514 clusters_2[clusters_2!=-1] = clusters_2[clusters_2!=-1] + ovl 

1515 

1516 remove_clusters = [[]] 

1517 keep_clusters = [] 

1518 og_clusters = [np.copy(clusters_1), np.copy(clusters_2)] 

1519 

1520 # loop until done 

1521 while True: 

1522 

1523 # compute unique clusters and cluster sizes 

1524 # of cluster that have not been iterated over: 

1525 c1_labels, c1_size = unique_counts(clusters_1[(clusters_1 != -1) & (c1_keep == 0)]) 

1526 c2_labels, c2_size = unique_counts(clusters_2[(clusters_2 != -1) & (c2_keep == 0)]) 

1527 

1528 # if all clusters are done, break from loop: 

1529 if len(c1_size) == 0 and len(c2_size) == 0: 

1530 break 

1531 

1532 # if the biggest cluster is in clusters_1, keep it and discard all clusters 

1533 # of clusters_2 on the same indices: 

1534 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 0: 

1535 

1536 # remove all the mappings from the other indices 

1537 cluster_mappings, _ = unique_counts(clusters_2[clusters_1 == c1_labels[np.argmax(c1_size)]]) 

1538 

1539 clusters_2[np.isin(clusters_2, cluster_mappings)] = -1 

1540 

1541 c1_keep[clusters_1==c1_labels[np.argmax(c1_size)]] = 1 

1542 

1543 remove_clusters.append(cluster_mappings) 

1544 keep_clusters.append(c1_labels[np.argmax(c1_size)]) 

1545 

1546 if verbose > 0: 

1547 print('Keep cluster %i of group 1, delete clusters %s of group 2' % (c1_labels[np.argmax(c1_size)], str(cluster_mappings[cluster_mappings!=-1] - ovl))) 

1548 

1549 # if the biggest cluster is in clusters_2, keep it and discard all mapped clusters in clusters_1 

1550 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 1: 

1551 

1552 # remove all the mappings from the other indices 

1553 cluster_mappings, _ = unique_counts(clusters_1[clusters_2 == c2_labels[np.argmax(c2_size)]]) 

1554 

1555 clusters_1[np.isin(clusters_1, cluster_mappings)] = -1 

1556 

1557 c2_keep[clusters_2==c2_labels[np.argmax(c2_size)]] = 1 

1558 

1559 remove_clusters.append(cluster_mappings) 

1560 keep_clusters.append(c2_labels[np.argmax(c2_size)]) 

1561 

1562 if verbose > 0: 

1563 print('Keep cluster %i of group 2, delete clusters %s of group 1' % (c2_labels[np.argmax(c2_size)] - ovl, str(cluster_mappings[cluster_mappings!=-1]))) 

1564 

1565 # combine results  

1566 clusters = (clusters_1+1)*c1_keep + (clusters_2+1)*c2_keep - 1 

1567 x_merged = (x_1)*c1_keep + (x_2)*c2_keep 

1568 

1569 return clusters, x_merged, np.vstack([c1_keep, c2_keep]) 

1570 

1571 
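# Minimal usage sketch of merge_clusters() on two toy labelings of the same
# five EODs with a one-to-one correspondence between both label sets; the
# bigger cluster of either method wins, and ties go to the first method:
import numpy as np
from thunderfish.pulses import merge_clusters

clusters_1 = np.array([0, 0, 0, 1, 1])  # e.g. labels from peak-centered EODs
clusters_2 = np.array([5, 5, 5, 6, 6])  # e.g. labels from trough-centered EODs
x_1 = np.arange(5)
x_2 = np.arange(5)
clusters, x_merged, mask = merge_clusters(clusters_1, clusters_2, x_1, x_2)
print(clusters)  # here all EODs keep the labels of the first method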

1572def extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths, 

1573 clusters, rate, width_fac, verbose=0): 

1574 """ Extract mean EODs and EOD timepoints for each EOD cluster. 

1575 

1576 Parameters 

1577 ---------- 

1578 data: list of floats 

1579 Raw recording data. 

1580 eod_x: list of ints 

1581 Locations of EODs in samples. 

1582 eod_peak_x : list of ints 

1583 Locations of EOD peaks in samples. 

1584 eod_tr_x : list of ints 

1585 Locations of EOD troughs in samples. 

1586 eod_widths: list of ints 

1587 EOD widths in samples. 

1588 clusters: list of ints 

1589 EOD cluster labels. 

1590 rate: float 

1591 Sampling rate of the recording in Hertz. 

1592 width_fac : float 

1593 Multiplication factor for window used to extract EOD. 

1594  

1595 verbose : int (optional) 

1596 Verbosity level. 

1597 

1598 Returns 

1599 ------- 

1600 mean_eods: list of 2D arrays (3, eod_length) 

1601 The average EOD for each detected fish. First row is time in seconds, 

1602 second row the mean EOD, third row the standard deviation. 

1603 eod_times: list of 1D arrays 

1604 For each detected fish the times of EOD in seconds. 

1605 eod_peak_times: list of 1D arrays 

1606 For each detected fish the times of EOD peaks in seconds. 

1607 eod_trough_times: list of 1D arrays 

1608 For each detected fish the times of EOD troughs in seconds. 

1609 eod_labels: list of ints 

1610 Cluster label for each detected fish. 

1611 """ 

1612 mean_eods, eod_times, eod_peak_times, eod_tr_times, eod_heights, cluster_labels = [], [], [], [], [], [] 

1613 

1614 for cluster in np.unique(clusters): 

1615 if cluster!=-1: 

1616 cutwidth = np.mean(eod_widths[clusters==cluster])*width_fac 

1617 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1618 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1619 

1620 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] for x in current_x[current_clusters==cluster]]) 

1621 mean_eod = np.mean(snippets, axis=0) 

1622 eod_time = np.arange(len(mean_eod))/rate - cutwidth/rate 

1623 

1624 mean_eod = np.vstack([eod_time, mean_eod, np.std(snippets, axis=0)]) 

1625 

1626 mean_eods.append(mean_eod) 

1627 eod_times.append(eod_x[clusters==cluster]/rate) 

1628 eod_heights.append(np.min(mean_eod[1])-np.max(mean_eod[1])) # peak-to-trough height of the mean waveform (negative, so the largest EODs sort first) 

1629 eod_peak_times.append(eod_peak_x[clusters==cluster]/rate) 

1630 eod_tr_times.append(eod_tr_x[clusters==cluster]/rate) 

1631 cluster_labels.append(cluster) 

1632 

1633 order = np.argsort(eod_heights) # sort by height; argsort avoids comparing numpy arrays when heights are equal 

 return ([mean_eods[i] for i in order], [eod_times[i] for i in order], 

 [eod_peak_times[i] for i in order], [eod_tr_times[i] for i in order], 

 [cluster_labels[i] for i in order]) 

1634 

1635 
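# Sketch of the snippet averaging performed by extract_means(), with
# hypothetical data: a window of width_fac mean EOD widths is cut around
# each EOD position and the snippets are averaged sample by sample.
import numpy as np

rate = 10000.0
data = np.random.randn(1000)
eod_x = np.array([200, 400, 600])  # EOD positions in samples
cutwidth = 20*3                    # mean width (20 samples) times width_fac (3)
snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] for x in eod_x])
mean_eod = np.mean(snippets, axis=0)
eod_time = np.arange(len(mean_eod))/rate - cutwidth/rate
print(mean_eod.shape, '%.4f .. %.4f s' % (eod_time[0], eod_time[-1]))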

1636def find_clipped_clusters(clusters, mean_eods, eod_times, 

1637 eod_peaktimes, eod_troughtimes, 

1638 cluster_labels, width_factor, 

1639 clip_threshold=0.9, verbose=0): 

1640 """ Detect EODs that are clipped and set all clusterlabels of these clipped EODs to -1. 

1641  

1642 Also return the mean EODs and timepoints of these clipped EODs. 

1643 

1644 Parameters 

1645 ---------- 

1646 clusters: array of ints 

1647 Cluster labels for each EOD in a recording. 

1648 mean_eods: list of numpy arrays 

1649 Mean EOD waveform for each cluster. 

1650 eod_times: list of numpy arrays 

1651 EOD timepoints for each EOD cluster. 

1652 eod_peaktimes : list of numpy arrays 

1653 EOD peak times for each EOD cluster. 

1654 eod_troughtimes : list of numpy arrays 

1655 EOD trough times for each EOD cluster. 

1656 cluster_labels: numpy array 

1657 Unique EOD cluster labels. 

width_factor : float 

 Width multiplier that was used for extracting the mean EODs. 

1658 clip_threshold: float (optional) 

1659 Threshold for detecting clipped EODs. 

1660 

1661 verbose: int (optional) 

1662 Verbosity level. 

1663 

1664 Returns 

1665 ------- 

1666 clusters : array of ints 

1667 Cluster labels for each EOD in the recording, where clipped EODs have been set to -1. 

1668 clipped_eods : list of numpy arrays 

1669 Mean EOD waveforms for each clipped EOD cluster. 

1670 clipped_times : list of numpy arrays 

1671 EOD timepoints for each clipped EOD cluster. 

1672 clipped_peaktimes : list of numpy arrays 

1673 EOD peaktimes for each clipped EOD cluster. 

1674 clipped_troughtimes : list of numpy arrays 

1675 EOD troughtimes for each clipped EOD cluster. 

1676 """ 

1677 clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes, clipped_labels = [], [], [], [], [] 

1678 

1679 for mean_eod, eod_time, eod_peaktime, eod_troughtime,label in zip(mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels): 

1680 

1681 if (np.count_nonzero(mean_eod[1]>clip_threshold) > len(mean_eod[1])/(width_factor*2)) or (np.count_nonzero(mean_eod[1] < -clip_threshold) > len(mean_eod[1])/(width_factor*2)): 

1682 clipped_eods.append(mean_eod) 

1683 clipped_times.append(eod_time) 

1684 clipped_peaktimes.append(eod_peaktime) 

1685 clipped_troughtimes.append(eod_troughtime) 

1686 clipped_labels.append(label) 

1687 if verbose>0: 

1688 print('deleting clipped cluster %i' % label) 

1689 

1690 clusters[np.isin(clusters, clipped_labels)] = -1 

1691 

1692 return clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes 

1693 

1694 
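# Toy version of the clipping test in find_clipped_clusters(): a cluster is
# flagged when too many samples of its mean EOD sit close to the amplitude
# limit (clip_threshold=0.9 matches the function's default; width_factor=4
# and the waveform are hypothetical):
import numpy as np

clip_threshold = 0.9
width_factor = 4
mean_eod = np.clip(1.5*np.sin(np.linspace(0, 2*np.pi, 100)), -1, 1)
n_high = np.count_nonzero(mean_eod > clip_threshold)
n_low = np.count_nonzero(mean_eod < -clip_threshold)
print('clipped:', max(n_high, n_low) > len(mean_eod)/(width_factor*2))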

1695def delete_moving_fish(clusters, eod_t, T, eod_heights, eod_widths, 

1696 rate, min_dt=0.25, stepsize=0.05, 

1697 sliding_window_factor=2000, verbose=0, 

1698 plot_level=0, save_plot=False, save_path='', 

1699 ftype='pdf', return_data=[]): 

1700 """ 

1701 Use a sliding window to find the time window with the fewest fish present 

1702 simultaneously, then delete all EOD clusters that are not detected in that window. 

1703 

1704 Do this only for EODs within the same width clusters, as a 

1705 moving fish will preserve its EOD width. 

1706 

1707 Parameters 

1708 ---------- 

1709 clusters: list of ints 

1710 EOD cluster labels. 

1711 eod_t: list of floats 

1712 Timepoints of the EODs (in seconds). 

1713 T: float 

1714 Length of recording (in seconds). 

1715 eod_heights: list of floats 

1716 EOD amplitudes. 

1717 eod_widths: list of floats 

1718 EOD widths (in seconds). 

1719 rate: float 

1720 Recording data sampling rate. 

1721 

1722 min_dt : float (optional) 

1723 Minimum sliding window size (in seconds). 

1724 stepsize : float (optional) 

1725 Sliding window stepsize (in seconds). 

1726 sliding_window_factor : float 

1727 Multiplier for sliding window width, 

1728 where the sliding window width = median(EOD_width)*sliding_window_factor. 

1729 verbose : int (optional) 

1730 Verbosity level. 

1731 plot_level : int (optional) 

1732 Similar to verbosity levels, but with plots.  

1733 Only set to > 0 for debugging purposes. 

1734 save_plot : bool (optional) 

1735 Set to True to save the plots created by plot_level. 

1736 save_path : string (optional) 

1737 Path for saving plots. Only used if save_plot is True. 

1738 ftype : string (optional) 

1739 Define the filetype to save the plots in if save_plots is set to True. 

1740 Options are: 'png', 'jpg', 'svg' ... 

1741 return_data : list of strings (optional) 

1742 Keys that specify data to be logged. The key that can be used to log data 

1743 in this function is 'moving_fish' (see extract_pulsefish()). 

1744 

1745 Returns 

1746 ------- 

1747 clusters : list of ints 

1748 Cluster labels, where deleted clusters have been set to -1. 

1749 window : list of 2 floats 

1750 Start and end of window selected for deleting moving fish in seconds. 

1751 mf_dict : dictionary 

1752 Key value pairs of logged data. Data to be logged is specified by return_data. 

1753 """ 

1754 mf_dict = {} 

1755 

1756 if len(np.unique(clusters[clusters != -1])) == 0: 

1757 return clusters, [0, 1], {} 

1758 

1759 all_keep_clusters = [] 

1760 width_classes = merge_gaussians(eod_widths, np.copy(clusters), 0.75) 

1761 

1762 all_windows = [] 

1763 all_dts = [] 

1764 ev_num = 0 

1765 for iw, w in enumerate(np.unique(width_classes[clusters >= 0])): 

1766 # initialize variables 

1767 min_clusters = 100 

1768 average_height = 0 

1769 sparse_clusters = 100 

1770 keep_clusters = [] 

1771 

1772 dt = max(min_dt, np.median(eod_widths[width_classes==w])*sliding_window_factor) 

1773 window_start = 0 

1774 window_end = dt 

1775 

1776 wclusters = clusters[width_classes==w] 

1777 weod_t = eod_t[width_classes==w] 

1778 weod_heights = eod_heights[width_classes==w] 

1779 weod_widths = eod_widths[width_classes==w] 

1780 

1781 all_dts.append(dt) 

1782 

1783 if verbose>0: 

1784 print('sliding window dt = %f'%dt) 

1785 

1786 # make W dependent on width?? 

1787 ignore_steps = np.zeros(len(np.arange(0, T-dt+stepsize, stepsize))) 

1788 

1789 for i, t in enumerate(np.arange(0, T-dt+stepsize, stepsize)): 

1790 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1791 if len(np.unique(current_clusters))==0: 

1792 ignore_steps[max(0, i-int(dt/stepsize)):i+int(dt/stepsize)] = 1 # clip the lower bound so the slice does not wrap around 

1793 if verbose>0: 

1794 print('No pulsefish in recording at T=%.2f:%.2f' % (t, t+dt)) 

1795 

1796 

1797 x = np.arange(0, T-dt+stepsize, stepsize) 

1798 y = np.ones(len(x)) 

1799 

1800 running_sum = np.ones(len(np.arange(0, T+stepsize, stepsize))) 

1801 ulabs = np.unique(wclusters[wclusters>=0]) 

1802 

1803 # sliding window 

1804 for j, (t, ignore_step) in enumerate(zip(x, ignore_steps)): 

1805 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1806 current_widths = weod_widths[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1807 

1808 unique_clusters = np.unique(current_clusters) 

1809 y[j] = len(unique_clusters) 

1810 

1811 if (len(unique_clusters) <= min_clusters) and \ 

1812 (ignore_step==0) and \ 

1813 (len(unique_clusters) > 0): # at least one cluster in this window 

1814 

1815 current_labels = np.isin(wclusters, unique_clusters) 

1816 current_height = np.mean(weod_heights[current_labels]) 

1817 

1818 # compute nr of clusters that are too sparse 

1819 clusters_after_deletion = np.unique(remove_sparse_detections(np.copy(clusters[np.isin(clusters, unique_clusters)]), rate*eod_widths[np.isin(clusters, unique_clusters)], rate, T)) 

1820 current_sparse_clusters = len(unique_clusters) - len(clusters_after_deletion[clusters_after_deletion!=-1]) 

1821 

1822 if current_sparse_clusters <= sparse_clusters and \ 

1823 ((current_sparse_clusters<sparse_clusters) or 

1824 (current_height > average_height) or 

1825 (len(unique_clusters) < min_clusters)): 

1826 

1827 keep_clusters = unique_clusters 

1828 min_clusters = len(unique_clusters) 

1829 average_height = current_height 

1830 window_end = t+dt 

1831 sparse_clusters = current_sparse_clusters 

1832 

1833 all_keep_clusters.append(keep_clusters) 

1834 all_windows.append(window_end) 

1835 

1836 if 'moving_fish' in return_data or plot_level>0: 

1837 if 'w' in mf_dict: 

1838 mf_dict['w'].append(np.median(eod_widths[width_classes==w])) 

1839 mf_dict['T'].append(T) 

1840 mf_dict['dt'].append(dt) 

1841 mf_dict['clusters'].append(wclusters) 

1842 mf_dict['t'].append(weod_t) 

1843 mf_dict['fishcount'].append([x+0.5*(x[1]-x[0]), y]) 

1844 mf_dict['ignore_steps'].append(ignore_steps) 

1845 else: 

1846 mf_dict['w'] = [np.median(eod_widths[width_classes==w])] 

1847 mf_dict['T'] = [T] 

1848 mf_dict['dt'] = [dt] 

1849 mf_dict['clusters'] = [wclusters] 

1850 mf_dict['t'] = [weod_t] 

1851 mf_dict['fishcount'] = [[x+0.5*(x[1]-x[0]), y]] 

1852 mf_dict['ignore_steps'] = [ignore_steps] 

1853 

1854 if verbose>0: 

1855 print('Estimated nr of pulsefish in recording: %i' % len(np.concatenate(all_keep_clusters))) 

1856 

1857 if plot_level>0: 

1858 plot_moving_fish(mf_dict['w'], mf_dict['dt'], mf_dict['clusters'],mf_dict['t'], 

1859 mf_dict['fishcount'], T, mf_dict['ignore_steps']) 

1860 if save_plot: 

1861 plt.savefig('%sdelete_moving_fish.%s' % (save_path, ftype)) 

1862 # empty dict 

1863 if 'moving_fish' not in return_data: 

1864 mf_dict = {} 

1865 

1866 # delete all clusters that are not selected 

1867 clusters[np.invert(np.isin(clusters, np.concatenate(all_keep_clusters)))] = -1 

1868 

1869 return clusters, [np.max(all_windows)-np.max(all_dts), np.max(all_windows)], mf_dict 

1870 

1871 
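# Sketch of the sliding-window count at the heart of delete_moving_fish():
# count the distinct clusters in each window and keep the window with the
# fewest simultaneously present fish (toy event times and labels, all times
# in seconds):
import numpy as np

eod_t = np.array([0.1, 0.5, 0.9, 1.2, 1.4, 1.8])
clusters = np.array([0, 1, 0, 1, 2, 2])
T, dt, stepsize = 2.0, 1.0, 0.5
counts = []
for t in np.arange(0, T-dt+stepsize, stepsize):
    in_win = clusters[(eod_t >= t) & (eod_t < t+dt)]
    counts.append(len(np.unique(in_win)))
best = np.argmin(counts)
print('fish per window:', counts, '-> keep window at t=%.1f s' % (best*stepsize))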

1872def remove_sparse_detections(clusters, eod_widths, rate, T, 

1873 min_density=0.0005, verbose=0): 

1874 """ Remove all EOD clusters that are too sparse 

1875 

1876 Parameters 

1877 ---------- 

1878 clusters : list of ints 

1879 Cluster labels. 

1880 eod_widths : list of ints 

1881 Cluster widths in samples. 

1882 rate : float 

1883 Sampling rate. 

1884 T : float 

1885 Length of recording in seconds. 

1886 min_density : float (optional) 

1887 Minimum density for realistic EOD detections. 

1888 verbose : int (optional) 

1889 Verbosity level. 

1890 

1891 Returns 

1892 ------- 

1893 clusters : list of ints 

1894 Cluster labels, where sparse clusters have been set to -1. 

1895 """ 

1896 for c in np.unique(clusters): 

1897 if c!=-1: 

1898 

1899 n = len(clusters[clusters==c]) 

1900 w = np.median(eod_widths[clusters==c])/rate 

1901 

1902 if n*w < T*min_density: 

1903 if verbose>0: 

1904 print('cluster %i is too sparse'%c) 

1905 clusters[clusters==c] = -1 

1906 return clusters
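
# The sparseness criterion of remove_sparse_detections() as a quick toy
# computation: a cluster survives only if its EODs cover at least a fraction
# min_density of the recording (n and w are hypothetical):
rate, T, min_density = 40000.0, 60.0, 0.0005
n = 25       # number of EODs in the cluster
w = 20/rate  # median EOD width in seconds
print('coverage %.4f s vs. minimum %.4f s -> %s'
      % (n*w, T*min_density, 'keep' if n*w >= T*min_density else 'delete'))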