Coverage for src/thunderfish/pulseplots.py: 0%

487 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 16:21 +0000

1""" 

2Plot and save key steps in pulses.py for visualizing the algorithm. 

3""" 

4 

5import glob 

6import numpy as np 

7from scipy import stats 

8from matplotlib import gridspec, ticker 

9import matplotlib.pyplot as plt 

10try: 

11 from matplotlib.colors import colorConverter as cc 

12except ImportError: 

13 import matplotlib.colors as cc 

14try: 

15 from matplotlib.colors import to_hex 

16except ImportError: 

17 from matplotlib.colors import rgb2hex as to_hex 

18from matplotlib.patches import ConnectionPatch, Rectangle 

19from matplotlib.lines import Line2D 

20import warnings 

def warn(*args, **kwargs):
    """Accept any arguments and do nothing.

    Installed right below as a replacement for `warnings.warn`, so that
    every warning issued through that function is silently dropped.
    """
    return None


# silence all warnings that are issued via warnings.warn():
warnings.warn = warn

27 

28 

# plotting parameters and colors:
# all colors are taken from matplotlib's qualitative "Dark2" color map.
cmap = plt.get_cmap("Dark2")
# base colors: c_o is used for widths and c_g for heights (see plot_bgm()),
# c_grey for the Gaussians of the BGM fit.
c_g, c_o, c_grey = cmap(0), cmap(1), cmap(7)
# base colors indexed by `pt`, the peak/trough centering index used below.
cmap_pts = [cmap(k) for k in (2, 3)]

35 

36 

def darker(color, saturation):
    """Make a color darker.

    From bendalab/plottools package.

    Parameters
    ----------
    color: dict or matplotlib color spec
        A matplotlib color (hex string, name color string, rgb tuple)
        or a dictionary with a 'color' or 'facecolor' key.
    saturation: float
        The smaller the saturation, the darker the returned color.
        A saturation of 0 returns black.
        A saturation of 1 leaves the color untouched.
        A saturation of 2 returns white.
        Values outside the range 0 to 2 are clipped to that range.

    Returns
    -------
    color: string or dictionary
        The darker color as an upper-case hexadecimal RGB string
        (e.g. '#RRGGBB').
        If `color` is a dictionary, a copy of the dictionary is returned
        with the value of 'color' or 'facecolor' set to the darker color.
    """
    # dictionaries: darken the entry stored under 'color' or,
    # if that is not possible, the one stored under 'facecolor':
    for key in ('color', 'facecolor'):
        try:
            c = color[key]
            cd = dict(**color)
            cd[key] = darker(c, saturation)
            return cd
        except (KeyError, TypeError):
            # not a dictionary or no such key, try the next option
            pass
    # plain color specification:
    if saturation > 2:
        saturation = 2
    if saturation > 1:
        # saturations between 1 and 2 lighten the color
        return lighter(color, 2.0-saturation)
    if saturation < 0:
        saturation = 0
    r, g, b = cc.to_rgb(color)
    # scale all channels towards black:
    rd = r*saturation
    gd = g*saturation
    bd = b*saturation
    return to_hex((rd, gd, bd)).upper()

83 

84 

def lighter(color, lightness):
    """Make a color lighter.

    From bendalab/plottools package.

    Parameters
    ----------
    color: dict or matplotlib color spec
        A matplotlib color (hex string, name color string, rgb tuple)
        or a dictionary with a 'color' or 'facecolor' key.
    lightness: float
        The smaller the lightness, the lighter the returned color.
        A lightness of 0 returns white.
        A lightness of 1 leaves the color untouched.
        A lightness of 2 returns black.

    Returns
    -------
    color: string or dict
        The lighter color as an upper-case hexadecimal RGB string
        (e.g. '#RRGGBB').
        If `color` is a dictionary, a copy of the dictionary is returned
        with the value of 'color' or 'facecolor' set to the lighter color.
    """
    # a dictionary gets its 'color' entry lightened, or else its 'facecolor':
    for key in ('color', 'facecolor'):
        try:
            entry = color[key]
            result = dict(**color)
            result[key] = lighter(entry, lightness)
            return result
        except (KeyError, TypeError):
            continue
    # from here on `color` is a plain color specification.
    lightness = min(lightness, 2)
    if lightness > 1:
        # the upper half of the range is handled by darkening
        return darker(color, 2.0-lightness)
    # fraction of the way each channel is moved towards white:
    whiten = 1.0 - max(lightness, 0)
    rgb = cc.to_rgb(color)
    return to_hex(tuple(v + whiten*(1.0 - v) for v in rgb)).upper()

131 

132 

