Coverage for src / thunderfish / pulseplots.py: 0%

487 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-15 17:50 +0000

1""" 

Plot and save key steps in pulses.py for visualizing the algorithm.

3""" 

4 

5import glob 

6import numpy as np 

7import matplotlib.pyplot as plt 

8 

9from scipy import stats 

10from matplotlib import gridspec, ticker 

11try: 

12 from matplotlib.colors import colorConverter as cc 

13except ImportError: 

14 import matplotlib.colors as cc 

15try: 

16 from matplotlib.colors import to_hex 

17except ImportError: 

18 from matplotlib.colors import rgb2hex as to_hex 

19from matplotlib.patches import ConnectionPatch, Rectangle 

20from matplotlib.lines import Line2D 

21 

22import warnings 

def warn(*args, **kwargs):
    """No-op stand-in for `warnings.warn`: accept any arguments, emit nothing.

    Installed below as `warnings.warn`, so every later call made through
    the `warnings.warn` module attribute - by this module or by any other
    code in the same process - is silently dropped.
    """
    return None


warnings.warn = warn

29 

30 

# plotting parameters and colors:
# one shared colormap; plot elements pick its entries by index
cmap = plt.get_cmap("Dark2")
c_g, c_o, c_grey = (cmap(k) for k in (0, 1, 7))  # heights, widths, BGM gaussians
cmap_pts = [cmap(k) for k in (2, 3)]             # peak- / trough-centered clusters

37 

38 

def darker(color, saturation):
    """Make a color darker.

    From bendalab/plottools package.

    Parameters
    ----------
    color: dict or matplotlib color spec
        A matplotlib color (hex string, name color string, rgb tuple)
        or a dictionary with an 'color' or 'facecolor' key.
    saturation: float
        The smaller the saturation, the darker the returned color.
        A saturation of 0 returns black.
        A saturation of 1 leaves the color untouched.
        A saturation of 2 returns white.
        Values below 0 are treated as 0, values above 2 as 2.

    Returns
    -------
    color: string or dictionary
        The darker color as an upper-case hexadecimal RGB string (e.g. '#RRGGBB').
        If `color` is a dictionary, a copy of the dictionary is returned
        with the value of 'color' or 'facecolor' set to the darker color.
    """
    # Dictionaries of plot properties: darken the 'color' entry if present,
    # otherwise the 'facecolor' entry. Plain color specs (strings, tuples)
    # raise KeyError/TypeError on these lookups and fall through.
    try:
        c = color['color']
        cd = dict(**color)
        cd['color'] = darker(c, saturation)
        return cd
    except (KeyError, TypeError):
        try:
            c = color['facecolor']
            cd = dict(**color)
            cd['facecolor'] = darker(c, saturation)
            return cd
        except (KeyError, TypeError):
            # clip to the documented [0, 2] range:
            if saturation > 2:
                saturation = 2
            if saturation > 1:
                # saturation above 1 brightens: mirror around 1 and delegate
                return lighter(color, 2.0-saturation)
            if saturation < 0:
                saturation = 0
            r, g, b = cc.to_rgb(color)
            # scale each channel towards 0 (black)
            rd = r*saturation
            gd = g*saturation
            bd = b*saturation
            return to_hex((rd, gd, bd)).upper()

85 

86 

def lighter(color, lightness):
    """Make a color lighter.

    From bendalab/plottools package.

    Parameters
    ----------
    color: dict or matplotlib color spec
        A matplotlib color (hex string, name color string, rgb tuple)
        or a dictionary with an 'color' or 'facecolor' key.
    lightness: float
        The smaller the lightness, the lighter the returned color.
        A lightness of 0 returns white.
        A lightness of 1 leaves the color untouched.
        A lightness of 2 returns black.
        Values below 0 are treated as 0, values above 2 as 2.

    Returns
    -------
    color: string or dict
        The lighter color as an upper-case hexadecimal RGB string (e.g. '#RRGGBB').
        If `color` is a dictionary, a copy of the dictionary is returned
        with the value of 'color' or 'facecolor' set to the lighter color.
    """
    # Dictionaries of plot properties: lighten the 'color' entry if there
    # is one, otherwise the 'facecolor' entry. Plain color specs raise
    # KeyError/TypeError on these lookups and fall through.
    for key in ('color', 'facecolor'):
        try:
            inner = color[key]
            result = dict(**color)
            result[key] = lighter(inner, lightness)
            return result
        except (KeyError, TypeError):
            pass
    if lightness > 1:
        # lightness above 1 darkens; it is capped at 2, which yields black
        return darker(color, 2.0 - min(lightness, 2))
    lightness = max(lightness, 0)
    # blend each channel towards 1 (white)
    blend = 1.0 - lightness
    red, green, blue = cc.to_rgb(color)
    return to_hex(tuple(ch + blend*(1.0 - ch) for ch in (red, green, blue))).upper()

133 

134 

