1""" 

2Extract and cluster EOD waverforms of pulse-type electric fish. 

3 

4## Main function 

5 

6- `extract_pulsefish()`: checks for pulse-type fish based on the EOD amplitude and shape. 

7 

8""" 

import os
import numpy as np

from scipy import stats
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.mixture import BayesianGaussianMixture
from sklearn.metrics import pairwise_distances
from thunderlab.eventdetection import detect_peaks, median_std_threshold

from .pulseplots import *

import warnings
def warn(*args, **kwargs):
    """
    Ignore all warnings.
    """
    pass
warnings.warn = warn

try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        def decorator_jit(func):
            return func
        return decorator_jit


# upgrade numpy functions for backwards compatibility:
if not hasattr(np, 'isin'):
    np.isin = np.in1d

def unique_counts(ar):
    """ Find the unique elements of an array and their counts, ignoring shape.

    The code is condensed from numpy version 1.17.0.

    Parameters
    ----------
    ar : numpy array
        Input array

    Returns
    -------
    unique_values : numpy array
        Unique values in array ar.
    unique_counts : numpy array
        Number of instances for each unique value in ar.
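
    Examples
    --------
    Illustrative sketch on a plain list (any array-like works):

    >>> unique_counts([1, 1, 2, 2, 2, 5])
    (array([1, 2, 5]), array([2, 3, 1]))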

61 """ 

62 try: 

63 return np.unique(ar, return_counts=True) 

64 except TypeError: 

65 ar = np.asanyarray(ar).flatten() 

66 ar.sort() 

67 mask = np.empty(ar.shape, dtype=bool_) 

68 mask[:1] = True 

69 mask[1:] = ar[1:] != ar[:-1] 

70 idx = np.concatenate(np.nonzero(mask) + ([mask.size],)) 

71 return ar[mask], np.diff(idx) 

72 

73 

74########################################################################### 


def extract_pulsefish(data, rate, amax, width_factor_shape=3,
                      width_factor_wave=8, width_factor_display=4,
                      verbose=0, plot_level=0, save_plots=False,
                      save_path='', ftype='png', return_data=[]):
    """ Extract and cluster pulse-type fish EODs from single channel data.

    Takes recording data containing an unknown number of pulsefish and extracts the mean
    EOD and EOD timepoints for each fish present in the recording.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data in Hertz.
    amax: float
        Maximum amplitude of data range.
    width_factor_shape : float (optional)
        Width multiplier used for EOD shape analysis.
        EOD snippets are extracted based on width between the
        peak and trough multiplied by the width factor.
    width_factor_wave : float (optional)
        Width multiplier used for wavefish detection.
    width_factor_display : float (optional)
        Width multiplier used for EOD mean extraction and display.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path: string (optional)
        Path for saving plots.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Specify data that should be logged and returned in a dictionary. Each clustering
        step has a specific keyword that results in adding different variables to the log dictionary.
        Optional keys for return_data and the resulting additional key-value pairs to the log dictionary are:

        - 'all_eod_times':
            - 'all_times': list of two lists of floats.
                All peak (`all_times[0]`) and trough times (`all_times[1]`) extracted
                by the peak detection algorithm. Times are given in seconds.
            - 'eod_troughtimes': list of 1D arrays.
                The timepoints in seconds of each unique extracted EOD cluster,
                where each 1D array encodes one cluster.

        - 'peak_detection':
            - 'data': 1D numpy array of floats.
                Quadratically interpolated data which was used for peak detection.
            - 'interp_fac': float.
                Interpolation factor of raw data.
            - 'peaks_1': 1D numpy array of ints.
                Peak indices on interpolated data after first peak detection step.
            - 'troughs_1': 1D numpy array of ints.
                Trough indices on interpolated data after first peak detection step.
            - 'peaks_2': 1D numpy array of ints.
                Peak indices on interpolated data after second peak detection step.
            - 'troughs_2': 1D numpy array of ints.
                Trough indices on interpolated data after second peak detection step.
            - 'peaks_3': 1D numpy array of ints.
                Peak indices on interpolated data after third peak detection step.
            - 'troughs_3': 1D numpy array of ints.
                Trough indices on interpolated data after third peak detection step.
            - 'peaks_4': 1D numpy array of ints.
                Peak indices on interpolated data after fourth peak detection step.
            - 'troughs_4': 1D numpy array of ints.
                Trough indices on interpolated data after fourth peak detection step.

        - 'all_cluster_steps':
            - 'rate': float.
                Sampling rate of interpolated data.
            - 'EOD_widths': list of three 1D numpy arrays.
                The first list entry gives the unique labels of all width clusters
                as a list of ints.
                The second list entry gives the width values for each EOD in samples
                as a 1D numpy array of ints.
                The third list entry gives the width labels for each EOD
                as a 1D numpy array of ints.
            - 'EOD_heights': nested lists (2 layers) of three 1D numpy arrays.
                The first list entry gives the unique labels of all height clusters
                as a list of ints for each width cluster.
                The second list entry gives the height values for each EOD
                as a 1D numpy array of floats for each width cluster.
                The third list entry gives the height labels for each EOD
                as a 1D numpy array of ints for each width cluster.
            - 'EOD_shapes': nested lists (3 layers) of three 1D numpy arrays.
                The first list entry gives the raw EOD snippets as a 2D numpy array
                for each height cluster in a width cluster.
                The second list entry gives the snippet PCA values for each EOD
                as a 2D numpy array of floats for each height cluster in a width cluster.
                The third list entry gives the shape labels for each EOD as a 1D numpy array
                of ints for each height cluster in a width cluster.
            - 'discarding_masks': nested lists (2 layers) of 1D numpy arrays.
                The masks of EODs that are discarded by the discarding step of the algorithm.
                The masks are 1D boolean arrays where instances that are set to True are
                discarded by the algorithm. Discarding masks are saved in nested lists
                that represent the width and height clusters.
            - 'merge_masks': nested lists (2 layers) of 2D numpy arrays.
                The masks of EODs that are discarded by the merging step of the algorithm.
                The masks are 2D boolean arrays where for each sample point `i` either
                `merge_mask[i,0]` or `merge_mask[i,1]` is set to True. Here, `merge_mask[:,0]`
                represents the peak-centered clusters and `merge_mask[:,1]` represents the
                trough-centered clusters. Merge masks are saved in nested lists that
                represent the width and height clusters.

        - 'BGM_width':
            - 'BGM_width': dictionary
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD widths),
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'BGM_height':
            This key adds a new dictionary for each width cluster.
            - 'BGM_height_*n*' : dictionary, where *n* defines the width cluster as an int.
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD heights),
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by BGM model (before merging based on merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'snippet_clusters':
            This key adds a new dictionary for each height cluster.
            - 'snippet_clusters*_n_m_p*' : dictionary, where *n* defines the width cluster
              (int), *m* defines the height cluster (int) and *p* defines shape clustering
              on peak or trough centered EOD snippets (string: 'peak' or 'trough').
                - 'raw_snippets': 2D numpy array (nsamples, nfeatures).
                    Raw EOD snippets.
                - 'snippets': 2D numpy array.
                    Normalized EOD snippets.
                - 'features': 2D numpy array (nsamples, nfeatures).
                    PCA values for each normalized EOD snippet.
                - 'clusters': 1D numpy array of ints.
                    Cluster labels.
                - 'rate': float.
                    Sampling rate of snippets.

        - 'eod_deletion':
            This key adds two dictionaries for each (peak centered) shape cluster,
            where *cluster* (int) is the unique shape cluster label.
            - 'mask_*cluster*' : list of four booleans.
                The mask for each cluster discarding step.
                The first instance represents the artefact masks, where artefacts
                are set to True.
                The second instance represents the unreliable cluster masks,
                where unreliable clusters are set to True.
                The third instance represents the wavefish masks, where wavefish
                are set to True.
                The fourth instance represents the sidepeak masks, where sidepeaks
                are set to True.
            - 'vals_*cluster*' : list of lists.
                All variables that are used for each cluster deletion step.
                The first instance is a list of two 1D numpy arrays: the mean EOD and
                the FFT of that mean EOD.
                The second instance is a 1D numpy array with all EOD width to ISI ratios.
                The third instance is a list with three entries:
                The first entry is a 1D numpy array with a zoomed-out version of the mean EOD.
                The second entry is a list of two 1D numpy arrays that define the peak
                and trough indices of the zoomed-out mean EOD.
                The third entry contains a list of two values that represent the
                peak-trough pair in the zoomed-out mean EOD with the largest height
                difference.
            - 'rate' : float.
                EOD snippet sampling rate.

        - 'masks':
            - 'masks' : 2D numpy array (4,N).
                Each row contains masks for each EOD detected by the EOD peak detection step.
                The first row defines the artefact masks, the second row defines the
                unreliable EOD masks,
                the third row defines the wavefish masks and the fourth row defines
                the sidepeak masks.

        - 'moving_fish':
            - 'moving_fish': dictionary.
                - 'w' : list of floats.
                    Median width for each width cluster that the moving fish algorithm is
                    computed on (in seconds).
                - 'T' : list of floats.
                    Length of analyzed recording for each width cluster (in seconds).
                - 'dt' : list of floats.
                    Sliding window size (in seconds) for each width cluster.
                - 'clusters' : list of 1D numpy int arrays.
                    Cluster labels for each EOD cluster in a width cluster.
                - 't' : list of 1D numpy float arrays.
                    EOD emission times for each EOD in a width cluster.
                - 'fishcount' : list of lists.
                    Sliding window timepoints and fishcounts for each width cluster.
                - 'ignore_steps' : list of 1D int arrays.
                    Mask for fishcounts that were ignored (ignored if True) in the
                    moving_fish analysis.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First column is time in seconds,
        second column the mean EOD, third column the standard error.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD peaks or troughs in seconds.
        Use these timepoints for EOD averaging.
    eod_peaktimes: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    zoom_window: tuple of floats
        Start and end time of suggested window for plotting EOD timepoints.
    log_dict: dictionary
        Dictionary with logged variables, where variables to log are specified
        by `return_data`.
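
    Examples
    --------
    A minimal sketch, assuming `data` is a 1-D single-channel recording,
    `rate` its sampling rate in Hertz, and `amax` the maximum amplitude
    of the data range (how the recording is loaded is up to the caller):

    >>> res = extract_pulsefish(data, rate, amax, verbose=1)
    >>> mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict = res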

303 """ 

304 if verbose > 0: 

305 print('') 

306 if verbose > 1: 

307 print(70*'#') 

308 print('##### extract_pulsefish', 46*'#') 

309 

310 if (save_plots and plot_level>0 and save_path): 

311 # create folder to save things in. 

312 if not os.path.exists(save_path): 

313 os.makedirs(save_path) 

314 else: 

315 save_path = '' 

316 

317 mean_eods, eod_times, eod_peaktimes, zoom_window = [], [], [], [] 

318 log_dict = {} 

319 

320 # interpolate: 

321 i_rate = 500000.0 

322 #i_rate = rate 

323 try: 

324 f = interp1d(np.arange(len(data))/rate, data, kind='quadratic') 

325 i_data = f(np.arange(0.0, (len(data)-1)/rate, 1.0/i_rate)) 

326 except MemoryError: 

327 i_rate = rate 

328 i_data = data 

329 log_dict['data'] = i_data # TODO: could be removed 

330 log_dict['rate'] = i_rate # TODO: could be removed 

331 log_dict['i_data'] = i_data 

332 log_dict['i_rate'] = i_rate 

333 # log_dict["interp_fac"] = interp_fac # TODO: is not set anymore 

334 

335 # standard deviation of data in small snippets: 

336 win_size = int(0.002*rate) # 2ms windows 

337 threshold = median_std_threshold(data, win_size) # TODO make this a parameter 

338 

339 # extract peaks: 

340 if 'peak_detection' in return_data: 

341 x_peak, x_trough, eod_heights, eod_widths, pd_log_dict = \ 

342 detect_pulses(i_data, i_rate, threshold, 

343 width_fac=np.max([width_factor_shape, 

344 width_factor_display, 

345 width_factor_wave]), 

346 verbose=verbose-1, return_data=True) 

347 log_dict.update(pd_log_dict) 

348 else: 

349 x_peak, x_trough, eod_heights, eod_widths = \ 

350 detect_pulses(i_data, i_rate, threshold, 

351 width_fac=np.max([width_factor_shape, 

352 width_factor_display, 

353 width_factor_wave]), 

354 verbose=verbose-1, return_data=False) 

355 

356 if len(x_peak) > 0: 

357 # cluster 

358 clusters, x_merge, c_log_dict = cluster(x_peak, x_trough, 

359 eod_heights, 

360 eod_widths, i_data, 

361 i_rate, 

362 width_factor_shape, 

363 width_factor_wave, 

364 merge_threshold_height=0.1*amax, 

365 verbose=verbose-1, 

366 plot_level=plot_level-1, 

367 save_plots=save_plots, 

368 save_path=save_path, 

369 ftype=ftype, 

370 return_data=return_data) 

371 

372 # extract mean eods and times 

373 mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \ 

374 extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, clusters, 

375 i_rate, width_factor_display, verbose=verbose-1) 

376 

377 # determine clipped clusters (save them, but ignore in other steps) 

378 clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes = \ 

379 find_clipped_clusters(clusters, mean_eods, eod_times, eod_peaktimes, 

380 eod_troughtimes, cluster_labels, width_factor_display, 

381 verbose=verbose-1) 

382 

383 # delete the moving fish 

384 clusters, zoom_window, mf_log_dict = \ 

385 delete_moving_fish(clusters, x_merge/i_rate, len(data)/rate, 

386 eod_heights, eod_widths/i_rate, i_rate, 

387 verbose=verbose-1, plot_level=plot_level-1, save_plot=save_plots, 

388 save_path=save_path, ftype=ftype, return_data=return_data) 

389 

390 if 'moving_fish' in return_data: 

391 log_dict['moving_fish'] = mf_log_dict 

392 

393 clusters = remove_sparse_detections(clusters, eod_widths, i_rate, 

394 len(data)/rate, verbose=verbose-1) 

395 

396 # extract mean eods 

397 mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \ 

398 extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, 

399 clusters, i_rate, width_factor_display, verbose=verbose-1) 

400 

401 mean_eods.extend(clipped_eods) 

402 eod_times.extend(clipped_times) 

403 eod_peaktimes.extend(clipped_peaktimes) 

404 eod_troughtimes.extend(clipped_troughtimes) 

405 

406 if plot_level > 0: 

407 plot_all(data, eod_peaktimes, eod_troughtimes, rate, mean_eods) 

408 if save_plots: 

409 plt.savefig('%sextract_pulsefish_results.%s' % (save_path, ftype)) 

410 if save_plots: 

411 plt.close('all') 

412 

413 if 'all_eod_times' in return_data: 

414 log_dict['all_times'] = [x_peak/i_rate, x_trough/i_rate] 

415 log_dict['eod_troughtimes'] = eod_troughtimes 

416 

417 log_dict.update(c_log_dict) 

418 

419 return mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict 


def detect_pulses(data, rate, thresh, min_rel_slope_diff=0.25,
                  min_width=0.00005, max_width=0.01, width_fac=5.0,
                  verbose=0, return_data=False):
    """Detect pulses in data.

    Was `def extract_eod_times(data, rate, width_factor,
    interp_freq=500000, max_peakwidth=0.01,
    min_peakwidth=None, verbose=0, return_data=[],
    save_path='')` before.

    Parameters
    ----------
    data: 1-D array of float
        The data to be analysed.
    rate: float
        Sampling rate of the data.
    thresh: float
        Threshold for peak and trough detection via `detect_peaks()`.
        Must be a positive number that sets the minimum difference
        between a peak and a trough.
    min_rel_slope_diff: float
        Minimum required difference between left and right slope (between
        peak and troughs) relative to mean slope for deciding which trough
        to take based on slope difference.
    min_width: float
        Minimum width (peak-trough distance) of pulses in seconds.
    max_width: float
        Maximum width (peak-trough distance) of pulses in seconds.
    width_fac: float
        Pulses extend plus or minus `width_fac` times their width
        (distance between peak and assigned trough).
        Only pulses are returned that can fully be analysed with this width.
    verbose : int (optional)
        Verbosity level.
    return_data : bool
        If `True` data of this function is logged and returned (see
        extract_pulsefish()).

    Returns
    -------
    peak_indices: array of ints
        Indices of EOD peaks in data.
    trough_indices: array of ints
        Indices of EOD troughs in data. There is one x_trough for each x_peak.
    heights: array of floats
        EOD heights for each x_peak.
    widths: array of ints
        EOD widths for each x_peak (in samples).
    peak_detection_result : dictionary
        Key value pairs of logged data.
        This is only returned if `return_data` is `True`.
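
    Examples
    --------
    A minimal sketch on synthetic data: biphasic pulses (derivative of a
    Gaussian, about 1 ms wide peak-to-trough) on a slightly noisy baseline:

    >>> rate = 44100.0
    >>> t = np.arange(0.0, 1.0, 1.0/rate)
    >>> data = 0.01*np.random.randn(len(t))
    >>> for t0 in np.arange(0.05, 1.0, 0.05):
    ...     data -= (t - t0)/0.0005*np.exp(-0.5*((t - t0)/0.0005)**2)
    >>> peaks, troughs, heights, widths = detect_pulses(data, rate, thresh=0.5)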

474 """ 

475 peak_detection_result = {} 

476 

477 # detect peaks and troughs in the data: 

478 peak_indices, trough_indices = detect_peaks(data, thresh) 

479 if verbose > 0: 

480 print('Peaks/troughs detected in data: %5d %5d' 

481 % (len(peak_indices), len(trough_indices))) 

482 if return_data: 

483 peak_detection_result.update(peaks_1=np.array(peak_indices), 

484 troughs_1=np.array(trough_indices)) 

485 if len(peak_indices) < 2 or \ 

486 len(trough_indices) < 2 or \ 

487 len(peak_indices) > len(data)/20: 

488 # TODO: if too many peaks increase threshold! 

489 if verbose > 0: 

490 print('No or too many peaks/troughs detected in data.') 

491 if return_data: 

492 return np.array([], dtype=int), np.array([], dtype=int), \ 

493 np.array([]), np.array([], dtype=int), peak_detection_result 

494 else: 

495 return np.array([], dtype=int), np.array([], dtype=int), \ 

496 np.array([]), np.array([], dtype=int) 

497 

498 # assign troughs to peaks: 

499 peak_indices, trough_indices, heights, widths, slopes = \ 

500 assign_side_peaks(data, peak_indices, trough_indices, min_rel_slope_diff) 

501 if verbose > 1: 

502 print('Number of peaks after assigning side-peaks: %5d' 

503 % (len(peak_indices))) 

504 if return_data: 

505 peak_detection_result.update(peaks_2=np.array(peak_indices), 

506 troughs_2=np.array(trough_indices)) 

507 

508 # check widths: 

509 keep = ((widths>min_width*rate) & (widths<max_width*rate)) 

510 peak_indices = peak_indices[keep] 

511 trough_indices = trough_indices[keep] 

512 heights = heights[keep] 

513 widths = widths[keep] 

514 slopes = slopes[keep] 

515 if verbose > 1: 

516 print('Number of peaks after checking pulse width: %5d' 

517 % (len(peak_indices))) 

518 if return_data: 

519 peak_detection_result.update(peaks_3=np.array(peak_indices), 

520 troughs_3=np.array(trough_indices)) 

521 

522 # discard connected peaks: 

523 same = np.nonzero(trough_indices[:-1] == trough_indices[1:])[0] 

524 keep = np.ones(len(trough_indices), dtype=bool) 

525 for i in same: 

526 # same troughs at trough_indices[i] and trough_indices[i+1]: 

527 s = slopes[i:i+2] 

528 rel_slopes = np.abs(np.diff(s))[0]/np.mean(s) 

529 if rel_slopes > min_rel_slope_diff: 

530 keep[i+(s[1]<s[0])] = False 

531 else: 

532 keep[i+(heights[i+1]<heights[i])] = False 

533 peak_indices = peak_indices[keep] 

534 trough_indices = trough_indices[keep] 

535 heights = heights[keep] 

536 widths = widths[keep] 

537 if verbose > 1: 

538 print('Number of peaks after merging pulses: %5d' 

539 % (len(peak_indices))) 

540 if return_data: 

541 peak_detection_result.update(peaks_4=np.array(peak_indices), 

542 troughs_4=np.array(trough_indices)) 

543 if len(peak_indices) == 0: 

544 if verbose > 0: 

545 print('No peaks remain as pulse candidates.') 

546 if return_data: 

547 return np.array([], dtype=int), np.array([], dtype=int), \ 

548 np.array([]), np.array([], dtype=int), peak_detection_result 

549 else: 

550 return np.array([], dtype=int), np.array([], dtype=int), \ 

551 np.array([]), np.array([], dtype=int) 

552 

553 # only take those where the maximum cutwidth does not cause issues - 

554 # if the width_fac times the width + x is more than length. 

555 keep = ((peak_indices - widths > 0) & 

556 (peak_indices + widths < len(data)) & 

557 (trough_indices - widths > 0) & 

558 (trough_indices + widths < len(data))) 

559 

560 if verbose > 0: 

561 print('Remaining peaks after EOD extraction: %5d' 

562 % (p.sum(keep))) 

563 print('') 

564 

565 if return_data: 

566 return peak_indices[keep], trough_indices[keep], \ 

567 heights[keep], widths[keep], peak_detection_result 

568 else: 

569 return peak_indices[keep], trough_indices[keep], \ 

570 heights[keep], widths[keep] 


@jit(nopython=True)
def assign_side_peaks(data, peak_indices, trough_indices,
                      min_rel_slope_diff=0.25):
    """Assign to each peak the trough resulting in a pulse with the steepest slope or largest height.

    The slope between a peak and a trough is computed as the height
    difference divided by the distance between peak and trough. If the
    slopes between the left and the right trough differ by less than
    `min_rel_slope_diff`, then the heights of the two troughs relative
    to the peak are compared instead.

    Was `def detect_eod_peaks(data, main_indices, side_indices,
    max_width=20, min_width=2, verbose=0)` before.

    Parameters
    ----------
    data: array of floats
        Data in which the events were detected.
    peak_indices: array of ints
        Indices of the detected peaks in the data time series.
    trough_indices: array of ints
        Indices of the detected troughs in the data time series.
    min_rel_slope_diff: float
        Minimum required difference of left and right slope relative
        to mean slope.

    Returns
    -------
    peak_indices: array of ints
        Peak indices. Same as input `peak_indices` but potentially shorter
        by one or two elements.
    trough_indices: array of ints
        Corresponding trough indices of trough to the left or right
        of the peaks.
    heights: array of floats
        Peak heights (distance between peak and corresponding trough amplitude).
    widths: array of ints
        Peak widths (distance between peak and corresponding trough indices).
    slopes: array of floats
        Peak slopes (height divided by width).
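
    Examples
    --------
    A small sketch: one peak at index 2 flanked by troughs at indices 1
    and 3. The right trough gives the much steeper slope, so it is the
    one assigned to the peak:

    >>> data = np.array([0.0, -0.1, 1.0, -0.9, 0.0])
    >>> peaks = np.array([2])
    >>> troughs = np.array([1, 3])
    >>> p, t, h, w, s = assign_side_peaks(data, peaks, troughs)
    >>> int(t[0]), float(h[0])
    (3, 1.9)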

613 """ 

614 # is a main or side peak first? 

615 peak_first = int(peak_indices[0] < trough_indices[0]) 

616 # is a main or side peak last? 

617 peak_last = int(peak_indices[-1] > trough_indices[-1]) 

618 # ensure all peaks to have side peaks (troughs) at both sides, 

619 # i.e. troughs at same index and next index are before and after peak: 

620 peak_indices = peak_indices[peak_first:len(peak_indices)-peak_last] 

621 y = data[peak_indices] 

622 

623 # indices of troughs on the left and right side of main peaks: 

624 l_indices = np.arange(len(peak_indices)) 

625 r_indices = l_indices + 1 

626 

627 # indices, distance to peak, height, and slope of left troughs: 

628 l_side_indices = trough_indices[l_indices] 

629 l_distance = np.abs(peak_indices - l_side_indices) 

630 l_height = np.abs(y - data[l_side_indices]) 

631 l_slope = np.abs(l_height/l_distance) 

632 

633 # indices, distance to peak, height, and slope of right troughs: 

634 r_side_indices = trough_indices[r_indices] 

635 r_distance = np.abs(r_side_indices - peak_indices) 

636 r_height = np.abs(y - data[r_side_indices]) 

637 r_slope = np.abs(r_height/r_distance) 

638 

639 # which trough to assign to the peak? 

640 # - either the one with the steepest slope, or 

641 # - when slopes are similar on both sides 

642 # (within `min_rel_slope_diff` difference), 

643 # the trough with the maximum height difference to the peak: 

644 rel_slopes = np.abs(l_slope-r_slope)/(0.5*(l_slope+r_slope)) 

645 take_slopes = rel_slopes > min_rel_slope_diff 

646 take_left = l_height > r_height 

647 take_left[take_slopes] = l_slope[take_slopes] > r_slope[take_slopes] 

648 

649 # assign troughs, heights, widths, and slopes: 

650 trough_indices = np.where(take_left, 

651 trough_indices[:-1], trough_indices[1:]) 

652 heights = np.where(take_left, l_height, r_height) 

653 widths = np.where(take_left, l_distance, r_distance) 

654 slopes = np.where(take_left, l_slope, r_slope) 

655 

656 return peak_indices, trough_indices, heights, widths, slopes 


def cluster(eod_xp, eod_xt, eod_heights, eod_widths, data, rate,
            width_factor_shape, width_factor_wave, n_gaus_height=10,
            merge_threshold_height=0.1, n_gaus_width=3,
            merge_threshold_width=0.5, minp=10, verbose=0,
            plot_level=0, save_plots=False, save_path='', ftype='pdf',
            return_data=[]):
    """Cluster EODs.

    First cluster on EOD widths using a Bayesian Gaussian
    Mixture (BGM) model, then cluster on EOD heights using a
    BGM model. Lastly, cluster on EOD waveform with DBSCAN.
    Clustering on EOD waveform is performed twice, once on
    peak-centered EODs and once on trough-centered EODs.
    Non-pulsetype EOD clusters are deleted, and clusters are
    merged afterwards.

    Parameters
    ----------
    eod_xp : list of ints
        Locations of EOD peaks in indices.
    eod_xt: list of ints
        Locations of EOD troughs in indices.
    eod_heights: list of floats
        EOD heights.
    eod_widths: list of ints
        EOD widths in samples.
    data: array of floats
        Data in which to detect pulse EODs.
    rate : float
        Sampling rate of `data`.
    width_factor_shape : float
        Multiplier for snippet extraction width. This factor is
        multiplied with the width between the peak and trough of a
        single EOD.
    width_factor_wave : float
        Multiplier for wavefish extraction width.
    n_gaus_height : int (optional)
        Number of gaussians to use for the clustering based on EOD height.
    merge_threshold_height : float (optional)
        Threshold for merging clusters that are similar in height.
    n_gaus_width : int (optional)
        Number of gaussians to use for the clustering based on EOD width.
    merge_threshold_width : float (optional)
        Threshold for merging clusters that are similar in width.
    minp : int (optional)
        Minimum number of points for core clusters (DBSCAN).
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save created plots.
    save_path : string (optional)
        Path to save plots to. Only used if save_plots==True.
    ftype : string (optional)
        Filetype to save plot images in.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'all_cluster_steps', 'BGM_width', 'BGM_height',
        'snippet_clusters', 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    labels : list of ints
        EOD cluster labels based on height and EOD waveform.
    x_merge : list of ints
        Locations of EODs in clusters.
    saved_data : dictionary
        Key value pairs of logged data. Data to be logged is specified
        by return_data.
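
    Examples
    --------
    A minimal sketch, assuming peak/trough locations, heights, and widths
    as returned by `detect_pulses()` on interpolated data `i_data` with
    sampling rate `i_rate` (this mirrors the call in `extract_pulsefish()`):

    >>> labels, x_merge, log = cluster(x_peak, x_trough, eod_heights,
    ...                                eod_widths, i_data, i_rate,
    ...                                width_factor_shape=3,
    ...                                width_factor_wave=8)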

731 """ 

732 saved_data = {} 

733 

734 if plot_level>0 or 'all_cluster_steps' in return_data: 

735 all_heightlabels = [] 

736 all_shapelabels = [] 

737 all_snippets = [] 

738 all_features = [] 

739 all_heights = [] 

740 all_unique_heightlabels = [] 

741 

742 all_p_clusters = -1 * np.ones(len(eod_xp)) 

743 all_t_clusters = -1 * np.ones(len(eod_xp)) 

744 artefact_masks_p = np.ones(len(eod_xp), dtype=bool) 

745 artefact_masks_t = np.ones(len(eod_xp), dtype=bool) 

746 

747 x_merge = -1 * np.ones(len(eod_xp)) 

748 

749 max_label_p = 0 # keep track of the labels so that no labels are overwritten 

750 max_label_t = 0 

751 

752 # loop only over height clusters that are bigger than minp 

753 # first cluster on width 

754 width_labels, bgm_log_dict = BGM(1000*eod_widths/rate, 

755 merge_threshold_width, 

756 n_gaus_width, use_log=False, 

757 verbose=verbose-1, 

758 plot_level=plot_level-1, 

759 xlabel='width [ms]', 

760 save_plot=save_plots, 

761 save_path=save_path, 

762 save_name='width', ftype=ftype, 

763 return_data=return_data) 

764 saved_data.update(bgm_log_dict) 

765 

766 if verbose > 0: 

767 print('Clusters generated based on EOD width:') 

768 for l in np.unique(width_labels): 

769 print(f'N_{l} = {len(width_labels[width_labels==l]):4d} h_{l} = {np.mean(eod_widths[width_labels==l]):.4f}') 

770 

771 w_labels, w_counts = unique_counts(width_labels) 

772 unique_width_labels = w_labels[w_counts>minp] 

773 

774 for wi, width_label in enumerate(unique_width_labels): 

775 

776 # select only features in one width cluster at a time 

777 w_eod_widths = eod_widths[width_labels==width_label] 

778 w_eod_heights = eod_heights[width_labels==width_label] 

779 w_eod_xp = eod_xp[width_labels==width_label] 

780 w_eod_xt = eod_xt[width_labels==width_label] 

781 width = int(width_factor_shape*np.median(w_eod_widths)) 

782 if width > w_eod_xp[0]: 

783 width = w_eod_xp[0] 

784 if width > w_eod_xt[0]: 

785 width = w_eod_xt[0] 

786 if width > len(data) - w_eod_xp[-1]: 

787 width = len(data) - w_eod_xp[-1] 

788 if width > len(data) - w_eod_xt[-1]: 

789 width = len(data) - w_eod_xt[-1] 

790 

791 wp_clusters = -1 * np.ones(len(w_eod_xp)) 

792 wt_clusters = -1 * np.ones(len(w_eod_xp)) 

793 wartefact_mask = np.ones(len(w_eod_xp)) 

794 

795 # determine height labels 

796 raw_p_snippets, p_snippets, p_features, p_bg_ratio = \ 

797 extract_snippet_features(data, w_eod_xp, w_eod_heights, width) 

798 raw_t_snippets, t_snippets, t_features, t_bg_ratio = \ 

799 extract_snippet_features(data, w_eod_xt, w_eod_heights, width) 

800 

801 height_labels, bgm_log_dict = \ 

802 BGM(w_eod_heights, min(merge_threshold_height, 

803 np.median(np.min(np.vstack([p_bg_ratio, t_bg_ratio]), 

804 axis=0))), n_gaus_height, use_log=True, 

805 verbose=verbose-1, plot_level=plot_level-1, xlabel = 

806 'height [a.u.]', save_plot=save_plots, 

807 save_path=save_path, save_name = 'height_%d' % wi, 

808 ftype=ftype, return_data=return_data) 

809 saved_data.update(bgm_log_dict) 

810 

811 if verbose > 0: 

812 print('Clusters generated based on EOD height:') 

813 for l in np.unique(height_labels): 

814 print(f'N_{l} = {len(height_labels[height_labels==l]):4d} h_{l} = {np.mean(w_eod_heights[height_labels==l]):.4f}') 

815 

816 h_labels, h_counts = unique_counts(height_labels) 

817 unique_height_labels = h_labels[h_counts>minp] 

818 

819 if plot_level > 0 or 'all_cluster_steps' in return_data: 

820 all_heightlabels.append(height_labels) 

821 all_heights.append(w_eod_heights) 

822 all_unique_heightlabels.append(unique_height_labels) 

823 shape_labels = [] 

824 cfeatures = [] 

825 csnippets = [] 

826 

827 for hi, height_label in enumerate(unique_height_labels): 

828 

829 h_eod_widths = w_eod_widths[height_labels==height_label] 

830 h_eod_heights = w_eod_heights[height_labels==height_label] 

831 h_eod_xp = w_eod_xp[height_labels==height_label] 

832 h_eod_xt = w_eod_xt[height_labels==height_label] 

833 

834 p_clusters = cluster_on_shape(p_features[height_labels==height_label], 

835 p_bg_ratio, minp, verbose=0) 

836 t_clusters = cluster_on_shape(t_features[height_labels==height_label], 

837 t_bg_ratio, minp, verbose=0) 

838 

839 if plot_level > 1: 

840 plot_feature_extraction(raw_p_snippets[height_labels==height_label], 

841 p_snippets[height_labels==height_label], 

842 p_features[height_labels==height_label], 

843 p_clusters, 1/rate, 0) 

844 plt.savefig('%sDBSCAN_peak_w%i_h%i.%s' % (save_path, wi, hi, ftype)) 

845 plot_feature_extraction(raw_t_snippets[height_labels==height_label], 

846 t_snippets[height_labels==height_label], 

847 t_features[height_labels==height_label], 

848 t_clusters, 1/rate, 1) 

849 plt.savefig('%sDBSCAN_trough_w%i_h%i.%s' % (save_path, wi, hi, ftype)) 

850 

851 if 'snippet_clusters' in return_data: 

852 saved_data[f'snippet_clusters_{width_label}_{height_label}_peak'] = { 

853 'raw_snippets': raw_p_snippets[height_labels==height_label], 

854 'snippets': p_snippets[height_labels==height_label], 

855 'features': p_features[height_labels==height_label], 

856 'clusters': p_clusters, 

857 'rate': rate} 

858 saved_data['snippet_clusters_{width_label}_{height_label}_trough'] = { 

859 'raw_snippets': raw_t_snippets[height_labels==height_label], 

860 'snippets': t_snippets[height_labels==height_label], 

861 'features': t_features[height_labels==height_label], 

862 'clusters': t_clusters, 

863 'rate': rate} 

864 

865 if plot_level > 0 or 'all_cluster_steps' in return_data: 

866 shape_labels.append([p_clusters, t_clusters]) 

867 cfeatures.append([p_features[height_labels==height_label], 

868 t_features[height_labels==height_label]]) 

869 csnippets.append([p_snippets[height_labels==height_label], 

870 t_snippets[height_labels==height_label]]) 

871 

872 p_clusters[p_clusters==-1] = -max_label_p - 1 

873 wp_clusters[height_labels==height_label] = p_clusters + max_label_p 

874 max_label_p = max(np.max(wp_clusters), np.max(all_p_clusters)) + 1 

875 

876 t_clusters[t_clusters==-1] = -max_label_t - 1 

877 wt_clusters[height_labels==height_label] = t_clusters + max_label_t 

878 max_label_t = max(np.max(wt_clusters), np.max(all_t_clusters)) + 1 

879 

880 if verbose > 0: 

881 if np.max(wp_clusters) == -1: 

882 print(f'No EOD peaks in width cluster {width_label}') 

883 else: 

884 unique_clusters = np.unique(wp_clusters[wp_clusters!=-1]) 

885 if len(unique_clusters) > 1: 

886 print('{len(unique_clusters)} different EOD peaks in width cluster {width_label}') 

887 

888 if plot_level > 0 or 'all_cluster_steps' in return_data: 

889 all_shapelabels.append(shape_labels) 

890 all_snippets.append(csnippets) 

891 all_features.append(cfeatures) 

892 

893 # for each cluster, save fft + label 

894 # so I end up with features for each label, and the masks. 

895 # then I can extract e.g. first artefact or wave etc. 

896 

897 # remove artefacts here, based on the mean snippets ffts. 

898 artefact_masks_p[width_labels==width_label], sdict = \ 

899 remove_artefacts(p_snippets, wp_clusters, rate, 

900 verbose=verbose-1, return_data=return_data) 

901 saved_data.update(sdict) 

902 artefact_masks_t[width_labels==width_label], _ = \ 

903 remove_artefacts(t_snippets, wt_clusters, rate, 

904 verbose=verbose-1, return_data=return_data) 

905 

906 # update maxlab so that no clusters are overwritten 

907 all_p_clusters[width_labels==width_label] = wp_clusters 

908 all_t_clusters[width_labels==width_label] = wt_clusters 

909 

910 # remove all non-reliable clusters 

911 unreliable_fish_mask_p, saved_data = \ 

912 delete_unreliable_fish(all_p_clusters, eod_widths, eod_xp, 

913 verbose=verbose-1, sdict=saved_data) 

914 unreliable_fish_mask_t, _ = \ 

915 delete_unreliable_fish(all_t_clusters, eod_widths, eod_xt, verbose=verbose-1) 

916 

917 wave_mask_p, sidepeak_mask_p, saved_data = \ 

918 delete_wavefish_and_sidepeaks(data, all_p_clusters, eod_xp, eod_widths, 

919 width_factor_wave, verbose=verbose-1, sdict=saved_data) 

920 wave_mask_t, sidepeak_mask_t, _ = \ 

921 delete_wavefish_and_sidepeaks(data, all_t_clusters, eod_xt, eod_widths, 

922 width_factor_wave, verbose=verbose-1) 

923 

924 og_clusters = [np.copy(all_p_clusters), np.copy(all_t_clusters)] 

925 og_labels = np.copy(all_p_clusters + all_t_clusters) 

926 

927 # go through all clusters and masks?? 

928 all_p_clusters[(artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p)] = -1 

929 all_t_clusters[(artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t)] = -1 

930 

931 # merge here. 

932 all_clusters, x_merge, mask = merge_clusters(np.copy(all_p_clusters), 

933 np.copy(all_t_clusters), 

934 eod_xp, eod_xt, 

935 verbose=verbose - 1) 

936 

937 if 'all_cluster_steps' in return_data or plot_level > 0: 

938 all_dmasks = [] 

939 all_mmasks = [] 

940 

941 discarding_masks = \ 

942 np.vstack(((artefact_masks_p | unreliable_fish_mask_p | wave_mask_p | sidepeak_mask_p), 

943 (artefact_masks_t | unreliable_fish_mask_t | wave_mask_t | sidepeak_mask_t))) 

944 merge_mask = mask 

945 

946 # save the masks in the same formats as the snippets 

947 for wi, (width_label, w_shape_label, heightlabels, unique_height_labels) in enumerate(zip(unique_width_labels, all_shapelabels, all_heightlabels, all_unique_heightlabels)): 

948 w_dmasks = discarding_masks[:,width_labels==width_label] 

949 w_mmasks = merge_mask[:,width_labels==width_label] 

950 

951 wd_2 = [] 

952 wm_2 = [] 

953 

954 for hi, (height_label, h_shape_label) in enumerate(zip(unique_height_labels, w_shape_label)): 

955 

956 h_dmasks = w_dmasks[:,heightlabels==height_label] 

957 h_mmasks = w_mmasks[:,heightlabels==height_label] 

958 

959 wd_2.append(h_dmasks) 

960 wm_2.append(h_mmasks) 

961 

962 all_dmasks.append(wd_2) 

963 all_mmasks.append(wm_2) 

964 

965 if plot_level > 0: 

966 plot_clustering(rate, [unique_width_labels, eod_widths, width_labels], 

967 [all_unique_heightlabels, all_heights, all_heightlabels], 

968 [all_snippets, all_features, all_shapelabels], 

969 all_dmasks, all_mmasks) 

970 if save_plots: 

971 plt.savefig('%sclustering.%s' % (save_path, ftype)) 

972 

973 if 'all_cluster_steps' in return_data: 

974 saved_data = {'rate': rate, 

975 'EOD_widths': [unique_width_labels, eod_widths, width_labels], 

976 'EOD_heights': [all_unique_heightlabels, all_heights, all_heightlabels], 

977 'EOD_shapes': [all_snippets, all_features, all_shapelabels], 

978 'discarding_masks': all_dmasks, 

979 'merge_masks': all_mmasks 

980 } 

981 

982 if 'masks' in return_data: 

983 saved_data = {'masks' : np.vstack(((artefact_masks_p & artefact_masks_t), 

984 (unreliable_fish_mask_p & unreliable_fish_mask_t), 

985 (wave_mask_p & wave_mask_t), 

986 (sidepeak_mask_p & sidepeak_mask_t), 

987 (all_p_clusters+all_t_clusters)))} 

988 

989 if verbose > 0: 

990 print('Clusters generated based on height, width and shape: ') 

991 for l in np.unique(all_clusters[all_clusters != -1]): 

992 print('N_{int(l)} = {len(all_clusters[all_clusters == l]):4d}') 

993 

994 return all_clusters, x_merge, saved_data 


def BGM(x, merge_threshold=0.1, n_gaus=5, max_iter=200, n_init=5,
        use_log=False, verbose=0, plot_level=0, xlabel='x [a.u.]',
        save_plot=False, save_path='', save_name='', ftype='pdf',
        return_data=[]):
    """ Use a Bayesian Gaussian Mixture Model to cluster one-dimensional data.

    Additional steps are used to merge clusters that are closer than
    `merge_threshold`. Broad gaussian fits that cover one or more other
    gaussian fits are split by their intersections with the other
    gaussians.

    Parameters
    ----------
    x : 1D numpy array
        Features to compute clustering on.

    merge_threshold : float (optional)
        Ratio for merging nearby gaussians.
    n_gaus: int (optional)
        Maximum number of gaussians to fit on data.
    max_iter : int (optional)
        Maximum number of iterations for gaussian fit.
    n_init : int (optional)
        Number of initializations for the gaussian fit.
    use_log: boolean (optional)
        Set to True to compute the gaussian fit on the logarithm of x.
        Can improve clustering on features with nonlinear relationships such as peak height.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    xlabel : string (optional)
        Xlabel for displaying BGM plot.
    save_plot : bool (optional)
        Set to True to save created plot.
    save_path : string (optional)
        Path to location where data should be saved. Only used if save_plot==True.
    save_name : string (optional)
        Filename of the saved plot. Useful as usually multiple BGM models are generated.
    ftype : string (optional)
        Filetype of plot image if save_plots==True.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'BGM_width' and/or 'BGM_height' (see extract_pulsefish()).

    Returns
    -------
    labels : 1D numpy array
        Cluster labels for each sample in x.
    bgm_dict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
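
    Examples
    --------
    An illustrative sketch on synthetic one-dimensional data drawn from
    two well-separated gaussians (labels are remapped to increase with
    increasing x):

    >>> x = np.concatenate((np.random.normal(1.0, 0.05, 100),
    ...                     np.random.normal(10.0, 0.5, 100)))
    >>> labels, bgm_dict = BGM(x, merge_threshold=0.1, n_gaus=3)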

1049 """ 

1050 

1051 bgm_dict = {} 

1052 

1053 if len(np.unique(x)) > n_gaus: 

1054 BGM_model = BayesianGaussianMixture(n_components=n_gaus, max_iter=max_iter, n_init=n_init) 

1055 if use_log: 

1056 labels = BGM_model.fit_predict(stats.zscore(np.log(x)).reshape(-1, 1)) 

1057 else: 

1058 labels = BGM_model.fit_predict(stats.zscore(x).reshape(-1, 1)) 

1059 else: 

1060 return np.zeros(len(x)), bgm_dict 

1061 

1062 if verbose>0: 

1063 if not BGM_model.converged_: 

1064 print('!!! Gaussian mixture did not converge !!!') 

1065 

1066 cur_labels = np.unique(labels) 

1067 

1068 # map labels to be increasing for increasing values for x 

1069 maxlab = len(cur_labels) 

1070 aso = np.argsort([np.median(x[labels == l]) for l in cur_labels]) + 100 

1071 for i, a in zip(cur_labels, aso): 

1072 labels[labels==i] = a 

1073 labels = labels - 100 

1074 

1075 # separate gaussian clusters that can be split by other clusters 

1076 splits = np.sort(np.copy(x))[1:][np.diff(labels[np.argsort(x)])!=0] 

1077 

1078 labels[:] = 0 

1079 for i, split in enumerate(splits): 

1080 labels[x>=split] = i+1 

1081 

1082 labels_before_merge = np.copy(labels) 

1083 

1084 # merge gaussian clusters that are closer than merge_threshold 

1085 labels = merge_gaussians(x, labels, merge_threshold) 

1086 

1087 if 'BGM_'+save_name.split('_')[0] in return_data or plot_level>0: 

1088 

1089 #sort model attributes by model_means_ 

1090 means = [m[0] for m in BGM_model.means_] 

1091 weights = [w for w in BGM_model.weights_] 

1092 variances = [v[0][0] for v in BGM_model.covariances_] 

1093 weights = [w for _, w in sorted(zip(means, weights))] 

1094 variances = [v for _, v in sorted(zip(means, variances))] 

1095 means = sorted(means) 

1096 

1097 if plot_level>0: 

1098 plot_bgm(x, means, variances, weights, use_log, labels_before_merge, 

1099 labels, xlabel) 

1100 if save_plot: 

1101 plt.savefig('%sBGM_%s.%s' % (save_path, save_name, ftype)) 

1102 

1103 if 'BGM_'+save_name.split('_')[0] in return_data: 

1104 bgm_dict['BGM_'+save_name] = {'x':x, 

1105 'use_log':use_log, 

1106 'BGM':[weights, means, variances], 

1107 'labels':labels_before_merge, 

1108 'xlab':xlabel} 

1109 

1110 return labels, bgm_dict 


def merge_gaussians(x, labels, merge_threshold=0.1):
    """ Merge all clusters whose medians are close. Only works in 1D.

    Parameters
    ----------
    x : 1D array of ints or floats
        Features used for clustering.
    labels : 1D array of ints
        Labels for each sample in x.
    merge_threshold : float (optional)
        Similarity threshold to merge clusters.

    Returns
    -------
    labels : 1D array of ints
        Merged labels for each sample in x.
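
    Examples
    --------
    A small sketch: clusters 0 and 1 have medians 1.0 and 1.05, which
    differ by less than 10%, so they are merged; cluster 2 stays:

    >>> x = np.array([1.0, 1.0, 1.05, 1.05, 10.0, 10.0])
    >>> labels = np.array([0, 0, 1, 1, 2, 2])
    >>> merge_gaussians(x, labels, merge_threshold=0.1)
    array([0, 0, 0, 0, 2, 2])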

1129 """ 

1130 

1131 # compare all the means of the gaussians. If they are too close, merge them. 

1132 unique_labels = np.unique(labels[labels!=-1]) 

1133 x_medians = [np.median(x[labels==l]) for l in unique_labels] 

1134 

1135 # fill a dict with the label mappings 

1136 mapping = {} 

1137 for label_1, x_m1 in zip(unique_labels, x_medians): 

1138 for label_2, x_m2 in zip(unique_labels, x_medians): 

1139 if label_1!=label_2: 

1140 if np.abs(np.diff([x_m1, x_m2]))/np.max([x_m1, x_m2]) < merge_threshold: 

1141 mapping[label_1] = label_2 

1142 # apply mapping 

1143 for map_key, map_value in mapping.items(): 

1144 labels[labels==map_key] = map_value 

1145 

1146 return labels 


def extract_snippet_features(data, eod_x, eod_heights, width, n_pc=5):
    """ Extract snippets from recording data, normalize them, and perform PCA.

    Parameters
    ----------
    data : 1D numpy array of floats
        Recording data.
    eod_x : 1D array of ints
        Locations of EODs as indices.
    eod_heights: 1D array of floats
        EOD heights.
    width : int
        Width to cut out to each side in samples.

    n_pc : int (optional)
        Number of PCs to use for PCA.

    Returns
    -------
    raw_snippets : 2D numpy array (N, EOD_width)
        Raw extracted EOD snippets.
    snippets : 2D numpy array (N, EOD_width)
        Normalized EOD snippets.
    features : 2D numpy array (N, n_pc)
        PC values of EOD snippets.
    bg_ratio : 1D numpy array (N)
        Ratio of the background activity slopes compared to EOD height.
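
    Examples
    --------
    A minimal sketch, assuming `data` is a 1-D recording, `eod_x` holds EOD
    peak indices (all at least `width` samples away from the edges), and
    `eod_heights` the corresponding EOD heights:

    >>> raw, snippets, features, bg_ratio = extract_snippet_features(
    ...     data, eod_x, eod_heights, width=100)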

1176 """ 

1177 # extract snippets with corresponding width 

1178 raw_snippets = np.vstack([data[x-width:x+width] for x in eod_x]) 

1179 

1180 # subtract the slope and normalize the snippets 

1181 snippets, bg_ratio = subtract_slope(np.copy(raw_snippets), eod_heights) 

1182 snippets = StandardScaler().fit_transform(snippets.T).T 

1183 

1184 # scale so that the absolute integral = 1. 

1185 snippets = (snippets.T/np.sum(np.abs(snippets), axis=1)).T 

1186 

1187 # compute features for clustering on waveform 

1188 features = PCA(n_pc).fit_transform(snippets) 

1189 

1190 return raw_snippets, snippets, features, bg_ratio 


def cluster_on_shape(features, bg_ratio, minp, percentile=80,
                     max_epsilon=0.01, slope_ratio_factor=4,
                     min_cluster_fraction=0.01, verbose=0):
    """Separate EODs by their shape using DBSCAN.

    Parameters
    ----------
    features : 2D numpy array of floats (N, n_pc)
        PCA features of each EOD in a recording.
    bg_ratio : 1D array of floats
        Ratio of background activity slope the EOD is superimposed on.
    minp : int
        Minimum number of points for core cluster (DBSCAN).

    percentile : int (optional)
        Percentile of KNN distribution, where K=minp, to use as epsilon for DBSCAN.
    max_epsilon : float (optional)
        Maximum epsilon to use for DBSCAN clustering. This is used to avoid adding
        noisy clusters.
    slope_ratio_factor : float (optional)
        Influence of the slope-to-EOD ratio on the epsilon parameter.
        A slope_ratio_factor of 4 means that slope-to-EOD ratios > 1/4
        start influencing epsilon.
    min_cluster_fraction : float (optional)
        Minimum fraction of all evaluated datapoints that can form a single cluster.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    labels : 1D array of ints
        Cluster labels for each EOD; -1 marks noise (DBSCAN convention).
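
    Examples
    --------
    A minimal sketch on random features (in practice `features` and
    `bg_ratio` come from `extract_snippet_features()`):

    >>> features = np.random.randn(100, 5)
    >>> bg_ratio = np.full(100, 0.01)
    >>> labels = cluster_on_shape(features, bg_ratio, minp=10)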

1225 """ 

1226 

1227 # determine clustering threshold from data 

1228 minpc = max(minp, int(len(features)*min_cluster_fraction)) 

1229 knn = np.sort(pairwise_distances(features, features), axis=0)[minpc] 

1230 eps = min(max(1, slope_ratio_factor*np.median(bg_ratio))*max_epsilon, 

1231 np.percentile(knn, percentile)) 

1232 

1233 if verbose>1: 

1234 print('epsilon = %f'%eps) 

1235 print('Slope to EOD ratio = %f'%np.median(bg_ratio)) 

1236 

1237 # cluster on EOD shape 

1238 return DBSCAN(eps=eps, min_samples=minpc).fit(features).labels_ 


def subtract_slope(snippets, heights):
    """ Subtract underlying slope from all EOD snippets.

    Parameters
    ----------
    snippets: 2-D numpy array
        All EODs in a recording, stacked as snippets.
        Shape = (number of EODs, EOD width)
    heights: 1D numpy array
        EOD heights.

    Returns
    -------
    snippets: 2-D numpy array
        EOD snippets with underlying slope subtracted.
    bg_ratio : 1-D numpy array
        EOD height/background activity height.
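
    Examples
    --------
    A small sketch: a single snippet riding on a linear slope from 0.0 to
    0.4 is flattened, and the slope height relative to the EOD height
    (here 2.0) is returned:

    >>> snippets = np.array([[0.0, 1.1, -0.8, 0.4]])
    >>> heights = np.array([2.0])
    >>> flat, bg_ratio = subtract_slope(snippets.copy(), heights)
    >>> bg_ratio
    array([0.2])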

1258 """ 

1259 

1260 left_y = snippets[:,0] 

1261 right_y = snippets[:,-1] 

1262 

1263 try: 

1264 slopes = np.linspace(left_y, right_y, snippets.shape[1]) 

1265 except ValueError: 

1266 delta = (right_y - left_y)/snippets.shape[1] 

1267 slopes = np.arange(0, snippets.shape[1], dtype=snippets.dtype).reshape((-1,) + (1,) * np.ndim(delta))*delta + left_y 

1268 

1269 return snippets - slopes.T, np.abs(left_y-right_y)/heights 


def remove_artefacts(all_snippets, clusters, rate,
                     freq_low=20000, threshold=0.75,
                     verbose=0, return_data=[]):
    """ Create a mask for EOD clusters that result from artefacts, based on power in low frequency spectrum.

    Parameters
    ----------
    all_snippets: 2D array
        EOD snippets. Shape=(nEODs, EOD length)
    clusters: list of ints
        EOD cluster labels.
    rate : float
        Sampling rate of original recording data.
    freq_low: float
        Frequency up to which low frequency components are summed up.
    threshold : float (optional)
        Minimum value for the sum of low frequency components relative to
        the sum over all spectral amplitudes that separates artefact from
        clean pulsefish clusters.
    verbose : int (optional)
        Verbosity level.
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log data in this function is
        'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    mask: numpy array of booleans
        Set to True for every EOD which is an artefact.
    adict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
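
    Examples
    --------
    A minimal sketch, assuming `snippets` are EOD snippets and `clusters`
    the corresponding shape-cluster labels:

    >>> mask, adict = remove_artefacts(snippets, clusters, rate=500000.0)
    >>> clusters[mask] = -1  # discard artefact EODs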

1303 """ 

1304 adict = {} 

1305 

1306 mask = np.zeros(clusters.shape, dtype=bool) 

1307 

1308 for cluster in np.unique(clusters[clusters >= 0]): 

1309 snippets = all_snippets[clusters == cluster] 

1310 mean_eod = np.mean(snippets, axis=0) 

1311 mean_eod = mean_eod - np.mean(mean_eod) 

1312 mean_eod_fft = np.abs(np.fft.rfft(mean_eod)) 

1313 freqs = np.fft.rfftfreq(len(mean_eod), 1/rate) 

1314 low_frequency_ratio = np.sum(mean_eod_fft[freqs<freq_low])/np.sum(mean_eod_fft) 

1315 if low_frequency_ratio < threshold: # TODO: check threshold! 

1316 mask[clusters==cluster] = True 

1317 

1318 if verbose > 0: 

1319 print('Deleting cluster %i with low frequency ratio of %.3f (min %.3f)' % (cluster, low_frequency_ratio, threshold)) 

1320 

1321 if 'eod_deletion' in return_data: 

1322 adict['vals_%d' % cluster] = [mean_eod, mean_eod_fft] 

1323 adict['mask_%d' % cluster] = [np.any(mask[clusters==cluster])] 

1324 

1325 return mask, adict 


def delete_unreliable_fish(clusters, eod_widths, eod_x, verbose=0, sdict={}):
    """ Create a mask for EOD clusters that are either mixed with noise or other fish, or wavefish.

    This is the case when the ratio between the EOD width and the ISI is too large.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of floats or ints
        EOD widths in samples or seconds.
    eod_x : list of ints or floats
        EOD times in samples or seconds.

    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps,
        see remove_artefacts().

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every unreliable EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
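
    Examples
    --------
    A small sketch: a cluster whose median EOD width exceeds half of the
    smallest inter-spike interval is flagged as unreliable:

    >>> clusters = np.array([0, 0, 0])
    >>> eod_widths = np.array([10, 10, 10])
    >>> eod_x = np.array([100, 115, 200])
    >>> mask, _ = delete_unreliable_fish(clusters, eod_widths, eod_x)
    >>> mask
    array([ True,  True,  True])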

1357 """ 

1358 mask = np.zeros(clusters.shape, dtype=bool) 

1359 for cluster in np.unique(clusters[clusters >= 0]): 

1360 if len(eod_x[cluster == clusters]) < 2: 

1361 mask[clusters == cluster] = True 

1362 if verbose>0: 

1363 print('deleting unreliable cluster %i, number of EOD times %d < 2' % (cluster, len(eod_x[cluster==clusters]))) 

1364 elif np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) > 0.5: 

1365 if verbose>0: 

1366 print('deleting unreliable cluster %i, score=%f' % (cluster, np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])))) 

1367 mask[clusters==cluster] = True 

1368 if 'vals_%d' % cluster in sdict: 

1369 sdict['vals_%d' % cluster].append(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) 

1370 sdict['mask_%d' % cluster].append(any(mask[clusters==cluster])) 

1371 return mask, sdict 


def delete_wavefish_and_sidepeaks(data, clusters, eod_x, eod_widths,
                                  width_fac, max_slope_deviation=0.5,
                                  max_phases=4, verbose=0, sdict={}):
    """ Create a mask for EODs that are likely from wavefish, or sidepeaks of bigger EODs.

    Parameters
    ----------
    data : list of floats
        Raw recording data.
    clusters : list of ints
        Cluster labels.
    eod_x : list of ints
        Indices of EOD times.
    eod_widths : list of ints
        EOD widths in samples.
    width_fac : float
        Multiplier for EOD analysis width.

    max_slope_deviation: float (optional)
        Maximum deviation of position of maximum slope in snippets from
        center position in multiples of mean width of EOD.
    max_phases : int (optional)
        Maximum number of phases for any EOD.
        If the mean EOD has more phases than this, it is not a pulse EOD.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps, see remove_artefacts().

    Returns
    -------
    mask_wave: numpy array of booleans
        Set to True for every EOD which is a wavefish EOD.
    mask_sidepeak: numpy array of booleans
        Set to True for every snippet which is centered around a sidepeak of an EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask_wave = np.zeros(clusters.shape, dtype=bool)
    mask_sidepeak = np.zeros(clusters.shape, dtype=bool)

    for i, cluster in enumerate(np.unique(clusters[clusters >= 0])):
        mean_width = np.mean(eod_widths[clusters == cluster])
        cutwidth = mean_width*width_fac
        current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
        current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
        snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)]
                              for x in current_x[current_clusters==cluster]])

        # extract information on main peaks and troughs:
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)

        # detect peaks and troughs on data + some maxima/minima at the
        # end, so that the sides are also considered for peak detection:
        pk, tr = detect_peaks(np.concatenate(([-10*mean_eod[0]], mean_eod, [10*mean_eod[-1]])),
                              np.std(mean_eod))
        pk = pk[(pk>0)&(pk<len(mean_eod))]
        tr = tr[(tr>0)&(tr<len(mean_eod))]

        if len(pk)>0 and len(tr)>0:
            idxs = np.sort(np.concatenate((pk, tr)))
            slopes = np.abs(np.diff(mean_eod[idxs]))
            m_slope = np.argmax(slopes)
            centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2))

            # compute all height differences of peaks and troughs within snippets.
            # if they are all similar, it is probably noise or a wavefish.
            idxs = np.sort(np.concatenate((pk, tr)))
            hdiffs = np.diff(mean_eod[idxs])

            if centered > max_slope_deviation*mean_width:   # TODO: check, factor was probably 0.16
                if verbose > 0:
                    print('Deleting cluster %i, which is a sidepeak' % cluster)
                mask_sidepeak[clusters==cluster] = True

            w_diff = np.abs(np.diff(np.sort(np.concatenate((pk, tr)))))

            if np.abs(np.diff(idxs[m_slope:m_slope+2])) < np.mean(eod_widths[clusters==cluster])*0.5 or len(pk) + len(tr) > max_phases or np.min(w_diff) > 2*cutwidth/width_fac:   # or len(hdiffs[np.abs(hdiffs)>0.5*(np.max(mean_eod)-np.min(mean_eod))])>max_phases:
                if verbose > 0:
                    print('Deleting cluster %i, which is a wavefish' % cluster)
                mask_wave[clusters==cluster] = True
            if 'vals_%d' % cluster in sdict:
                sdict['vals_%d' % cluster].append([mean_eod, [pk, tr],
                                                   idxs[m_slope:m_slope+2]])
                sdict['mask_%d' % cluster].append(any(mask_wave[clusters==cluster]))
                sdict['mask_%d' % cluster].append(any(mask_sidepeak[clusters==cluster]))

    return mask_wave, mask_sidepeak, sdict

1466 

1467 
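

# A minimal sketch (not part of the original module) of the side-peak test in
# delete_wavefish_and_sidepeaks(): the steepest peak-trough transition of the
# mean EOD must lie near the snippet center, otherwise the snippets were
# aligned on a secondary peak. The indices below are synthetic.
def _demo_sidepeak_test():
    mean_eod = np.zeros(100)
    mean_eod[70], mean_eod[75] = 1.0, -1.0    # main peak and trough, off-center
    idxs = np.array([70, 75])                 # sorted peak and trough indices
    m_slope = np.argmax(np.abs(np.diff(mean_eod[idxs])))
    centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2))
    print(centered)                           # 20 samples off-center -> side-peak candidate

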

def merge_clusters(clusters_1, clusters_2, x_1, x_2, verbose=0):
    """ Merge clusters resulting from two clustering methods.

    This method only works if clustering is performed on the same EODs
    with the same ordering, where there is a one-to-one mapping from
    clusters_1 to clusters_2.

    Parameters
    ----------
    clusters_1: list of ints
        EOD cluster labels for cluster method 1.
    clusters_2: list of ints
        EOD cluster labels for cluster method 2.
    x_1: list of ints
        Indices of EODs for cluster method 1 (clusters_1).
    x_2: list of ints
        Indices of EODs for cluster method 2 (clusters_2).
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    clusters : list of ints
        Merged clusters.
    x_merged : list of ints
        Merged cluster indices.
    mask : 2D numpy array of ints (2, N)
        Mask for EODs that are selected from clusters_1 (mask[0]) and
        from clusters_2 (mask[1]).
    """
    if verbose > 0:
        print('\nMerge clusters:')

    # these arrays become 1 for each EOD that is chosen from that array:
    c1_keep = np.zeros(len(clusters_1))
    c2_keep = np.zeros(len(clusters_2))

    # offset the labels of the second cluster list so the two label sets do not overlap:
    ovl = np.max(clusters_1) + 1
    clusters_2[clusters_2 != -1] = clusters_2[clusters_2 != -1] + ovl

    remove_clusters = [[]]
    keep_clusters = []
    og_clusters = [np.copy(clusters_1), np.copy(clusters_2)]

    # loop until all clusters have been processed:
    while True:
        # compute unique clusters and cluster sizes
        # of clusters that have not been iterated over:
        c1_labels, c1_size = unique_counts(clusters_1[(clusters_1 != -1) & (c1_keep == 0)])
        c2_labels, c2_size = unique_counts(clusters_2[(clusters_2 != -1) & (c2_keep == 0)])

        # if all clusters are done, break from loop:
        if len(c1_size) == 0 and len(c2_size) == 0:
            break

        # if the biggest cluster is in clusters_1, keep it and discard all
        # overlapping clusters of clusters_2:
        elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 0:
            # remove all mappings to clusters of the other group:
            cluster_mappings, _ = unique_counts(clusters_2[clusters_1 == c1_labels[np.argmax(c1_size)]])
            clusters_2[np.isin(clusters_2, cluster_mappings)] = -1
            c1_keep[clusters_1 == c1_labels[np.argmax(c1_size)]] = 1
            remove_clusters.append(cluster_mappings)
            keep_clusters.append(c1_labels[np.argmax(c1_size)])
            if verbose > 0:
                print('Keep cluster %i of group 1, delete clusters %s of group 2' % (c1_labels[np.argmax(c1_size)], str(cluster_mappings[cluster_mappings != -1] - ovl)))

        # if the biggest cluster is in clusters_2, keep it and discard all
        # overlapping clusters of clusters_1:
        elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 1:
            # remove all mappings to clusters of the other group:
            cluster_mappings, _ = unique_counts(clusters_1[clusters_2 == c2_labels[np.argmax(c2_size)]])
            clusters_1[np.isin(clusters_1, cluster_mappings)] = -1
            c2_keep[clusters_2 == c2_labels[np.argmax(c2_size)]] = 1
            remove_clusters.append(cluster_mappings)
            keep_clusters.append(c2_labels[np.argmax(c2_size)])
            if verbose > 0:
                print('Keep cluster %i of group 2, delete clusters %s of group 1' % (c2_labels[np.argmax(c2_size)] - ovl, str(cluster_mappings[cluster_mappings != -1])))

    # combine the results:
    clusters = (clusters_1 + 1)*c1_keep + (clusters_2 + 1)*c2_keep - 1
    x_merged = x_1*c1_keep + x_2*c2_keep

    return clusters, x_merged, np.vstack([c1_keep, c2_keep])
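

# A minimal sketch (not part of the original module) of merge_clusters() on
# toy labels: method 1 groups all four EODs into one cluster, method 2 splits
# them in two, so the larger method-1 cluster wins and both method-2 clusters
# are discarded. All values are made up.
def _demo_merge_clusters():
    c1 = np.array([0, 0, 0, 0])               # labels from clustering method 1
    c2 = np.array([0, 0, 1, 1])               # labels from clustering method 2
    x1 = np.array([10, 20, 30, 40])           # EOD indices for method 1
    x2 = np.array([12, 22, 32, 42])           # EOD indices for method 2
    clusters, x_merged, mask = merge_clusters(c1, c2, x1, x2, verbose=1)
    print(clusters)                           # [0. 0. 0. 0.] -- method-1 cluster kept
    print(x_merged)                           # [10. 20. 30. 40.] -- method-1 indices kept

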

def extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths,
                  clusters, rate, width_fac, verbose=0):
    """ Extract mean EODs and EOD timepoints for each EOD cluster.

    Parameters
    ----------
    data: list of floats
        Raw recording data.
    eod_x: list of ints
        Locations of EODs in samples.
    eod_peak_x : list of ints
        Locations of EOD peaks in samples.
    eod_tr_x : list of ints
        Locations of EOD troughs in samples.
    eod_widths: list of ints
        EOD widths in samples.
    clusters: list of ints
        EOD cluster labels.
    rate: float
        Sampling rate of recording.
    width_fac : float
        Multiplication factor for window used to extract EOD.

    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. The first row is time in seconds,
        the second row the mean EOD, the third row the standard deviation.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD in seconds.
    eod_peak_times: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    eod_trough_times: list of 1D arrays
        For each detected fish the times of EOD troughs in seconds.
    eod_labels: list of ints
        Cluster label for each detected fish.
    """
    mean_eods, eod_times, eod_peak_times, eod_tr_times, eod_heights, cluster_labels = [], [], [], [], [], []

    for cluster in np.unique(clusters):
        if cluster != -1:
            cutwidth = np.mean(eod_widths[clusters==cluster])*width_fac
            current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]
            current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))]

            snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] for x in current_x[current_clusters==cluster]])
            mean_eod = np.mean(snippets, axis=0)
            eod_time = np.arange(len(mean_eod))/rate - cutwidth/rate

            mean_eod = np.vstack([eod_time, mean_eod, np.std(snippets, axis=0)])

            mean_eods.append(mean_eod)
            eod_times.append(eod_x[clusters==cluster]/rate)
            # negative peak-to-peak height of the waveform row, so that
            # ascending sort puts the largest EOD first:
            eod_heights.append(np.min(mean_eod[1])-np.max(mean_eod[1]))
            eod_peak_times.append(eod_peak_x[clusters==cluster]/rate)
            eod_tr_times.append(eod_tr_x[clusters==cluster]/rate)
            cluster_labels.append(cluster)

    # sort all outputs by EOD height:
    order = np.argsort(eod_heights)
    return ([mean_eods[i] for i in order], [eod_times[i] for i in order],
            [eod_peak_times[i] for i in order], [eod_tr_times[i] for i in order],
            [cluster_labels[i] for i in order])
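

# A minimal sketch (not part of the original module) of extract_means() on a
# synthetic recording with two identical pulses in a single cluster. All
# numbers are made up.
def _demo_extract_means():
    rate = 1000.0                             # Hz
    data = np.zeros(300)
    data[50], data[150] = 1.0, 1.0            # two identical one-sample pulses
    eod_x = np.array([50, 150])               # EOD locations in samples
    widths = np.array([10, 10])               # EOD widths in samples
    clusters = np.zeros(2, dtype=int)         # both EODs in cluster 0
    means, times, peaks, troughs, labels = extract_means(
        data, eod_x, eod_x, eod_x, widths, clusters, rate, 2)
    print(times[0])                           # [0.05 0.15] -- EOD times in seconds
    print(means[0].shape)                     # (3, 40) -- time, mean and std rows

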

def find_clipped_clusters(clusters, mean_eods, eod_times,
                          eod_peaktimes, eod_troughtimes,
                          cluster_labels, width_factor,
                          clip_threshold=0.9, verbose=0):
    """ Detect EODs that are clipped and set all cluster labels of these clipped EODs to -1.

    Also return the mean EODs and timepoints of these clipped EODs.

    Parameters
    ----------
    clusters: array of ints
        Cluster labels for each EOD in a recording.
    mean_eods: list of numpy arrays
        Mean EOD waveform for each cluster.
    eod_times: list of numpy arrays
        EOD timepoints for each EOD cluster.
    eod_peaktimes: list of numpy arrays
        EOD peaktimes for each EOD cluster.
    eod_troughtimes: list of numpy arrays
        EOD troughtimes for each EOD cluster.
    cluster_labels: numpy array
        Unique EOD cluster labels.
    width_factor: float
        Multiplier that was used for the analysis width of the mean EODs;
        scales the clipping criterion.
    clip_threshold: float
        Threshold for detecting clipped EODs.

    verbose: int
        Verbosity level.

    Returns
    -------
    clusters : array of ints
        Cluster labels for each EOD in the recording, where clipped EODs have been set to -1.
    clipped_eods : list of numpy arrays
        Mean EOD waveforms for each clipped EOD cluster.
    clipped_times : list of numpy arrays
        EOD timepoints for each clipped EOD cluster.
    clipped_peaktimes : list of numpy arrays
        EOD peaktimes for each clipped EOD cluster.
    clipped_troughtimes : list of numpy arrays
        EOD troughtimes for each clipped EOD cluster.
    """
    clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes, clipped_labels = [], [], [], [], []

    for mean_eod, eod_time, eod_peaktime, eod_troughtime, label in zip(mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels):
        if (np.count_nonzero(mean_eod[1] > clip_threshold) > len(mean_eod[1])/(width_factor*2)) or \
           (np.count_nonzero(mean_eod[1] < -clip_threshold) > len(mean_eod[1])/(width_factor*2)):
            clipped_eods.append(mean_eod)
            clipped_times.append(eod_time)
            clipped_peaktimes.append(eod_peaktime)
            clipped_troughtimes.append(eod_troughtime)
            clipped_labels.append(label)
            if verbose>0:
                print('clipped pulsefish: deleting cluster %i' % label)

    clusters[np.isin(clusters, clipped_labels)] = -1

    return clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes
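

# A minimal sketch (not part of the original module) of the clipping test in
# find_clipped_clusters(): if more than 1/(2*width_factor) of the mean EOD
# samples lie beyond +-clip_threshold, the whole cluster is discarded. The
# saturated pulse and all parameter values below are synthetic.
def _demo_clipped_cluster():
    t = np.linspace(-1, 1, 100)
    amp = np.clip(2*np.exp(-t**2/0.02), 0, 1.0)   # pulse saturating at 1.0
    mean_eod = np.vstack([t, amp, np.zeros_like(t)])
    clusters = np.zeros(5, dtype=int)             # five EODs, all in cluster 0
    times = np.zeros(5)
    clusters, _, _, _, _ = find_clipped_clusters(
        clusters, [mean_eod], [times], [times], [times], [0], 8, verbose=1)
    print(clusters)                               # [-1 -1 -1 -1 -1]

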

def delete_moving_fish(clusters, eod_t, T, eod_heights, eod_widths,
                       rate, min_dt=0.25, stepsize=0.05,
                       sliding_window_factor=2000, verbose=0,
                       plot_level=0, save_plot=False, save_path='',
                       ftype='pdf', return_data=[]):
    """
    Use a sliding window to detect the minimum number of fish detected simultaneously,
    then delete all other EOD clusters.

    Do this only for EODs within the same width clusters, as a
    moving fish will preserve its EOD width.

    Parameters
    ----------
    clusters: list of ints
        EOD cluster labels.
    eod_t: list of floats
        Timepoints of the EODs (in seconds).
    T: float
        Length of recording (in seconds).
    eod_heights: list of floats
        EOD amplitudes.
    eod_widths: list of floats
        EOD widths (in seconds).
    rate: float
        Recording data sampling rate.

    min_dt : float (optional)
        Minimum sliding window size (in seconds).
    stepsize : float (optional)
        Sliding window stepsize (in seconds).
    sliding_window_factor : float
        Multiplier for the sliding window width,
        where the sliding window width = median(EOD_width)*sliding_window_factor.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plot : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path : string (optional)
        Path for saving plots. Only used if save_plot is True.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log data
        in this function is 'moving_fish' (see extract_pulsefish()).

    Returns
    -------
    clusters : list of ints
        Cluster labels, where deleted clusters have been set to -1.
    window : list of 2 floats
        Start and end of the window selected for deleting moving fish in seconds.
    mf_dict : dictionary
        Key-value pairs of logged data. Data to be logged is specified by return_data.
    """

    mf_dict = {}

    if len(np.unique(clusters[clusters != -1])) == 0:
        return clusters, [0, 1], {}

    all_keep_clusters = []
    width_classes = merge_gaussians(eod_widths, np.copy(clusters), 0.75)

    all_windows = []
    all_dts = []
    ev_num = 0
    for iw, w in enumerate(np.unique(width_classes[clusters >= 0])):
        # initialize variables:
        min_clusters = 100
        average_height = 0
        sparse_clusters = 100
        keep_clusters = []

        dt = max(min_dt, np.median(eod_widths[width_classes==w])*sliding_window_factor)
        window_start = 0
        window_end = dt

        wclusters = clusters[width_classes==w]
        weod_t = eod_t[width_classes==w]
        weod_heights = eod_heights[width_classes==w]
        weod_widths = eod_widths[width_classes==w]

        all_dts.append(dt)

        if verbose>0:
            print('sliding window dt = %f' % dt)

        # TODO: make the window width dependent on the EOD width?
        ignore_steps = np.zeros(len(np.arange(0, T-dt+stepsize, stepsize)))

        for i, t in enumerate(np.arange(0, T-dt+stepsize, stepsize)):
            current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
            if len(np.unique(current_clusters))==0:
                # clamp the start index at zero to avoid negative-index wraparound:
                ignore_steps[max(0, i-int(dt/stepsize)):i+int(dt/stepsize)] = 1
                if verbose>0:
                    print('No pulsefish in recording at T=%.2f:%.2f' % (t, t+dt))

        x = np.arange(0, T-dt+stepsize, stepsize)
        y = np.ones(len(x))

        running_sum = np.ones(len(np.arange(0, T+stepsize, stepsize)))
        ulabs = np.unique(wclusters[wclusters>=0])

        # sliding window:
        for j, (t, ignore_step) in enumerate(zip(x, ignore_steps)):
            current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]
            current_widths = weod_widths[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)]

            unique_clusters = np.unique(current_clusters)
            y[j] = len(unique_clusters)

            if (len(unique_clusters) <= min_clusters) and \
               (ignore_step == 0) and \
               (len(unique_clusters) > 0):

                current_labels = np.isin(wclusters, unique_clusters)
                current_height = np.mean(weod_heights[current_labels])

                # compute the number of clusters that are too sparse:
                clusters_after_deletion = np.unique(remove_sparse_detections(np.copy(clusters[np.isin(clusters, unique_clusters)]), rate*eod_widths[np.isin(clusters, unique_clusters)], rate, T))
                current_sparse_clusters = len(unique_clusters) - len(clusters_after_deletion[clusters_after_deletion != -1])

                if current_sparse_clusters <= sparse_clusters and \
                   ((current_sparse_clusters < sparse_clusters) or
                    (current_height > average_height) or
                    (len(unique_clusters) < min_clusters)):

                    keep_clusters = unique_clusters
                    min_clusters = len(unique_clusters)
                    average_height = current_height
                    window_end = t + dt
                    sparse_clusters = current_sparse_clusters

        all_keep_clusters.append(keep_clusters)
        all_windows.append(window_end)


        if 'moving_fish' in return_data or plot_level > 0:
            if 'w' in mf_dict:
                mf_dict['w'].append(np.median(eod_widths[width_classes==w]))
                mf_dict['T'].append(T)
                mf_dict['dt'].append(dt)
                mf_dict['clusters'].append(wclusters)
                mf_dict['t'].append(weod_t)
                mf_dict['fishcount'].append([x+0.5*(x[1]-x[0]), y])
                mf_dict['ignore_steps'].append(ignore_steps)
            else:
                mf_dict['w'] = [np.median(eod_widths[width_classes==w])]
                mf_dict['T'] = [T]
                mf_dict['dt'] = [dt]
                mf_dict['clusters'] = [wclusters]
                mf_dict['t'] = [weod_t]
                mf_dict['fishcount'] = [[x+0.5*(x[1]-x[0]), y]]
                mf_dict['ignore_steps'] = [ignore_steps]


    if verbose>0:
        print('Estimated nr of pulsefish in recording: %i' % len(all_keep_clusters))

    if plot_level>0:
        plot_moving_fish(mf_dict['w'], mf_dict['dt'], mf_dict['clusters'], mf_dict['t'],
                         mf_dict['fishcount'], T, mf_dict['ignore_steps'])
        if save_plot:
            plt.savefig('%sdelete_moving_fish.%s' % (save_path, ftype))

    # empty the dictionary if it was only filled for plotting:
    if 'moving_fish' not in return_data:
        mf_dict = {}

    # delete all clusters that are not selected:
    clusters[np.invert(np.isin(clusters, np.concatenate(all_keep_clusters)))] = -1

    return clusters, [np.max(all_windows)-np.max(all_dts), np.max(all_windows)], mf_dict
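

# A minimal sketch (not part of the original module) of the sliding-window
# count that delete_moving_fish() is built on: for every window position the
# number of distinct clusters present is counted, and the window with the
# fewest simultaneously active fish is the one that is kept. Times, labels
# and window parameters below are synthetic.
def _demo_sliding_window_count():
    eod_t = np.array([0.1, 0.2, 0.3, 1.1, 1.2, 1.3, 1.4])
    clusters = np.array([0, 1, 0, 0, 0, 0, 0])   # fish 1 disappears after 0.2 s
    dt, stepsize, T = 0.5, 0.25, 2.0
    for t in np.arange(0, T-dt+stepsize, stepsize):
        in_window = (eod_t >= t) & (eod_t < t+dt)
        print('t=%.2f-%.2f: %d fish' % (t, t+dt, len(np.unique(clusters[in_window]))))

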

def remove_sparse_detections(clusters, eod_widths, rate, T,
                             min_density=0.0005, verbose=0):
    """ Remove all EOD clusters that are too sparse.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of ints
        Cluster widths in samples.
    rate : float
        Sampling rate.
    T : float
        Length of recording in seconds.
    min_density : float (optional)
        Minimum density for realistic EOD detections.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    clusters : list of ints
        Cluster labels, where sparse clusters have been set to -1.
    """
    for c in np.unique(clusters):
        if c != -1:
            n = len(clusters[clusters==c])
            w = np.median(eod_widths[clusters==c])/rate

            if n*w < T*min_density:
                if verbose>0:
                    print('cluster %i is too sparse' % c)
                clusters[clusters==c] = -1
    return clusters
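

# A minimal sketch (not part of the original module) of the density criterion
# above: with only three 0.5 ms EODs in a minute of recording, the summed EOD
# duration n*w stays below T*min_density and the cluster is discarded. All
# values are synthetic.
def _demo_sparse_detections():
    rate = 40000.0                            # Hz
    T = 60.0                                  # recording length in seconds
    clusters = np.array([0, 0, 0])            # three detections in one cluster
    eod_widths = np.array([20, 20, 20])       # 20 samples = 0.5 ms each
    clusters = remove_sparse_detections(clusters, eod_widths, rate, T, verbose=1)
    print(clusters)                           # [-1 -1 -1]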