def xscalebar(ax, x, y, width, wunit=None, wformat=None, ha='left', va='bottom',
              lw=None, color=None, capsize=None, clw=None, **kwargs):
    """Horizontal scale bar with label.

    From bendalab/plottools package.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where to draw the scale bar.
    x: float
        x-coordinate where to draw the scale bar in relative units of the axes.
    y: float
        y-coordinate where to draw the scale bar in relative units of the axes.
    width: float
        Length of the scale bar in units of the data's x-values.
        If the label is a formatted number, the bar is drawn with the
        length that is printed in the label.
    wunit: string or None
        Optional unit of the data's x-values.
    wformat: string or None
        Optional format string for formatting the label of the scale bar
        or simply a string used for labeling the scale bar.
        Defaults to '%.0f', or '%.1f' for widths smaller than one.
    ha: 'left', 'right', or 'center'
        Scale bar aligned left, right, or centered to (x, y)
    va: 'top' or 'bottom'
        Label of the scale bar either above or below the scale bar.
    lw: int, float, None
        Line width of the scale bar. Defaults to 2.
    color: matplotlib color
        Color of the scalebar. Defaults to black.
    capsize: float or None
        If larger than zero draw cap lines at the ends of the bar.
        The cap lines extend by `capsize` above and below the bar,
        converted to data units via the pixel height of the axes.
    clw: int, float, None
        Line width of the cap lines. Defaults to 0.5.
    kwargs: key-word arguments
        Passed on to `ax.text()` used to print the scale bar label.

    Returns
    -------
    x0: float
        Left end of the scale bar in data coordinates.
    x1: float
        Right end of the scale bar in data coordinates.
    y: float
        y-position of the scale bar in data coordinates.
    """
    ax.autoscale(False)
    # ax dimensions:
    pixely = np.abs(np.diff(ax.get_window_extent().get_points()[:,1]))[0]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    unitx = xmax - xmin
    unity = ymax - ymin
    # data units per pixel along the y-axis:
    dyu = np.abs(unity)/pixely
    # transform x, y from relative units to axis units:
    x = xmin + x*unitx
    y = ymin + y*unity
    # bar length:
    if wformat is None:
        wformat = '%.0f'
        if width < 1.0:
            wformat = '%.1f'
    try:
        ls = wformat % width
    except TypeError:
        # wformat contains no format specifier: use it verbatim as label
        ls = wformat
    else:
        try:
            # make the bar exactly as long as the number shown in the label:
            width = float(ls)
        except ValueError:
            # the label contains more than just the number: keep width
            pass
    # bar:
    if ha == 'left':
        x0 = x
        x1 = x+width
    elif ha == 'right':
        x0 = x-width
        x1 = x
    else:
        x0 = x-0.5*width
        x1 = x+0.5*width
    # line width:
    if lw is None:
        lw = 2
    # color:
    if color is None:
        color = 'k'
    # scalebar:
    lh = ax.plot([x0, x1], [y, y], '-', color=color, lw=lw,
                 solid_capstyle='butt', clip_on=False)
    # get y position of line in figure pixel coordinates:
    ly = np.array(lh[0].get_window_extent(ax.get_figure().canvas.get_renderer()))[0,1]
    # caps:
    if capsize is None:
        capsize = 0
    if clw is None:
        clw = 0.5
    if capsize > 0.0:
        dy = capsize*dyu
        ax.plot([x0, x0], [y-dy, y+dy], '-', color=color, lw=clw,
                solid_capstyle='butt', clip_on=False)
        ax.plot([x1, x1], [y-dy, y+dy], '-', color=color, lw=clw,
                solid_capstyle='butt', clip_on=False)
    # label:
    if wunit:
        ls += u'\u2009%s' % wunit
    if va == 'top':
        th = ax.text(0.5*(x0+x1), y, ls, clip_on=False,
                     ha='center', va='bottom', **kwargs)
        # get y coordinate of text bottom in figure pixel coordinates:
        ty = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[0,1]
        dty = ly+0.5*lw + 2.0 - ty
    else:
        th = ax.text(0.5*(x0+x1), y, ls, clip_on=False,
                     ha='center', va='top', **kwargs)
        # get y coordinate of text top in figure pixel coordinates:
        ty = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[1,1]
        dty = ly-0.5*lw - 2.0 - ty
    # shift the label by the pixel offset dty so that it clears the bar:
    th.set_position((0.5*(x0+x1), y+dyu*dty))
    return x0, x1, y

242 

243 

def yscalebar(ax, x, y, height, hunit=None, hformat=None, ha='left', va='bottom',
              lw=None, color=None, capsize=None, clw=None, **kwargs):
    """Vertical scale bar with label.

    From bendalab/plottools package.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where to draw the scale bar.
    x: float
        x-coordinate where to draw the scale bar in relative units of the axes.
    y: float
        y-coordinate where to draw the scale bar in relative units of the axes.
    height: float
        Length of the scale bar in units of the data's y-values.
        If the label is a formatted number, the bar is drawn with the
        length that is printed in the label.
    hunit: string or None
        Optional unit of the data's y-values.
    hformat: string or None
        Optional format string for formatting the label of the scale bar
        or simply a string used for labeling the scale bar.
        Defaults to '%.0f', or '%.1f' for heights smaller than one.
    ha: 'left' or 'right'
        Label of the scale bar either to the left or to the right
        of the scale bar.
    va: 'top', 'bottom', or 'center'
        Scale bar aligned above, below, or centered on (x, y).
    lw: int, float, None
        Line width of the scale bar. Defaults to 2.
    color: matplotlib color
        Color of the scalebar. Defaults to black.
    capsize: float or None
        If larger than zero draw cap lines at the ends of the bar.
        The cap lines extend by `capsize` to the left and right of the bar,
        converted to data units via the pixel width of the axes.
    clw: int, float, None
        Line width of the cap lines. Defaults to 0.5.
    kwargs: key-word arguments
        Passed on to `ax.text()` used to print the scale bar label.

    Returns
    -------
    x: float
        x-position of the scale bar in data coordinates.
    y0: float
        Lower end of the scale bar in data coordinates.
    y1: float
        Upper end of the scale bar in data coordinates.
    """
    ax.autoscale(False)
    # ax dimensions:
    pixelx = np.abs(np.diff(ax.get_window_extent().get_points()[:,0]))[0]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    unitx = xmax - xmin
    unity = ymax - ymin
    # data units per pixel along the x-axis:
    dxu = np.abs(unitx)/pixelx
    # transform x, y from relative units to axis units:
    x = xmin + x*unitx
    y = ymin + y*unity
    # bar length:
    if hformat is None:
        hformat = '%.0f'
        if height < 1.0:
            hformat = '%.1f'
    try:
        ls = hformat % height
    except TypeError:
        # hformat contains no format specifier: use it verbatim as label
        ls = hformat
    else:
        try:
            # make the bar exactly as long as the number shown in the label
            # (same as xscalebar() does for the width):
            height = float(ls)
        except ValueError:
            # the label contains more than just the number: keep height
            pass
    # bar:
    if va == 'bottom':
        y0 = y
        y1 = y+height
    elif va == 'top':
        y0 = y-height
        y1 = y
    else:
        y0 = y-0.5*height
        y1 = y+0.5*height
    # line width:
    if lw is None:
        lw = 2
    # color:
    if color is None:
        color = 'k'
    # scalebar:
    lh = ax.plot([x, x], [y0, y1], '-', color=color, lw=lw,
                 solid_capstyle='butt', clip_on=False)
    # get x position of line in figure pixel coordinates:
    lx = np.array(lh[0].get_window_extent(ax.get_figure().canvas.get_renderer()))[0,0]
    # caps:
    if capsize is None:
        capsize = 0
    if clw is None:
        clw = 0.5
    if capsize > 0.0:
        dx = capsize*dxu
        ax.plot([x-dx, x+dx], [y0, y0], '-', color=color, lw=clw, solid_capstyle='butt',
                clip_on=False)
        ax.plot([x-dx, x+dx], [y1, y1], '-', color=color, lw=clw, solid_capstyle='butt',
                clip_on=False)
    # label:
    if hunit:
        ls += u'\u2009%s' % hunit
    if ha == 'right':
        th = ax.text(x, 0.5*(y0+y1), ls, clip_on=False, rotation=90.0,
                     ha='left', va='center', **kwargs)
        # get x coordinate of left text edge in figure pixel coordinates:
        tx = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[0,0]
        dtx = lx+0.5*lw + 2.0 - tx
    else:
        th = ax.text(x, 0.5*(y0+y1), ls, clip_on=False, rotation=90.0,
                     ha='right', va='center', **kwargs)
        # get x coordinate of right text edge in figure pixel coordinates:
        tx = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[1,0]
        dtx = lx-0.5*lw - 1.0 - tx
    # shift the label by the pixel offset dtx so that it clears the bar:
    th.set_position((x+dxu*dtx, 0.5*(y0+y1)))
    return x, y0, y1