def xscalebar(ax, x, y, width, wunit=None, wformat=None, ha='left', va='bottom',
              lw=None, color=None, capsize=None, clw=None, **kwargs):
    """Horizontal scale bar with label.

    From bendalab/plottools package.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where to draw the scale bar.
    x: float
        x-coordinate where to draw the scale bar in relative units of the axes.
    y: float
        y-coordinate where to draw the scale bar in relative units of the axes.
    width: float
        Length of the scale bar in units of the data's x-values.
    wunit: string or None
        Optional unit of the data's x-values.
    wformat: string or None
        Optional format string for formatting the label of the scale bar
        or simply a string used for labeling the scale bar.
        If the formatted label is a plain number, the bar length is set
        to that number so that bar and label agree.
    ha: 'left', 'right', or 'center'
        Scale bar aligned left, right, or centered to (x, y)
    va: 'top' or 'bottom'
        Label of the scale bar either above or below the scale bar.
    lw: int, float, None
        Line width of the scale bar. Defaults to 2.
    color: matplotlib color
        Color of the scalebar. Defaults to black.
    capsize: float or None
        If larger than zero draw cap lines at the ends of the bar.
        Each cap extends this many figure pixels above and below the bar.
    clw: int, float, None
        Line width of the cap lines. Defaults to 0.5.
    kwargs: key-word arguments
        Passed on to `ax.text()` used to print the scale bar label.

    Returns
    -------
    x0: float
        Start of the scale bar in data units.
    x1: float
        End of the scale bar in data units.
    y: float
        Vertical position of the scale bar in data units.
    """
    ax.autoscale(False)
    # axes height in figure pixels, needed to convert pixel offsets to data units:
    pixely = np.abs(np.diff(ax.get_window_extent().get_points()[:, 1]))[0]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    unitx = xmax - xmin
    unity = ymax - ymin
    # data units per pixel along y (caps and label offset are vertical):
    dyu = np.abs(unity)/pixely
    # transform x, y from relative units to axis units:
    x = xmin + x*unitx
    y = ymin + y*unity
    # bar length and label:
    if wformat is None:
        wformat = '%.0f'
        if width < 1.0:
            wformat = '%.1f'
    try:
        ls = wformat % width
    except TypeError:
        # wformat is a plain label without a conversion specifier
        ls = wformat
    else:
        try:
            # make the bar exactly as long as the label claims
            width = float(ls)
        except ValueError:
            # the label contains more than a number (e.g. '%.0f ms'):
            # keep the requested width
            pass
    # bar end points in data units:
    if ha == 'left':
        x0, x1 = x, x+width
    elif ha == 'right':
        x0, x1 = x-width, x
    else:
        x0, x1 = x-0.5*width, x+0.5*width
    # line width:
    if lw is None:
        lw = 2
    # color:
    if color is None:
        color = 'k'
    # scalebar:
    lh = ax.plot([x0, x1], [y, y], '-', color=color, lw=lw,
                 solid_capstyle='butt', clip_on=False)
    # get y position of line in figure pixel coordinates:
    ly = np.array(lh[0].get_window_extent(ax.get_figure().canvas.get_renderer()))[0, 1]
    # caps:
    if capsize is None:
        capsize = 0
    if clw is None:
        clw = 0.5
    if capsize > 0.0:
        dy = capsize*dyu
        ax.plot([x0, x0], [y-dy, y+dy], '-', color=color, lw=clw,
                solid_capstyle='butt', clip_on=False)
        ax.plot([x1, x1], [y-dy, y+dy], '-', color=color, lw=clw,
                solid_capstyle='butt', clip_on=False)
    # label:
    if wunit:
        ls += u'\u2009%s' % wunit
    if va == 'top':
        th = ax.text(0.5*(x0+x1), y, ls, clip_on=False,
                     ha='center', va='bottom', **kwargs)
        # get y coordinate of text bottom in figure pixel coordinates:
        ty = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[0, 1]
        # pixel shift that puts the text just above the bar:
        dty = ly+0.5*lw + 2.0 - ty
    else:
        th = ax.text(0.5*(x0+x1), y, ls, clip_on=False,
                     ha='center', va='top', **kwargs)
        # get y coordinate of text top in figure pixel coordinates:
        ty = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[1, 1]
        # pixel shift that puts the text just below the bar:
        dty = ly-0.5*lw - 2.0 - ty
    th.set_position((0.5*(x0+x1), y+dyu*dty))
    return x0, x1, y

244 

245 

def yscalebar(ax, x, y, height, hunit=None, hformat=None, ha='left', va='bottom',
              lw=None, color=None, capsize=None, clw=None, **kwargs):
    """Vertical scale bar with label.

    From bendalab/plottools package.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where to draw the scale bar.
    x: float
        x-coordinate where to draw the scale bar in relative units of the axes.
    y: float
        y-coordinate where to draw the scale bar in relative units of the axes.
    height: float
        Length of the scale bar in units of the data's y-values.
    hunit: string
        Unit of the data's y-values.
    hformat: string or None
        Optional format string for formatting the label of the scale bar
        or simply a string used for labeling the scale bar.
        If the formatted label is a plain number, the bar length is set
        to that number so that bar and label agree.
    ha: 'left' or 'right'
        Label of the scale bar either to the left or to the right
        of the scale bar.
    va: 'top', 'bottom', or 'center'
        Scale bar aligned above, below, or centered on (x, y).
    lw: int, float, None
        Line width of the scale bar. Defaults to 2.
    color: matplotlib color
        Color of the scalebar. Defaults to black.
    capsize: float or None
        If larger than zero draw cap lines at the ends of the bar.
        Each cap extends this many figure pixels left and right of the bar.
    clw: int, float
        Line width of the cap lines. Defaults to 0.5.
    kwargs: key-word arguments
        Passed on to `ax.text()` used to print the scale bar label.

    Returns
    -------
    x: float
        Horizontal position of the scale bar in data units.
    y0, y1: float
        The two ends of the scale bar in data units.
    """
    ax.autoscale(False)
    # axes width in figure pixels, needed to convert pixel offsets to data units:
    pixelx = np.abs(np.diff(ax.get_window_extent().get_points()[:, 0]))[0]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    unitx = xmax - xmin
    unity = ymax - ymin
    # data units per pixel along x (caps and label offset are horizontal):
    dxu = np.abs(unitx)/pixelx
    # transform x, y from relative units to axis units:
    x = xmin + x*unitx
    y = ymin + y*unity
    # bar length and label:
    if hformat is None:
        hformat = '%.0f'
        if height < 1.0:
            hformat = '%.1f'
    try:
        ls = hformat % height
    except TypeError:
        # hformat is a plain label without a conversion specifier
        ls = hformat
    else:
        try:
            # make the bar exactly as long as the label claims
            height = float(ls)
        except ValueError:
            # the label contains more than a number (e.g. '%.0f mV'):
            # keep the requested height
            pass
    # bar end points in data units:
    if va == 'bottom':
        y0, y1 = y, y+height
    elif va == 'top':
        y0, y1 = y-height, y
    else:
        y0, y1 = y-0.5*height, y+0.5*height
    # line width:
    if lw is None:
        lw = 2
    # color:
    if color is None:
        color = 'k'
    # scalebar:
    lh = ax.plot([x, x], [y0, y1], '-', color=color, lw=lw,
                 solid_capstyle='butt', clip_on=False)
    # get x position of line in figure pixel coordinates:
    lx = np.array(lh[0].get_window_extent(ax.get_figure().canvas.get_renderer()))[0, 0]
    # caps:
    if capsize is None:
        capsize = 0
    if clw is None:
        clw = 0.5
    if capsize > 0.0:
        dx = capsize*dxu
        ax.plot([x-dx, x+dx], [y0, y0], '-', color=color, lw=clw, solid_capstyle='butt',
                clip_on=False)
        ax.plot([x-dx, x+dx], [y1, y1], '-', color=color, lw=clw, solid_capstyle='butt',
                clip_on=False)
    # label:
    if hunit:
        ls += u'\u2009%s' % hunit
    if ha == 'right':
        th = ax.text(x, 0.5*(y0+y1), ls, clip_on=False, rotation=90.0,
                     ha='left', va='center', **kwargs)
        # get x coordinate of the text's left edge in figure pixel coordinates:
        tx = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[0, 0]
        # pixel shift that puts the text just right of the bar:
        dtx = lx+0.5*lw + 2.0 - tx
    else:
        th = ax.text(x, 0.5*(y0+y1), ls, clip_on=False, rotation=90.0,
                     ha='right', va='center', **kwargs)
        # get x coordinate of the text's right edge in figure pixel coordinates:
        tx = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[1, 0]
        # pixel shift that puts the text just left of the bar:
        dtx = lx-0.5*lw - 1.0 - tx
    th.set_position((x+dxu*dtx, 0.5*(y0+y1)))
    return x, y0, y1

