# src/thunderfish/pulses.py

1""" 

2Extract and cluster EOD waverforms of pulse-type electric fish. 

3 

4## Main function 

5 

6- `extract_pulsefish()`: checks for pulse-type fish based on the EOD amplitude and shape. 

7 

8""" 

9 

10import os 

11import numpy as np 

12from scipy import stats 

13from scipy.interpolate import interp1d 

14from sklearn.preprocessing import StandardScaler 

15from sklearn.decomposition import PCA 

16from sklearn.cluster import DBSCAN 

17from sklearn.mixture import BayesianGaussianMixture 

18from sklearn.metrics import pairwise_distances 

19from thunderlab.eventdetection import detect_peaks, median_std_threshold 

20from .pulseplots import * 

21 

22import warnings 

23def warn(*args, **kwargs): 

24 """ 

25 Ignore all warnings. 

26 """ 

27 pass 

28warnings.warn = warn 

29 

30try: 

31 from numba import jit 

32except ImportError: 

33 def jit(*args, **kwargs): 

34 def decorator_jit(func): 

35 return func 

36 return decorator_jit 

37 

38 

39# upgrade numpy functions for backwards compatibility: 

40if not hasattr(np, 'isin'): 

41 np.isin = np.in1d 

42 

def unique_counts(ar):
    """ Find the unique elements of an array and their counts, ignoring shape.

    The code is condensed from numpy version 1.17.0.

    Parameters
    ----------
    ar : numpy array
        Input array

    Returns
    -------
    unique_values : numpy array
        Unique values in array ar.
    unique_counts : numpy array
        Number of instances for each unique value in ar.
    """
    try:
        return np.unique(ar, return_counts=True)
    except TypeError:
        ar = np.asanyarray(ar).flatten()
        ar.sort()
        mask = np.empty(ar.shape, dtype=bool)
        mask[:1] = True
        mask[1:] = ar[1:] != ar[:-1]
        idx = np.concatenate(np.nonzero(mask) + ([mask.size],))
        return ar[mask], np.diff(idx)

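# Illustrative sketch added for this write-up, not part of the original
# module: unique_counts() behaves like np.unique(..., return_counts=True)
# and is used below to discard width and height clusters with fewer than
# `minp` members.
def _demo_unique_counts():
    labels = np.array([0, 0, 1, 1, 1, 2])
    values, counts = unique_counts(labels)
    assert np.all(values == np.array([0, 1, 2]))
    assert np.all(counts == np.array([2, 3, 1]))
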

###########################################################################


def extract_pulsefish(data, samplerate, width_factor_shape=3,
                      width_factor_wave=8, width_factor_display=4,
                      verbose=0, plot_level=0, save_plots=False,
                      save_path='', ftype='png', return_data=[]):
    """ Extract and cluster pulse-type fish EODs from single channel data.

    Takes recording data containing an unknown number of pulsefish and extracts the mean
    EOD and EOD timepoints for each fish present in the recording.

    Parameters
    ----------
    data : 1-D array of float
        The data to be analysed.
    samplerate : float
        Sampling rate of the data in Hertz.
    width_factor_shape : float (optional)
        Width multiplier used for EOD shape analysis.
        EOD snippets are extracted based on the width between the
        peak and trough multiplied by the width factor.
    width_factor_wave : float (optional)
        Width multiplier used for wavefish detection.
    width_factor_display : float (optional)
        Width multiplier used for EOD mean extraction and display.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save the plots created by plot_level.
    save_path : string (optional)
        Path for saving plots.
    ftype : string (optional)
        Define the filetype to save the plots in if save_plots is set to True.
        Options are: 'png', 'jpg', 'svg' ...
    return_data : list of strings (optional)
        Specify data that should be logged and returned in a dictionary. Each clustering
        step has a specific keyword that results in adding different variables to the log
        dictionary. Optional keys for return_data and the resulting additional key-value
        pairs to the log dictionary are:

        - 'all_eod_times':
            - 'all_times': list of two lists of floats.
                All peak (`all_times[0]`) and trough times (`all_times[1]`) extracted
                by the peak detection algorithm. Times are given in seconds.
            - 'eod_troughtimes': list of 1D arrays.
                The timepoints in seconds of each unique extracted EOD cluster,
                where each 1D array encodes one cluster.

        - 'peak_detection':
            - 'data': 1D numpy array of floats.
                Quadratically interpolated data which was used for peak detection.
            - 'interp_fac': float.
                Interpolation factor of raw data.
            - 'peaks_1': 1D numpy array of ints.
                Peak indices on interpolated data after first peak detection step.
            - 'troughs_1': 1D numpy array of ints.
                Trough indices on interpolated data after first peak detection step.
            - 'peaks_2': 1D numpy array of ints.
                Peak indices on interpolated data after second peak detection step.
            - 'troughs_2': 1D numpy array of ints.
                Trough indices on interpolated data after second peak detection step.
            - 'peaks_3': 1D numpy array of ints.
                Peak indices on interpolated data after third peak detection step.
            - 'troughs_3': 1D numpy array of ints.
                Trough indices on interpolated data after third peak detection step.
            - 'peaks_4': 1D numpy array of ints.
                Peak indices on interpolated data after fourth peak detection step.
            - 'troughs_4': 1D numpy array of ints.
                Trough indices on interpolated data after fourth peak detection step.

        - 'all_cluster_steps':
            - 'samplerate': float.
                Samplerate of interpolated data.
            - 'EOD_widths': list of three 1D numpy arrays.
                The first list entry gives the unique labels of all width clusters
                as a list of ints.
                The second list entry gives the width values for each EOD in samples
                as a 1D numpy array of ints.
                The third list entry gives the width labels for each EOD
                as a 1D numpy array of ints.
            - 'EOD_heights': nested lists (2 layers) of three 1D numpy arrays.
                The first list entry gives the unique labels of all height clusters
                as a list of ints for each width cluster.
                The second list entry gives the height values for each EOD
                as a 1D numpy array of floats for each width cluster.
                The third list entry gives the height labels for each EOD
                as a 1D numpy array of ints for each width cluster.
            - 'EOD_shapes': nested lists (3 layers) of three 1D numpy arrays.
                The first list entry gives the raw EOD snippets as a 2D numpy array
                for each height cluster in a width cluster.
                The second list entry gives the snippet PCA values for each EOD
                as a 2D numpy array of floats for each height cluster in a width cluster.
                The third list entry gives the shape labels for each EOD as a 1D numpy array
                of ints for each height cluster in a width cluster.
            - 'discarding_masks': nested lists (two layers) of 1D numpy arrays.
                The masks of EODs that are discarded by the discarding step of the algorithm.
                The masks are 1D boolean arrays where instances that are set to True are
                discarded by the algorithm. Discarding masks are saved in nested lists
                that represent the width and height clusters.
            - 'merge_masks': nested lists (two layers) of 2D numpy arrays.
                The masks of EODs that are discarded by the merging step of the algorithm.
                The masks are 2D boolean arrays where for each sample point `i` either
                `merge_mask[i,0]` or `merge_mask[i,1]` is set to True. Here, `merge_mask[:,0]`
                represents the peak-centered clusters and `merge_mask[:,1]` represents the
                trough-centered clusters. Merge masks are saved in nested lists that
                represent the width and height clusters.

        - 'BGM_width':
            - 'BGM_width': dictionary
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD widths).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by the BGM model (before merging based on the merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'BGM_height':
            This key adds a new dictionary for each width cluster.
            - 'BGM_height_*n*' : dictionary, where *n* defines the width cluster as an int.
                - 'x': 1D numpy array of floats.
                    BGM input values (in this case the EOD heights).
                - 'use_log': boolean.
                    True if the z-scored logarithm of the data was used as BGM input.
                - 'BGM': list of three 1D numpy arrays.
                    The first instance are the weights of the Gaussian fits.
                    The second instance are the means of the Gaussian fits.
                    The third instance are the variances of the Gaussian fits.
                - 'labels': 1D numpy array of ints.
                    Labels defined by the BGM model (before merging based on the merge factor).
                - 'xlab': string.
                    Label for plot (defines the units of the BGM data).

        - 'snippet_clusters':
            This key adds a new dictionary for each height cluster.
            - 'snippet_clusters_*n*_*m*_*p*' : dictionary, where *n* defines the width cluster
              (int), *m* defines the height cluster (int) and *p* defines shape clustering
              on peak or trough centered EOD snippets (string: 'peak' or 'trough').
                - 'raw_snippets': 2D numpy array (nsamples, nfeatures).
                    Raw EOD snippets.
                - 'snippets': 2D numpy array.
                    Normalized EOD snippets.
                - 'features': 2D numpy array (nsamples, nfeatures).
                    PCA values for each normalized EOD snippet.
                - 'clusters': 1D numpy array of ints.
                    Cluster labels.
                - 'samplerate': float.
                    Samplerate of snippets.

        - 'eod_deletion':
            This key adds two dictionaries for each (peak centered) shape cluster,
            where *cluster* (int) is the unique shape cluster label.
            - 'mask_*cluster*' : list of four booleans.
                The mask for each cluster discarding step.
                The first instance represents the artefact masks, where artefacts
                are set to True.
                The second instance represents the unreliable cluster masks,
                where unreliable clusters are set to True.
                The third instance represents the wavefish masks, where wavefish
                are set to True.
                The fourth instance represents the sidepeak masks, where sidepeaks
                are set to True.
            - 'vals_*cluster*' : list of lists.
                All variables that are used for each cluster deletion step.
                The first instance is a list of two 1D numpy arrays: the mean EOD and
                the FFT of that mean EOD.
                The second instance is a 1D numpy array with all EOD width to ISI ratios.
                The third instance is a list with three entries:
                The first entry is a 1D numpy array zoomed out version of the mean EOD.
                The second entry is a list of two 1D numpy arrays that define the peak
                and trough indices of the zoomed out mean EOD.
                The third entry contains a list of two values that represent the
                peak-trough pair in the zoomed out mean EOD with the largest height
                difference.
            - 'samplerate' : float.
                EOD snippet samplerate.

        - 'masks':
            - 'masks' : 2D numpy array (5, N).
                The first four rows contain masks for each EOD detected by the EOD
                peak detection step: the first row defines the artefact masks, the
                second row the unreliable EOD masks, the third row the wavefish masks
                and the fourth row the sidepeak masks. The fifth row contains the
                summed peak and trough cluster labels.

        - 'moving_fish':
            - 'moving_fish': dictionary.
                - 'w' : list of floats.
                    Median width for each width cluster that the moving fish algorithm is
                    computed on (in seconds).
                - 'T' : list of floats.
                    Length of analyzed recording for each width cluster (in seconds).
                - 'dt' : list of floats.
                    Sliding window size (in seconds) for each width cluster.
                - 'clusters' : list of 1D numpy int arrays.
                    Cluster labels for each EOD cluster in a width cluster.
                - 't' : list of 1D numpy float arrays.
                    EOD emission times for each EOD in a width cluster.
                - 'fishcount' : list of lists.
                    Sliding window timepoints and fishcounts for each width cluster.
                - 'ignore_steps' : list of 1D int arrays.
                    Mask for fishcounts that were ignored (ignored if True) in the
                    moving_fish analysis.

    Returns
    -------
    mean_eods: list of 2D arrays (3, eod_length)
        The average EOD for each detected fish. First column is time in seconds,
        second column the mean eod, third column the standard error.
    eod_times: list of 1D arrays
        For each detected fish the times of EOD peaks or troughs in seconds.
        Use these timepoints for EOD averaging.
    eod_peaktimes: list of 1D arrays
        For each detected fish the times of EOD peaks in seconds.
    zoom_window: tuple of floats
        Start and endtime of suggested window for plotting EOD timepoints.
    log_dict: dictionary
        Dictionary with logged variables, where variables to log are specified
        by `return_data`.
    """

    if verbose > 0:
        print('')
        if verbose > 1:
            print(70*'#')
        print('##### extract_pulsefish', 46*'#')

    if (save_plots and plot_level>0 and save_path):
        # create folder to save things in.
        if not os.path.exists(save_path):
            os.makedirs(save_path)
    else:
        save_path = ''

    mean_eods, eod_times, eod_peaktimes, zoom_window = [], [], [], []
    log_dict = {}

    # interpolate:
    i_samplerate = 500000.0
    #i_samplerate = samplerate
    try:
        f = interp1d(np.arange(len(data))/samplerate, data, kind='quadratic')
        i_data = f(np.arange(0.0, (len(data)-1)/samplerate, 1.0/i_samplerate))
    except MemoryError:
        i_samplerate = samplerate
        i_data = data
    log_dict['data'] = i_data               # TODO: could be removed
    log_dict['samplerate'] = i_samplerate   # TODO: could be removed
    log_dict['i_data'] = i_data
    log_dict['i_samplerate'] = i_samplerate
    #log_dict["interp_fac"] = interp_fac    # TODO: is not set anymore

    # standard deviation of data in small snippets:
    threshold = median_std_threshold(data, samplerate)  # TODO: make this a parameter

    # extract peaks:
    if 'peak_detection' in return_data:
        x_peak, x_trough, eod_heights, eod_widths, pd_log_dict = \
            detect_pulses(i_data, i_samplerate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose-1, return_data=True)
        log_dict.update(pd_log_dict)
    else:
        x_peak, x_trough, eod_heights, eod_widths = \
            detect_pulses(i_data, i_samplerate, threshold,
                          width_fac=np.max([width_factor_shape,
                                            width_factor_display,
                                            width_factor_wave]),
                          verbose=verbose-1, return_data=False)

    if len(x_peak) > 0:
        # cluster
        clusters, x_merge, c_log_dict = cluster(x_peak, x_trough,
                                                eod_heights,
                                                eod_widths, i_data,
                                                i_samplerate,
                                                width_factor_shape,
                                                width_factor_wave,
                                                verbose=verbose-1,
                                                plot_level=plot_level-1,
                                                save_plots=save_plots,
                                                save_path=save_path,
                                                ftype=ftype,
                                                return_data=return_data)

        # extract mean eods and times
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths, clusters,
                          i_samplerate, width_factor_display, verbose=verbose-1)

        # determine clipped clusters (save them, but ignore in other steps)
        clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes = \
            find_clipped_clusters(clusters, mean_eods, eod_times, eod_peaktimes,
                                  eod_troughtimes, cluster_labels, width_factor_display,
                                  verbose=verbose-1)

        # delete the moving fish
        clusters, zoom_window, mf_log_dict = \
            delete_moving_fish(clusters, x_merge/i_samplerate, len(data)/samplerate,
                               eod_heights, eod_widths/i_samplerate, i_samplerate,
                               verbose=verbose-1, plot_level=plot_level-1, save_plot=save_plots,
                               save_path=save_path, ftype=ftype, return_data=return_data)

        if 'moving_fish' in return_data:
            log_dict['moving_fish'] = mf_log_dict

        clusters = remove_sparse_detections(clusters, eod_widths, i_samplerate,
                                            len(data)/samplerate, verbose=verbose-1)

        # extract mean eods
        mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels = \
            extract_means(i_data, x_merge, x_peak, x_trough, eod_widths,
                          clusters, i_samplerate, width_factor_display, verbose=verbose-1)

        mean_eods.extend(clipped_eods)
        eod_times.extend(clipped_times)
        eod_peaktimes.extend(clipped_peaktimes)
        eod_troughtimes.extend(clipped_troughtimes)

        if plot_level > 0:
            plot_all(data, eod_peaktimes, eod_troughtimes, samplerate, mean_eods)
            if save_plots:
                plt.savefig('%sextract_pulsefish_results.%s' % (save_path, ftype))
        if save_plots:
            plt.close('all')

        if 'all_eod_times' in return_data:
            log_dict['all_times'] = [x_peak/i_samplerate, x_trough/i_samplerate]
            log_dict['eod_troughtimes'] = eod_troughtimes

        log_dict.update(c_log_dict)

    return mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict

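# Illustrative usage sketch added for this write-up; it is not part of the
# original module and assumes the complete thunderfish.pulses module (the
# helper functions extract_means() etc. are defined further below). A noisy
# synthetic pulse train stands in for a real single-channel recording.
def _demo_extract_pulsefish():
    samplerate = 44100.0
    data = np.random.randn(10*int(samplerate))*0.001
    kernel = np.diff(np.exp(-0.5*(np.arange(-10, 11)/3.0)**2))  # biphasic pulse
    for x in range(1000, len(data)-1000, 441):                  # ~100 Hz pulse train
        data[x:x+len(kernel)] += kernel
    # note: interpolation to 500 kHz makes this run for a while on long data:
    mean_eods, eod_times, eod_peaktimes, zoom_window, log_dict = \
        extract_pulsefish(data, samplerate)
    for i, times in enumerate(eod_times):
        print('fish %d: %d EODs detected' % (i, len(times)))
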

def detect_pulses(data, samplerate, thresh, min_rel_slope_diff=0.25,
                  min_width=0.00005, max_width=0.01, width_fac=5.0,
                  verbose=0, return_data=False):
    """Detect pulses in data.

    Was `def extract_eod_times(data, samplerate, width_factor,
    interp_freq=500000, max_peakwidth=0.01, min_peakwidth=None,
    verbose=0, return_data=[], save_path='')` before.

    Parameters
    ----------
    data : 1-D array of float
        The data to be analysed.
    samplerate : float
        Sampling rate of the data.
    thresh : float
        Threshold for peak and trough detection via `detect_peaks()`.
        Must be a positive number that sets the minimum difference
        between a peak and a trough.
    min_rel_slope_diff : float
        Minimum required difference between left and right slope (between
        peak and troughs) relative to mean slope for deciding which trough
        to take based on slope difference.
    min_width : float
        Minimum width (peak-trough distance) of pulses in seconds.
    max_width : float
        Maximum width (peak-trough distance) of pulses in seconds.
    width_fac : float
        Pulses extend plus or minus `width_fac` times their width
        (distance between peak and assigned trough).
        Only pulses are returned that can fully be analysed with this width.
    verbose : int (optional)
        Verbosity level.
    return_data : bool
        If `True` data of this function is logged and returned (see
        extract_pulsefish()).

    Returns
    -------
    peak_indices : array of ints
        Indices of EOD peaks in data.
    trough_indices : array of ints
        Indices of EOD troughs in data. There is one x_trough for each x_peak.
    heights : array of floats
        EOD heights for each x_peak.
    widths : array of ints
        EOD widths for each x_peak (in samples).
    peak_detection_result : dictionary
        Key value pairs of logged data.
        This is only returned if `return_data` is `True`.
    """
    peak_detection_result = {}

    # detect peaks and troughs in the data:
    peak_indices, trough_indices = detect_peaks(data, thresh)
    if verbose > 0:
        print('Peaks/troughs detected in data: %5d %5d'
              % (len(peak_indices), len(trough_indices)))
    if return_data:
        peak_detection_result.update(peaks_1=np.array(peak_indices),
                                     troughs_1=np.array(trough_indices))
    if len(peak_indices) < 2 or \
       len(trough_indices) < 2 or \
       len(peak_indices) > len(data)/20:
        # TODO: if too many peaks increase threshold!
        if verbose > 0:
            print('No or too many peaks/troughs detected in data.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # assign troughs to peaks:
    peak_indices, trough_indices, heights, widths, slopes = \
        assign_side_peaks(data, peak_indices, trough_indices, min_rel_slope_diff)
    if verbose > 1:
        print('Number of peaks after assigning side-peaks: %5d'
              % len(peak_indices))
    if return_data:
        peak_detection_result.update(peaks_2=np.array(peak_indices),
                                     troughs_2=np.array(trough_indices))

    # check widths:
    keep = ((widths > min_width*samplerate) & (widths < max_width*samplerate))
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    slopes = slopes[keep]
    if verbose > 1:
        print('Number of peaks after checking pulse width: %5d'
              % len(peak_indices))
    if return_data:
        peak_detection_result.update(peaks_3=np.array(peak_indices),
                                     troughs_3=np.array(trough_indices))

    # discard connected peaks:
    same = np.nonzero(trough_indices[:-1] == trough_indices[1:])[0]
    keep = np.ones(len(trough_indices), dtype=bool)
    for i in same:
        # same troughs at trough_indices[i] and trough_indices[i+1]:
        s = slopes[i:i+2]
        rel_slopes = np.abs(np.diff(s))[0]/np.mean(s)
        if rel_slopes > min_rel_slope_diff:
            keep[i+(s[1]<s[0])] = False
        else:
            keep[i+(heights[i+1]<heights[i])] = False
    peak_indices = peak_indices[keep]
    trough_indices = trough_indices[keep]
    heights = heights[keep]
    widths = widths[keep]
    if verbose > 1:
        print('Number of peaks after merging pulses: %5d'
              % len(peak_indices))
    if return_data:
        peak_detection_result.update(peaks_4=np.array(peak_indices),
                                     troughs_4=np.array(trough_indices))
    if len(peak_indices) == 0:
        if verbose > 0:
            print('No peaks remain as pulse candidates.')
        if return_data:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int), peak_detection_result
        else:
            return np.array([], dtype=int), np.array([], dtype=int), \
                np.array([]), np.array([], dtype=int)

    # only take those pulses where the cut width around peak and trough
    # stays within the data:
    keep = ((peak_indices - widths > 0) &
            (peak_indices + widths < len(data)) &
            (trough_indices - widths > 0) &
            (trough_indices + widths < len(data)))

    if verbose > 0:
        print('Remaining peaks after EOD extraction: %5d'
              % np.sum(keep))
        print('')

    if return_data:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep], peak_detection_result
    else:
        return peak_indices[keep], trough_indices[keep], \
            heights[keep], widths[keep]

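# Illustrative sketch, not part of the original module: run detect_pulses() on
# two synthetic biphasic pulses in faint noise. At least two peaks and troughs
# are required, since the outermost extrema only serve as boundaries for
# trough assignment. The threshold would normally come from
# median_std_threshold(), as in extract_pulsefish() above.
def _demo_detect_pulses():
    samplerate = 100000.0
    data = np.random.randn(10000)*0.001
    for x in (3000, 6000):
        data[x:x+20] += np.sin(np.linspace(0.0, 2.0*np.pi, 20))
    peaks, troughs, heights, widths = detect_pulses(data, samplerate, thresh=0.5)
    print(peaks, troughs, heights, widths)
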

@jit(nopython=True)
def assign_side_peaks(data, peak_indices, trough_indices,
                      min_rel_slope_diff=0.25):
    """Assign to each peak the trough resulting in a pulse with the steepest slope or largest height.

    The slope between a peak and a trough is computed as the height
    difference divided by the distance between peak and trough. If the
    slopes to the left and the right trough differ by less than
    `min_rel_slope_diff`, then just the heights of the two troughs
    relative to the peak are compared.

    Was `def detect_eod_peaks(data, main_indices, side_indices,
    max_width=20, min_width=2, verbose=0)` before.

    Parameters
    ----------
    data : array of floats
        Data in which the events were detected.
    peak_indices : array of ints
        Indices of the detected peaks in the data time series.
    trough_indices : array of ints
        Indices of the detected troughs in the data time series.
    min_rel_slope_diff : float
        Minimum required difference of left and right slope relative
        to mean slope.

    Returns
    -------
    peak_indices : array of ints
        Peak indices. Same as input `peak_indices` but potentially shorter
        by one or two elements.
    trough_indices : array of ints
        Corresponding trough indices of the trough to the left or right
        of each peak.
    heights : array of floats
        Peak heights (distance between peak and corresponding trough amplitude).
    widths : array of ints
        Peak widths (distance between peak and corresponding trough indices).
    slopes : array of floats
        Peak slopes (height divided by width).
    """
    # is a main or side peak first?
    peak_first = int(peak_indices[0] < trough_indices[0])
    # is a main or side peak last?
    peak_last = int(peak_indices[-1] > trough_indices[-1])
    # ensure that all peaks have side peaks (troughs) at both sides,
    # i.e. troughs at same index and next index are before and after peak:
    peak_indices = peak_indices[peak_first:len(peak_indices)-peak_last]
    y = data[peak_indices]

    # indices of troughs on the left and right side of main peaks:
    l_indices = np.arange(len(peak_indices))
    r_indices = l_indices + 1

    # indices, distance to peak, height, and slope of left troughs:
    l_side_indices = trough_indices[l_indices]
    l_distance = np.abs(peak_indices - l_side_indices)
    l_height = np.abs(y - data[l_side_indices])
    l_slope = np.abs(l_height/l_distance)

    # indices, distance to peak, height, and slope of right troughs:
    r_side_indices = trough_indices[r_indices]
    r_distance = np.abs(r_side_indices - peak_indices)
    r_height = np.abs(y - data[r_side_indices])
    r_slope = np.abs(r_height/r_distance)

    # which trough to assign to the peak?
    # - either the one with the steepest slope, or
    # - when slopes are similar on both sides
    #   (within `min_rel_slope_diff` difference),
    #   the trough with the maximum height difference to the peak:
    rel_slopes = np.abs(l_slope-r_slope)/(0.5*(l_slope+r_slope))
    take_slopes = rel_slopes > min_rel_slope_diff
    take_left = l_height > r_height
    take_left[take_slopes] = l_slope[take_slopes] > r_slope[take_slopes]

    # assign troughs, heights, widths, and slopes:
    trough_indices = np.where(take_left,
                              trough_indices[:-1], trough_indices[1:])
    heights = np.where(take_left, l_height, r_height)
    widths = np.where(take_left, l_distance, r_distance)
    slopes = np.where(take_left, l_slope, r_slope)

    return peak_indices, trough_indices, heights, widths, slopes

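# Worked example (illustrative, not part of the original module): the trough
# to the right of the peak is deeper, so its slope is steeper; the relative
# slope difference (0.4) exceeds min_rel_slope_diff, so the right trough is
# assigned to the peak.
def _demo_assign_side_peaks():
    data = np.array([0.0, -0.2, 0.5, 1.0, 0.2, -0.8, 0.1, 0.0])
    peaks = np.array([3])          # peak of height 1.0
    troughs = np.array([1, 5])     # troughs at -0.2 and -0.8
    p, t, heights, widths, slopes = assign_side_peaks(data, peaks, troughs)
    print(t, heights, widths)      # [5] [1.8] [2]
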

def cluster(eod_xp, eod_xt, eod_heights, eod_widths, data, samplerate,
            width_factor_shape, width_factor_wave,
            n_gaus_height=10, merge_threshold_height=0.1, n_gaus_width=3,
            merge_threshold_width=0.5, minp=10,
            verbose=0, plot_level=0, save_plots=False, save_path='', ftype='pdf',
            return_data=[]):
    """Cluster EODs.

    First cluster on EOD widths using a Bayesian Gaussian
    Mixture (BGM) model, then cluster on EOD heights using a
    BGM model. Lastly, cluster on EOD waveform with DBSCAN.
    Clustering on EOD waveform is performed twice, once on
    peak-centered EODs and once on trough-centered EODs.
    Non-pulsetype EOD clusters are deleted, and clusters are
    merged afterwards.

    Parameters
    ----------
    eod_xp : list of ints
        Locations of EOD peaks in indices.
    eod_xt : list of ints
        Locations of EOD troughs in indices.
    eod_heights : list of floats
        EOD heights.
    eod_widths : list of ints
        EOD widths in samples.
    data : array of floats
        Data in which to detect pulse EODs.
    samplerate : float
        Sample rate of `data`.
    width_factor_shape : float
        Multiplier for snippet extraction width. This factor is
        multiplied with the width between the peak and trough of a
        single EOD.
    width_factor_wave : float
        Multiplier for wavefish extraction width.
    n_gaus_height : int (optional)
        Number of gaussians to use for the clustering based on EOD height.
    merge_threshold_height : float (optional)
        Threshold for merging clusters that are similar in height.
    n_gaus_width : int (optional)
        Number of gaussians to use for the clustering based on EOD width.
    merge_threshold_width : float (optional)
        Threshold for merging clusters that are similar in width.
    minp : int (optional)
        Minimum number of points for core clusters (DBSCAN).
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    save_plots : bool (optional)
        Set to True to save created plots.
    save_path : string (optional)
        Path to save plots to. Only used if save_plots==True.
    ftype : string (optional)
        Filetype to save plot images in.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'all_cluster_steps', 'BGM_width', 'BGM_height',
        'snippet_clusters', 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    labels : list of ints
        EOD cluster labels based on height and EOD waveform.
    x_merge : list of ints
        Locations of EODs in clusters.
    saved_data : dictionary
        Key value pairs of logged data. Data to be logged is specified
        by return_data.
    """
    saved_data = {}

    if plot_level>0 or 'all_cluster_steps' in return_data:
        all_heightlabels = []
        all_shapelabels = []
        all_snippets = []
        all_features = []
        all_heights = []
        all_unique_heightlabels = []

    all_p_clusters = -1 * np.ones(len(eod_xp))
    all_t_clusters = -1 * np.ones(len(eod_xp))
    artefact_masks_p = np.ones(len(eod_xp), dtype=bool)
    artefact_masks_t = np.ones(len(eod_xp), dtype=bool)

    x_merge = -1 * np.ones(len(eod_xp))

    max_label_p = 0   # keep track of the labels so that no labels are overwritten
    max_label_t = 0

    # loop only over height clusters that are bigger than minp
    # first cluster on width:
    width_labels, bgm_log_dict = BGM(1000*eod_widths/samplerate, merge_threshold_width,
                                     n_gaus_width, use_log=False, verbose=verbose-1,
                                     plot_level=plot_level-1, xlabel='width [ms]',
                                     save_plot=save_plots, save_path=save_path,
                                     save_name='width', ftype=ftype, return_data=return_data)
    saved_data.update(bgm_log_dict)

    if verbose>0:
        print('Clusters generated based on EOD width:')
        [print('N_{} = {:>4}  w_{} = {:.4f}'.format(l, len(width_labels[width_labels==l]),
                                                    l, np.mean(eod_widths[width_labels==l])))
         for l in np.unique(width_labels)]

    w_labels, w_counts = unique_counts(width_labels)
    unique_width_labels = w_labels[w_counts>minp]

    for wi, width_label in enumerate(unique_width_labels):

        # select only features in one width cluster at a time:
        w_eod_widths = eod_widths[width_labels==width_label]
        w_eod_heights = eod_heights[width_labels==width_label]
        w_eod_xp = eod_xp[width_labels==width_label]
        w_eod_xt = eod_xt[width_labels==width_label]
        width = int(width_factor_shape*np.median(w_eod_widths))
        if width > w_eod_xp[0]:
            width = w_eod_xp[0]
        if width > w_eod_xt[0]:
            width = w_eod_xt[0]
        if width > len(data) - w_eod_xp[-1]:
            width = len(data) - w_eod_xp[-1]
        if width > len(data) - w_eod_xt[-1]:
            width = len(data) - w_eod_xt[-1]

        wp_clusters = -1 * np.ones(len(w_eod_xp))
        wt_clusters = -1 * np.ones(len(w_eod_xp))
        wartefact_mask = np.ones(len(w_eod_xp))

        # determine height labels:
        raw_p_snippets, p_snippets, p_features, p_bg_ratio = \
            extract_snippet_features(data, w_eod_xp, w_eod_heights, width)
        raw_t_snippets, t_snippets, t_features, t_bg_ratio = \
            extract_snippet_features(data, w_eod_xt, w_eod_heights, width)

        height_labels, bgm_log_dict = \
            BGM(w_eod_heights,
                min(merge_threshold_height,
                    np.median(np.min(np.vstack([p_bg_ratio, t_bg_ratio]), axis=0))),
                n_gaus_height, use_log=True, verbose=verbose-1, plot_level=plot_level-1,
                xlabel='height [a.u.]', save_plot=save_plots, save_path=save_path,
                save_name='height_%d' % wi, ftype=ftype, return_data=return_data)
        saved_data.update(bgm_log_dict)

        if verbose>0:
            print('Clusters generated based on EOD height:')
            [print('N_{} = {:>4}  h_{} = {:.4f}'.format(l, len(height_labels[height_labels==l]),
                                                        l, np.mean(w_eod_heights[height_labels==l])))
             for l in np.unique(height_labels)]

        h_labels, h_counts = unique_counts(height_labels)
        unique_height_labels = h_labels[h_counts>minp]

        if plot_level>0 or 'all_cluster_steps' in return_data:
            all_heightlabels.append(height_labels)
            all_heights.append(w_eod_heights)
            all_unique_heightlabels.append(unique_height_labels)
            shape_labels = []
            cfeatures = []
            csnippets = []

        for hi, height_label in enumerate(unique_height_labels):

            h_eod_widths = w_eod_widths[height_labels==height_label]
            h_eod_heights = w_eod_heights[height_labels==height_label]
            h_eod_xp = w_eod_xp[height_labels==height_label]
            h_eod_xt = w_eod_xt[height_labels==height_label]

            p_clusters = cluster_on_shape(p_features[height_labels==height_label],
                                          p_bg_ratio, minp, verbose=0)
            t_clusters = cluster_on_shape(t_features[height_labels==height_label],
                                          t_bg_ratio, minp, verbose=0)

            if plot_level>1:
                plot_feature_extraction(raw_p_snippets[height_labels==height_label],
                                        p_snippets[height_labels==height_label],
                                        p_features[height_labels==height_label],
                                        p_clusters, 1/samplerate, 0)
                plt.savefig('%sDBSCAN_peak_w%i_h%i.%s' % (save_path, wi, hi, ftype))
                plot_feature_extraction(raw_t_snippets[height_labels==height_label],
                                        t_snippets[height_labels==height_label],
                                        t_features[height_labels==height_label],
                                        t_clusters, 1/samplerate, 1)
                plt.savefig('%sDBSCAN_trough_w%i_h%i.%s' % (save_path, wi, hi, ftype))

            if 'snippet_clusters' in return_data:
                saved_data['snippet_clusters_%i_%i_peak' % (width_label, height_label)] = {
                    'raw_snippets': raw_p_snippets[height_labels==height_label],
                    'snippets': p_snippets[height_labels==height_label],
                    'features': p_features[height_labels==height_label],
                    'clusters': p_clusters,
                    'samplerate': samplerate}
                saved_data['snippet_clusters_%i_%i_trough' % (width_label, height_label)] = {
                    'raw_snippets': raw_t_snippets[height_labels==height_label],
                    'snippets': t_snippets[height_labels==height_label],
                    'features': t_features[height_labels==height_label],
                    'clusters': t_clusters,
                    'samplerate': samplerate}

            if plot_level>0 or 'all_cluster_steps' in return_data:
                shape_labels.append([p_clusters, t_clusters])
                cfeatures.append([p_features[height_labels==height_label],
                                  t_features[height_labels==height_label]])
                csnippets.append([p_snippets[height_labels==height_label],
                                  t_snippets[height_labels==height_label]])

            p_clusters[p_clusters==-1] = -max_label_p - 1
            wp_clusters[height_labels==height_label] = p_clusters + max_label_p
            max_label_p = max(np.max(wp_clusters), np.max(all_p_clusters)) + 1

            t_clusters[t_clusters==-1] = -max_label_t - 1
            wt_clusters[height_labels==height_label] = t_clusters + max_label_t
            max_label_t = max(np.max(wt_clusters), np.max(all_t_clusters)) + 1

        if verbose > 0:
            if np.max(wp_clusters) == -1:
                print('No EOD peaks in width cluster %i' % width_label)
            else:
                unique_clusters = np.unique(wp_clusters[wp_clusters!=-1])
                if len(unique_clusters) > 1:
                    print('%i different EOD peaks in width cluster %i'
                          % (len(unique_clusters), width_label))

        if plot_level>0 or 'all_cluster_steps' in return_data:
            all_shapelabels.append(shape_labels)
            all_snippets.append(csnippets)
            all_features.append(cfeatures)

        # for each cluster, save fft + label
        # so I end up with features for each label, and the masks.
        # then I can extract e.g. first artefact or wave etc.

        # remove artefacts here, based on the mean snippets ffts:
        artefact_masks_p[width_labels==width_label], sdict = \
            remove_artefacts(p_snippets, wp_clusters, samplerate,
                             verbose=verbose-1, return_data=return_data)
        saved_data.update(sdict)
        artefact_masks_t[width_labels==width_label], _ = \
            remove_artefacts(t_snippets, wt_clusters, samplerate,
                             verbose=verbose-1, return_data=return_data)

        # update maxlab so that no clusters are overwritten:
        all_p_clusters[width_labels==width_label] = wp_clusters
        all_t_clusters[width_labels==width_label] = wt_clusters

    # remove all non-reliable clusters:
    unreliable_fish_mask_p, saved_data = \
        delete_unreliable_fish(all_p_clusters, eod_widths, eod_xp,
                               verbose=verbose-1, sdict=saved_data)
    unreliable_fish_mask_t, _ = \
        delete_unreliable_fish(all_t_clusters, eod_widths, eod_xt, verbose=verbose-1)

    wave_mask_p, sidepeak_mask_p, saved_data = \
        delete_wavefish_and_sidepeaks(data, all_p_clusters, eod_xp, eod_widths,
                                      width_factor_wave, verbose=verbose-1, sdict=saved_data)
    wave_mask_t, sidepeak_mask_t, _ = \
        delete_wavefish_and_sidepeaks(data, all_t_clusters, eod_xt, eod_widths,
                                      width_factor_wave, verbose=verbose-1)

    og_clusters = [np.copy(all_p_clusters), np.copy(all_t_clusters)]
    og_labels = np.copy(all_p_clusters + all_t_clusters)

    # go through all clusters and masks??
    all_p_clusters[(artefact_masks_p | unreliable_fish_mask_p |
                    wave_mask_p | sidepeak_mask_p)] = -1
    all_t_clusters[(artefact_masks_t | unreliable_fish_mask_t |
                    wave_mask_t | sidepeak_mask_t)] = -1

    # merge here:
    all_clusters, x_merge, mask = merge_clusters(np.copy(all_p_clusters),
                                                 np.copy(all_t_clusters), eod_xp, eod_xt,
                                                 verbose=verbose-1)

    if 'all_cluster_steps' in return_data or plot_level>0:
        all_dmasks = []
        all_mmasks = []

        discarding_masks = \
            np.vstack(((artefact_masks_p | unreliable_fish_mask_p |
                        wave_mask_p | sidepeak_mask_p),
                       (artefact_masks_t | unreliable_fish_mask_t |
                        wave_mask_t | sidepeak_mask_t)))
        merge_mask = mask

        # save the masks in the same formats as the snippets:
        for wi, (width_label, w_shape_label, heightlabels, unique_height_labels) in \
                enumerate(zip(unique_width_labels, all_shapelabels,
                              all_heightlabels, all_unique_heightlabels)):
            w_dmasks = discarding_masks[:,width_labels==width_label]
            w_mmasks = merge_mask[:,width_labels==width_label]

            wd_2 = []
            wm_2 = []

            for hi, (height_label, h_shape_label) in \
                    enumerate(zip(unique_height_labels, w_shape_label)):

                h_dmasks = w_dmasks[:,heightlabels==height_label]
                h_mmasks = w_mmasks[:,heightlabels==height_label]

                wd_2.append(h_dmasks)
                wm_2.append(h_mmasks)

            all_dmasks.append(wd_2)
            all_mmasks.append(wm_2)

        if plot_level>0:
            plot_clustering(samplerate, [unique_width_labels, eod_widths, width_labels],
                            [all_unique_heightlabels, all_heights, all_heightlabels],
                            [all_snippets, all_features, all_shapelabels],
                            all_dmasks, all_mmasks)
            if save_plots:
                plt.savefig('%sclustering.%s' % (save_path, ftype))

        if 'all_cluster_steps' in return_data:
            saved_data = {'samplerate': samplerate,
                          'EOD_widths': [unique_width_labels, eod_widths, width_labels],
                          'EOD_heights': [all_unique_heightlabels, all_heights,
                                          all_heightlabels],
                          'EOD_shapes': [all_snippets, all_features, all_shapelabels],
                          'discarding_masks': all_dmasks,
                          'merge_masks': all_mmasks}

    if 'masks' in return_data:
        saved_data = {'masks': np.vstack(((artefact_masks_p & artefact_masks_t),
                                          (unreliable_fish_mask_p & unreliable_fish_mask_t),
                                          (wave_mask_p & wave_mask_t),
                                          (sidepeak_mask_p & sidepeak_mask_t),
                                          (all_p_clusters + all_t_clusters)))}

    if verbose>0:
        print('Clusters generated based on height, width and shape: ')
        [print('N_{} = {:>4}'.format(int(l), len(all_clusters[all_clusters == l])))
         for l in np.unique(all_clusters[all_clusters != -1])]

    return all_clusters, x_merge, saved_data


def BGM(x, merge_threshold=0.1, n_gaus=5, max_iter=200, n_init=5,
        use_log=False, verbose=0, plot_level=0, xlabel='x [a.u.]',
        save_plot=False, save_path='', save_name='', ftype='pdf', return_data=[]):
    """ Use a Bayesian Gaussian Mixture Model to cluster one-dimensional data.

    Additional steps are used to merge clusters that are closer than
    `merge_threshold`. Broad gaussian fits that cover one or more other
    gaussian fits are split by their intersections with the other
    gaussians.

    Parameters
    ----------
    x : 1D numpy array
        Features to compute clustering on.
    merge_threshold : float (optional)
        Ratio for merging nearby gaussians.
    n_gaus : int (optional)
        Maximum number of gaussians to fit on data.
    max_iter : int (optional)
        Maximum number of iterations for gaussian fit.
    n_init : int (optional)
        Number of initializations for the gaussian fit.
    use_log : boolean (optional)
        Set to True to compute the gaussian fit on the logarithm of x.
        Can improve clustering on features with nonlinear relationships such as peak height.
    verbose : int (optional)
        Verbosity level.
    plot_level : int (optional)
        Similar to verbosity levels, but with plots.
        Only set to > 0 for debugging purposes.
    xlabel : string (optional)
        Xlabel for displaying BGM plot.
    save_plot : bool (optional)
        Set to True to save created plot.
    save_path : string (optional)
        Path to location where data should be saved. Only used if save_plot==True.
    save_name : string (optional)
        Filename of the saved plot. Useful as usually multiple BGM models are generated.
    ftype : string (optional)
        Filetype of plot image if save_plot==True.
    return_data : list of strings (optional)
        Keys that specify data to be logged. Keys that can be used to log data
        in this function are: 'BGM_width' and/or 'BGM_height' (see extract_pulsefish()).

    Returns
    -------
    labels : 1D numpy array
        Cluster labels for each sample in x.
    bgm_dict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    bgm_dict = {}

    if len(np.unique(x)) > n_gaus:
        BGM_model = BayesianGaussianMixture(n_components=n_gaus, max_iter=max_iter,
                                            n_init=n_init)
        if use_log:
            labels = BGM_model.fit_predict(stats.zscore(np.log(x)).reshape(-1, 1))
        else:
            labels = BGM_model.fit_predict(stats.zscore(x).reshape(-1, 1))
    else:
        return np.zeros(len(x)), bgm_dict

    if verbose>0:
        if not BGM_model.converged_:
            print('!!! Gaussian mixture did not converge !!!')

    cur_labels = np.unique(labels)

    # map labels to be increasing for increasing values for x:
    maxlab = len(cur_labels)
    aso = np.argsort([np.median(x[labels == l]) for l in cur_labels]) + 100
    for i, a in zip(cur_labels, aso):
        labels[labels == i] = a
    labels = labels - 100

    # separate gaussian clusters that can be split by other clusters:
    splits = np.sort(np.copy(x))[1:][np.diff(labels[np.argsort(x)]) != 0]

    labels[:] = 0
    for i, split in enumerate(splits):
        labels[x >= split] = i + 1

    labels_before_merge = np.copy(labels)

    # merge gaussian clusters that are closer than merge_threshold:
    labels = merge_gaussians(x, labels, merge_threshold)

    if 'BGM_'+save_name.split('_')[0] in return_data or plot_level>0:

        # sort model attributes by model means_:
        means = [m[0] for m in BGM_model.means_]
        weights = [w for w in BGM_model.weights_]
        variances = [v[0][0] for v in BGM_model.covariances_]
        weights = [w for _, w in sorted(zip(means, weights))]
        variances = [v for _, v in sorted(zip(means, variances))]
        means = sorted(means)

        if plot_level>0:
            plot_bgm(x, means, variances, weights, use_log, labels_before_merge,
                     labels, xlabel)
            if save_plot:
                plt.savefig('%sBGM_%s.%s' % (save_path, save_name, ftype))

        if 'BGM_'+save_name.split('_')[0] in return_data:
            bgm_dict['BGM_'+save_name] = {'x': x,
                                          'use_log': use_log,
                                          'BGM': [weights, means, variances],
                                          'labels': labels_before_merge,
                                          'xlab': xlabel}

    return labels, bgm_dict

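# Illustrative sketch, not part of the original module: two well separated
# modes of simulated EOD widths should come out of the BGM wrapper as two
# labels, ordered by increasing width.
def _demo_bgm():
    x = np.concatenate((np.random.normal(0.5, 0.02, 200),
                        np.random.normal(1.5, 0.05, 200)))
    labels, _ = BGM(x, merge_threshold=0.1, n_gaus=5)
    print('found %d width clusters' % len(np.unique(labels)))
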

def merge_gaussians(x, labels, merge_threshold=0.1):
    """ Merge all clusters whose medians are close to each other. Only works in 1D.

    Parameters
    ----------
    x : 1D array of ints or floats
        Features used for clustering.
    labels : 1D array of ints
        Labels for each sample in x.
    merge_threshold : float (optional)
        Similarity threshold to merge clusters.

    Returns
    -------
    labels : 1D array of ints
        Merged labels for each sample in x.
    """
    # compare all the medians of the gaussians. If they are too close, merge them.
    unique_labels = np.unique(labels[labels != -1])
    x_medians = [np.median(x[labels == l]) for l in unique_labels]

    # fill a dict with the label mappings:
    mapping = {}
    for label_1, x_m1 in zip(unique_labels, x_medians):
        for label_2, x_m2 in zip(unique_labels, x_medians):
            if label_1 != label_2:
                if np.abs(np.diff([x_m1, x_m2]))/np.max([x_m1, x_m2]) < merge_threshold:
                    mapping[label_1] = label_2
    # apply mapping:
    for map_key, map_value in mapping.items():
        labels[labels == map_key] = map_value

    return labels

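# Worked example (illustrative, not part of the original module): the cluster
# medians 1.0 and 1.05 differ by less than merge_threshold (|1.0-1.05|/1.05 is
# about 0.048), so labels 0 and 1 are merged; the cluster at 2.0 survives.
def _demo_merge_gaussians():
    x = np.array([1.0, 1.0, 1.05, 1.05, 2.0, 2.0])
    labels = np.array([0, 0, 1, 1, 2, 2])
    merged = merge_gaussians(x, labels, merge_threshold=0.1)
    print(np.unique(merged))   # two labels remain
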

def extract_snippet_features(data, eod_x, eod_heights, width, n_pc=5):
    """ Extract snippets from recording data, normalize them, and perform PCA.

    Parameters
    ----------
    data : 1D numpy array of floats
        Recording data.
    eod_x : 1D array of ints
        Locations of EODs as indices.
    eod_heights : 1D array of floats
        EOD heights.
    width : int
        Width to cut out to each side in samples.
    n_pc : int (optional)
        Number of PCs to use for PCA.

    Returns
    -------
    raw_snippets : 2D numpy array (N, EOD_width)
        Raw extracted EOD snippets.
    snippets : 2D numpy array (N, EOD_width)
        Normalized EOD snippets.
    features : 2D numpy array (N, n_pc)
        PC values of EOD snippets.
    bg_ratio : 1D numpy array (N)
        Ratio of the background activity slopes compared to EOD height.
    """
    # extract snippets with corresponding width:
    raw_snippets = np.vstack([data[x-width:x+width] for x in eod_x])

    # subtract the slope and normalize the snippets:
    snippets, bg_ratio = subtract_slope(np.copy(raw_snippets), eod_heights)
    snippets = StandardScaler().fit_transform(snippets.T).T

    # scale so that the absolute integral = 1:
    snippets = (snippets.T/np.sum(np.abs(snippets), axis=1)).T

    # compute features for clustering on waveform:
    features = PCA(n_pc).fit(snippets).transform(snippets)

    return raw_snippets, snippets, features, bg_ratio

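# Illustrative sketch, not part of the original module: cut snippets around
# three synthetic gaussian pulses and reduce them to PCA features.
def _demo_extract_snippet_features():
    data = np.random.randn(1000)*0.01
    eod_x = np.array([200, 500, 800])
    pulse = np.exp(-0.5*(np.arange(-5, 5)/1.5)**2)
    for x in eod_x:
        data[x-5:x+5] += pulse
    heights = np.ones(len(eod_x))
    raw, snippets, features, bg_ratio = \
        extract_snippet_features(data, eod_x, heights, width=20, n_pc=3)
    print(raw.shape, snippets.shape, features.shape)   # (3, 40) (3, 40) (3, 3)
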

def cluster_on_shape(features, bg_ratio, minp, percentile=80, max_epsilon=0.01,
                     slope_ratio_factor=4, min_cluster_fraction=0.01, verbose=0):
    """Separate EODs by their shape using DBSCAN.

    Parameters
    ----------
    features : 2D numpy array of floats (N, n_pc)
        PCA features of each EOD in a recording.
    bg_ratio : 1D array of floats
        Ratio of background activity slope the EOD is superimposed on.
    minp : int
        Minimum number of points for core cluster (DBSCAN).
    percentile : int (optional)
        Percentile of KNN distribution, where K=minp, to use as epsilon for DBSCAN.
    max_epsilon : float (optional)
        Maximum epsilon to use for DBSCAN clustering. This is used to avoid adding
        noisy clusters.
    slope_ratio_factor : float (optional)
        Influence of the slope-to-EOD ratio on the epsilon parameter.
        A slope_ratio_factor of 4 means that slope-to-EOD ratios > 1/4
        start influencing epsilon.
    min_cluster_fraction : float (optional)
        Minimum fraction of all evaluated datapoints that can form a single cluster.
    verbose : int (optional)
        Verbosity level.

    Returns
    -------
    labels : 1D array of ints
        DBSCAN cluster labels for each EOD.
    """
    # determine clustering threshold from data:
    minpc = max(minp, int(len(features)*min_cluster_fraction))
    knn = np.sort(pairwise_distances(features, features), axis=0)[minpc]
    eps = min(max(1, slope_ratio_factor*np.median(bg_ratio))*max_epsilon,
              np.percentile(knn, percentile))

    if verbose>1:
        print('epsilon = %f' % eps)
        print('Slope to EOD ratio = %f' % np.median(bg_ratio))

    # cluster on EOD shape:
    return DBSCAN(eps=eps, min_samples=minpc).fit(features).labels_

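# Illustrative sketch, not part of the original module: two tight blobs of
# fake PCA features, with no background slope, are separated by DBSCAN into
# two shape clusters (plus possibly a few -1 noise labels).
def _demo_cluster_on_shape():
    features = np.vstack((np.random.normal(0.0, 0.0005, (50, 2)),
                          np.random.normal(0.01, 0.0005, (50, 2))))
    bg_ratio = np.zeros(len(features))
    labels = cluster_on_shape(features, bg_ratio, minp=10)
    print(np.unique(labels))
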

def subtract_slope(snippets, heights):
    """ Subtract underlying slope from all EOD snippets.

    Parameters
    ----------
    snippets : 2-D numpy array
        All EODs in a recording stacked as snippets.
        Shape = (number of EODs, EOD width)
    heights : 1D numpy array
        EOD heights.

    Returns
    -------
    snippets : 2-D numpy array
        EOD snippets with underlying slope subtracted.
    bg_ratio : 1-D numpy array
        EOD height/background activity height.
    """
    left_y = snippets[:,0]
    right_y = snippets[:,-1]

    try:
        slopes = np.linspace(left_y, right_y, snippets.shape[1])
    except ValueError:
        # fallback for older numpy versions where np.linspace() does not
        # accept array arguments:
        delta = (right_y - left_y)/snippets.shape[1]
        slopes = np.arange(0, snippets.shape[1],
                           dtype=snippets.dtype).reshape((-1,) + (1,)*np.ndim(delta))*delta + left_y

    return snippets - slopes.T, np.abs(left_y-right_y)/heights

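# Worked example (illustrative, not part of the original module): a pure
# linear ramp is removed completely, and the background ratio is the ramp
# height divided by the EOD height.
def _demo_subtract_slope():
    ramp = np.linspace(0.0, 1.0, 11).reshape(1, -1)   # one snippet, 11 samples
    flat, bg_ratio = subtract_slope(np.copy(ramp), np.array([2.0]))
    print(np.max(np.abs(flat)))   # ~0
    print(bg_ratio)               # |0 - 1|/2 = [0.5]
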

def remove_artefacts(all_snippets, clusters, samplerate,
                     freq_low=20000, threshold=0.75,
                     verbose=0, return_data=[]):
    """ Create a mask for EOD clusters that result from artefacts, based on power in low frequency spectrum.

    Parameters
    ----------
    all_snippets : 2D array
        EOD snippets. Shape=(nEODs, EOD length)
    clusters : list of ints
        EOD cluster labels.
    samplerate : float
        Samplerate of original recording data.
    freq_low : float
        Frequency up to which low frequency components are summed up.
    threshold : float (optional)
        Minimum value for the sum of low frequency components relative to
        the sum over all spectral amplitudes that separates artefact from
        clean pulsefish clusters.
    verbose : int (optional)
        Verbosity level.
    return_data : list of strings (optional)
        Keys that specify data to be logged. The key that can be used to log data
        in this function is 'eod_deletion' (see extract_pulsefish()).

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every EOD which is an artefact.
    adict : dictionary
        Key value pairs of logged data. Data to be logged is specified by return_data.
    """
    adict = {}

    mask = np.zeros(clusters.shape, dtype=bool)

    for cluster in np.unique(clusters[clusters >= 0]):
        snippets = all_snippets[clusters == cluster]
        mean_eod = np.mean(snippets, axis=0)
        mean_eod = mean_eod - np.mean(mean_eod)
        mean_eod_fft = np.abs(np.fft.rfft(mean_eod))
        freqs = np.fft.rfftfreq(len(mean_eod), 1/samplerate)
        low_frequency_ratio = np.sum(mean_eod_fft[freqs<freq_low])/np.sum(mean_eod_fft)
        if low_frequency_ratio < threshold:  # TODO: check threshold!
            mask[clusters==cluster] = True

            if verbose > 0:
                print('Deleting cluster %i with low frequency ratio of %.3f (min %.3f)'
                      % (cluster, low_frequency_ratio, threshold))

        if 'eod_deletion' in return_data:
            adict['vals_%d' % cluster] = [mean_eod, mean_eod_fft]
            adict['mask_%d' % cluster] = [np.any(mask[clusters==cluster])]

    return mask, adict

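# Illustrative sketch, not part of the original module: a smooth pulse keeps
# most of its spectral amplitude below freq_low and passes, whereas white
# noise snippets spread power to high frequencies and should be masked as
# artefacts.
def _demo_remove_artefacts():
    samplerate = 500000.0
    pulse = np.exp(-0.5*(np.arange(-50, 50)/10.0)**2)
    snippets = np.vstack((np.tile(pulse, (20, 1)),
                          np.random.randn(20, 100)))
    clusters = np.array([0]*20 + [1]*20)
    mask, _ = remove_artefacts(snippets, clusters, samplerate, verbose=1)
    print(mask[:20].any(), mask[20:].all())   # expected: False True
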

def delete_unreliable_fish(clusters, eod_widths, eod_x, verbose=0, sdict={}):
    """ Create a mask for EOD clusters that are either mixed with noise or other fish, or wavefish.

    This is the case when the ratio between the EOD width and the ISI is too large.

    Parameters
    ----------
    clusters : list of ints
        Cluster labels.
    eod_widths : list of floats or ints
        EOD widths in samples or seconds.
    eod_x : list of ints or floats
        EOD times in samples or seconds.
    verbose : int (optional)
        Verbosity level.
    sdict : dictionary
        Dictionary that is used to log data. This is only used if a dictionary
        was created by remove_artefacts().
        For logging data in noise and wavefish discarding steps,
        see remove_artefacts().

    Returns
    -------
    mask : numpy array of booleans
        Set to True for every unreliable EOD.
    sdict : dictionary
        Key value pairs of logged data. Data is only logged if a dictionary
        was instantiated by remove_artefacts().
    """
    mask = np.zeros(clusters.shape, dtype=bool)
    for cluster in np.unique(clusters[clusters >= 0]):
        if len(eod_x[cluster == clusters]) < 2:
            mask[clusters == cluster] = True
            if verbose>0:
                print('deleting unreliable cluster %i, number of EOD times %d < 2'
                      % (cluster, len(eod_x[cluster==clusters])))
        elif np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters])) > 0.5:
            if verbose>0:
                print('deleting unreliable cluster %i, score=%f'
                      % (cluster, np.max(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))))
            mask[clusters==cluster] = True
        if 'vals_%d' % cluster in sdict:
            sdict['vals_%d' % cluster].append(np.median(eod_widths[clusters==cluster])/np.diff(eod_x[cluster==clusters]))
            sdict['mask_%d' % cluster].append(any(mask[clusters==cluster]))
    return mask, sdict


1357def delete_wavefish_and_sidepeaks(data, clusters, eod_x, eod_widths, 

1358 width_fac, max_slope_deviation=0.5, 

1359 max_phases=4, verbose=0, sdict={}): 

1360 """ Create a mask for EODs that are likely from wavefish, or sidepeaks of bigger EODs. 

1361 

1362 Parameters 

1363 ---------- 

1364 data : list of floats 

1365 Raw recording data. 

1366 clusters : list of ints 

1367 Cluster labels. 

1368 eod_x : list of ints 

1369 Indices of EOD times. 

1370 eod_widths : list of ints 

1371 EOD widths in samples. 

1372 width_fac : float 

1373 Multiplier for EOD analysis width. 

1374 

1375 max_slope_deviation: float (optional) 

1376 Maximum deviation of position of maximum slope in snippets from 

1377 center position in multiples of mean width of EOD. 

1378 max_phases : int (optional) 

1379 Maximum number of phases for any EOD.  

1380 If the mean EOD has more phases than this, it is not a pulse EOD. 

1381 verbose : int (optional)  

1382 Verbosity level. 

1383 sdict : dictionary 

1384 Dictionary that is used to log data. This is only used if a dictionary 

1385 was created by remove_artefacts(). 

1386 For logging data in noise and wavefish discarding steps, see remove_artefacts(). 

1387 

1388 Returns 

1389 ------- 

1390 mask_wave: numpy array of booleans 

1391 Set to True for every EOD which is a wavefish EOD. 

1392 mask_sidepeak: numpy array of booleans 

1393 Set to True for every snippet which is centered around a sidepeak of an EOD. 

1394 sdict : dictionary 

1395 Key value pairs of logged data. Data is only logged if a dictionary 

1396 was instantiated by remove_artefacts(). 

1397 """ 

1398 mask_wave = np.zeros(clusters.shape, dtype=bool) 

1399 mask_sidepeak = np.zeros(clusters.shape, dtype=bool) 

1400 

1401 for i, cluster in enumerate(np.unique(clusters[clusters >= 0])): 

1402 mean_width = np.mean(eod_widths[clusters == cluster]) 

1403 cutwidth = mean_width*width_fac 

1404 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1405 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1406 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] 

1407 for x in current_x[current_clusters==cluster]]) 

1408 

1409 # extract information on main peaks and troughs: 

1410 mean_eod = np.mean(snippets, axis=0) 

1411 mean_eod = mean_eod - np.mean(mean_eod) 

1412 

1413 # detect peaks and troughs on data + some maxima/minima at the 

1414 # end, so that the sides are also considered for peak detection: 

1415 pk, tr = detect_peaks(np.concatenate(([-10*mean_eod[0]], mean_eod, [10*mean_eod[-1]])), 

1416 np.std(mean_eod)) 

1417 pk = pk[(pk>0)&(pk<len(mean_eod))] 

1418 tr = tr[(tr>0)&(tr<len(mean_eod))] 

1419 

1420 if len(pk)>0 and len(tr)>0: 

1421 idxs = np.sort(np.concatenate((pk, tr))) 

1422 slopes = np.abs(np.diff(mean_eod[idxs])) 

1423 m_slope = np.argmax(slopes) 

1424 centered = np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2)) 

1425 

1426 # compute all height differences of peaks and troughs within snippets. 

1427 # if they are all similar, it is probably noise or a wavefish. 

1428 idxs = np.sort(np.concatenate((pk, tr))) 

1429 hdiffs = np.diff(mean_eod[idxs]) 

1430 

1431 if centered > max_slope_deviation*mean_width: # TODO: check, factor was probably 0.16 

1432 if verbose > 0: 

1433 print('Deleting cluster %i, which is a sidepeak' % cluster) 

1434 mask_sidepeak[clusters==cluster] = True 

1435 

1436 w_diff = np.abs(np.diff(np.sort(np.concatenate((pk, tr))))) 

1437 

1438 if np.abs(np.diff(idxs[m_slope:m_slope+2])) < np.mean(eod_widths[clusters==cluster])*0.5 or len(pk) + len(tr)>max_phases or np.min(w_diff)>2*cutwidth/width_fac: #or len(hdiffs[np.abs(hdiffs)>0.5*(np.max(mean_eod)-np.min(mean_eod))])>max_phases: 

1439 if verbose>0: 

1440 print('Deleting cluster %i, which is a wavefish' % cluster) 

1441 mask_wave[clusters==cluster] = True 

1442 if 'vals_%d' % cluster in sdict: 

1443 sdict['vals_%d' % cluster].append([mean_eod, [pk, tr], 

1444 idxs[m_slope:m_slope+2]]) 

1445 sdict['mask_%d' % cluster].append(any(mask_wave[clusters==cluster])) 

1446 sdict['mask_%d' % cluster].append(any(mask_sidepeak[clusters==cluster])) 

1447 

1448 return mask_wave, mask_sidepeak, sdict 

1449 

1450 
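# The sidepeak criterion of delete_wavefish_and_sidepeaks() reduced to its
# core: locate the steepest flank between adjacent peaks and troughs of the
# mean EOD and measure its offset from the snippet center. A hypothetical
# sketch assuming detect_peaks() from thunderlab, as imported at the top of
# this module, and at least one peak and one trough in mean_eod.
def _slope_center_offset(mean_eod):
    pk, tr = detect_peaks(mean_eod, np.std(mean_eod))
    idxs = np.sort(np.concatenate((pk, tr)))
    m_slope = np.argmax(np.abs(np.diff(mean_eod[idxs])))
    # offset of the closer end of the steepest flank from the snippet center:
    return np.min(np.abs(idxs[m_slope:m_slope+2] - len(mean_eod)//2))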

1451def merge_clusters(clusters_1, clusters_2, x_1, x_2, verbose=0): 

1452 """ Merge clusters resulting from two clustering methods. 

1453 

1454 This method only works if clustering is performed on the same EODs 

1455 with the same ordering, where there is a one-to-one mapping from

1456 clusters_1 to clusters_2.  

1457 

1458 Parameters 

1459 ---------- 

1460 clusters_1: list of ints 

1461 EOD cluster labels for cluster method 1. 

1462 clusters_2: list of ints 

1463 EOD cluster labels for cluster method 2. 

1464 x_1: list of ints 

1465 Indices of EODs for cluster method 1 (clusters_1). 

1466 x_2: list of ints 

1467 Indices of EODs for cluster method 2 (clusters_2). 

1468 verbose : int (optional) 

1469 Verbosity level. 

1470 

1471 Returns 

1472 ------- 

1473 clusters : list of ints 

1474 Merged clusters. 

1475 x_merged : list of ints 

1476 Merged cluster indices. 

1477 mask : 2d numpy array of ints (2, N)

1478 Mask for clusters that are selected from clusters_1 (mask[0]) and

1479 from clusters_2 (mask[1]).

1480 """ 

1481 if verbose > 0: 

1482 print('\nMerge cluster:') 

1483 

1484 # these arrays become 1 for each EOD that is chosen from that array 

1485 c1_keep = np.zeros(len(clusters_1)) 

1486 c2_keep = np.zeros(len(clusters_2)) 

1487 

1488 # offset the labels of one of the cluster lists to avoid overlap

1489 ovl = np.max(clusters_1) + 1 

1490 clusters_2[clusters_2!=-1] = clusters_2[clusters_2!=-1] + ovl 

1491 

1492 remove_clusters = [[]] 

1493 keep_clusters = [] 

1494 og_clusters = [np.copy(clusters_1), np.copy(clusters_2)] 

1495 

1496 # loop until done

1497 while True: 

1498 

1499 # compute unique clusters and cluster sizes 

1500 # of cluster that have not been iterated over: 

1501 c1_labels, c1_size = unique_counts(clusters_1[(clusters_1 != -1) & (c1_keep == 0)]) 

1502 c2_labels, c2_size = unique_counts(clusters_2[(clusters_2 != -1) & (c2_keep == 0)]) 

1503 

1504 # if all clusters are done, break from loop: 

1505 if len(c1_size) == 0 and len(c2_size) == 0: 

1506 break 

1507 

1508 # if the biggest cluster is in clusters_1, keep it and discard all clusters

1509 # of clusters_2 on the same indices:

1510 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 0: 

1511 

1512 # remove all the mappings from the other indices 

1513 cluster_mappings, _ = unique_counts(clusters_2[clusters_1 == c1_labels[np.argmax(c1_size)]]) 

1514 

1515 clusters_2[np.isin(clusters_2, cluster_mappings)] = -1 

1516 

1517 c1_keep[clusters_1==c1_labels[np.argmax(c1_size)]] = 1 

1518 

1519 remove_clusters.append(cluster_mappings) 

1520 keep_clusters.append(c1_labels[np.argmax(c1_size)]) 

1521 

1522 if verbose > 0: 

1523 print('Keep cluster %i of group 1, delete clusters %s of group 2' % (c1_labels[np.argmax(c1_size)], str(cluster_mappings[cluster_mappings!=-1] - ovl))) 

1524 

1525 # if the biggest cluster is in clusters_2, keep it and discard all mappings in clusters_1:

1526 elif np.argmax([np.max(np.append(c1_size, 0)), np.max(np.append(c2_size, 0))]) == 1: 

1527 

1528 # remove all the mappings from the other indices 

1529 cluster_mappings, _ = unique_counts(clusters_1[clusters_2 == c2_labels[np.argmax(c2_size)]]) 

1530 

1531 clusters_1[np.isin(clusters_1, cluster_mappings)] = -1 

1532 

1533 c2_keep[clusters_2==c2_labels[np.argmax(c2_size)]] = 1 

1534 

1535 remove_clusters.append(cluster_mappings) 

1536 keep_clusters.append(c2_labels[np.argmax(c2_size)]) 

1537 

1538 if verbose > 0: 

1539 print('Keep cluster %i of group 2, delete clusters %s of group 1' % (c2_labels[np.argmax(c2_size)] - ovl, str(cluster_mappings[cluster_mappings!=-1]))) 

1540 

1541 # combine results  

1542 clusters = (clusters_1+1)*c1_keep + (clusters_2+1)*c2_keep - 1 

1543 x_merged = (x_1)*c1_keep + (x_2)*c2_keep 

1544 

1545 return clusters, x_merged, np.vstack([c1_keep, c2_keep]) 

1546 

1547 
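# Hypothetical usage of merge_clusters() on two label assignments of the same
# five EODs, e.g. one clustering based on peak heights and one based on trough
# heights; all labels and indices below are made up for illustration.
def _merge_example():
    c1 = np.array([0, 0, 0, 1, 1])
    c2 = np.array([2, 2, 3, 3, 3])
    x1 = np.array([10, 20, 30, 40, 50])  # EOD indices for method 1
    x2 = np.array([11, 21, 31, 41, 51])  # EOD indices for method 2
    # the largest remaining cluster wins each iteration, so here all labels
    # are taken from c1:
    return merge_clusters(c1, c2, x1, x2)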

1548def extract_means(data, eod_x, eod_peak_x, eod_tr_x, eod_widths, clusters, samplerate, 

1549 width_fac, verbose=0): 

1550 """ Extract mean EODs and EOD timepoints for each EOD cluster. 

1551 

1552 Parameters 

1553 ---------- 

1554 data: list of floats 

1555 Raw recording data. 

1556 eod_x: list of ints 

1557 Locations of EODs in samples. 

1558 eod_peak_x : list of ints 

1559 Locations of EOD peaks in samples. 

1560 eod_tr_x : list of ints 

1561 Locations of EOD troughs in samples. 

1562 eod_widths: list of ints 

1563 EOD widths in samples. 

1564 clusters: list of ints 

1565 EOD cluster labels.

1566 samplerate: float

1567 Sampling rate of the recording in Hertz.

1568 width_fac : float 

1569 Multiplication factor for window used to extract EOD. 

1570  

1571 verbose : int (optional) 

1572 Verbosity level. 

1573 

1574 Returns 

1575 ------- 

1576 mean_eods: list of 2D arrays (3, eod_length) 

1577 The average EOD for each detected fish. First column is time in seconds, 

1578 second column the mean EOD, third column the standard deviation.

1579 eod_times: list of 1D arrays 

1580 For each detected fish the times of EOD in seconds. 

1581 eod_peak_times: list of 1D arrays 

1582 For each detected fish the times of EOD peaks in seconds. 

1583 eod_trough_times: list of 1D arrays 

1584 For each detected fish the times of EOD troughs in seconds. 

1585 eod_labels: list of ints 

1586 Cluster label for each detected fish. 

1587 """ 

1588 mean_eods, eod_times, eod_peak_times, eod_tr_times, eod_heights, cluster_labels = [], [], [], [], [], [] 

1589 

1590 for cluster in np.unique(clusters): 

1591 if cluster!=-1: 

1592 cutwidth = np.mean(eod_widths[clusters==cluster])*width_fac 

1593 current_x = eod_x[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1594 current_clusters = clusters[(eod_x>cutwidth) & (eod_x<(len(data)-cutwidth))] 

1595 

1596 snippets = np.vstack([data[int(x-cutwidth):int(x+cutwidth)] for x in current_x[current_clusters==cluster]]) 

1597 mean_eod = np.mean(snippets, axis=0) 

1598 eod_time = np.arange(len(mean_eod))/samplerate - cutwidth/samplerate 

1599 

1600 mean_eod = np.vstack([eod_time, mean_eod, np.std(snippets, axis=0)]) 

1601 

1602 mean_eods.append(mean_eod) 

1603 eod_times.append(eod_x[clusters==cluster]/samplerate) 

1604 eod_heights.append(np.min(mean_eod[1])-np.max(mean_eod[1]))  # sort key from the waveform row only

1605 eod_peak_times.append(eod_peak_x[clusters==cluster]/samplerate) 

1606 eod_tr_times.append(eod_tr_x[clusters==cluster]/samplerate) 

1607 cluster_labels.append(cluster) 

1608 

1609 return [m for _, m in sorted(zip(eod_heights, mean_eods))], [t for _, t in sorted(zip(eod_heights, eod_times))], [pt for _, pt in sorted(zip(eod_heights, eod_peak_times))], [tt for _, tt in sorted(zip(eod_heights, eod_tr_times))], [c for _, c in sorted(zip(eod_heights, cluster_labels))] 

1610 

1611 
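# The snippet averaging at the heart of extract_means() as a standalone
# sketch; the helper is hypothetical and not part of thunderfish. It cuts a
# window of width_fac times the EOD width around each EOD index and stacks
# time, mean waveform and standard deviation, like the mean_eod arrays
# returned above.
def _mean_eod_sketch(data, eod_x, width, samplerate, width_fac=4):
    cutwidth = width*width_fac
    snippets = np.vstack([data[int(x - cutwidth):int(x + cutwidth)]
                          for x in eod_x])
    mean_eod = np.mean(snippets, axis=0)
    time = np.arange(len(mean_eod))/samplerate - cutwidth/samplerate
    return np.vstack([time, mean_eod, np.std(snippets, axis=0)])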

1612def find_clipped_clusters(clusters, mean_eods, eod_times, eod_peaktimes, eod_troughtimes, 

1613 cluster_labels, width_factor, clip_threshold=0.9, verbose=0): 

1614 """ Detect EODs that are clipped and set all clusterlabels of these clipped EODs to -1. 

1615  

1616 Also return the mean EODs and timepoints of these clipped EODs. 

1617 

1618 Parameters 

1619 ---------- 

1620 clusters: array of ints 

1621 Cluster labels for each EOD in a recording. 

1622 mean_eods: list of numpy arrays 

1623 Mean EOD waveform for each cluster. 

1624 eod_times: list of numpy arrays 

1625 EOD timepoints for each EOD cluster. 

1626 eod_peaktimes : list of numpy arrays

1627 EOD peaktimes for each EOD cluster.

1628 eod_troughtimes : list of numpy arrays

1629 EOD troughtimes for each EOD cluster.

1630 cluster_labels: numpy array 

1631 Unique EOD clusterlabels. 

1632 width_factor: float

Width multiplier that was used for mean EOD extraction.

1633 clip_threshold: float

Threshold for detecting clipped EODs.

1634  

1635 verbose: int 

1636 Verbosity level. 

1637 

1638 Returns 

1639 ------- 

1640 clusters : array of ints 

1641 Cluster labels for each EOD in the recording, where clipped EODs have been set to -1. 

1642 clipped_eods : list of numpy arrays 

1643 Mean EOD waveforms for each clipped EOD cluster. 

1644 clipped_times : list of numpy arrays 

1645 EOD timepoints for each clipped EOD cluster. 

1646 clipped_peaktimes : list of numpy arrays 

1647 EOD peaktimes for each clipped EOD cluster. 

1648 clipped_troughtimes : list of numpy arrays 

1649 EOD troughtimes for each clipped EOD cluster. 

1650 """ 

1651 clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes, clipped_labels = [], [], [], [], [] 

1652 

1653 for mean_eod, eod_time, eod_peaktime, eod_troughtime,label in zip(mean_eods, eod_times, eod_peaktimes, eod_troughtimes, cluster_labels): 

1654 

1655 if (np.count_nonzero(mean_eod[1]>clip_threshold) > len(mean_eod[1])/(width_factor*2)) or (np.count_nonzero(mean_eod[1] < -clip_threshold) > len(mean_eod[1])/(width_factor*2)): 

1656 clipped_eods.append(mean_eod) 

1657 clipped_times.append(eod_time) 

1658 clipped_peaktimes.append(eod_peaktime) 

1659 clipped_troughtimes.append(eod_troughtime) 

1660 clipped_labels.append(label) 

1661 if verbose>0: 

1662 print('EOD cluster %i is clipped' % label)

1663 

1664 clusters[np.isin(clusters, clipped_labels)] = -1 

1665 

1666 return clusters, clipped_eods, clipped_times, clipped_peaktimes, clipped_troughtimes 

1667 

1668 
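# The clipping test of find_clipped_clusters() in isolation: if more than a
# width-dependent fraction of the samples of the mean EOD lies beyond
# +-clip_threshold, the waveform is considered clipped. A sketch assuming a
# recording normalized to the range [-1, 1]; mean_eod here is the bare
# waveform, i.e. row 1 of the arrays returned by extract_means().
def _is_clipped(mean_eod, width_factor, clip_threshold=0.9):
    n = len(mean_eod)
    return (np.count_nonzero(mean_eod > clip_threshold) > n/(width_factor*2) or
            np.count_nonzero(mean_eod < -clip_threshold) > n/(width_factor*2))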

1669def delete_moving_fish(clusters, eod_t, T, eod_heights, eod_widths, samplerate, 

1670 min_dt=0.25, stepsize=0.05, sliding_window_factor=2000, 

1671 verbose=0, plot_level=0, save_plot=False, save_path='', 

1672 ftype='pdf', return_data=[]): 

1673 """ 

1674 Use a sliding window to find the time interval with the lowest number of fish

1675 present at the same time, then delete all EOD clusters that are not present there.

1676 

1677 Do this only for EODs within the same width clusters, as a 

1678 moving fish will preserve its EOD width. 

1679 

1680 Parameters 

1681 ---------- 

1682 clusters: list of ints 

1683 EOD cluster labels. 

1684 eod_t: list of floats 

1685 Timepoints of the EODs (in seconds). 

1686 T: float 

1687 Length of recording (in seconds). 

1688 eod_heights: list of floats 

1689 EOD amplitudes. 

1690 eod_widths: list of floats 

1691 EOD widths (in seconds). 

1692 samplerate: float 

1693 Recording data samplerate. 

1694 

1695 min_dt : float (optional) 

1696 Minimum sliding window size (in seconds). 

1697 stepsize : float (optional) 

1698 Sliding window stepsize (in seconds). 

1699 sliding_window_factor : float 

1700 Multiplier for sliding window width, 

1701 where the sliding window width = median(EOD_width)*sliding_window_factor. 

1702 verbose : int (optional) 

1703 Verbosity level. 

1704 plot_level : int (optional) 

1705 Similar to verbosity levels, but with plots.  

1706 Only set to > 0 for debugging purposes. 

1707 save_plot : bool (optional) 

1708 Set to True to save the plots created by plot_level. 

1709 save_path : string (optional) 

1710 Path for saving plots. Only used if save_plot is set to True.

1711 ftype : string (optional) 

1712 Define the filetype to save the plots in if save_plots is set to True. 

1713 Options are: 'png', 'jpg', 'svg' ... 

1714 return_data : list of strings (optional) 

1715 Keys that specify data to be logged. The key that can be used to log data 

1716 in this function is 'moving_fish' (see extract_pulsefish()). 

1717 

1718 Returns 

1719 ------- 

1720 clusters : list of ints 

1721 Cluster labels, where deleted clusters have been set to -1. 

1722 window : list of 2 floats 

1723 Start and end of window selected for deleting moving fish in seconds. 

1724 mf_dict : dictionary 

1725 Key value pairs of logged data. Data to be logged is specified by return_data. 

1726 """ 

1727 mf_dict = {} 

1728 

1729 if len(np.unique(clusters[clusters != -1])) == 0: 

1730 return clusters, [0, 1], {} 

1731 

1732 all_keep_clusters = [] 

1733 width_classes = merge_gaussians(eod_widths, np.copy(clusters), 0.75) 

1734 

1735 all_windows = [] 

1736 all_dts = [] 

1737 ev_num = 0 

1738 for iw, w in enumerate(np.unique(width_classes[clusters >= 0])): 

1739 # initialize variables 

1740 min_clusters = 100 

1741 average_height = 0 

1742 sparse_clusters = 100 

1743 keep_clusters = [] 

1744 

1745 dt = max(min_dt, np.median(eod_widths[width_classes==w])*sliding_window_factor) 

1746 window_start = 0 

1747 window_end = dt 

1748 

1749 wclusters = clusters[width_classes==w] 

1750 weod_t = eod_t[width_classes==w] 

1751 weod_heights = eod_heights[width_classes==w] 

1752 weod_widths = eod_widths[width_classes==w] 

1753 

1754 all_dts.append(dt) 

1755 

1756 if verbose>0: 

1757 print('sliding window dt = %f'%dt) 

1758 

1759 # make W dependent on width?? 

1760 ignore_steps = np.zeros(len(np.arange(0, T-dt+stepsize, stepsize))) 

1761 

1762 for i, t in enumerate(np.arange(0, T-dt+stepsize, stepsize)): 

1763 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1764 if len(np.unique(current_clusters))==0: 

1765 ignore_steps[max(0, i-int(dt/stepsize)):i+int(dt/stepsize)] = 1

1766 if verbose>0: 

1767 print('No pulsefish in recording at T=%.2f:%.2f' % (t, t+dt)) 

1768 

1769 

1770 x = np.arange(0, T-dt+stepsize, stepsize) 

1771 y = np.ones(len(x)) 

1772 


1775 

1776 # sliding window 

1777 for j, (t, ignore_step) in enumerate(zip(x, ignore_steps)): 

1778 current_clusters = wclusters[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1779 current_widths = weod_widths[(weod_t>=t)&(weod_t<t+dt)&(wclusters!=-1)] 

1780 

1781 unique_clusters = np.unique(current_clusters) 

1782 y[j] = len(unique_clusters) 

1783 

1784 if (len(unique_clusters) <= min_clusters) and \ 

1785 (ignore_step==0) and \ 

1786 (len(unique_clusters) > 0):  # window contains at least one cluster

1787 

1788 current_labels = np.isin(wclusters, unique_clusters) 

1789 current_height = np.mean(weod_heights[current_labels]) 

1790 

1791 # compute nr of clusters that are too sparse 

1792 clusters_after_deletion = np.unique(remove_sparse_detections(np.copy(clusters[np.isin(clusters, unique_clusters)]), samplerate*eod_widths[np.isin(clusters, unique_clusters)], samplerate, T)) 

1793 current_sparse_clusters = len(unique_clusters) - len(clusters_after_deletion[clusters_after_deletion!=-1]) 

1794 

1795 if current_sparse_clusters <= sparse_clusters and \ 

1796 ((current_sparse_clusters<sparse_clusters) or 

1797 (current_height > average_height) or 

1798 (len(unique_clusters) < min_clusters)): 

1799 

1800 keep_clusters = unique_clusters 

1801 min_clusters = len(unique_clusters) 

1802 average_height = current_height 

1803 window_end = t+dt 

1804 sparse_clusters = current_sparse_clusters 

1805 

1806 all_keep_clusters.append(keep_clusters) 

1807 all_windows.append(window_end) 

1808 

1809 if 'moving_fish' in return_data or plot_level>0: 

1810 if 'w' in mf_dict: 

1811 mf_dict['w'].append(np.median(eod_widths[width_classes==w])) 

1812 mf_dict['T'] = T 

1813 mf_dict['dt'].append(dt) 

1814 mf_dict['clusters'].append(wclusters) 

1815 mf_dict['t'].append(weod_t) 

1816 mf_dict['fishcount'].append([x+0.5*(x[1]-x[0]), y]) 

1817 mf_dict['ignore_steps'].append(ignore_steps) 

1818 else: 

1819 mf_dict['w'] = [np.median(eod_widths[width_classes==w])] 

1820 mf_dict['T'] = T

1821 mf_dict['dt'] = [dt] 

1822 mf_dict['clusters'] = [wclusters] 

1823 mf_dict['t'] = [weod_t] 

1824 mf_dict['fishcount'] = [[x+0.5*(x[1]-x[0]), y]] 

1825 mf_dict['ignore_steps'] = [ignore_steps] 

1826 

1827 if verbose>0: 

1828 print('Estimated nr of pulsefish in recording: %i' % len(np.concatenate(all_keep_clusters)))

1829 

1830 if plot_level>0: 

1831 plot_moving_fish(mf_dict['w'], mf_dict['dt'], mf_dict['clusters'], mf_dict['t'],

1832 mf_dict['fishcount'], T, mf_dict['ignore_steps']) 

1833 if save_plot: 

1834 plt.savefig('%sdelete_moving_fish.%s' % (save_path, ftype)) 

1835 # empty dict 

1836 if 'moving_fish' not in return_data: 

1837 mf_dict = {} 

1838 

1839 # delete all clusters that are not selected 

1840 clusters[np.invert(np.isin(clusters, np.concatenate(all_keep_clusters)))] = -1 

1841 

1842 return clusters, [np.max(all_windows)-np.max(all_dts), np.max(all_windows)], mf_dict 

1843 

1844 
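# The sliding-window count that delete_moving_fish() minimizes, as a toy
# standalone: for each window position count the distinct cluster labels
# present. The helper and its defaults are hypothetical, not part of
# thunderfish.
def _fish_count_per_window(eod_t, clusters, T, dt=0.5, stepsize=0.05):
    counts = []
    for t in np.arange(0, T - dt + stepsize, stepsize):
        in_win = clusters[(eod_t >= t) & (eod_t < t + dt) & (clusters != -1)]
        counts.append(len(np.unique(in_win)))
    return np.array(counts)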

1845def remove_sparse_detections(clusters, eod_widths, samplerate, T, 

1846 min_density=0.0005, verbose=0): 

1847 """ Remove all EOD clusters that are too sparse 

1848 

1849 Parameters 

1850 ---------- 

1851 clusters : list of ints 

1852 Cluster labels. 

1853 eod_widths : list of ints 

1854 Cluster widths in samples. 

1855 samplerate : float 

1856 Samplerate. 

1857 T : float 

1858 Length of the recording in seconds.

1859 min_density : float (optional) 

1860 Minimum density for realistic EOD detections. 

1861 verbose : int (optional) 

1862 Verbosity level. 

1863 

1864 Returns 

1865 ------- 

1866 clusters : list of ints 

1867 Cluster labels, where sparse clusters have been set to -1. 

1868 """ 

1869 for c in np.unique(clusters): 

1870 if c!=-1: 

1871 

1872 n = len(clusters[clusters==c]) 

1873 w = np.median(eod_widths[clusters==c])/samplerate 

1874 

1875 if n*w < T*min_density: 

1876 if verbose>0: 

1877 print('cluster %i is too sparse'%c) 

1878 clusters[clusters==c] = -1 

1879 return clusters
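# The density criterion of remove_sparse_detections() with hypothetical
# numbers: 10 EODs of 1 ms each cover only 10 ms of a 60 s recording, less
# than the 30 ms implied by min_density=0.0005, so such a cluster would be
# discarded.
def _is_dense_enough(n=10, width_s=0.001, T=60.0, min_density=0.0005):
    return n*width_s >= T*min_density    # 0.01 >= 0.03 is False: too sparse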