356 

357 

def arrowed_spines(ax, ms=10):
    """Put an arrow head on top of the y-axis of a plot.

    The arrow head is a triangular marker drawn at the upper left corner
    of the current axes limits. The limits are left unchanged.

    Parameters
    ----------
    ax : matplotlib figure axis
        Axis on which the arrow should be plotted.
    ms : float
        Size of the marker, passed on as `s` to `ax.scatter()`.
    """
    x_limits = ax.get_xlim()
    y_limits = ax.get_ylim()
    left, top = x_limits[0], y_limits[1]
    ax.scatter([left], [top], s=ms, marker='^', clip_on=False, color='k')
    # re-apply the limits that were read before the marker was added:
    ax.set_xlim(*x_limits)
    ax.set_ylim(*y_limits)

371 

372 

def loghist(ax, x, bmin, bmax, n, c, orientation='vertical', label=''):
    """Plot a histogram with logarithmically spaced bins.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to plot the histogram on.
    x : numpy array
        Input data for histogram.
    bmin : float
        Lower edge of the first histogram bin. Its logarithm is taken.
    bmax : float
        Upper edge of the last histogram bin. Its logarithm is taken.
    n : int
        Number of bin edges, i.e. the histogram has `n-1` bins.
    c : matplotlib color
        Color of histogram.
    orientation : string (optional)
        Histogram orientation.
        Defaults to 'vertical'.
    label : string (optional)
        Label for x.
        Defaults to '' (no label).

    Returns
    -------
    n : array
        The values of the histogram bins.
    bins : array
        The edges of the bins.
    patches : BarContainer
        Container of individual artists used to create the histogram.
    """
    # edges equally spaced in log space between bmin and bmax:
    log_lo = np.log(bmin)
    log_hi = np.log(bmax)
    bin_edges = np.exp(np.linspace(log_lo, log_hi, n))
    return ax.hist(x, bins=bin_edges, color=c, orientation=orientation, label=label)

408 

409 

def plot_all(data, eod_p_times, eod_tr_times, fs, mean_eods):
    """Quick way to view the output of extract_pulsefish in a single plot.

    A new figure is created. Its upper row shows the recording with the
    detected peaks and troughs marked, its lower row one panel with the
    mean EOD for each entry of `mean_eods`. Without any detected EODs
    only the recording is plotted.

    Parameters
    ----------
    data: array
        Recording data.
    eod_p_times: list of arrays
        EOD peak times for each pulsefish. They are plotted on the
        time axis and multiplied by `fs` to index into `data`.
    eod_tr_times: list of arrays
        EOD trough times for each pulsefish, used like `eod_p_times`.
    fs: float
        Samplerate.
    mean_eods: list of numpy arrays
        Mean EODs of each pulsefish found in the recording.
        For each of them, row 0 is plotted as the time axis,
        row 1 as the waveform, and row 2 as the half-width of the
        shaded band around the waveform. All are multiplied by 1000
        for plotting.
    """
    fig = plt.figure(figsize=(10, 5))
    n_fish = len(eod_p_times)

    if n_fish == 0:
        # nothing was detected, show the recording only:
        plt.plot(np.arange(len(data))/fs, data, c='k', alpha=0.3)
        plt.tight_layout()
        return

    gs = gridspec.GridSpec(2, n_fish)

    # upper row: the recording with marked peaks and troughs
    ax = fig.add_subplot(gs[0,:])
    ax.plot(np.arange(len(data))/fs, data, c='k', alpha=0.3)
    for i, (pt, tt) in enumerate(zip(eod_p_times, eod_tr_times)):
        # peaks first, then troughs, both in the color of the fish:
        for times in (pt, tt):
            ax.plot(times, data[(times*fs).astype('int')], 'o', label=i+1, ms=10, c=cmap(i))
    ax.set_xlabel('time [s]')
    ax.set_ylabel('amplitude [V]')

    # lower row: one panel for each mean EOD
    for i, m in enumerate(mean_eods):
        ax = fig.add_subplot(gs[1,i])
        ax.plot(1000*m[0], 1000*m[1], c='k')
        lower = 1000*(m[1]-m[2])
        upper = 1000*(m[1]+m[2])
        ax.fill_between(1000*m[0], lower, upper, color=cmap(i))
        ax.set_xlabel('time [ms]')
        ax.set_ylabel('amplitude [mV]')

    plt.tight_layout()

