# src/thunderfish/pulses.py

1""" 

2Extract and cluster EOD waverforms of pulse-type electric fish. 

3 

4## Main function 

5 

6- `extract_pulsefish()`: checks for pulse-type fish based on the EOD amplitude and shape. 

7 

8""" 

import os
import numpy as np
from scipy import stats
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.mixture import BayesianGaussianMixture
from sklearn.metrics import pairwise_distances
from thunderlab.eventdetection import detect_peaks, median_std_threshold
from .pulseplots import *

import warnings

def warn(*args, **kwargs):
    """
    Ignore all warnings.
    """
    pass
warnings.warn = warn

try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        def decorator_jit(func):
            return func
        return decorator_jit


# upgrade numpy functions for backwards compatibility:
if not hasattr(np, 'isin'):
    np.isin = np.in1d


def unique_counts(ar):
    """ Find the unique elements of an array and their counts, ignoring shape.

    The code is condensed from numpy version 1.17.0.

    Parameters
    ----------
    ar : numpy array
        Input array

    Returns
    -------
    unique_values : numpy array
        Unique values in array ar.
    unique_counts : numpy array
        Number of instances for each unique value in ar.
    """
    try:
        return np.unique(ar, return_counts=True)
    except TypeError:
        ar = np.asanyarray(ar).flatten()
        ar.sort()
        mask = np.empty(ar.shape, dtype=bool)
        mask[:1] = True
        mask[1:] = ar[1:] != ar[:-1]
        idx = np.concatenate(np.nonzero(mask) + ([mask.size],))
        return ar[mask], np.diff(idx)

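# A minimal usage sketch for unique_counts(); the `_demo_` helper is
# hypothetical and not part of the module API:
def _demo_unique_counts():
    values, counts = unique_counts(np.array([3, 1, 2, 2, 3, 3]))
    # values -> array([1, 2, 3]), counts -> array([1, 2, 3])
    assert np.array_equal(values, np.array([1, 2, 3]))
    assert np.array_equal(counts, np.array([1, 2, 3]))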

###########################################################################


def extract_pulsefish(data, rate, amax, width_factor_shape=3,
                      width_factor_wave=8, width_factor_display=4,
                      verbose=0, plot_level=0, save_plots=False,
                      save_path='', ftype='png', return_data=[]):
    """ Extract and cluster pulse-type fish EODs from single channel data.

    Takes recording data containing an unknown number of pulsefish and extracts the mean
    EOD and EOD timepoints for each fish present in the recording.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data in Hertz.
    amax: float
        Maximum amplitude of data range.
    width_factor_shape : float (optional)
        Width multiplier used for EOD shape analysis.
        EOD snippets are extracted based on width between the
        peak and trough multiplied by the width factor.
    width_factor_wave : float (optional)
        Width multiplier used for wavefish detection.
    width_factor_display : float (optional)
        Width multiplier used for EOD mean extraction and display.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path: string (optional)
        Path for saving plots.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Specify data that should be logged and returned in a dictionary. Each clustering
        step has a specific keyword that results in adding different variables to the log dictionary.
        Optional keys for return_data and the resulting additional key-value pairs to the log dictionary are:

        - 'all_eod_times':
            - 'all_times': list of two lists of floats.
                All peak (`all_times[0]`) and trough times (`all_times[1]`) extracted
                by the peak detection algorithm. Times are given in seconds.
            - 'eod_troughtimes': list of 1D arrays.
                The timepoints in seconds of each unique extracted EOD cluster,
                where each 1D array encodes one cluster.

        - 'peak_detection':
            - "data": 1D numpy array of floats.
                Quadratically interpolated data which was used for peak detection.
            - "interp_fac": float.
                Interpolation factor of raw data.
            - "peaks_1": 1D numpy array of ints.
                Peak indices on interpolated data after first peak detection step.
            - "troughs_1": 1D numpy array of ints.
                Trough indices on interpolated data after first peak detection step.
            - "peaks_2": 1D numpy array of ints.
                Peak indices on interpolated data after second peak detection step.
            - "troughs_2": 1D numpy array of ints.
                Trough indices on interpolated data after second peak detection step.
            - "peaks_3": 1D numpy array of ints.
                Peak indices on interpolated data after third peak detection step.
            - "troughs_3": 1D numpy array of ints.
                Trough indices on interpolated data after third peak detection step.
            - "peaks_4": 1D numpy array of ints.
                Peak indices on interpolated data after fourth peak detection step.
            - "troughs_4": 1D numpy array of ints.
                Trough indices on interpolated data after fourth peak detection step.

        - 'all_cluster_steps':
            - 'rate': float.
                Sampling rate of interpolated data.
            - 'EOD_widths': list of three 1D numpy arrays.
                The first list entry gives the unique labels of all width clusters
                as a list of ints.
                The second list entry gives the width values for each EOD in samples
                as a 1D numpy array of ints.
                The third list entry gives the width labels for each EOD
                as a 1D numpy array of ints.
            - 'EOD_heights': nested lists (2 layers) of three 1D numpy arrays.
                The first list entry gives the unique labels of all height clusters
                as a list of ints for each width cluster.
                The second list entry gives the height values for each EOD
                as a 1D numpy array of floats for each width cluster.
                The third list entry gives the height labels for each EOD
                as a 1D numpy array of ints for each width cluster.
            - 'EOD_shapes': nested lists (3 layers) of three 1D numpy arrays.
                The first list entry gives the raw EOD snippets as a 2D numpy array
                for each height cluster in a width cluster.
                The second list entry gives the snippet PCA values for each EOD
                as a 2D numpy array of floats for each height cluster in a width cluster.
                The third list entry gives the shape labels for each EOD as a 1D numpy array
                of ints for each height cluster in a width cluster.
            - 'discarding_masks': nested lists (two layers) of 1D numpy arrays.
                The masks of EODs that are discarded by the discarding step of the algorithm.
                The masks are 1D boolean arrays where instances that are set to True are
                discarded by the algorithm. Discarding masks are saved in nested lists
                that represent the width and height clusters.
            - 'merge_masks': nested lists (two layers) of 2D numpy arrays.
                The masks of EODs that are discarded by the merging step of the algorithm.
                The masks are 2D boolean arrays where for each sample point `i` either
                `merge_mask[i,0]` or `merge_mask[i,1]` is set to True. Here, `merge_mask[:,0]`
                represents the peak-centered clusters and `merge_mask[:,1]` represents the
                trough-centered clusters. Merge masks are saved in nested lists that
                represent the width and height clusters.

        - 'BGM_width':
            - 'BGM_width': dictionary
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD widths),
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'BGM_height':
            This key adds a new dictionary for each width cluster.
            - 'BGM_height_*n*' : dictionary, where *n* defines the width cluster as an int.
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD heights),
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'snippet_clusters':
            This key adds a new dictionary for each height cluster.
            - 'snippet_clusters*_n_m_p*' : dictionary, where *n* defines the width cluster
              (int), *m* defines the height cluster (int) and *p* defines shape clustering
              on peak or trough centered EOD snippets (string: 'peak' or 'trough').
                - 'raw_snippets': 2D numpy array (nsamples, nfeatures).
                    Raw EOD snippets.
                - 'snippets': 2D numpy array.
                    Normalized EOD snippets.
                - 'features': 2D numpy array (nsamples, nfeatures).
                    PCA values for each normalized EOD snippet.
                - 'clusters': 1D numpy array of ints.
                    Cluster labels.
                - 'rate': float.
                    Sampling rate of snippets.

        - 'eod_deletion':
            This key adds two dictionaries for each (peak centered) shape cluster,
            where *cluster* (int) is the unique shape cluster label.
            - 'mask_*cluster*' : list of four booleans.
                The mask for each cluster discarding step.
                The first instance represents the artefact masks, where artefacts
                are set to True.
                The second instance represents the unreliable cluster masks,
                where unreliable clusters are set to True.
                The third instance represents the wavefish masks, where wavefish
                are set to True.
                The fourth instance represents the sidepeak masks, where sidepeaks
                are set to True.
            - 'vals_*cluster*' : list of lists.
                All variables that are used for each cluster deletion step.
                The first instance is a list of two 1D numpy arrays: the mean EOD and
                the FFT of that mean EOD.
                The second instance is a 1D numpy array with all EOD width to ISI ratios.
                The third instance is a list with three entries:
                The first entry is a 1D numpy array zoomed out version of the mean EOD.
                The second entry is a list of two 1D numpy arrays that define the peak
                and trough indices of the zoomed out mean EOD.
                The third entry contains a list of two values that represent the
                peak-trough pair in the zoomed out mean EOD with the largest height
                difference.
            - 'rate' : float.
                EOD snippet sampling rate.

        - 'masks':
            - 'masks' : 2D numpy array (4,N).
                Each row contains masks for each EOD detected by the EOD peakdetection step.
                The first row defines the artefact masks, the second row defines the
                unreliable EOD masks,
                the third row defines the wavefish masks and the fourth row defines
                the sidepeak masks.

        - 'moving_fish':
            - 'moving_fish': dictionary.
                - 'w' : list of floats.
                    Median width for each width cluster that the moving fish algorithm is
                    computed on (in seconds).
                - 'T' : list of floats.
                    Length of analyzed recording for each width cluster (in seconds).
                - 'dt' : list of floats.
                    Sliding window size (in seconds) for each width cluster.
                - 'clusters' : list of 1D numpy int arrays.
                    Cluster labels for each EOD cluster in a width cluster.
                - 't' : list of 1D numpy float arrays.
                    EOD emission times for each EOD in a width cluster.
                - 'fishcount' : list of lists.
                    Sliding window timepoints and fishcounts for each width cluster.
                - 'ignore_steps' : list of 1D int arrays.
                    Mask for fishcounts that were ignored (ignored if True) in the
                    moving_fish analysis.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First column is time in seconds,
        second column the mean eod, third column the standard error.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD peaks or troughs in seconds.
        Use these timepoints for EOD averaging.
    eod_peaktimes: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    zoom_window: tuple of floats
        Start and endtime of suggested window for plotting EOD timepoints.
    log_dict: dictionary
        Dictionary with logged variables, where variables to log are specified
        by `return_data`.
    """

    if verbose > 0:
        print('')
    if verbose > 1:
        print(70*'#')
        print('##### extract_pulsefish', 46*'#')

    if (save_plots and plot_level>0 and save_path):
        # create folder to save things in.
        if not os.path.exists(save_path):
            os.makedirs(save_path)
    else:
        save_path = ''

    mean_eods, eod_times, eod_peaktimes, zoom_window = [], [], [], []
    log_dict = {}

    # interpolate:
    i_rate = 500000.0
    #i_rate = rate
    try:
        f = interp1d(np.arange(len(data))/rate, data, kind='quadratic')
        i_data = f(np.arange(0.0, (len(data)-1)/rate, 1.0/i_rate))
    except MemoryError:
        i_rate = rate
        i_data = data
    log_dict['data'] = i_data   # TODO: could be removed
    log_dict['rate'] = i_rate   # TODO: could be removed
    log_dict['i_data'] = i_data
    log_dict['i_rate'] = i_rate
    #log_dict["interp_fac"] = interp_fac  # TODO: is not set anymore

    # standard deviation of data in small snippets:
    win_size = int(0.002*rate)   # 2ms windows
    threshold = median_std_threshold(data, win_size)   # TODO make this a parameter

    # extract peaks:
    if 'peak_detection' in return_data:
        x_peak, x_trough, eod_heights, eod_widths, pd_log_dict = \
            detect_pulses(i_data, i_rate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose-1, return_data=True)
        log_dict.update(pd_log_dict)
    else:
        x_peak, x_trough, eod_heights, eod_widths = \
            detect_pulses(i_data, i_rate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose-1, return_data=False)

    if len(x_peak) > 0:
        # cluster
        clusters, x_merge, c_log_dict = cluster(x_peak, x_trough,
                                                eod_heights,
                                                eod_widths, i_data,
                                                i_rate,
                                                width_factor_shape,
                                                width_factor_wave,
                                                merge_threshold_height=0.1*amax,
                                                verbose=verbose-1,
                                                plot_level=plot_level-1,
                                                save_plots=save_plots,
                                                save_path=save_path,
                                                ftype=ftype,
                                                return_data=return_data)

        # extract mean eods and times
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, clusters,
                          i_rate, width_factor_display, verbose=verbose-1)

        # determine clipped clusters (save them, but ignore in other steps)
        clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes = \
            find_clipped_clusters(clusters, mean_eods, eod_times, eod_peaktimes,
                                  eod_troughtimes, cluster_labels, width_factor_display,
                                  verbose=verbose-1)

        # delete the moving fish
        clusters, zoom_window, mf_log_dict = \
            delete_moving_fish(clusters, x_merge/i_rate, len(data)/rate,
                               eod_heights, eod_widths/i_rate, i_rate,
                               verbose=verbose-1, plot_level=plot_level-1, save_plot=save_plots,
                               save_path=save_path, ftype=ftype, return_data=return_data)

        if 'moving_fish' in return_data:
            log_dict['moving_fish'] = mf_log_dict

        clusters = remove_sparse_detections(clusters, eod_widths, i_rate,
                                            len(data)/rate, verbose=verbose-1)

        # extract mean eods
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths,
                          clusters, i_rate, width_factor_display, verbose=verbose-1)

        mean_eods.extend(clipped_eods)
        eod_times.extend(clipped_times)
        eod_peaktimes.extend(clipped_peaktimes)
        eod_troughtimes.extend(clipped_troughtimes)

        if plot_level > 0:
            plot_all(data, eod_peaktimes, eod_troughtimes, rate, mean_eods)
            if save_plots:
                plt.savefig('%sextract_pulsefish_results.%s' % (save_path, ftype))
        if save_plots:
            plt.close('all')

        if 'all_eod_times' in return_data:
            log_dict['all_times'] = [x_peak/i_rate, x_trough/i_rate]
            log_dict['eod_troughtimes'] = eod_troughtimes

        log_dict.update(c_log_dict)

    return mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict


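# A hedged sketch of how extract_pulsefish() is typically called. The
# recording, sampling rate, and amplitude range below are invented for
# illustration; the `_demo_` helper is hypothetical and not part of the API:
def _demo_extract_pulsefish():
    rate = 44100.0                                # sampling rate in Hertz
    data = np.random.randn(10*int(rate))*0.01     # stand-in for a real recording
    mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict = \
        extract_pulsefish(data, rate, amax=1.0)
    for eod, times in zip(mean_eods, eod_times):
        # each eod is a (3, eod_length) array: time, mean EOD, standard error
        print('fish with %d EODs of %d samples' % (len(times), eod.shape[1]))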

def detect_pulses(data, rate, thresh, min_rel_slope_diff=0.25,
                  min_width=0.00005, max_width=0.01, width_fac=5.0,
                  verbose=0, return_data=False):
    """Detect pulses in data.

    Was `def extract_eod_times(data, rate, width_factor,
    interp_freq=500000, max_peakwidth=0.01,
    min_peakwidth=None, verbose=0, return_data=[],
    save_path='')` before.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data.
    thresh: float
        Threshold for peak and trough detection via `detect_peaks()`.
        Must be a positive number that sets the minimum difference
        between a peak and a trough.
    min_rel_slope_diff: float
        Minimum required difference between left and right slope (between
        peak and troughs) relative to mean slope for deciding which trough
        to take based on slope difference.
    min_width: float
        Minimum width (peak-trough distance) of pulses in seconds.
    max_width: float
        Maximum width (peak-trough distance) of pulses in seconds.
    width_fac: float
        Pulses extend plus or minus `width_fac` times their width
        (distance between peak and assigned trough).
        Only pulses are returned that can fully be analysed with this width.
    verbose : int (optional)
        Verbosity level.
    return_data : bool
        If `True` data of this function is logged and returned (see
        extract_pulsefish()).

    Returns
    -------
    peak_indices: array of ints
        Indices of EOD peaks in data.
    trough_indices: array of ints
        Indices of EOD troughs in data. There is one x_trough for each x_peak.
    heights: array of floats
        EOD heights for each x_peak.
    widths: array of ints
        EOD widths for each x_peak (in samples).
    peak_detection_result : dictionary
        Key value pairs of logged data.
        This is only returned if `return_data` is `True`.

    """
    peak_detection_result = {}

    # detect peaks and troughs in the data:
    peak_indices, trough_indices = detect_peaks(data, thresh)
    if verbose > 0:
        print('Peaks/troughs detected in data: %5d %5d'
              % (len(peak_indices), len(trough_indices)))
    if return_data:
        peak_detection_result.update(peaks_1=np.array(peak_indices),
                                     troughs_1=np.array(trough_indices))
    if len(peak_indices) < 2 or \
       len(trough_indices) < 2 or \
       len(peak_indices) > len(data)/20:
        # TODO: if too many peaks increase threshold!
        if verbose > 0:
            print('No or too many peaks/troughs detected in data.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # assign troughs to peaks:
    peak_indices, trough_indices, heights, widths, slopes = \
        assign_side_peaks(data, peak_indices, trough_indices, min_rel_slope_diff)
    if verbose > 1:
        print('Number of peaks after assigning side-peaks: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_2=np.array(peak_indices),
                                     troughs_2=np.array(trough_indices))

    # check widths:
    keep = ((widths>min_width*rate) & (widths<max_width*rate))
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    slopes = slopes[keep]
    if verbose > 1:
        print('Number of peaks after checking pulse width: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_3=np.array(peak_indices),
                                     troughs_3=np.array(trough_indices))

    # discard connected peaks:
    same = np.nonzero(trough_indices[:-1] == trough_indices[1:])[0]
    keep = np.ones(len(trough_indices), dtype=bool)
    for i in same:
        # same troughs at trough_indices[i] and trough_indices[i+1]:
        s = slopes[i:i+2]
        rel_slopes = np.abs(np.diff(s))[0]/np.mean(s)
        if rel_slopes > min_rel_slope_diff:
            keep[i+(s[1]<s[0])] = False
        else:
            keep[i+(heights[i+1]<heights[i])] = False
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    if verbose > 1:
        print('Number of peaks after merging pulses: %5d'
              % (len(peak_indices)))
    if return_data:
        peak_detection_result.update(peaks_4=np.array(peak_indices),
                                     troughs_4=np.array(trough_indices))
    if len(peak_indices) == 0:
        if verbose > 0:
            print('No peaks remain as pulse candidates.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # only take those peaks where the cutwidth does not cause issues, i.e.
    # where peaks and troughs plus/minus their width stay within the data:
    keep = ((peak_indices - widths > 0) &
            (peak_indices + widths < len(data)) &
            (trough_indices - widths > 0) &
            (trough_indices + widths < len(data)))

    if verbose > 0:
        print('Remaining peaks after EOD extraction: %5d'
              % (np.sum(keep)))
        print('')

    if return_data:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep], peak_detection_result
    else:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep]


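# Hedged usage sketch for detect_pulses() on a noisy trace containing one
# artificial biphasic pulse; all values are illustrative, and the threshold
# is derived the same way as in extract_pulsefish() above:
def _demo_detect_pulses():
    rate = 44100.0
    data = np.random.randn(int(rate))*0.0001
    data[1000:1012] += [0.05, 0.15, 0.3, 0.45, 0.5, 0.3,
                        0.0, -0.3, -0.5, -0.45, -0.2, -0.05]
    thresh = median_std_threshold(data, int(0.002*rate))
    peaks, troughs, heights, widths = detect_pulses(data, rate, thresh)
    print('%d pulse candidate(s)' % len(peaks))   # typically the single pulse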

@jit(nopython=True)
def assign_side_peaks(data, peak_indices, trough_indices,
                      min_rel_slope_diff=0.25):
    """Assign to each peak the trough resulting in a pulse with the steepest slope or largest height.

    The slope between a peak and a trough is computed as the height
    difference divided by the distance between peak and trough. If the
    slopes between the left and the right trough differ by less than
    `min_rel_slope_diff`, then the heights of the two troughs relative
    to the peak are compared instead.

    Was `def detect_eod_peaks(data, main_indices, side_indices,
    max_width=20, min_width=2, verbose=0)` before.

    Parameters
    ----------
    data: array of floats
        Data in which the events were detected.
    peak_indices: array of ints
        Indices of the detected peaks in the data time series.
    trough_indices: array of ints
        Indices of the detected troughs in the data time series.
    min_rel_slope_diff: float
        Minimum required difference of left and right slope relative
        to mean slope.

    Returns
    -------
    peak_indices: array of ints
        Peak indices. Same as input `peak_indices` but potentially shorter
        by one or two elements.
    trough_indices: array of ints
        Corresponding trough indices of the trough to the left or right
        of each peak.
    heights: array of floats
        Peak heights (distance between peak and corresponding trough amplitude).
    widths: array of ints
        Peak widths (distance between peak and corresponding trough indices).
    slopes: array of floats
        Peak slopes (height divided by width).
    """
    # is a main or side peak first?
    peak_first = int(peak_indices[0] < trough_indices[0])
    # is a main or side peak last?
    peak_last = int(peak_indices[-1] > trough_indices[-1])
    # ensure all peaks to have side peaks (troughs) at both sides,
    # i.e. troughs at same index and next index are before and after peak:
    peak_indices = peak_indices[peak_first:len(peak_indices)-peak_last]
    y = data[peak_indices]

    # indices of troughs on the left and right side of main peaks:
    l_indices = np.arange(len(peak_indices))
    r_indices = l_indices + 1

    # indices, distance to peak, height, and slope of left troughs:
    l_side_indices = trough_indices[l_indices]
    l_distance = np.abs(peak_indices - l_side_indices)
    l_height = np.abs(y - data[l_side_indices])
    l_slope = np.abs(l_height/l_distance)

    # indices, distance to peak, height, and slope of right troughs:
    r_side_indices = trough_indices[r_indices]
    r_distance = np.abs(r_side_indices - peak_indices)
    r_height = np.abs(y - data[r_side_indices])
    r_slope = np.abs(r_height/r_distance)

    # which trough to assign to the peak?
    # - either the one with the steepest slope, or
    # - when slopes are similar on both sides
    #   (within `min_rel_slope_diff` difference),
    #   the trough with the maximum height difference to the peak:
    rel_slopes = np.abs(l_slope-r_slope)/(0.5*(l_slope+r_slope))
    take_slopes = rel_slopes > min_rel_slope_diff
    take_left = l_height > r_height
    take_left[take_slopes] = l_slope[take_slopes] > r_slope[take_slopes]

    # assign troughs, heights, widths, and slopes:
    trough_indices = np.where(take_left,
                              trough_indices[:-1], trough_indices[1:])
    heights = np.where(take_left, l_height, r_height)
    widths = np.where(take_left, l_distance, r_distance)
    slopes = np.where(take_left, l_slope, r_slope)

    return peak_indices, trough_indices, heights, widths, slopes


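# Worked toy example of the trough-assignment rule (values invented): the
# peak at index 2 lies between troughs at indices 0 and 3; the right trough
# gives the steeper slope (1.8 vs 0.6), so it is assigned to the peak:
def _demo_assign_side_peaks():
    data = np.array([-0.2, 0.0, 1.0, -0.8, 0.0])
    peaks = np.array([2])
    troughs = np.array([0, 3])
    p, t, heights, widths, slopes = assign_side_peaks(data, peaks, troughs)
    # t -> array([3]), heights -> array([1.8]), widths -> array([1])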

def cluster(eod_xp, eod_xt, eod_heights, eod_widths, data, rate,
            width_factor_shape, width_factor_wave, n_gaus_height=10,
            merge_threshold_height=0.1, n_gaus_width=3,
            merge_threshold_width=0.5, minp=10, verbose=0,
            plot_level=0, save_plots=False, save_path='', ftype='pdf',
            return_data=[]):
    """Cluster EODs.

    First cluster on EOD widths using a Bayesian Gaussian
    Mixture (BGM) model, then cluster on EOD heights using a
    BGM model. Lastly, cluster on EOD waveform with DBSCAN.
    Clustering on EOD waveform is performed twice, once on
    peak-centered EODs and once on trough-centered EODs.
    Non-pulsetype EOD clusters are deleted, and clusters are
    merged afterwards.

    Parameters
    ----------
    eod_xp : list of ints
        Locations of EOD peaks in indices.
    eod_xt: list of ints
        Locations of EOD troughs in indices.
    eod_heights: list of floats
        EOD heights.
    eod_widths: list of ints
        EOD widths in samples.
    data: array of floats
        Data in which to detect pulse EODs.
    rate : float
        Sampling rate of `data`.
    width_factor_shape : float
        Multiplier for snippet extraction width. This factor is
        multiplied with the width between the peak and trough of a
        single EOD.
    width_factor_wave : float
        Multiplier for wavefish extraction width.
    n_gaus_height : int (optional)
        Number of gaussians to use for the clustering based on EOD height.
    merge_threshold_height : float (optional)
        Threshold for merging clusters that are similar in height.
    n_gaus_width : int (optional)
        Number of gaussians to use for the clustering based on EOD width.
    merge_threshold_width : float (optional)
        Threshold for merging clusters that are similar in width.
    minp : int (optional)
        Minimum number of points for core clusters (DBSCAN).
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save created plots.
    save_path : string (optional)
        Path to save plots to. Only used if save_plots==True.
    ftype : string (optional)
        Filetype to save plot images in.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'all_cluster_steps', 'BGM_width', 'BGM_height',
        'snippet_clusters', 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    labels : list of ints
        EOD cluster labels based on height and EOD waveform.
    x_merge : list of ints
        Locations of EODs in clusters.
    saved_data : dictionary
        Key value pairs of logged data. Data to be logged is specified
        by return_data.

    """
    saved_data = {}

    if plot_level>0 or 'all_cluster_steps' in return_data:
        all_heightlabels = []
        all_shapelabels = []
        all_snippets = []
        all_features = []
        all_heights = []
        all_unique_heightlabels = []

    all_p_clusters = -1 * np.ones(len(eod_xp))
    all_t_clusters = -1 * np.ones(len(eod_xp))
    artefact_masks_p = np.ones(len(eod_xp), dtype=bool)
    artefact_masks_t = np.ones(len(eod_xp), dtype=bool)

    x_merge = -1 * np.ones(len(eod_xp))

    max_label_p = 0   # keep track of the labels so that no labels are overwritten
    max_label_t = 0

    # loop only over height clusters that are bigger than minp
    # first cluster on width
    width_labels, bgm_log_dict = BGM(1000*eod_widths/rate,
                                     merge_threshold_width,
                                     n_gaus_width, use_log=False,
                                     verbose=verbose-1,
                                     plot_level=plot_level-1,
                                     xlabel='width [ms]',
                                     save_plot=save_plots,
                                     save_path=save_path,
                                     save_name='width', ftype=ftype,
                                     return_data=return_data)
    saved_data.update(bgm_log_dict)

    if verbose > 0:
        print('Clusters generated based on EOD width:')
        for l in np.unique(width_labels):
            print(f'N_{l} = {len(width_labels[width_labels==l]):4d} h_{l} = {np.mean(eod_widths[width_labels==l]):.4f}')

    w_labels, w_counts = unique_counts(width_labels)
    unique_width_labels = w_labels[w_counts>minp]

    for wi, width_label in enumerate(unique_width_labels):

        # select only features in one width cluster at a time
        w_eod_widths = eod_widths[width_labels==width_label]
        w_eod_heights = eod_heights[width_labels==width_label]
        w_eod_xp = eod_xp[width_labels==width_label]
        w_eod_xt = eod_xt[width_labels==width_label]
        width = int(width_factor_shape*np.median(w_eod_widths))
        if width > w_eod_xp[0]:
            width = w_eod_xp[0]
        if width > w_eod_xt[0]:
            width = w_eod_xt[0]
        if width > len(data) - w_eod_xp[-1]:
            width = len(data) - w_eod_xp[-1]
        if width > len(data) - w_eod_xt[-1]:
            width = len(data) - w_eod_xt[-1]

        wp_clusters = -1 * np.ones(len(w_eod_xp))
        wt_clusters = -1 * np.ones(len(w_eod_xp))
        wartefact_mask = np.ones(len(w_eod_xp))

        # determine height labels
        raw_p_snippets, p_snippets, p_features, p_bg_ratio = \
            extract_snippet_features(data, w_eod_xp, w_eod_heights, width)
        raw_t_snippets, t_snippets, t_features, t_bg_ratio = \
            extract_snippet_features(data, w_eod_xt, w_eod_heights, width)

        height_labels, bgm_log_dict = \
            BGM(w_eod_heights, min(merge_threshold_height,
                                   np.median(np.min(np.vstack([p_bg_ratio, t_bg_ratio]),
                                                    axis=0))), n_gaus_height, use_log=True,
                verbose=verbose-1, plot_level=plot_level-1,
                xlabel='height [a.u.]', save_plot=save_plots,
                save_path=save_path, save_name='height_%d' % wi,
                ftype=ftype, return_data=return_data)
        saved_data.update(bgm_log_dict)

        if verbose > 0:
            print('Clusters generated based on EOD height:')
            for l in np.unique(height_labels):
                print(f'N_{l} = {len(height_labels[height_labels==l]):4d} h_{l} = {np.mean(w_eod_heights[height_labels==l]):.4f}')

        h_labels, h_counts = unique_counts(height_labels)
        unique_height_labels = h_labels[h_counts>minp]

        if plot_level > 0 or 'all_cluster_steps' in return_data:
            all_heightlabels.append(height_labels)
            all_heights.append(w_eod_heights)
            all_unique_heightlabels.append(unique_height_labels)
            shape_labels = []
            cfeatures = []
            csnippets = []

        for hi, height_label in enumerate(unique_height_labels):

            h_eod_widths = w_eod_widths[height_labels==height_label]
            h_eod_heights = w_eod_heights[height_labels==height_label]
            h_eod_xp = w_eod_xp[height_labels==height_label]
            h_eod_xt = w_eod_xt[height_labels==height_label]

            p_clusters = cluster_on_shape(p_features[height_labels==height_label],
                                          p_bg_ratio, minp, verbose=0)
            t_clusters = cluster_on_shape(t_features[height_labels==height_label],
                                          t_bg_ratio, minp, verbose=0)

            if plot_level > 1:
                plot_feature_extraction(raw_p_snippets[height_labels==height_label],
                                        p_snippets[height_labels==height_label],
                                        p_features[height_labels==height_label],
                                        p_clusters, 1/rate, 0)
                plt.savefig('%sDBSCAN_peak_w%i_h%i.%s' % (save_path, wi, hi, ftype))
                plot_feature_extraction(raw_t_snippets[height_labels==height_label],
                                        t_snippets[height_labels==height_label],
                                        t_features[height_labels==height_label],
                                        t_clusters, 1/rate, 1)
                plt.savefig('%sDBSCAN_trough_w%i_h%i.%s' % (save_path, wi, hi, ftype))

            if 'snippet_clusters' in return_data:
                saved_data[f'snippet_clusters_{width_label}_{height_label}_peak'] = {
                    'raw_snippets': raw_p_snippets[height_labels==height_label],
                    'snippets': p_snippets[height_labels==height_label],
                    'features': p_features[height_labels==height_label],
                    'clusters': p_clusters,
                    'rate': rate}
                saved_data[f'snippet_clusters_{width_label}_{height_label}_trough'] = {
                    'raw_snippets': raw_t_snippets[height_labels==height_label],
                    'snippets': t_snippets[height_labels==height_label],
                    'features': t_features[height_labels==height_label],
                    'clusters': t_clusters,
                    'rate': rate}

            if plot_level > 0 or 'all_cluster_steps' in return_data:
                shape_labels.append([p_clusters, t_clusters])
                cfeatures.append([p_features[height_labels==height_label],
                                  t_features[height_labels==height_label]])
                csnippets.append([p_snippets[height_labels==height_label],
                                  t_snippets[height_labels==height_label]])

            p_clusters[p_clusters==-1] = -max_label_p - 1
            wp_clusters[height_labels==height_label] = p_clusters + max_label_p
            max_label_p = max(np.max(wp_clusters), np.max(all_p_clusters)) + 1

            t_clusters[t_clusters==-1] = -max_label_t - 1
            wt_clusters[height_labels==height_label] = t_clusters + max_label_t
            max_label_t = max(np.max(wt_clusters), np.max(all_t_clusters)) + 1

        if verbose > 0:
            if np.max(wp_clusters) == -1:
                print(f'No EOD peaks in width cluster {width_label}')
            else:
                unique_clusters = np.unique(wp_clusters[wp_clusters!=-1])
                if len(unique_clusters) > 1:
                    print(f'{len(unique_clusters)} different EOD peaks in width cluster {width_label}')

        if plot_level > 0 or 'all_cluster_steps' in return_data:
            all_shapelabels.append(shape_labels)
            all_snippets.append(csnippets)
            all_features.append(cfeatures)

        # for each cluster, save fft + label, so that features and masks are
        # available per label and e.g. the first artefact or wave can be
        # extracted later.

        # remove artefacts here, based on the mean snippets ffts.
        artefact_masks_p[width_labels==width_label], sdict = \
            remove_artefacts(p_snippets, wp_clusters, rate,
                             verbose=verbose-1, return_data=return_data)
        saved_data.update(sdict)
        artefact_masks_t[width_labels==width_label], _ = \
            remove_artefacts(t_snippets, wt_clusters, rate,
                             verbose=verbose-1, return_data=return_data)

        # update maxlab so that no clusters are overwritten
        all_p_clusters[width_labels==width_label] = wp_clusters
        all_t_clusters[width_labels==width_label] = wt_clusters

    # remove all non-reliable clusters
    unreliable_fish_mask_p, saved_data = \
        delete_unreliable_fish(all_p_clusters, eod_widths, eod_xp,
                               verbose=verbose-1, sdict=saved_data)
    unreliable_fish_mask_t, _ = \
        delete_unreliable_fish(all_t_clusters, eod_widths, eod_xt, verbose=verbose-1)

    wave_mask_p, sidepeak_mask_p, saved_data = \
        delete_wavefish_and_sidepeaks(data, all_p_clusters, eod_xp, eod_widths,
                                      width_factor_wave, verbose=verbose-1, sdict=saved_data)
    wave_mask_t, sidepeak_mask_t, _ = \
        delete_wavefish_and_sidepeaks(data, all_t_clusters, eod_xt, eod_widths,
                                      width_factor_wave, verbose=verbose-1)

    og_clusters = [np.copy(all_p_clusters), np.copy(all_t_clusters)]
    og_labels = np.copy(all_p_clusters + all_t_clusters)

    # discard all EODs that are masked by any of the deletion steps:
    all_p_clusters[(artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p)] = -1
    all_t_clusters[(artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)] = -1

    # merge here.
    all_clusters, x_merge, mask = merge_clusters(np.copy(all_p_clusters),
                                                 np.copy(all_t_clusters),
                                                 eod_xp, eod_xt,
                                                 verbose=verbose-1)

    if 'all_cluster_steps' in return_data or plot_level > 0:
        all_dmasks = []
        all_mmasks = []

        discarding_masks = \
            np.vstack(((artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p),
                       (artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)))
        merge_mask = mask

        # save the masks in the same formats as the snippets
        for wi, (width_label, w_shape_label, heightlabels, unique_height_labels) in enumerate(zip(unique_width_labels, all_shapelabels, all_heightlabels, all_unique_heightlabels)):
            w_dmasks = discarding_masks[:,width_labels==width_label]
            w_mmasks = merge_mask[:,width_labels==width_label]

            wd_2 = []
            wm_2 = []

            for hi, (height_label, h_shape_label) in enumerate(zip(unique_height_labels, w_shape_label)):

                h_dmasks = w_dmasks[:,heightlabels==height_label]
                h_mmasks = w_mmasks[:,heightlabels==height_label]

                wd_2.append(h_dmasks)
                wm_2.append(h_mmasks)

            all_dmasks.append(wd_2)
            all_mmasks.append(wm_2)

        if plot_level > 0:
            plot_clustering(rate, [unique_width_labels, eod_widths, width_labels],
                            [all_unique_heightlabels, all_heights, all_heightlabels],
                            [all_snippets, all_features, all_shapelabels],
                            all_dmasks, all_mmasks)
            if save_plots:
                plt.savefig('%sclustering.%s' % (save_path, ftype))

        if 'all_cluster_steps' in return_data:
            saved_data = {'rate': rate,
                          'EOD_widths': [unique_width_labels, eod_widths, width_labels],
                          'EOD_heights': [all_unique_heightlabels, all_heights, all_heightlabels],
                          'EOD_shapes': [all_snippets, all_features, all_shapelabels],
                          'discarding_masks': all_dmasks,
                          'merge_masks': all_mmasks
                          }

    if 'masks' in return_data:
        saved_data = {'masks' : np.vstack(((artefact_masks_p & artefact_masks_t),
                                           (unreliable_fish_mask_p & unreliable_fish_mask_t),
                                           (wave_mask_p & wave_mask_t),
                                           (sidepeak_mask_p & sidepeak_mask_t),
                                           (all_p_clusters+all_t_clusters)))}

    if verbose > 0:
        print('Clusters generated based on height, width and shape: ')
        for l in np.unique(all_clusters[all_clusters != -1]):
            print(f'N_{int(l)} = {len(all_clusters[all_clusters == l]):4d}')

    return all_clusters, x_merge, saved_data


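# Hedged end-to-end sketch wiring detect_pulses() into cluster(); the width
# factors mirror the defaults used by extract_pulsefish(), but the data is
# synthetic, so real recordings are needed for meaningful clusters:
def _demo_cluster():
    rate = 44100.0
    data = np.random.randn(5*int(rate))*0.001
    thresh = median_std_threshold(data, int(0.002*rate))
    xp, xt, heights, widths = detect_pulses(data, rate, thresh)
    if len(xp) > 0:
        labels, x_merge, _ = cluster(xp, xt, heights, widths, data, rate,
                                     width_factor_shape=3, width_factor_wave=8)
        print('EOD clusters:', np.unique(labels[labels != -1]))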

def BGM(x, merge_threshold=0.1, n_gaus=5, max_iter=200, n_init=5,
        use_log=False, verbose=0, plot_level=0, xlabel='x [a.u.]',
        save_plot=False, save_path='', save_name='', ftype='pdf',
        return_data=[]):
    """ Use a Bayesian Gaussian Mixture Model to cluster one-dimensional data.

    Additional steps are used to merge clusters that are closer than
    `merge_threshold`. Broad gaussian fits that cover one or more other
    gaussian fits are split by their intersections with the other
    gaussians.

    Parameters
    ----------
    x : 1D numpy array
        Features to compute clustering on.

    merge_threshold : float (optional)
        Ratio for merging nearby gaussians.
    n_gaus: int (optional)
        Maximum number of gaussians to fit on data.
    max_iter : int (optional)
        Maximum number of iterations for gaussian fit.
    n_init : int (optional)
        Number of initializations for the gaussian fit.
    use_log: boolean (optional)
        Set to True to compute the gaussian fit on the logarithm of x.
        Can improve clustering on features with nonlinear relationships such as peak height.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    xlabel : string (optional)
        Xlabel for displaying BGM plot.
    save_plot : bool (optional)
        Set to True to save created plot.
    save_path : string (optional)
        Path to location where data should be saved. Only used if save_plot==True.
    save_name : string (optional)
        Filename of the saved plot. Useful as usually multiple BGM models are generated.
    ftype : string (optional)
        Filetype of plot image if save_plots==True.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'BGM_width' and/or 'BGM_height' (see extract_pulsefish()).

    Returns
    -------
    labels : 1D numpy array
        Cluster labels for each sample in x.
    bgm_dict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """

    bgm_dict = {}

    if len(np.unique(x)) > n_gaus:
        BGM_model = BayesianGaussianMixture(n_components=n_gaus, max_iter=max_iter,
                                            n_init=n_init)
        if use_log:
            labels = BGM_model.fit_predict(stats.zscore(np.log(x)).reshape(-1, 1))
        else:
            labels = BGM_model.fit_predict(stats.zscore(x).reshape(-1, 1))
    else:
        return np.zeros(len(x)), bgm_dict

    if verbose>0:
        if not BGM_model.converged_:
            print('!!! Gaussian mixture did not converge !!!')

    cur_labels = np.unique(labels)

    # map labels to be increasing for increasing values for x
    maxlab = len(cur_labels)
    aso = np.argsort([np.median(x[labels == l]) for l in cur_labels]) + 100
    for i, a in zip(cur_labels, aso):
        labels[labels==i] = a
    labels = labels - 100

    # separate gaussian clusters that can be split by other clusters
    splits = np.sort(np.copy(x))[1:][np.diff(labels[np.argsort(x)])!=0]

    labels[:] = 0
    for i, split in enumerate(splits):
        labels[x>=split] = i+1

    labels_before_merge = np.copy(labels)

    # merge gaussian clusters that are closer than merge_threshold
    labels = merge_gaussians(x, labels, merge_threshold)

    if 'BGM_'+save_name.split('_')[0] in return_data or plot_level>0:

        # sort model attributes by model means
        means = [m[0] for m in BGM_model.means_]
        weights = [w for w in BGM_model.weights_]
        variances = [v[0][0] for v in BGM_model.covariances_]
        weights = [w for _, w in sorted(zip(means, weights))]
        variances = [v for _, v in sorted(zip(means, variances))]
        means = sorted(means)

        if plot_level>0:
            plot_bgm(x, means, variances, weights, use_log, labels_before_merge,
                     labels, xlabel)
            if save_plot:
                plt.savefig('%sBGM_%s.%s' % (save_path, save_name, ftype))

        if 'BGM_'+save_name.split('_')[0] in return_data:
            bgm_dict['BGM_'+save_name] = {'x': x,
                                          'use_log': use_log,
                                          'BGM': [weights, means, variances],
                                          'labels': labels_before_merge,
                                          'xlab': xlabel}

    return labels, bgm_dict


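# Minimal sketch of BGM() on clearly bimodal 1-D data (values invented).
# Labels increase with the median of x, so the values around 1.0 end up in
# cluster 0 and those around 2.0 in cluster 1:
def _demo_bgm():
    x = np.concatenate((np.random.normal(1.0, 0.05, 100),
                        np.random.normal(2.0, 0.05, 100)))
    labels, _ = BGM(x, merge_threshold=0.1, n_gaus=3)
    print(np.unique(labels))   # expected: [0 1]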

def merge_gaussians(x, labels, merge_threshold=0.1):
    """ Merge all clusters whose medians are close. Only works in 1D.

    Parameters
    ----------
    x : 1D array of ints or floats
        Features used for clustering.
    labels : 1D array of ints
        Labels for each sample in x.
    merge_threshold : float (optional)
        Similarity threshold to merge clusters.

    Returns
    -------
    labels : 1D array of ints
        Merged labels for each sample in x.
    """

    # compare all the means of the gaussians. If they are too close, merge them.
    unique_labels = np.unique(labels[labels!=-1])
    x_medians = [np.median(x[labels==l]) for l in unique_labels]

    # fill a dict with the label mappings
    mapping = {}
    for label_1, x_m1 in zip(unique_labels, x_medians):
        for label_2, x_m2 in zip(unique_labels, x_medians):
            if label_1!=label_2:
                if np.abs(np.diff([x_m1, x_m2]))/np.max([x_m1, x_m2]) < merge_threshold:
                    mapping[label_1] = label_2
    # apply mapping
    for map_key, map_value in mapping.items():
        labels[labels==map_key] = map_value

    return labels


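# Worked toy example for merge_gaussians(): the two cluster medians (1.01
# and 1.06) differ by about 5%, which is below the 10% merge threshold, so
# a single label remains (values invented):
def _demo_merge_gaussians():
    x = np.array([1.00, 1.01, 1.02, 1.05, 1.06, 1.07])
    labels = np.array([0, 0, 0, 1, 1, 1])
    merged = merge_gaussians(x, labels, merge_threshold=0.1)
    assert len(np.unique(merged)) == 1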

def extract_snippet_features(data, eod_x, eod_heights, width, n_pc=5):
    """ Extract snippets from recording data, normalize them, and perform PCA.

    Parameters
    ----------
    data : 1D numpy array of floats
        Recording data.
    eod_x : 1D array of ints
        Locations of EODs as indices.
    eod_heights: 1D array of floats
        EOD heights.
    width : int
        Width to cut out to each side in samples.

    n_pc : int (optional)
        Number of PCs to use for PCA.

    Returns
    -------
    raw_snippets : 2D numpy array (N, EOD_width)
        Raw extracted EOD snippets.
    snippets : 2D numpy array (N, EOD_width)
        Normalized EOD snippets.
    features : 2D numpy array (N, n_pc)
        PC values of EOD snippets.
    bg_ratio : 1D numpy array (N)
        Ratio of the background activity slopes compared to EOD height.
    """
    # extract snippets with corresponding width
    raw_snippets = np.vstack([data[x-width:x+width] for x in eod_x])

    # subtract the slope and normalize the snippets
    snippets, bg_ratio = subtract_slope(np.copy(raw_snippets), eod_heights)
    snippets = StandardScaler().fit_transform(snippets.T).T

    # scale so that the absolute integral = 1.
    snippets = (snippets.T/np.sum(np.abs(snippets), axis=1)).T

    # compute features for clustering on waveform
    features = PCA(n_pc).fit_transform(snippets)

    return raw_snippets, snippets, features, bg_ratio


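# Hedged sketch of extract_snippet_features() around two invented EOD
# locations; n_pc is lowered to 2 here because PCA cannot extract more
# components than there are snippets:
def _demo_extract_snippet_features():
    rate = 44100.0
    data = np.random.randn(int(rate))*0.01
    eod_x = np.array([1000, 2000])
    eod_heights = np.array([0.5, 0.5])
    raw, snips, feats, bg = extract_snippet_features(data, eod_x, eod_heights,
                                                     width=100, n_pc=2)
    # raw and snips have shape (2, 200), feats has shape (2, 2)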

def cluster_on_shape(features, bg_ratio, minp, percentile=80,
                     max_epsilon=0.01, slope_ratio_factor=4,
                     min_cluster_fraction=0.01, verbose=0):
    """Separate EODs by their shape using DBSCAN.

    Parameters
    ----------
    features : 2D numpy array of floats (N, n_pc)
        PCA features of each EOD in a recording.
    bg_ratio : 1D array of floats
        Ratio of background activity slope the EOD is superimposed on.
    minp : int
        Minimum number of points for core cluster (DBSCAN).

    percentile : int (optional)
        Percentile of KNN distribution, where K=minp, to use as epsilon for DBSCAN.
    max_epsilon : float (optional)
        Maximum epsilon to use for DBSCAN clustering. This is used to avoid adding
        noisy clusters.
    slope_ratio_factor : float (optional)
        Influence of the slope-to-EOD ratio on the epsilon parameter.
        A slope_ratio_factor of 4 means that slope-to-EOD ratios >1/4
        start influencing epsilon.
    min_cluster_fraction : float (optional)
        Minimum fraction of all evaluated datapoints that can form a single cluster.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    labels : 1D array of ints
        DBSCAN cluster labels for each sample in features.
    """

    # determine clustering threshold from data
    minpc = max(minp, int(len(features)*min_cluster_fraction))
    knn = np.sort(pairwise_distances(features, features), axis=0)[minpc]
    eps = min(max(1, slope_ratio_factor*np.median(bg_ratio))*max_epsilon,
              np.percentile(knn, percentile))

    if verbose>1:
        print('epsilon = %f'%eps)
        print('Slope to EOD ratio = %f'%np.median(bg_ratio))

    # cluster on EOD shape
    return DBSCAN(eps=eps, min_samples=minpc).fit(features).labels_


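# Sketch of the epsilon heuristic above on two tight, well separated blobs
# of PCA features (all values invented): epsilon is taken from the KNN
# distance distribution, capped by max_epsilon, so the blobs typically come
# out as two shape clusters (label -1 marks noise points):
def _demo_cluster_on_shape():
    rng = np.random.default_rng(1)
    features = np.vstack((rng.normal(0.0, 0.0005, (30, 2)),
                          rng.normal(1.0, 0.0005, (30, 2))))
    bg_ratio = np.full(60, 0.05)
    labels = cluster_on_shape(features, bg_ratio, minp=5)
    print(np.unique(labels))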

def subtract_slope(snippets, heights):
    """ Subtract underlying slope from all EOD snippets.

    Parameters
    ----------
    snippets: 2-D numpy array
        All EODs in a recording stacked as snippets.
        Shape = (number of EODs, EOD width)
    heights: 1D numpy array
        EOD heights.

    Returns
    -------
    snippets: 2-D numpy array
        EOD snippets with underlying slope subtracted.
    bg_ratio : 1-D numpy array
        EOD height/background activity height.
    """

    left_y = snippets[:,0]
    right_y = snippets[:,-1]

    try:
        slopes = np.linspace(left_y, right_y, snippets.shape[1])
    except ValueError:
        delta = (right_y - left_y)/snippets.shape[1]
        slopes = np.arange(0, snippets.shape[1], dtype=snippets.dtype).reshape((-1,) + (1,)*np.ndim(delta))*delta + left_y

    return snippets - slopes.T, np.abs(left_y-right_y)/heights


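# Worked toy example for subtract_slope(): a snippet rising linearly from
# 0 to 1 is flattened to zeros, and bg_ratio is |left-right|/height = 0.5
# (values invented):
def _demo_subtract_slope():
    snippets = np.linspace(0.0, 1.0, 5).reshape(1, -1)
    flat, bg_ratio = subtract_slope(np.copy(snippets), np.array([2.0]))
    assert np.allclose(flat, 0.0) and np.allclose(bg_ratio, 0.5)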

def remove_artefacts(all_snippets, clusters, rate,
                     freq_low=20000, threshold=0.75,
                     verbose=0, return_data=[]):
    """ Create a mask for EOD clusters that result from artefacts, based on power in the low frequency spectrum.

    Parameters
    ----------
    all_snippets: 2D array
        EOD snippets. Shape=(nEODs, EOD length)
    clusters: list of ints
        EOD cluster labels.
    rate : float
        Sampling rate of original recording data.
    freq_low: float
        Frequency up to which low frequency components are summed up.
    threshold : float (optional)
        Minimum value for the sum of low frequency components relative to
        the sum over all spectral amplitudes that separates artefact from
        clean pulsefish clusters.
    verbose : int (optional)
        Verbosity level.
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log data in this function is
        'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    mask: numpy array of booleans
        Set to True for every EOD which is an artefact.
    adict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    adict = {}

    mask = np.zeros(clusters.shape, dtype=bool)

    for cluster in np.unique(clusters[clusters >= 0]):
        snippets = all_snippets[clusters == cluster]
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)
        mean_eod_fft = np.abs(np.fft.rfft(mean_eod))
        freqs = np.fft.rfftfreq(len(mean_eod), 1/rate)
        low_frequency_ratio = np.sum(mean_eod_fft[freqs<freq_low])/np.sum(mean_eod_fft)
        if low_frequency_ratio < threshold:   # TODO: check threshold!
            mask[clusters==cluster] = True

            if verbose > 0:
                print('Deleting cluster %i with low frequency ratio of %.3f (min %.3f)' % (cluster, low_frequency_ratio, threshold))

        if 'eod_deletion' in return_data:
            adict['vals_%d' % cluster] = [mean_eod, mean_eod_fft]
            adict['mask_%d' % cluster] = [np.any(mask[clusters==cluster])]

    return mask, adict


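# Illustrative check of the low-frequency criterion: snippets of a pure
# 28.8 kHz tone (an "artefact") carry almost no power below freq_low, so
# their low-frequency ratio falls below the threshold and the cluster is
# masked (values invented):
def _demo_remove_artefacts():
    rate = 96000.0
    t = np.arange(100)/rate
    snippets = np.sin(2*np.pi*28800.0*t).reshape(1, -1).repeat(5, axis=0)
    clusters = np.zeros(5, dtype=int)
    mask, _ = remove_artefacts(snippets, clusters, rate)
    assert np.all(mask)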

def delete_unreliable_fish(clusters, eod_widths, eod_x, verbose=0, sdict={}):
    """ Create a mask for EOD clusters that are either mixed with noise or other fish, or wavefish.

    This is the case when the ratio between the EOD width and the ISI is too large.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of floats or ints
        EOD widths in samples or seconds.
    eod_x : list of ints or floats
        EOD times in samples or seconds.

    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps,
        see remove_artefacts().

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every unreliable EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask = np.zeros(clusters.shape, dtype=bool)
    for cluster in np.unique(clusters[clusters >= 0]):
        if len(eod_x[cluster == clusters]) < 2:
            mask[clusters == cluster] = True
            if verbose>0:
                print('deleting unreliable cluster %i, number of EOD times %d < 2' % (cluster, len(eod_x[cluster==clusters])))
        elif np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) > 0.5:
            if verbose>0:
                print('deleting unreliable cluster %i, score=%f' % (cluster, np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))))
            mask[clusters==cluster] = True
        if 'vals_%d' % cluster in sdict:
            sdict['vals_%d' % cluster].append(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))
            sdict['mask_%d' % cluster].append(any(mask[clusters==cluster]))
    return mask, sdict


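# Worked toy example of the width-to-ISI rule: EODs 100 samples wide that
# are only 150 samples apart give a ratio of 100/150 > 0.5, so the whole
# cluster is masked as unreliable (values invented):
def _demo_delete_unreliable_fish():
    clusters = np.zeros(3, dtype=int)
    eod_widths = np.array([100, 100, 100])
    eod_x = np.array([0, 150, 300])
    mask, _ = delete_unreliable_fish(clusters, eod_widths, eod_x)
    assert np.all(mask)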

1372def delete_wavefish_and_sidepeaks(data, clusters, eod_x, eod_widths, 

1373 width_fac, max_slope_deviation=0.5, 

1374 max_phases=4, verbose=0, sdict={}): 

1375 """ Create a mask for EODs that are likely from wavefish, or sidepeaks of bigger EODs. 

1376 

1377 Parameters 

1378 ---------- 

1379 data : list of floats 

1380 Raw recording data. 

1381 clusters : list of ints 

1382 Cluster labels. 

1383 eod_x : list of ints 

1384 Indices of EOD times. 

1385 eod_widths : list of ints 

1386 EOD widths in samples. 

1387 width_fac : float 

1388 Multiplier for EOD analysis width. 

1389 

1390 max_slope_deviation: float (optional) 

1391 Maximum deviation of position of maximum slope in snippets from 

1392 center position in multiples of mean width of EOD. 

1393 max_phases : int (optional) 

1394 Maximum number of phases for any EOD.  

1395 If the mean EOD has more phases than this, it is not a pulse EOD. 

1396 verbose : int (optional)  

1397 Verbosity level. 

1398 sdict : dictionary 

1399 Dictionary that is used to log data. This is only used if a dictionary 

1400 was created by remove_artefacts(). 

1401 For logging data in noise and wavefish discarding steps, see remove_artefacts(). 

1402 

1403 Returns 

1404 ------- 

1405 mask_wave: numpy array of booleans 

1406 Set to True for every EOD which is a wavefish EOD. 

1407 mask_sidepeak: numpy array of booleans 

1408 Set to True for every snippet which is centered around a sidepeak of an EOD. 

1409 sdict : dictionary 

1410 Key value pairs of logged data. Data is only logged if a dictionary 

1411 was instantiated by remove_artefacts(). 

1412 """ 

1413 mask_wave = np.zeros(clusters.shape, dtype=bool) 

1414 mask_sidepeak = np.zeros(clusters.shape, dtype=bool) 

1415 

1416 for i, cluster in enumerate(np.unique(clusters[clusters >= 0])): 

1417 mean_width = np.mean(eod_widths[clusters == cluster]) 

1418 cutwidth = mean_width*width_fac 

1419 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1420 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1421 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] 

1422 for x in current_x[current_clusters==cluster]]) 

1423 

1424 # extract information on main peaks and troughs: 

1425 mean_eod = np.mean(snippets, axis=0) 

1426 mean_eod = mean_eod - np.mean(mean_eod) 

1427 

1428 # detect peaks and troughs on data + some maxima/minima at the 

1429 # end, so that the sides are also considered for peak detection: 

1430 pk, tr = detect_peaks(np.concatenate(([-10*mean_eod[0]], mean_eod, [10*mean_eod[-1]])), 

1431 np.std(mean_eod)) 

1432 pk = pk[(pk>0)&(pk<len(mean_eod))] 

1433 tr = tr[(tr>0)&(tr<len(mean_eod))] 

1434 

1435 if len(pk)>0 and len(tr)>0: 

1436 idxs = np.sort(np.concatenate((pk, tr))) 

1437 slopes = np.abs(np.diff(mean_eod[idxs])) 

1438 m_slope = np.argmax(slopes) 

1439 centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2)) 

1440 

1441 # compute all height differences of peaks and troughs within snippets. 

1442 # if they are all similar, it is probably noise or a wavefish. 

1443 # idxs is already sorted above, so reuse it for the height differences:

1444 hdiffs = np.diff(mean_eod[idxs])

1445 

1446 if centered > max_slope_deviation*mean_width: # TODO: check, factor was probably 0.16 

1447 if verbose > 0: 

1448 print('Deleting cluster %i, which is a sidepeak' % cluster) 

1449 mask_sidepeak[clusters==cluster] = True 

1450 

1451 w_diff = np.abs(np.diff(idxs))  # gaps between consecutive extrema; idxs is already sorted

1452 

1453 if np.abs(np.diff(idxs[m_slope:m_slope+2])) < np.mean(eod_widths[clusters==cluster])*0.5 or len(pk) + len(tr)>max_phases or np.min(w_diff)>2*cutwidth/width_fac: #or len(hdiffs[np.abs(hdiffs)>0.5*(np.max(mean_eod)-np.min(mean_eod))])>max_phases: 

1454 if verbose>0: 

1455 print('Deleting cluster %i, which is a wavefish' % cluster) 

1456 mask_wave[clusters==cluster] = True 

1457 if 'vals_%d' % cluster in sdict: 

1458 sdict['vals_%d' % cluster].append([mean_eod, [pk, tr], 

1459 idxs[m_slope:m_slope+2]]) 

1460 sdict['mask_%d' % cluster].append(any(mask_wave[clusters==cluster])) 

1461 sdict['mask_%d' % cluster].append(any(mask_sidepeak[clusters==cluster])) 

1462 

1463 return mask_wave, mask_sidepeak, sdict 

1464 

1465 
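
# Editor's sketch of the wavefish criterion above, with made-up snippets:
# a pulse EOD has only a few extrema, while a wave EOD keeps oscillating,
# so its number of phases exceeds max_phases. Extrema are counted here with
# a plain numpy sign-change test instead of thunderlab's detect_peaks().
def _demo_count_phases(max_phases=4):
    import numpy as np
    t = np.linspace(0.0, 1.0, 1000)
    wave = np.sin(2*np.pi*10.0*t)               # wavefish-like snippet
    pulse = np.exp(-((t - 0.5)/0.1)**2)         # pulse-like snippet
    for name, x in [('wave', wave), ('pulse', pulse)]:
        n_extrema = np.sum(np.diff(np.sign(np.diff(x))) != 0)
        print('%s: %d phases -> %s' % (name, n_extrema,
              'wavefish' if n_extrema > max_phases else 'pulse'))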

1466def merge_clusters(clusters_1, clusters_2, x_1, x_2, verbose=0): 

1467 """ Merge clusters resulting from two clustering methods. 

1468 

1469 This method only works if clustering is performed on the same EODs 

1470 with the same ordering, where there is a one-to-one mapping from 

1471 clusters_1 to clusters_2.  

1472 

1473 Parameters 

1474 ---------- 

1475 clusters_1: list of ints 

1476 EOD cluster labels for cluster method 1. 

1477 clusters_2: list of ints 

1478 EOD cluster labels for cluster method 2. 

1479 x_1: list of ints 

1480 Indices of EODs for cluster method 1 (clusters_1). 

1481 x_2: list of ints 

1482 Indices of EODs for cluster method 2 (clusters_2). 

1483 verbose : int (optional) 

1484 Verbosity level. 

1485 

1486 Returns 

1487 ------- 

1488 clusters : list of ints 

1489 Merged clusters. 

1490 x_merged : list of ints 

1491 Merged cluster indices. 

1492 mask : 2d numpy array of ints (N, 2) 

1493 Mask for clusters that are selected from clusters_1 (mask[:,0]) and 

1494 from clusters_2 (mask[:,1]). 

1495 """ 

1496 if verbose > 0: 

1497 print('\nMerge cluster:') 

1498 

1499 # these arrays become 1 for each EOD that is chosen from that array 

1500 c1_keep = np.zeros(len(clusters_1)) 

1501 c2_keep = np.zeros(len(clusters_2)) 

1502 

1503 # offset the labels of clusters_2 so they do not overlap with those of clusters_1:

1504 ovl = np.max(clusters_1) + 1 

1505 clusters_2[clusters_2!=-1] = clusters_2[clusters_2!=-1] + ovl 

1506 

1507 remove_clusters = [[]] 

1508 keep_clusters = [] 

1509 og_clusters = [np.copy(clusters_1), np.copy(clusters_2)] 

1510 

1511 # loop until done

1512 while True: 

1513 

1514 # compute unique clusters and cluster sizes 

1515 # of cluster that have not been iterated over: 

1516 c1_labels, c1_size = unique_counts(clusters_1[(clusters_1 != -1) & (c1_keep == 0)]) 

1517 c2_labels, c2_size = unique_counts(clusters_2[(clusters_2 != -1) & (c2_keep == 0)]) 

1518 

1519 # if all clusters are done, break from loop: 

1520 if len(c1_size) == 0 and len(c2_size) == 0: 

1521 break 

1522 

1523 # if the biggest cluster is in clusters_1, keep it and discard all

1524 # clusters of clusters_2 at the same indices:

1525 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 0: 

1526 

1527 # remove all the mappings from the other indices 

1528 cluster_mappings, _ = unique_counts(clusters_2[clusters_1 == c1_labels[np.argmax(c1_size)]]) 

1529 

1530 clusters_2[np.isin(clusters_2, cluster_mappings)] = -1 

1531 

1532 c1_keep[clusters_1==c1_labels[np.argmax(c1_size)]] = 1 

1533 

1534 remove_clusters.append(cluster_mappings) 

1535 keep_clusters.append(c1_labels[np.argmax(c1_size)]) 

1536 

1537 if verbose > 0: 

1538 print('Keep cluster %i of group 1, delete clusters %s of group 2' % (c1_labels[np.argmax(c1_size)], str(cluster_mappings[cluster_mappings!=-1] - ovl))) 

1539 

1540 # if the biggest cluster is in clusters_2, keep it and discard all mapped clusters of clusters_1:

1541 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 1: 

1542 

1543 # remove all the mappings from the other indices 

1544 cluster_mappings, _ = unique_counts(clusters_1[clusters_2 == c2_labels[np.argmax(c2_size)]]) 

1545 

1546 clusters_1[np.isin(clusters_1, cluster_mappings)] = -1 

1547 

1548 c2_keep[clusters_2==c2_labels[np.argmax(c2_size)]] = 1 

1549 

1550 remove_clusters.append(cluster_mappings) 

1551 keep_clusters.append(c2_labels[np.argmax(c2_size)]) 

1552 

1553 if verbose > 0: 

1554 print('Keep cluster %i of group 2, delete clusters %s of group 1' % (c2_labels[np.argmax(c2_size)] - ovl, str(cluster_mappings[cluster_mappings!=-1]))) 

1555 

1556 # combine results  

1557 clusters = (clusters_1+1)*c1_keep + (clusters_2+1)*c2_keep - 1 

1558 x_merged = (x_1)*c1_keep + (x_2)*c2_keep 

1559 

1560 return clusters, x_merged, np.vstack([c1_keep, c2_keep]) 

1561 

1562 
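
# Editor's usage sketch for merge_clusters() on toy labels: both methods
# labeled the same five EODs; the largest cluster wins in each round and
# its counterparts in the other labeling are discarded. The toy labels are
# made up for illustration.
def _demo_merge_clusters():
    import numpy as np
    c1 = np.array([0, 0, 0, 1, 1])   # e.g. labels from clustering on peaks
    c2 = np.array([0, 0, 1, 1, 1])   # e.g. labels from clustering on troughs
    x = np.arange(5)
    clusters, x_merged, mask = merge_clusters(np.copy(c1), np.copy(c2),
                                              x, x, verbose=1)
    print(clusters.astype(int))      # -> [0 0 0 1 1], all taken from c1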

1563def extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths, 

1564 clusters, rate, width_fac, verbose=0): 

1565 """ Extract mean EODs and EOD timepoints for each EOD cluster. 

1566 

1567 Parameters 

1568 ---------- 

1569 data: list of floats 

1570 Raw recording data. 

1571 eod_x: list of ints 

1572 Locations of EODs in samples. 

1573 eod_peak_x : list of ints 

1574 Locations of EOD peaks in samples. 

1575 eod_tr_x : list of ints 

1576 Locations of EOD troughs in samples. 

1577 eod_widths: list of ints 

1578 EOD widths in samples. 

1579 clusters: list of ints 

1580 EOD cluster labels.

1581 rate: float

1582 Sampling rate of the recording in Hertz.

1583 width_fac : float 

1584 Multiplication factor for window used to extract EOD. 

1585  

1586 verbose : int (optional) 

1587 Verbosity level. 

1588 

1589 Returns 

1590 ------- 

1591 mean_eods: list of 2D arrays (3, eod_length) 

1592 The average EOD for each detected fish. First row is time in seconds,

1593 second row the mean EOD, third row the standard deviation.

1594 eod_times: list of 1D arrays 

1595 For each detected fish the times of EOD in seconds. 

1596 eod_peak_times: list of 1D arrays 

1597 For each detected fish the times of EOD peaks in seconds. 

1598 eod_trough_times: list of 1D arrays 

1599 For each detected fish the times of EOD troughs in seconds. 

1600 eod_labels: list of ints 

1601 Cluster label for each detected fish. 

1602 """ 

1603 mean_eods, eod_times, eod_peak_times, eod_tr_times, eod_heights, cluster_labels = [], [], [], [], [], [] 

1604 

1605 for cluster in np.unique(clusters): 

1606 if cluster!=-1: 

1607 cutwidth = np.mean(eod_widths[clusters==cluster])*width_fac 

1608 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1609 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1610 

1611 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] for x in current_x[current_clusters==cluster]]) 

1612 mean_eod = np.mean(snippets, axis=0) 

1613 eod_time = np.arange(len(mean_eod))/rate - cutwidth/rate 

1614 

1615 mean_eod = np.vstack([eod_time, mean_eod, np.std(snippets, axis=0)]) 

1616 

1617 mean_eods.append(mean_eod) 

1618 eod_times.append(eod_x[clusters==cluster]/rate) 

1619 eod_heights.append(np.min(mean_eod[1])-np.max(mean_eod[1]))  # negative peak-to-peak amplitude of the waveform row, used for sorting below

1620 eod_peak_times.append(eod_peak_x[clusters==cluster]/rate) 

1621 eod_tr_times.append(eod_tr_x[clusters==cluster]/rate) 

1622 cluster_labels.append(cluster) 

1623 

1624 # sort all outputs by EOD amplitude, largest fish first (heights are negative).
 # argsort avoids comparing numpy arrays when two heights happen to be equal:
 order = np.argsort(eod_heights)
 return ([mean_eods[i] for i in order], [eod_times[i] for i in order],
 [eod_peak_times[i] for i in order], [eod_tr_times[i] for i in order],
 [cluster_labels[i] for i in order])

1625 

1626 
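
# Editor's usage sketch for extract_means() on a synthetic recording with
# two identical pulses in a single cluster. All values are made up.
def _demo_extract_means():
    import numpy as np
    rate = 10000.0
    data = np.zeros(1000)
    for x in (300, 700):                       # two identical pulses
        data[x-5:x+5] = np.hanning(10)
    eod_x = np.array([300, 700])               # pulse centers in samples
    eod_widths = np.array([10, 10])            # pulse widths in samples
    clusters = np.array([0, 0])
    means, times, ptimes, ttimes, labels = extract_means(
        data, eod_x, eod_x, eod_x, eod_widths, clusters, rate, 3)
    print(means[0].shape, times[0])            # -> (3, 60) [0.03 0.07]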

1627def find_clipped_clusters(clusters, mean_eods, eod_times, 

1628 eod_peaktimes, eod_troughtimes, 

1629 cluster_labels, width_factor, 

1630 clip_threshold=0.9, verbose=0): 

1631 """ Detect EODs that are clipped and set all clusterlabels of these clipped EODs to -1. 

1632  

1633 Also return the mean EODs and timepoints of these clipped EODs. 

1634 

1635 Parameters 

1636 ---------- 

1637 clusters: array of ints 

1638 Cluster labels for each EOD in a recording. 

1639 mean_eods: list of numpy arrays 

1640 Mean EOD waveform for each cluster. 

1641 eod_times: list of numpy arrays 

1642 EOD timepoints for each EOD cluster. 

1643 eod_peaktimes : list of numpy arrays

1644 EOD peak times for each EOD cluster.

1645 eod_troughtimes : list of numpy arrays

1646 EOD trough times for each EOD cluster.

1647 cluster_labels: numpy array

1648 Unique EOD cluster labels.

 width_factor: float

 Multiplier that was used for extracting the mean EODs. A mean EOD is
 flagged as clipped when more than 1/(2*width_factor) of its samples
 exceed clip_threshold.

1649 clip_threshold: float 

1650 Threshold for detecting clipped EODs. 

1651  

1652 verbose: int 

1653 Verbosity level. 

1654 

1655 Returns 

1656 ------- 

1657 clusters : array of ints 

1658 Cluster labels for each EOD in the recording, where clipped EODs have been set to -1. 

1659 clipped_eods : list of numpy arrays 

1660 Mean EOD waveforms for each clipped EOD cluster. 

1661 clipped_times : list of numpy arrays 

1662 EOD timepoints for each clipped EOD cluster. 

1663 clipped_peaktimes : list of numpy arrays 

1664 EOD peaktimes for each clipped EOD cluster. 

1665 clipped_troughtimes : list of numpy arrays 

1666 EOD troughtimes for each clipped EOD cluster. 

1667 """ 

1668 clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes, clipped_labels = [], [], [], [], [] 

1669 

1670 for mean_eod, eod_time, eod_peaktime, eod_troughtime, label in zip(mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels): 

1671 

1672 if (np.count_nonzero(mean_eod[1]>clip_threshold) > len(mean_eod[1])/(width_factor*2)) or (np.count_nonzero(mean_eod[1] < -clip_threshold) > len(mean_eod[1])/(width_factor*2)): 

1673 clipped_eods.append(mean_eod) 

1674 clipped_times.append(eod_time) 

1675 clipped_peaktimes.append(eod_peaktime) 

1676 clipped_troughtimes.append(eod_troughtime) 

1677 clipped_labels.append(label) 

1678 if verbose>0: 

1679 print('Deleting clipped pulsefish cluster %i' % label) 

1680 

1681 clusters[np.isin(clusters, clipped_labels)] = -1 

1682 

1683 return clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes 

1684 

1685 
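
# Editor's sketch of the clipping test above, with a made-up, saturated
# waveform: a mean EOD is flagged as clipped when more than
# len(eod)/(2*width_factor) of its samples exceed clip_threshold.
def _demo_clip_test(clip_threshold=0.9, width_factor=4):
    import numpy as np
    eod = np.clip(1.5*np.sin(np.linspace(0.0, 2.0*np.pi, 100)), -1.0, 1.0)
    n_clipped = max(np.count_nonzero(eod > clip_threshold),
                    np.count_nonzero(eod < -clip_threshold))
    print('clipped' if n_clipped > len(eod)/(2*width_factor) else 'ok')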

1686def delete_moving_fish(clusters, eod_t, T, eod_heights, eod_widths, 

1687 rate, min_dt=0.25, stepsize=0.05, 

1688 sliding_window_factor=2000, verbose=0, 

1689 plot_level=0, save_plot=False, save_path='', 

1690 ftype='pdf', return_data=[]): 

1691 """ 

1692 Use a sliding window to find the time window with the fewest fish present 

1693 simultaneously, then delete all EOD clusters that do not occur in that window. 

1694 

1695 Do this only for EODs within the same width clusters, as a 

1696 moving fish will preserve its EOD width. 

1697 

1698 Parameters 

1699 ---------- 

1700 clusters: list of ints 

1701 EOD cluster labels. 

1702 eod_t: list of floats 

1703 Timepoints of the EODs (in seconds). 

1704 T: float 

1705 Length of recording (in seconds). 

1706 eod_heights: list of floats 

1707 EOD amplitudes. 

1708 eod_widths: list of floats 

1709 EOD widths (in seconds). 

1710 rate: float 

1711 Recording data sampling rate. 

1712 

1713 min_dt : float (optional) 

1714 Minimum sliding window size (in seconds). 

1715 stepsize : float (optional) 

1716 Sliding window stepsize (in seconds). 

1717 sliding_window_factor : float 

1718 Multiplier for sliding window width, 

1719 where the sliding window width = median(EOD_width)*sliding_window_factor. 

1720 verbose : int (optional) 

1721 Verbosity level. 

1722 plot_level : int (optional) 

1723 Similar to verbosity levels, but with plots.  

1724 Only set to > 0 for debugging purposes. 

1725 save_plot : bool (optional) 

1726 Set to True to save the plots created by plot_level. 

1727 save_path : string (optional) 

1728 Path for saving plots. Only used if save_plot is True. 

1729 ftype : string (optional) 

1730 Define the filetype to save the plots in if save_plots is set to True. 

1731 Options are: 'png', 'jpg', 'svg' ... 

1732 return_data : list of strings (optional) 

1733 Keys that specify data to be logged. The key that can be used to log data 

1734 in this function is 'moving_fish' (see extract_pulsefish()). 

1735 

1736 Returns 

1737 ------- 

1738 clusters : list of ints 

1739 Cluster labels, where deleted clusters have been set to -1. 

1740 window : list of 2 floats 

1741 Start and end of window selected for deleting moving fish in seconds. 

1742 mf_dict : dictionary 

1743 Key value pairs of logged data. Data to be logged is specified by return_data. 

1744 """ 

1745 mf_dict = {} 

1746 

1747 if len(np.unique(clusters[clusters != -1])) == 0: 

1748 return clusters, [0, 1], {} 

1749 

1750 all_keep_clusters = [] 

1751 width_classes = merge_gaussians(eod_widths, np.copy(clusters), 0.75) 

1752 

1753 all_windows = [] 

1754 all_dts = [] 

1755 ev_num = 0 

1756 for iw, w in enumerate(np.unique(width_classes[clusters >= 0])): 

1757 # initialize variables 

1758 min_clusters = 100 

1759 average_height = 0 

1760 sparse_clusters = 100 

1761 keep_clusters = [] 

1762 

1763 dt = max(min_dt, np.median(eod_widths[width_classes==w])*sliding_window_factor) 

1764 window_start = 0 

1765 window_end = dt 

1766 

1767 wclusters = clusters[width_classes==w] 

1768 weod_t = eod_t[width_classes==w] 

1769 weod_heights = eod_heights[width_classes==w] 

1770 weod_widths = eod_widths[width_classes==w] 

1771 

1772 all_dts.append(dt) 

1773 

1774 if verbose>0: 

1775 print('sliding window dt = %f'%dt) 

1776 

1777 # make W dependent on width?? 

1778 ignore_steps = np.zeros(len(np.arange(0, T-dt+stepsize, stepsize))) 

1779 

1780 for i, t in enumerate(np.arange(0, T-dt+stepsize, stepsize)): 

1781 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1782 if len(np.unique(current_clusters))==0: 

1783 ignore_steps[max(0, i-int(dt/stepsize)):i+int(dt/stepsize)] = 1  # clamp start to avoid a negative slice index

1784 if verbose>0: 

1785 print('No pulsefish in recording at T=%.2f:%.2f' % (t, t+dt)) 

1786 

1787 

1788 x = np.arange(0, T-dt+stepsize, stepsize) 

1789 y = np.ones(len(x)) 

1790 

1791 running_sum = np.ones(len(np.arange(0, T+stepsize, stepsize))) 

1792 ulabs = np.unique(wclusters[wclusters>=0]) 

1793 

1794 # sliding window 

1795 for j, (t, ignore_step) in enumerate(zip(x, ignore_steps)): 

1796 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1797 current_widths = weod_widths[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1798 

1799 unique_clusters = np.unique(current_clusters) 

1800 y[j] = len(unique_clusters) 

1801 

1802 if (len(unique_clusters) <= min_clusters) and \ 

1803 (ignore_step==0) and \ 

1804 (len(unique_clusters) > 0)):  # at least one cluster present

1805 

1806 current_labels = np.isin(wclusters, unique_clusters) 

1807 current_height = np.mean(weod_heights[current_labels]) 

1808 

1809 # compute nr of clusters that are too sparse 

1810 clusters_after_deletion = np.unique(remove_sparse_detections(np.copy(clusters[np.isin(clusters, unique_clusters)]), rate*eod_widths[np.isin(clusters, unique_clusters)], rate, T)) 

1811 current_sparse_clusters = len(unique_clusters) - len(clusters_after_deletion[clusters_after_deletion!=-1]) 

1812 

1813 if current_sparse_clusters <= sparse_clusters and \ 

1814 ((current_sparse_clusters<sparse_clusters) or 

1815 (current_height > average_height) or 

1816 (len(unique_clusters) < min_clusters)): 

1817 

1818 keep_clusters = unique_clusters 

1819 min_clusters = len(unique_clusters) 

1820 average_height = current_height 

1821 window_end = t+dt 

1822 sparse_clusters = current_sparse_clusters 

1823 

1824 all_keep_clusters.append(keep_clusters) 

1825 all_windows.append(window_end) 

1826 

1827 if 'moving_fish' in return_data or plot_level>0: 

1828 if 'w' in mf_dict: 

1829 mf_dict['w'].append(np.median(eod_widths[width_classes==w])) 

1830 mf_dict['T'].append(T) 

1831 mf_dict['dt'].append(dt) 

1832 mf_dict['clusters'].append(wclusters) 

1833 mf_dict['t'].append(weod_t) 

1834 mf_dict['fishcount'].append([x+0.5*(x[1]-x[0]), y]) 

1835 mf_dict['ignore_steps'].append(ignore_steps) 

1836 else: 

1837 mf_dict['w'] = [np.median(eod_widths[width_classes==w])] 

1838 mf_dict['T'] = [T] 

1839 mf_dict['dt'] = [dt] 

1840 mf_dict['clusters'] = [wclusters] 

1841 mf_dict['t'] = [weod_t] 

1842 mf_dict['fishcount'] = [[x+0.5*(x[1]-x[0]), y]] 

1843 mf_dict['ignore_steps'] = [ignore_steps] 

1844 

1845 if verbose>0: 

1846 print('Estimated nr of pulsefish in recording: %i' % len(np.concatenate(all_keep_clusters))) 

1847 

1848 if plot_level>0: 

1849 plot_moving_fish(mf_dict['w'], mf_dict['dt'], mf_dict['clusters'],mf_dict['t'], 

1850 mf_dict['fishcount'], T, mf_dict['ignore_steps']) 

1851 if save_plot: 

1852 plt.savefig('%sdelete_moving_fish.%s' % (save_path, ftype)) 

1853 # discard the log dictionary if logging was not requested: 

1854 if 'moving_fish' not in return_data: 

1855 mf_dict = {} 

1856 

1857 # delete all clusters that are not selected 

1858 clusters[np.invert(np.isin(clusters, np.concatenate(all_keep_clusters)))] = -1 

1859 

1860 return clusters, [np.max(all_windows)-np.max(all_dts), np.max(all_windows)], mf_dict 

1861 

1862 
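
# Editor's sketch of the sliding-window logic behind delete_moving_fish(),
# with made-up EOD times: count distinct cluster labels per window and pick
# the non-empty window with the fewest simultaneously present fish.
def _demo_sliding_fish_count():
    import numpy as np
    T, dt, stepsize = 10.0, 1.0, 0.5
    eod_t = np.array([0.5, 1.0, 1.5, 2.0, 8.0, 8.5, 9.0])
    clusters = np.array([0, 0, 1, 1, 0, 0, 0])
    starts = np.arange(0, T - dt + stepsize, stepsize)
    counts = np.array([len(np.unique(clusters[(eod_t >= t) & (eod_t < t + dt)]))
                       for t in starts])
    valid = counts > 0                     # ignore windows without any EODs
    best = starts[valid][np.argmin(counts[valid])]
    print('fewest fish (%d) in window %.1f-%.1f s'
          % (np.min(counts[valid]), best, best + dt))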

1863def remove_sparse_detections(clusters, eod_widths, rate, T, 

1864 min_density=0.0005, verbose=0): 

1865 """ Remove all EOD clusters that are too sparse 

1866 

1867 Parameters 

1868 ---------- 

1869 clusters : list of ints 

1870 Cluster labels. 

1871 eod_widths : list of ints 

1872 Cluster widths in samples. 

1873 rate : float 

1874 Sampling rate. 

1875 T : float 

1876 Length of the recording in seconds.

1877 min_density : float (optional) 

1878 Minimum density for realistic EOD detections. 

1879 verbose : int (optional) 

1880 Verbosity level. 

1881 

1882 Returns 

1883 ------- 

1884 clusters : list of ints 

1885 Cluster labels, where sparse clusters have been set to -1. 

1886 """ 

1887 for c in np.unique(clusters): 

1888 if c!=-1: 

1889 

1890 n = len(clusters[clusters==c]) 

1891 w = np.median(eod_widths[clusters==c])/rate 

1892 

1893 if n*w < T*min_density: 

1894 if verbose>0: 

1895 print('cluster %i is too sparse'%c) 

1896 clusters[clusters==c] = -1 

1897 return clusters