358 

359 

def arrowed_spines(ax, ms=10):
    """Mark the top of the y-axis with an arrow head.

    A black upward-pointing triangle marker is drawn at the upper left
    corner (x-minimum, y-maximum) of the current axis limits. The limits
    are restored afterwards so the marker leaves them unchanged.

    Parameters
    ----------
    ax : matplotlib figure axis
        Axis on which the arrow should be plot.
    ms : float
        Marker size, passed as `s` to `ax.scatter()`. Defaults to 10.
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.scatter([xlim[0]], [ylim[1]], s=ms, marker='^', clip_on=False, color='k')
    # scatter() may rescale the axes; put the original limits back
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

373 

374 

def loghist(ax, x, bmin, bmax, n, c, orientation='vertical', label=''):
    """Plot a histogram with logarithmically spaced bins.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to plot the histogram on.
    x : numpy array
        Input data for histogram.
    bmin : float
        Lowest bin edge; must be positive.
    bmax : float
        Highest bin edge; must be positive.
    n : int
        Number of logarithmically spaced bin edges between `bmin` and
        `bmax` (inclusive), i.e. the histogram has n-1 bins.
    c : matplotlib color
        Color of histogram.
    orientation : string (optional)
        Histogram orientation.
        Defaults to 'vertical'.
    label : string (optional)
        Label for x.
        Defaults to '' (no label).

    Returns
    -------
    n : array
        The values of the histogram bins.
    bins : array
        The edges of the bins.
    patches : BarContainer
        Container of individual artists used to create the histogram.
    """
    # equally spaced in log space -> geometric spacing of the edges
    log_edges = np.linspace(np.log(bmin), np.log(bmax), n)
    edges = np.exp(log_edges)
    return ax.hist(x, bins=edges, color=c, orientation=orientation, label=label)

410 

411 

def plot_all(data, eod_p_times, eod_tr_times, fs, mean_eods):
    """Quick way to view the output of extract_pulsefish in a single plot.

    The top panel shows the recording with the peaks and troughs of each
    detected pulsefish marked in its own color; the bottom row shows one
    panel with the mean EOD of each pulsefish.

    Parameters
    ----------
    data: array
        Recording data.
    eod_p_times: list of arrays of floats
        EOD peak times in seconds, presumably one array per pulsefish.
        They are converted to sample indices by multiplying with `fs`.
    eod_tr_times: list of arrays of floats
        EOD trough times in seconds, paired element-wise with `eod_p_times`.
    fs: float
        Sampling rate in Hertz.
    mean_eods: list of numpy arrays
        Mean EODs of each pulsefish found in the recording.
        Row 0 is time in seconds, row 1 the mean waveform and row 2 the
        spread drawn as a band around the mean (presumably the standard
        deviation - confirm with caller).
    """
    fig = plt.figure(figsize=(10, 5))

    if len(eod_p_times) > 0:
        # one column per fish; also make room for every mean EOD
        ncols = max(len(eod_p_times), len(mean_eods))
        gs = gridspec.GridSpec(2, ncols)
        ax = fig.add_subplot(gs[0, :])
        ax.plot(np.arange(len(data))/fs, data, c='k', alpha=0.3)

        for i, (pt, tt) in enumerate(zip(eod_p_times, eod_tr_times)):
            # Round (do not truncate) times back to sample indices:
            # e.g. 0.29*100 evaluates to 28.999999999999996.
            pidx = np.clip(np.round(np.asarray(pt)*fs), 0, len(data)-1).astype(int)
            tidx = np.clip(np.round(np.asarray(tt)*fs), 0, len(data)-1).astype(int)
            ax.plot(pt, data[pidx], 'o', label=i+1, ms=10, c=cmap(i))
            ax.plot(tt, data[tidx], 'o', label=i+1, ms=10, c=cmap(i))

        ax.set_xlabel('time [s]')
        ax.set_ylabel('amplitude [V]')

        for i, m in enumerate(mean_eods):
            ax = fig.add_subplot(gs[1, i])
            ax.plot(1000*m[0], 1000*m[1], c='k')
            ax.fill_between(1000*m[0], 1000*(m[1]-m[2]), 1000*(m[1]+m[2]), color=cmap(i))
            ax.set_xlabel('time [ms]')
            ax.set_ylabel('amplitude [mV]')
    else:
        # no fish detected: just show the recording
        ax = fig.add_subplot(111)
        ax.plot(np.arange(len(data))/fs, data, c='k', alpha=0.3)
        ax.set_xlabel('time [s]')
        ax.set_ylabel('amplitude [V]')

    fig.tight_layout()

453 

454 

def plot_clustering(rate, eod_widths, eod_hights, eod_shapes, disc_masks, merge_masks):
    """Plot all clustering steps.

    Plot clustering steps on width, height and shape. Then plot the remaining EODs after
    the EOD assessment step and the EODs after the merge step.

    The figure has five columns:
    1. histogram of EOD widths, colored by width cluster;
    2. histograms of EOD heights, one axis per width cluster;
    3. one 2x2 window per height cluster with PCA features (left) and
       snippets (right) of the shape clusters, peak-centered (`pt=0`, top
       row) and trough-centered (`pt=1`, bottom row);
    4. mean snippets of shape clusters without any discarded EOD;
    5. mean snippets of those clusters with any merge-mask entry set.
    Thin grey lines connect related panels across columns.

    Parameters
    ----------
    rate : float
        Sampling rate of EOD snippets.
        NOTE(review): `rate` is not used anywhere in this function.
    eod_widths : list of three 1D numpy arrays
        The first list entry gives the unique labels of all width clusters as a list of ints.
        The second list entry gives the width values for each EOD in samples as a
        1D numpy array of ints.
        The third list entry gives the width labels for each EOD as a 1D numpy array of ints.
    eod_hights : nested lists (2 layers) of three 1D numpy arrays
        The first list entry gives the unique labels of all height clusters as a list of ints
        for each width cluster.
        The second list entry gives the height values for each EOD as a 1D numpy array
        of floats for each width cluster.
        The third list entry gives the height labels for each EOD as a 1D numpy array
        of ints for each width cluster.
    eod_shapes : nested lists (3 layers) of three 1D numpy arrays
        The first list entry gives the raw EOD snippets as a 2D numpy array for each
        height cluster in a width cluster.
        The second list entry gives the snippet PCA values for each EOD as a 2D numpy array
        of floats for each height cluster in a width cluster.
        The third list entry gives the shape labels for each EOD as a 1D numpy array of ints
        for each height cluster in a width cluster.
        Within a height cluster, index 0/1 selects peak-/trough-centered data.
    disc_masks : Nested lists (two layers) of 1D numpy arrays
        The masks of EODs that are discarded by the discarding step of the algorithm.
        The masks are 1D boolean arrays where
        instances that are set to True are discarded by the algorithm. Discarding masks
        are saved in nested lists that represent the width and height clusters.
        Indexed as `dmasks[pt]` per height cluster (pt = 0 peak, 1 trough).
    merge_masks : Nested lists (two layers) of 2D numpy arrays
        The masks of EODs that are discarded by the merging step of the algorithm.
        The masks are 2D boolean arrays where
        for each sample point `i` either `merge_mask[i,0]` or `merge_mask[i,1]` is set to True.
        Here, merge_mask[:,0] represents the
        peak-centered clusters and `merge_mask[:,1]` represents the trough-centered clusters.
        Merge masks are saved in nested lists
        that represent the width and height clusters.
        NOTE(review): the code indexes these masks as `mm[0]`/`mm[1]` and
        `mmasks[pt, idx]`, i.e. with the centering on the FIRST axis
        (shape (2, N)), which contradicts `merge_mask[:,0]` above - confirm.
    """
    # create the figure and the inverse figure transform
    # (display coordinates -> figure fraction, used for connector lines).
    fig = plt.figure(figsize=(12, 7))
    transFigure = fig.transFigure.inverted()

    # set up the figure layout: five columns, one per clustering step
    outer = gridspec.GridSpec(1, 5, width_ratios=[1, 1, 2, 1, 2], left=0.05, right=0.95)

    # set titles for each clustering step (invisible axes that only hold the title)
    titles = ['1. Widths', '2. Heights', '3. Shape', '4. Pulse EODs', '5. Merge']
    for i, title in enumerate(titles):
        title_ax = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec = outer[i])
        ax = fig.add_subplot(title_ax[0])
        ax.text(0, 110, title, ha='center', va='bottom', clip_on=False)
        ax.set_xlim(-100, 100)
        ax.set_ylim(-100, 100)
        ax.axis('off')

    # compute sizes for each axis
    w_size = 1
    # number of width clusters -> one height histogram axis each
    h_size = len(eod_hights[1])

    # total number of height clusters -> one shape window each
    shape_size = np.sum([len(sl) for sl in eod_shapes[0]])

    # count required axes sized for the last two plot columns.
    # A label is counted when at least one of its EODs is not discarded
    # (column 4) or has a merge-mask entry set (column 5); labels are
    # shifted by +1 so that label 0 survives the multiplication by the mask.
    disc_size = 0
    merge_size= 0
    for shapelabel, dmasks, mmasks in zip(eod_shapes[2], disc_masks, merge_masks):
        for sl, dm, mm in zip(shapelabel, dmasks, mmasks):
            uld1 = np.unique((sl[0]+1)*np.invert(dm[0]))
            uld2 = np.unique((sl[1]+1)*np.invert(dm[1]))
            disc_size = disc_size+len(uld1[uld1>0])+len(uld2[uld2>0])

            uld1 = np.unique((sl[0]+1)*mm[0])
            uld2 = np.unique((sl[1]+1)*mm[1])
            merge_size = merge_size+len(uld1[uld1>0])+len(uld2[uld2>0])

    # set counters to keep track of the plot axes
    disc_block = 0
    merge_block = 0
    shape_count = 0

    # create all axes
    width_hist_ax = gridspec.GridSpecFromSubplotSpec(w_size, 1, subplot_spec = outer[0])
    hight_hist_ax = gridspec.GridSpecFromSubplotSpec(h_size, 1, subplot_spec = outer[1])
    shape_ax = gridspec.GridSpecFromSubplotSpec(shape_size, 1, subplot_spec = outer[2])
    # 2x2 window per height cluster: rows = peak/trough, columns = features/snippets
    shape_windows = [gridspec.GridSpecFromSubplotSpec(2, 2, hspace=0.0, wspace=0.0,
                                                      subplot_spec=shape_ax[i])
                     for i in range(shape_size)]

    # NOTE(review): if disc_size or merge_size is 0 these grids have zero rows;
    # check how GridSpecFromSubplotSpec handles that.
    EOD_delete_ax = gridspec.GridSpecFromSubplotSpec(disc_size, 1, subplot_spec=outer[3])
    EOD_merge_ax = gridspec.GridSpecFromSubplotSpec(merge_size, 1, subplot_spec=outer[4])

    # plot width labels histogram
    ax1 = fig.add_subplot(width_hist_ax[0])
    # set axes features.
    ax1.set_xscale('log')
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['bottom'].set_visible(False)
    ax1.axes.xaxis.set_visible(False)
    ax1.set_yticklabels([])

    # indices for plot colors (dark to light); fed to lighter()
    colidxsw = -np.linspace(-1.25, -0.5, h_size)

    # loop over width clusters
    for i, (wl, colw, uhl, eod_h, eod_h_labs, w_snip, w_feat, w_lab, w_dm, w_mm) in enumerate(zip(eod_widths[0], colidxsw, eod_hights[0], eod_hights[1], eod_hights[2], eod_shapes[0], eod_shapes[1], eod_shapes[2], disc_masks, merge_masks)):

        # plot width hist
        hw, _, _ = ax1.hist(eod_widths[1][eod_widths[2]==wl],
                            bins=np.linspace(np.min(eod_widths[1]), np.max(eod_widths[1]), 100),
                            color=lighter(c_o, colw), orientation='horizontal')

        # set arrow when the last hist is plot so the size of the axes are known.
        if i == h_size-1:
            arrowed_spines(ax1, ms=20)

        # determine total size of the height histograms now (used for the x-limit).
        my, b = np.histogram(eod_h, bins=np.exp(np.linspace(np.min(np.log(eod_h)),
                             np.max(np.log(eod_h)), 100)))
        maxy = np.max(my)

        # set axes features for height hist; axes are filled bottom-up.
        ax2 = fig.add_subplot(hight_hist_ax[h_size-i-1])
        ax2.set_xscale('log')
        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        ax2.set_xlim(0.9, maxy)
        ax2.axes.xaxis.set_visible(False)
        ax2.set_yscale('log')
        ax2.yaxis.set_major_formatter(ticker.NullFormatter())
        ax2.yaxis.set_minor_formatter(ticker.NullFormatter())

        # define colors for plots
        colidxsh = -np.linspace(-1.25, -0.5, len(uhl))

        # loop over height clusters within this width cluster
        for n, (hl, hcol, snippets, features, labels, dmasks, mmasks) in enumerate(zip(uhl, colidxsh, w_snip, w_feat, w_lab, w_dm, w_mm)):

            hh, _, _ = loghist(ax2, eod_h[eod_h_labs==hl], np.min(eod_h), np.max(eod_h), 100,
                               lighter(c_g, hcol), orientation='horizontal')

            # set arrow spines only on last plot
            if n == len(uhl)-1:
                arrowed_spines(ax2, ms=10)

            # plot line from the width histogram to the height histogram.
            # NOTE(review): appending to fig.lines relies on it being a mutable
            # list; fig.add_artist(line) may be more portable across matplotlib versions.
            if n == 0:
                coord1 = transFigure.transform(ax1.transData.transform([np.median(hw[hw!=0]),
                                               np.median(eod_widths[1][eod_widths[2]==wl])]))
                coord2 = transFigure.transform(ax2.transData.transform([0.9, np.mean(eod_h)]))
                line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                              transform=fig.transFigure, color='grey', linewidth=0.5)
                fig.lines.append(line)

            # compute sizes of the eod_discarding and merge steps; the block
            # counters are cumulative and include the current height cluster,
            # so columns 4 and 5 are filled from the bottom up.
            s1 = np.unique((labels[0]+1)*(~dmasks[0]))
            s2 = np.unique((labels[1]+1)*(~dmasks[1]))
            disc_block = disc_block + len(s1[s1>0]) + len(s2[s2>0])

            s1 = np.unique((labels[0]+1)*(mmasks[0]))
            s2 = np.unique((labels[1]+1)*(mmasks[1]))
            merge_block = merge_block + len(s1[s1>0]) + len(s2[s2>0])

            axs = []
            disc_count = 0
            merge_count = 0

            # now plot the clusters for peak and trough centerings
            for pt, cmap_pt in zip([0, 1], cmap_pts):

                ax3 = fig.add_subplot(shape_windows[shape_size-1-shape_count][pt,0])
                ax4 = fig.add_subplot(shape_windows[shape_size-1-shape_count][pt,1])

                # remove axes
                ax3.axes.xaxis.set_visible(False)
                ax4.axes.yaxis.set_visible(False)
                ax3.axes.yaxis.set_visible(False)
                ax4.axes.xaxis.set_visible(False)

                # set color indices (one per non-noise cluster; j walks through them)
                colidxss = -np.linspace(-1.25, -0.5, len(np.unique(labels[pt][labels[pt]>=0])))
                j=0
                for c in np.unique(labels[pt]):

                    if c<0:
                        # plot noise features + snippets
                        ax3.plot(features[pt][labels[pt]==c,0], features[pt][labels[pt]==c,1],
                                 '.', color='lightgrey', label='-1', rasterized=True)
                        ax4.plot(snippets[pt][labels[pt]==c].T, linewidth=0.1,
                                 color='lightgrey', label='-1', rasterized=True)
                    else:
                        # plot cluster features and snippets
                        ax3.plot(features[pt][labels[pt]==c,0], features[pt][labels[pt]==c,1],
                                 '.', color=lighter(cmap_pt, colidxss[j]), label=c,
                                 rasterized=True)
                        ax4.plot(snippets[pt][labels[pt]==c].T, linewidth=0.1,
                                 color=lighter(cmap_pt, colidxss[j]), label=c, rasterized=True)

                        # check if the current cluster is an EOD (no EOD of it is
                        # discarded), if yes, plot its mean snippet in column 4.
                        if np.sum(dmasks[pt][labels[pt]==c]) == 0:

                            ax = fig.add_subplot(EOD_delete_ax[disc_size-disc_block+disc_count])
                            ax.axis('off')

                            # plot mean EOD snippet
                            ax.plot(np.mean(snippets[pt][labels[pt]==c], axis=0),
                                    color=lighter(cmap_pt, colidxss[j]))
                            disc_count = disc_count + 1

                            # connect the snippet panel (right edge, vertical middle)
                            # with the mean-EOD panel (left edge, vertical middle).
                            coord1 = transFigure.transform(ax4.transData.transform([ax4.get_xlim()[1],
                                                           ax4.get_ylim()[0] + 0.5*(ax4.get_ylim()[1]-ax4.get_ylim()[0])]))
                            coord2 = transFigure.transform(ax.transData.transform([ax.get_xlim()[0],ax.get_ylim()[0] + 0.5*(ax.get_ylim()[1]-ax.get_ylim()[0])]))
                            line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                                          transform=fig.transFigure, color='grey',
                                          linewidth=0.5)
                            fig.lines.append(line)
                            axs.append(ax)

                            # plot the mean snippet in column 5 as well when any
                            # merge-mask entry of this cluster is set.
                            if np.sum(mmasks[pt, labels[pt]==c])>0:

                                ax = fig.add_subplot(EOD_merge_ax[merge_size-merge_block+merge_count])
                                ax.axis('off')

                                ax.plot(np.mean(snippets[pt][labels[pt]==c], axis=0),
                                        color=lighter(cmap_pt, colidxss[j]))
                                merge_count = merge_count + 1

                        j=j+1

                if pt==0:
                    # draw line from height cluster to EOD shape clusters.
                    coord1 = transFigure.transform(ax2.transData.transform([np.median(hh[hh!=0]),
                                                   np.median(eod_h[eod_h_labs==hl])]))
                    coord2 = transFigure.transform(ax3.transData.transform([ax3.get_xlim()[0],
                                                   ax3.get_ylim()[0]]))
                    line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                                  transform=fig.transFigure, color='grey', linewidth=0.5)
                    fig.lines.append(line)

            shape_count = shape_count + 1

            if len(axs)>0:
                # draw a vertical line right of the column-4 axes of this height
                # cluster, from the first to the last one, grouping them.
                coord1 = transFigure.transform(axs[0].transData.transform([axs[0].get_xlim()[1]+0.1*(axs[0].get_xlim()[1]-axs[0].get_xlim()[0]),
                                               axs[0].get_ylim()[1]-0.25*(axs[0].get_ylim()[1]-axs[0].get_ylim()[0])]))
                coord2 = transFigure.transform(axs[-1].transData.transform([axs[-1].get_xlim()[1]+0.1*(axs[-1].get_xlim()[1]-axs[-1].get_xlim()[0]),
                                               axs[-1].get_ylim()[0]+0.25*(axs[-1].get_ylim()[1]-axs[-1].get_ylim()[0])]))
                line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                              transform=fig.transFigure, color='grey', linewidth=1)
                fig.lines.append(line)

711 

712 

def plot_bgm(x, means, variances, weights, use_log, labels, labels_am, xlab):
    """Plot a BGM clustering step either on EOD width or height.

    The histogram of the data, split by BGM label, is drawn on the left
    axis. The square root of each weighted Gaussian is drawn on a twin
    log-scaled axis together with dashed lines where the most likely
    Gaussian changes. Labels that were merged into one are annotated
    above the plot.

    Parameters
    ----------
    x : 1D numpy array of floats
        BGM input values.
    means : list of floats
        BGM Gaussian means
    variances : list of floats
        BGM Gaussian variances.
    weights : list of floats
        BGM Gaussian weights.
    use_log : boolean
        True if the z-scored logarithm of the data was used as BGM input.
    labels : 1D numpy array of ints
        Labels defined by BGM model (before merging based on merge factor).
    labels_am : 1D numpy array of ints
        Labels defined by BGM model (after merging based on merge factor).
    xlab : string
        Label for plot (defines the units of the BGM data).
    """
    # base color depends on the clustered feature
    if 'width' in xlab:
        ccol = c_o
    elif 'height' in xlab:
        ccol = c_g
    else:
        ccol = 'b'

    # get the transform that was used as BGM input
    if use_log:
        x_transform = stats.zscore(np.log(x))
        xplot = np.exp(np.linspace(np.log(np.min(x)), np.log(np.max(x)), 1000))
    else:
        x_transform = stats.zscore(x)
        xplot = np.linspace(np.min(x), np.max(x), 1000)

    # Evaluate the square root of each weighted Gaussian on a grid in the
    # transformed space; grid point k corresponds to xplot[k].
    xgrid = np.linspace(np.min(x_transform), np.max(x_transform), 1000)
    gaussians = []
    gmax = 0
    for w, m, var in zip(weights, means, variances):
        gaus = np.sqrt(w*stats.norm.pdf(xgrid, m, np.sqrt(var)))
        gaussians.append(gaus)
        gmax = max(np.max(gaus), gmax)

    # compute classes defined by gaussian intersections
    classes = np.argmax(np.vstack(gaussians), axis=0)
    uclasses = np.unique(classes)

    # find the minimum of any gaussian that is within its class
    gmin = 100
    for c in uclasses:
        gmin = min(gmin, np.min(gaussians[c][classes == c]))

    # set up the figure
    fig, ax1 = plt.subplots(figsize=(8, 4.8))
    ax2 = ax1.twinx()
    ax1.spines['top'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax1.set_ylabel('#')
    ax2.set_ylabel('Likelihood')
    ax2.set_yscale('log')
    ax1.set_yscale('log')
    if use_log:
        ax1.set_xscale('log')
    ax1.set_xlabel(xlab)

    # plot the gaussians (dark to light)
    colidxs = -np.linspace(-1.25, -0.5, len(uclasses))
    for i, c in enumerate(uclasses):
        ax2.plot(xplot, gaussians[c], c=lighter(c_grey, colidxs[i]), linewidth=2,
                 label=r'$N(\mu_%i, \sigma_%i)$' % (c, c))

    # plot intersection lines where the most likely gaussian changes
    ax2.vlines(xplot[1:][np.diff(classes) != 0], 0, gmax/gmin, color='k', linewidth=2,
               linestyle='--')
    ax2.set_ylim(gmin, np.max(np.vstack(gaussians))*1.1)

    # plot data distributions and classes
    ulabels = np.unique(labels)
    colidxs = -np.linspace(-1.25, -0.5, len(ulabels))
    for i, lab in enumerate(ulabels):
        col = lighter(ccol, colidxs[i])
        if use_log:
            loghist(ax1, x[labels == lab], np.min(x), np.max(x), 100,
                    col, label=r'$x_%i$' % lab)
        else:
            ax1.hist(x[labels == lab], bins=np.linspace(np.min(x), np.max(x), 100),
                     color=col, label=r'$x_%i$' % lab)

    # annotate merged clusters (only the first two merged labels are shown)
    for lab in np.unique(labels_am):
        maps = np.unique(labels[labels_am == lab])
        if len(maps) > 1:
            med1 = np.median(x[labels == maps[0]])
            med2 = np.median(x[labels == maps[1]])
            ax2.plot([med1, med2], [1.2*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.plot([med1, med1], [1.1*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.plot([med2, med2], [1.1*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.annotate(r'$\frac{|{\tilde{x}_%i-\tilde{x}_%i}|}{max(\tilde{x}_%i, \tilde{x}_%i)} < \epsilon$' % (maps[0], maps[1], maps[0], maps[1]),
                         [med1*1.1, gmax*1.2], xytext=(10, 10), textcoords='offset points',
                         fontsize=12, annotation_clip=False, ha='center')

    # add legends and plot.
    ax2.legend(loc='lower left', frameon=False, bbox_to_anchor=(-0.05, 1.3),
               ncol=len(uclasses))
    ax1.legend(loc='upper left', frameon=False, bbox_to_anchor=(-0.05, 1.3),
               ncol=len(ulabels))
    plt.tight_layout()

826 

827 

def plot_feature_extraction(raw_snippets, normalized_snippets, features, labels, dt, pt):
    """Plot the feature extraction step on EOD snippets.

    The left column shows the raw and the normalized EOD snippets.
    The right side shows a lower-triangular grid of pairwise scatter
    plots of the features, plus scale bars. Everything is colored by
    cluster label; snippets with a negative label are drawn in light grey.

    Parameters
    ----------
    raw_snippets : 2D numpy array
        Raw EOD snippets, one snippet per row.
    normalized_snippets : 2D numpy array
        Normalized EOD snippets, one snippet per row.
    features : 2D numpy array
        PCA values for each normalized EOD snippet, one row per snippet and
        one column per feature. The feature grid needs at least two columns.
    labels : 1D numpy array of ints
        Cluster labels, one per snippet. Negative labels are drawn in grey.
    dt : float
        Sample interval of snippets in seconds (the time axis is in ms).
    pt : int
        Set to 0 for peak-centered EODs and set to 1 for trough-centered EODs.
    """
    ccol = cmap_pts[pt]
    n_features = features.shape[1]
    fmin = np.min(features)
    fmax = np.max(features)

    # set up the figure layout: snippets on the left, features on the right
    fig = plt.figure(figsize=(((2+0.2)*4.8), 4.8))
    outer = gridspec.GridSpec(1, 2, wspace=0.2, hspace=0)
    snip_ax = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[0], hspace=0.35)
    pc_ax = gridspec.GridSpecFromSubplotSpec(n_features-1, n_features-1,
                                             subplot_spec=outer[1], hspace=0, wspace=0)

    # Time axis in milliseconds, centered on the snippet. It is built from
    # integer sample indices so that it has exactly one entry per snippet
    # sample. np.arange() with a float step may return one element too many.
    n_samples = raw_snippets.shape[1]
    x = (np.arange(n_samples) - n_samples/2)*dt*1000

    ax_raw_snip = fig.add_subplot(snip_ax[0])
    ax_normalized_snip = fig.add_subplot(snip_ax[1])

    ax_overlay = fig.add_subplot(pc_ax[:, :])
    ax_overlay.set_title('Features')
    ax_overlay.axis('off')

    # Create every pairwise feature axes exactly once. Recent matplotlib
    # versions create a new axes on each add_subplot() call, so adding them
    # again per cluster would stack opaque axes and hide earlier clusters.
    pair_axes = []
    for n in range(n_features):
        for m in range(n):
            ax = fig.add_subplot(pc_ax[n-1, m])
            ax.set_xlim(fmin, fmax)
            ax.set_ylim(fmin, fmax)
            ax.set_xticks([])
            ax.set_yticks([])
            if m == 0:
                ax.set_ylabel('PC %i' % (n+1))
            if n == n_features-1:
                ax.set_xlabel('PC %i' % (m+1))
            pair_axes.append((n, m, ax))

    # Non-negative clusters get shades of the peak/trough color, produced by
    # lighter() with factors running from 1.25 down to 0.5.
    colidxs = -np.linspace(-1.25, -0.5, len(np.unique(labels[labels >= 0])))
    j = 0
    for c in np.unique(labels):
        if c < 0:
            color = 'lightgrey'
        else:
            color = lighter(ccol, colidxs[j])
            j += 1
        sel = labels == c
        ax_raw_snip.plot(x, raw_snippets[sel].T, color=color, label='-1',
                         rasterized=True, alpha=0.25)
        ax_normalized_snip.plot(x, normalized_snippets[sel].T, color=color, alpha=0.25)
        for n, m, ax in pair_axes:
            ax.scatter(features[sel, m], features[sel, n], marker='.',
                       color=color, alpha=0.25)

    # decorations of the snippet axes
    for ax, title in ((ax_raw_snip, 'Raw snippets'),
                      (ax_normalized_snip, 'Normalized snippets')):
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_title(title)
        ax.set_ylabel('Amplitude [a.u.]')
    ax_raw_snip.get_xaxis().set_ticklabels([])
    ax_normalized_snip.set_xlabel('Time [ms]')
    ax_raw_snip.axis('off')
    ax_normalized_snip.axis('off')

    # scale bars in the upper right cell of the feature grid
    ax = fig.add_subplot(pc_ax[0, n_features-2])
    ax.set_xlim(fmin, fmax)
    ax.set_ylim(fmin, fmax)

    # `size` is the number of decimals shown (at least 1). The bar length is
    # the feature range rounded down to that many decimals.
    # TODO: the bar should probably be only a fraction of the feature range.
    size = max(1, int(np.ceil(-np.log10(fmax-fmin))))
    wbar = np.floor((fmax-fmin)*10**size)/10**size
    xscalebar(ax, 0, 0, wbar, wformat='%%.%if' % size)
    yscalebar(ax, 0, 0, wbar, hformat='%%.%if' % size)
    ax.axis('off')

922 

def plot_moving_fish(ws, dts, clusterss, ts, fishcounts, T, ignore_stepss):
    """Plot moving fish detection step.

    For each width cluster, two stacked panels are drawn:

    - Top panel: an event plot of the EOD times of every EOD cluster with a
      non-negative label. The first sliding window is outlined.
    - Bottom panel: the fish count of each sliding window, plotted at the
      window center.

    Parameters
    ----------
    ws : list of floats
        Median width for each width cluster that the moving fish algorithm
        is computed on (in seconds).
    dts : list of floats
        Sliding window size (in seconds) for each width cluster.
    clusterss : list of 1D numpy int arrays
        Cluster labels for each EOD cluster in a width cluster.
        EODs with negative labels are not plotted.
    ts : list of 1D numpy float arrays
        EOD emission times for each EOD in a width cluster.
    fishcounts : list of lists
        Sliding window timepoints and fishcounts for each width cluster.
    T : float
        Length of analyzed recording in seconds.
    ignore_stepss : list of 1D int arrays
        Mask for fishcounts that were ignored (ignored if True) in the
        moving_fish analysis. Ignored counts are only drawn faintly.
    """
    fig = plt.figure()
    outer = gridspec.GridSpec(len(ws), 1)

    for i, (w, dt, clusters, t, fishcount, ignore_steps) in enumerate(
            zip(ws, dts, clusterss, ts, fishcounts, ignore_stepss)):
        gs = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[i])
        ax1 = fig.add_subplot(gs[0])  # EOD clusters
        ax2 = fig.add_subplot(gs[1])  # fish count (created only once)

        # One row of event markers per EOD cluster, skipping labels < 0.
        # The row count comes from the cluster list itself, so a width
        # cluster without any EOD cluster does not raise a NameError.
        cluster_ids = np.unique(clusters[clusters >= 0])
        n_clusters = len(cluster_ids)
        for row, c in enumerate(cluster_ids):
            ax1.eventplot(t[clusters == c], lineoffsets=row, linelengths=0.5,
                          color=cmap(i))

        # outline the first sliding window, with an arrow to its right
        rect = Rectangle((0, -0.5), dt, n_clusters, linewidth=1, linestyle='--',
                         edgecolor='k', facecolor='none', clip_on=False)
        ax1.add_patch(rect)
        ax1.arrow(dt+0.1, -0.5, 0.5, 0, head_width=0.1, head_length=0.1,
                  facecolor='k', edgecolor='k')

        ax1.set_title(r'$\tilde{w}_%i = %.3f ms$'%(i, 1000*w))
        ax1.set_ylabel('cluster #')
        ax1.set_yticks(range(0, n_clusters))
        ax1.set_xlabel('time')
        ax1.set_xlim(0, T)
        ax1.axes.xaxis.set_visible(False)
        for side in ('bottom', 'top', 'right', 'left'):
            ax1.spines[side].set_visible(False)

        # fish count, plotted at the center of each sliding window
        tcenter = np.asarray(fishcount[0]) + dt/2
        y = fishcount[1]
        # Use float copies so that ignored steps can be blanked with NaN even
        # when the counts are integers: NaN cannot be stored in an int array.
        yall = np.array(y, dtype=float)
        yvalid = yall.copy()
        yvalid[np.asarray(ignore_steps).astype(bool)] = np.nan

        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        ax2.axes.xaxis.set_visible(False)
        ax2.plot(tcenter, yall, linestyle='-', marker='.', c=cmap(i), alpha=0.25)
        ax2.plot(tcenter, yvalid, linestyle='-', marker='.', c=cmap(i))
        ax2.set_ylabel('Fish count')
        ax2.set_yticks(range(int(np.min(y)), 1+int(np.max(y))))
        ax2.set_xlim(0, T)

        # time scale bar only below the last width cluster
        if i == len(ws)-1:
            xscalebar(ax2, 1, 0, 1, wunit='s', ha='right')

        # connect both lower corners of the outlined window to its fish count
        for xa in (0, dt):
            con = ConnectionPatch([xa, -0.5], [dt/2, y[0]], "data", "data",
                                  axesA=ax1, axesB=ax2, color='k')
            ax2.add_artist(con)

    plt.xlim(0, T)