451 

452 

def plot_clustering(samplerate, eod_widths, eod_hights, eod_shapes, disc_masks, merge_masks):
    """Plot all clustering steps.

    Plot clustering steps on width, height and shape. Then plot the remaining EODs after
    the EOD assessment step and the EODs after the merge step.

    A new figure with five columns (widths, heights, shapes, pulse EODs, merge)
    is created. Grey lines connect each cluster with the clusters it was split
    into in the next column. Nothing is returned and the figure is neither
    shown nor saved.

    Parameters
    ----------
    samplerate : float
        Samplerate of EOD snippets. Not used by this function.
    eod_widths : list of three 1D numpy arrays
        The first list entry gives the unique labels of all width clusters as a list of ints.
        The second list entry gives the width values for each EOD in samples as a
        1D numpy array of ints.
        The third list entry gives the width labels for each EOD as a 1D numpy array of ints.
    eod_hights : nested lists (2 layers) of three 1D numpy arrays
        The first list entry gives the unique labels of all height clusters as a list of ints
        for each width cluster.
        The second list entry gives the height values for each EOD as a 1D numpy array
        of floats for each width cluster.
        The third list entry gives the height labels for each EOD as a 1D numpy array
        of ints for each width cluster.
    eod_shapes : nested lists (3 layers) of three 1D numpy arrays
        The first list entry gives the raw EOD snippets as a 2D numpy array for each
        height cluster in a width cluster.
        The second list entry gives the snippet PCA values for each EOD as a 2D numpy array
        of floats for each height cluster in a width cluster.
        The third list entry gives the shape labels for each EOD as a 1D numpy array of ints
        for each height cluster in a width cluster.
        For each height cluster, index 0 holds the peak-centered and index 1 the
        trough-centered data.
    disc_masks : Nested lists (two layers) of 1D numpy arrays
        The masks of EODs that are discarded by the discarding step of the algorithm.
        The masks are 1D boolean arrays where
        instances that are set to True are discarded by the algorithm. Discarding masks
        are saved in nested lists that represent the width and height clusters.
    merge_masks : Nested lists (two layers) of 2D numpy arrays
        The masks of EODs that are discarded by the merging step of the algorithm.
        The masks are 2D boolean arrays where
        for each sample point `i` either `merge_mask[i,0]` or `merge_mask[i,1]` is set to True.
        Here, merge_mask[:,0] represents the
        peak-centered clusters and `merge_mask[:,1]` represents the trough-centered clusters.
        Merge masks are saved in nested lists
        that represent the width and height clusters.
        NOTE(review): the code below indexes the masks as `mask[0]`, `mask[1]`
        and `mask[pt, ...]`, i.e. with the peak/trough index first - confirm
        the axis order against the caller.
    """
    # create figure + transparent figure.
    fig = plt.figure(figsize=(12, 7))
    # inverse of the figure transform, used to express points of the
    # individual axes in figure coordinates for the connecting lines.
    transFigure = fig.transFigure.inverted()

    # set up the figure layout
    outer = gridspec.GridSpec(1, 5, width_ratios=[1, 1, 2, 1, 2], left=0.05, right=0.95)

    # set titles for each clustering step
    titles = ['1. Widths', '2. Heights', '3. Shape', '4. Pulse EODs', '5. Merge']
    for i, title in enumerate(titles):
        # an invisible axes spanning the whole column carries the title
        # just above its upper limit (y=110 with limits -100 to 100).
        title_ax = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec = outer[i])
        ax = fig.add_subplot(title_ax[0])
        ax.text(0, 110, title, ha='center', va='bottom', clip_on=False)
        ax.set_xlim(-100, 100)
        ax.set_ylim(-100, 100)
        ax.axis('off')

    # compute sizes for each axis
    w_size = 1
    # one height histogram for each entry of the height values list:
    h_size = len(eod_hights[1])

    # total number of height clusters, summed over the width clusters:
    shape_size = np.sum([len(sl) for sl in eod_shapes[0]])

    # count required axes sized for the last two plot columns.
    disc_size = 0
    merge_size= 0
    for shapelabel, dmasks, mmasks in zip(eod_shapes[2], disc_masks, merge_masks):
        for sl, dm, mm in zip(shapelabel, dmasks, mmasks):
            # labels are shifted by one and multiplied with the mask, so that
            # masked-out EODs and label -1 map to zero. The unique values larger
            # than zero are then the clusters that are not discarded.
            uld1 = np.unique((sl[0]+1)*np.invert(dm[0]))
            uld2 = np.unique((sl[1]+1)*np.invert(dm[1]))
            disc_size = disc_size+len(uld1[uld1>0])+len(uld2[uld2>0])

            # same counting for the clusters selected by the merge masks:
            uld1 = np.unique((sl[0]+1)*mm[0])
            uld2 = np.unique((sl[1]+1)*mm[1])
            merge_size = merge_size+len(uld1[uld1>0])+len(uld2[uld2>0])

    # set counters to keep track of the plot axes
    disc_block = 0
    merge_block = 0
    shape_count = 0

    # create all axes
    width_hist_ax = gridspec.GridSpecFromSubplotSpec(w_size, 1, subplot_spec = outer[0])
    hight_hist_ax = gridspec.GridSpecFromSubplotSpec(h_size, 1, subplot_spec = outer[1])
    shape_ax = gridspec.GridSpecFromSubplotSpec(shape_size, 1, subplot_spec = outer[2])
    # each shape panel is a 2x2 grid: rows are peak/trough centering,
    # columns are features and snippets.
    shape_windows = [gridspec.GridSpecFromSubplotSpec(2, 2, hspace=0.0, wspace=0.0,
                                                      subplot_spec=shape_ax[i])
                     for i in range(shape_size)]

    EOD_delete_ax = gridspec.GridSpecFromSubplotSpec(disc_size, 1, subplot_spec=outer[3])
    EOD_merge_ax = gridspec.GridSpecFromSubplotSpec(merge_size, 1, subplot_spec=outer[4])

    # plot width labels histogram
    ax1 = fig.add_subplot(width_hist_ax[0])
    # set axes features.
    ax1.set_xscale('log')
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['bottom'].set_visible(False)
    ax1.axes.xaxis.set_visible(False)
    ax1.set_yticklabels([])

    # indices for plot colors (dark to light)
    # values run from 1.25 down to 0.5 and are passed as lightness to lighter().
    colidxsw = -np.linspace(-1.25, -0.5, h_size)

    # loop over the width clusters:
    for i, (wl, colw, uhl, eod_h, eod_h_labs, w_snip, w_feat, w_lab, w_dm, w_mm) in enumerate(zip(eod_widths[0], colidxsw, eod_hights[0], eod_hights[1], eod_hights[2], eod_shapes[0], eod_shapes[1], eod_shapes[2], disc_masks, merge_masks)):

        # plot width hist
        hw, _, _ = ax1.hist(eod_widths[1][eod_widths[2]==wl],
                            bins=np.linspace(np.min(eod_widths[1]), np.max(eod_widths[1]), 100),
                            color=lighter(c_o, colw), orientation='horizontal')

        # set arrow when the last hist is plot so the size of the axes are known.
        if i == h_size-1:
            arrowed_spines(ax1, ms=20)

        # determine total size of the height histograms now.
        my, b = np.histogram(eod_h, bins=np.exp(np.linspace(np.min(np.log(eod_h)),
                                                            np.max(np.log(eod_h)), 100)))
        maxy = np.max(my)

        # set axes features for height hist.
        # the height histograms are filled in from the bottom of the column upwards.
        ax2 = fig.add_subplot(hight_hist_ax[h_size-i-1])
        ax2.set_xscale('log')
        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        ax2.set_xlim(0.9, maxy)
        ax2.axes.xaxis.set_visible(False)
        ax2.set_yscale('log')
        ax2.yaxis.set_major_formatter(ticker.NullFormatter())
        ax2.yaxis.set_minor_formatter(ticker.NullFormatter())

        # define colors for plots
        colidxsh = -np.linspace(-1.25, -0.5, len(uhl))

        # loop over the height clusters of the current width cluster:
        for n, (hl, hcol, snippets, features, labels, dmasks, mmasks) in enumerate(zip(uhl, colidxsh, w_snip, w_feat, w_lab, w_dm, w_mm)):

            hh, _, _ = loghist(ax2, eod_h[eod_h_labs==hl], np.min(eod_h), np.max(eod_h), 100,
                               lighter(c_g, hcol), orientation='horizontal')

            # set arrow spines only on last plot
            if n == len(uhl)-1:
                arrowed_spines(ax2, ms=10)

            # plot line from the width histogram to the height histogram.
            if n == 0:
                coord1 = transFigure.transform(ax1.transData.transform([np.median(hw[hw!=0]),
                                               np.median(eod_widths[1][eod_widths[2]==wl])]))
                coord2 = transFigure.transform(ax2.transData.transform([0.9, np.mean(eod_h)]))
                line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                              transform=fig.transFigure, color='grey', linewidth=0.5)
                fig.lines.append(line)

            # compute sizes of the eod_discarding and merge steps
            # (same counting as for disc_size and merge_size above)
            s1 = np.unique((labels[0]+1)*(~dmasks[0]))
            s2 = np.unique((labels[1]+1)*(~dmasks[1]))
            disc_block = disc_block + len(s1[s1>0]) + len(s2[s2>0])

            s1 = np.unique((labels[0]+1)*(mmasks[0]))
            s2 = np.unique((labels[1]+1)*(mmasks[1]))
            merge_block = merge_block + len(s1[s1>0]) + len(s2[s2>0])

            # axes of the EODs of this height cluster that were not discarded,
            # and counters for the panels used within the current block:
            axs = []
            disc_count = 0
            merge_count = 0

            # now plot the clusters for peak and trough centerings
            for pt, cmap_pt in zip([0, 1], cmap_pts):

                # shape panels are also filled in from the bottom upwards;
                # left column: features, right column: snippets.
                ax3 = fig.add_subplot(shape_windows[shape_size-1-shape_count][pt,0])
                ax4 = fig.add_subplot(shape_windows[shape_size-1-shape_count][pt,1])

                # remove axes
                ax3.axes.xaxis.set_visible(False)
                ax4.axes.yaxis.set_visible(False)
                ax3.axes.yaxis.set_visible(False)
                ax4.axes.xaxis.set_visible(False)

                # set color indices
                colidxss = -np.linspace(-1.25, -0.5, len(np.unique(labels[pt][labels[pt]>=0])))
                # j counts the clusters with non-negative label for indexing colidxss.
                j=0
                for c in np.unique(labels[pt]):

                    if c<0:
                        # plot noise features + snippets
                        ax3.plot(features[pt][labels[pt]==c,0], features[pt][labels[pt]==c,1],
                                 '.', color='lightgrey', label='-1', rasterized=True)
                        ax4.plot(snippets[pt][labels[pt]==c].T, linewidth=0.1,
                                 color='lightgrey', label='-1', rasterized=True)
                    else:
                        # plot cluster features and snippets
                        ax3.plot(features[pt][labels[pt]==c,0], features[pt][labels[pt]==c,1],
                                 '.', color=lighter(cmap_pt, colidxss[j]), label=c,
                                 rasterized=True)
                        ax4.plot(snippets[pt][labels[pt]==c].T, linewidth=0.1,
                                 color=lighter(cmap_pt, colidxss[j]), label=c, rasterized=True)

                        # check if the current cluster is an EOD, if yes, plot it.
                        # (none of its members is marked in the discarding mask)
                        if np.sum(dmasks[pt][labels[pt]==c]) == 0:

                            ax = fig.add_subplot(EOD_delete_ax[disc_size-disc_block+disc_count])
                            ax.axis('off')

                            # plot mean EOD snippet
                            ax.plot(np.mean(snippets[pt][labels[pt]==c], axis=0),
                                    color=lighter(cmap_pt, colidxss[j]))
                            disc_count = disc_count + 1

                            # match colors and draw line..
                            # from the middle of the right edge of the snippet panel
                            # to the middle of the left edge of the EOD panel.
                            coord1 = transFigure.transform(ax4.transData.transform([ax4.get_xlim()[1],
                                                           ax4.get_ylim()[0] + 0.5*(ax4.get_ylim()[1]-ax4.get_ylim()[0])]))
                            coord2 = transFigure.transform(ax.transData.transform([ax.get_xlim()[0],ax.get_ylim()[0] + 0.5*(ax.get_ylim()[1]-ax.get_ylim()[0])]))
                            line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                                          transform=fig.transFigure, color='grey',
                                          linewidth=0.5)
                            fig.lines.append(line)
                            axs.append(ax)

                            # check if the current EOD survives the merge step
                            # if so, plot it.
                            # NOTE(review): this check is assumed to be nested inside
                            # the EOD check above - confirm against the repository.
                            if np.sum(mmasks[pt, labels[pt]==c])>0:

                                ax = fig.add_subplot(EOD_merge_ax[merge_size-merge_block+merge_count])
                                ax.axis('off')

                                ax.plot(np.mean(snippets[pt][labels[pt]==c], axis=0),
                                        color=lighter(cmap_pt, colidxss[j]))
                                merge_count = merge_count + 1

                        j=j+1

                if pt==0:
                    # draw line from height cluster to EOD shape clusters.
                    coord1 = transFigure.transform(ax2.transData.transform([np.median(hh[hh!=0]),
                                                   np.median(eod_h[eod_h_labs==hl])]))
                    coord2 = transFigure.transform(ax3.transData.transform([ax3.get_xlim()[0],
                                                   ax3.get_ylim()[0]]))
                    line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                                  transform=fig.transFigure, color='grey', linewidth=0.5)
                    fig.lines.append(line)

            shape_count = shape_count + 1

            if len(axs)>0:
                # plot lines that indicate the merged clusters.
                # a vertical bracket right of the EOD panels, from the first
                # to the last EOD of this height cluster.
                coord1 = transFigure.transform(axs[0].transData.transform([axs[0].get_xlim()[1]+0.1*(axs[0].get_xlim()[1]-axs[0].get_xlim()[0]),
                                               axs[0].get_ylim()[1]-0.25*(axs[0].get_ylim()[1]-axs[0].get_ylim()[0])]))
                coord2 = transFigure.transform(axs[-1].transData.transform([axs[-1].get_xlim()[1]+0.1*(axs[-1].get_xlim()[1]-axs[-1].get_xlim()[0]),
                                               axs[-1].get_ylim()[0]+0.25*(axs[-1].get_ylim()[1]-axs[-1].get_ylim()[0])]))
                line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                              transform=fig.transFigure, color='grey', linewidth=1)
                fig.lines.append(line)

709 

710 

def plot_bgm(x, means, variances, weights, use_log, labels, labels_am, xlab):
    """Plot a BGM clustering step either on EOD width or height.

    A new figure is created that shows histograms of the data for each
    BGM label together with the fitted Gaussians, the borders between
    them, and a bracket for clusters that were merged.
    Nothing is returned and the figure is neither shown nor saved.

    Parameters
    ----------
    x : 1D numpy array of floats
        BGM input values.
    means : list of floats
        BGM Gaussian means
    variances : list of floats
        BGM Gaussian variances.
    weights : list of floats
        BGM Gaussian weights.
    use_log : boolean
        True if the z-scored logarithm of the data was used as BGM input.
    labels : 1D numpy array of ints
        Labels defined by BGM model (before merging based on merge factor).
    labels_am : 1D numpy array of ints
        Labels defined by BGM model (after merging based on merge factor).
    xlab : string
        Label for plot (defines the units of the BGM data).
        The histogram color is selected depending on whether it
        contains 'width' or 'height'.
    """
    if 'width' in xlab:
        ccol = c_o
    elif 'height' in xlab:
        ccol = c_g
    else:
        ccol = 'b'

    # get the transform that was used as BGM input
    if use_log:
        x_transform = stats.zscore(np.log(x))
        xplot = np.exp(np.linspace(np.log(np.min(x)), np.log(np.max(x)), 1000))
    else:
        x_transform = stats.zscore(x)
        xplot = np.linspace(np.min(x), np.max(x), 1000)

    # compute the x values and gaussians
    # xt covers the transformed data with as many points as xplot,
    # so that the Gaussians can be plotted against xplot.
    xt = np.linspace(np.min(x_transform), np.max(x_transform), 1000)
    gaussians = []
    gmax = 0
    for w, m, var in zip(weights, means, variances):
        gaus = np.sqrt(w*stats.norm.pdf(xt, m, np.sqrt(var)))
        gaussians.append(gaus)
        gmax = max(np.max(gaus), gmax)

    # compute classes defined by gaussian intersections
    classes = np.argmax(np.vstack(gaussians), axis=0)
    uclasses = np.unique(classes)

    # find the minimum of any gaussian that is within its class
    gmin = 100
    for c in uclasses:
        gmin = min(gmin, np.min(gaussians[c][classes==c]))

    # set up the figure
    fig, ax1 = plt.subplots(figsize=(8, 4.8))
    ax2 = ax1.twinx()
    ax1.spines['top'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax1.set_xlabel('x [a.u.]')
    ax1.set_ylabel('#')
    ax2.set_ylabel('Likelihood')
    ax2.set_yscale('log')
    ax1.set_yscale('log')
    if use_log:
        ax1.set_xscale('log')
    # NOTE(review): xlab is assumed to label the x-axis for both linear and
    # logarithmic data - confirm against the repository.
    ax1.set_xlabel(xlab)

    # define colors for plotting gaussians
    colidxs = -np.linspace(-1.25, -0.5, len(uclasses))

    # plot the gaussians
    for i, c in enumerate(uclasses):
        ax2.plot(xplot, gaussians[c], c=lighter(c_grey, colidxs[i]), linewidth=2,
                 label=r'$N(\mu_%i, \sigma_%i)$'%(c, c))

    # plot intersection lines
    ax2.vlines(xplot[1:][np.diff(classes)!=0], 0, gmax/gmin, color='k', linewidth=2,
               linestyle='--')
    ax2.set_ylim(gmin, np.max(np.vstack(gaussians))*1.1)

    # plot data distributions and classes
    ulabels = np.unique(labels)
    colidxs = -np.linspace(-1.25, -0.5, len(ulabels))
    for i, l in enumerate(ulabels):
        if use_log:
            loghist(ax1, x[labels==l], np.min(x), np.max(x), 100,
                    lighter(ccol, colidxs[i]), label=r'$x_%i$'%l)
        else:
            ax1.hist(x[labels==l], bins=np.linspace(np.min(x), np.max(x), 100),
                     color=lighter(ccol, colidxs[i]), label=r'$x_%i$'%l)

    # annotate merged clusters
    for l in np.unique(labels_am):
        maps = np.unique(labels[labels_am==l])
        if len(maps) > 1:
            # bracket between the medians of the first two merged clusters:
            med1 = np.median(x[labels==maps[0]])
            med2 = np.median(x[labels==maps[1]])
            ax2.plot([med1, med2], [1.2*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.plot([med1, med1], [1.1*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.plot([med2, med2], [1.1*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.annotate(r'$\frac{|{\tilde{x}_%i-\tilde{x}_%i}|}{max(\tilde{x}_%i, \tilde{x}_%i)} < \epsilon$' % (maps[0], maps[1], maps[0], maps[1]), [med1*1.1, gmax*1.2], xytext=(10, 10), textcoords='offset points', fontsize=12, annotation_clip=False, ha='center')

    # add legends and plot.
    ax2.legend(loc='lower left', frameon=False, bbox_to_anchor=(-0.05, 1.3),
               ncol=len(uclasses))
    ax1.legend(loc='upper left', frameon=False, bbox_to_anchor=(-0.05, 1.3),
               ncol=len(ulabels))
    plt.tight_layout()

824 

825 

def plot_feature_extraction(raw_snippets, normalized_snippets, features, labels, dt, pt):
    """Plot the feature extraction step of the clustering on EOD shape.

    The left half of the figure shows the raw snippets (top) and the
    normalized snippets (bottom). The right half shows a lower-triangular
    matrix of pairwise scatter plots of the feature columns. All data are
    colored by cluster label; snippets with a negative label are drawn in
    light grey.

    Parameters
    ----------
    raw_snippets : 2D numpy array
        Raw EOD snippets.
    normalized_snippets : 2D numpy array
        Normalized EOD snippets.
    features : 2D numpy array
        PCA values for each normalized EOD snippet.
    labels : 1D numpy array of ints
        Cluster labels.
    dt : float
        Sample interval of snippets. It is multiplied by 1000 to build
        the time axis that is labeled 'Time [ms]'.
    pt : int
        Set to 0 for peak-centered EODs and set to 1 for trough-centered EODs.
    """
    ccol = cmap_pts[pt]
    n_features = features.shape[1]
    n_samples = raw_snippets.shape[1]

    # common data range of all features, used for every PC axes:
    fmin = np.min(features)
    fmax = np.max(features)

    # set up the figure layout:
    fig = plt.figure(figsize=(((2+0.2)*4.8), 4.8))
    outer = gridspec.GridSpec(1, 2, wspace=0.2, hspace=0)

    # time axis centered on the middle of the snippets:
    x = np.arange(-dt*1000*n_samples/2, dt*1000*n_samples/2, dt*1000)

    snip_ax = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[0], hspace=0.35)
    pc_ax = gridspec.GridSpecFromSubplotSpec(n_features-1, n_features-1,
                                             subplot_spec=outer[1], hspace=0, wspace=0)

    # 3 plots: raw snippets, normalized, pcs.
    # All axes are created and decorated exactly once: fig.add_subplot()
    # returns a new Axes on every call, so creating them inside the loop
    # over clusters would stack duplicate axes on top of each other.
    ax_raw_snip = fig.add_subplot(snip_ax[0])
    ax_normalized_snip = fig.add_subplot(snip_ax[1])

    for snip_axes, title in ((ax_raw_snip, 'Raw snippets'),
                             (ax_normalized_snip, 'Normalized snippets')):
        snip_axes.spines['top'].set_visible(False)
        snip_axes.spines['right'].set_visible(False)
        snip_axes.set_title(title)
        snip_axes.set_ylabel('Amplitude [a.u.]')
    ax_raw_snip.get_xaxis().set_ticklabels([])
    ax_normalized_snip.set_xlabel('Time [ms]')
    ax_raw_snip.axis('off')
    ax_normalized_snip.axis('off')

    # invisible axes spanning the whole feature grid, only carries the title:
    ax_overlay = fig.add_subplot(pc_ax[:, :])
    ax_overlay.set_title('Features')
    ax_overlay.axis('off')

    # lower triangle of pairwise feature axes, keyed by (row PC, column PC):
    pc_axes = {}
    for n in range(1, n_features):
        for m in range(n):
            ax = fig.add_subplot(pc_ax[n-1, m])
            # fixed limits; this also switches off autoscaling for the
            # scatter plots added below.
            ax.set_xlim(fmin, fmax)
            ax.set_ylim(fmin, fmax)
            ax.get_xaxis().set_ticklabels([])
            ax.get_yaxis().set_ticklabels([])
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            if m == 0:
                ax.set_ylabel('PC %i'%(n+1))
            if n == n_features-1:
                ax.set_xlabel('PC %i'%(m+1))
            pc_axes[(n, m)] = ax

    # one brightness level per cluster with non-negative label:
    cluster_ids = np.unique(labels)
    colidxs = -np.linspace(-1.25, -0.5, len(np.unique(labels[labels>=0])))
    j = 0

    for c in cluster_ids:
        if c < 0:
            color = 'lightgrey'
        else:
            color = lighter(ccol, colidxs[j])
            j = j+1

        members = labels == c
        # NOTE(review): every cluster is labeled '-1' here; no legend is
        # drawn in this function - confirm whether the label is needed.
        ax_raw_snip.plot(x, raw_snippets[members].T, color=color, label='-1',
                         rasterized=True, alpha=0.25)
        ax_normalized_snip.plot(x, normalized_snippets[members].T, color=color, alpha=0.25)

        for (n, m), ax in pc_axes.items():
            ax.scatter(features[members, m], features[members, n], marker='.',
                       color=color, alpha=0.25)

    # scale bars go into the top-right cell of the feature grid:
    ax = fig.add_subplot(pc_ax[0, n_features-2])
    ax.set_xlim(fmin, fmax)
    ax.set_ylim(fmin, fmax)

    # number of decimals needed to resolve the data range (at least one),
    # and the data range rounded down to that many decimals:
    size = max(1, int(np.ceil(-np.log10(fmax-fmin))))
    wbar = np.floor((fmax-fmin)*10**size)/10**size

    # should be smaller than the actual thing! so like x% of it?
    xscalebar(ax, 0, 0, wbar, wformat='%%.%if'%size)
    yscalebar(ax, 0, 0, wbar, hformat='%%.%if'%size)
    ax.axis('off')

920 

def plot_moving_fish(ws, dts, clusterss, ts, fishcounts, T, ignore_stepss):
    """Plot moving fish detection step.

    For each width cluster two axes are stacked: an event plot of the
    EOD times of every EOD cluster together with the outline of the first
    sliding window, and below it the fish count per sliding window.
    Fish counts that were ignored are drawn faded only.

    Parameters
    ----------
    ws : list of floats
        Median width for each width cluster that the moving fish algorithm is computed on
        (in seconds).
    dts : list of floats
        Sliding window size (in seconds) for each width cluster.
    clusterss : list of 1D numpy int arrays
        Cluster labels for each EOD cluster in a width cluster.
    ts : list of 1D numpy float arrays
        EOD emission times for each EOD in a width cluster.
    fishcounts : list of lists
        Sliding window timepoints and fishcounts for each width cluster.
    T : float
        Length of analyzed recording in seconds.
    ignore_stepss : list of 1D int arrays
        Mask for fishcounts that were ignored (ignored if True) in the moving_fish analysis.
    """
    fig = plt.figure()

    # one row per width cluster:
    outer = gridspec.GridSpec(len(ws), 1)

    for i, (w, dt, clusters, t, fishcount, ignore_steps) in enumerate(
            zip(ws, dts, clusterss, ts, fishcounts, ignore_stepss)):

        gs = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[i])

        # axis for clusters
        ax1 = fig.add_subplot(gs[0])
        # axis for fishcount (created only once, a second add_subplot()
        # on the same cell would leave an undecorated Axes underneath)
        ax2 = fig.add_subplot(gs[1])

        # plot clusters as eventplot, one row per cluster with label >= 0:
        cluster_ids = np.unique(clusters[clusters>=0])
        n_clusters = len(cluster_ids)
        for cnum, c in enumerate(cluster_ids):
            ax1.eventplot(t[clusters==c], lineoffsets=cnum, linelengths=0.5, color=cmap(i))

        # Plot the sliding window: a dashed box covering all cluster rows
        # over the first window, plus an arrow pointing to the right.
        rect = Rectangle((0, -0.5), dt, n_clusters, linewidth=1, linestyle='--',
                         edgecolor='k', facecolor='none', clip_on=False)
        ax1.add_patch(rect)
        ax1.arrow(dt+0.1, -0.5, 0.5, 0, head_width=0.1, head_length=0.1, facecolor='k',
                  edgecolor='k')

        # plot parameters
        ax1.set_title(r'$\tilde{w}_%i = %.3f ms$'%(i, 1000*w))
        ax1.set_ylabel('cluster #')
        ax1.set_yticks(range(0, n_clusters))
        ax1.set_xlabel('time')
        ax1.set_xlim(0, T)
        ax1.axes.xaxis.set_visible(False)
        for side in ('bottom', 'top', 'right', 'left'):
            ax1.spines[side].set_visible(False)

        # plot for fishcount
        x = np.asarray(fishcount[0])
        y = np.asarray(fishcount[1])

        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        # the time axis is hidden on every row:
        ax2.axes.xaxis.set_visible(False)

        # all counts faded, shifted to the center of each window:
        ax2.plot(x+dt/2, y, linestyle='-', marker='.', c=cmap(i), alpha=0.25)
        # counts that were not ignored in full color. The copy is float
        # so that NaN can be stored even when the counts are integers.
        yplot = np.array(y, dtype=float)
        yplot[np.asarray(ignore_steps).astype(bool)] = np.nan
        ax2.plot(x+dt/2, yplot, linestyle='-', marker='.', c=cmap(i))
        ax2.set_ylabel('Fish count')
        ax2.set_yticks(range(int(np.min(y)), 1+int(np.max(y))))
        ax2.set_xlim(0, T)

        # the last row additionally gets a scale bar labeled in seconds:
        if i == len(ws)-1:
            xscalebar(ax2, 1, 0, 1, wunit='s', ha='right')

        # connect both lower corners of the sliding window with the
        # first fish count:
        for x0 in (0, dt):
            con = ConnectionPatch([x0, -0.5], [dt/2, y[0]], "data", "data",
                                  axesA=ax1, axesB=ax2, color='k')
            ax2.add_artist(con)

    plt.xlim(0, T)