Coverage for src / thunderfish / pulseplots.py: 0%
487 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-15 17:50 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-15 17:50 +0000
1"""
2Plot and save key steps in pulses.py for visualizing the algorithm.
3"""
5import glob
6import numpy as np
7import matplotlib.pyplot as plt
9from scipy import stats
10from matplotlib import gridspec, ticker
11try:
12 from matplotlib.colors import colorConverter as cc
13except ImportError:
14 import matplotlib.colors as cc
15try:
16 from matplotlib.colors import to_hex
17except ImportError:
18 from matplotlib.colors import rgb2hex as to_hex
19from matplotlib.patches import ConnectionPatch, Rectangle
20from matplotlib.lines import Line2D
22import warnings
def warn(*args, **kwargs):
    """
    Do-nothing replacement for `warnings.warn()`.

    Accepts any positional and keyword arguments and discards them.
    """
    return None


# Install the no-op so that calls to `warnings.warn()` report nothing.
setattr(warnings, 'warn', warn)
# plotting parameters and colors, all taken from the "Dark2" colormap:
cmap = plt.get_cmap("Dark2")
# colors used for width (c_o), height (c_g) and gaussian (c_grey) plots:
c_g, c_o, c_grey = map(cmap, (0, 1, 7))
# base colors for peak-centered (index 0) and trough-centered (index 1) EODs:
cmap_pts = list(map(cmap, (2, 3)))
def darker(color, saturation):
    """Make a color darker.

    From bendalab/plottools package.

    Parameters
    ----------
    color: dict or matplotlib color spec
        A matplotlib color (hex string, name color string, rgb tuple)
        or a dictionary with an 'color' or 'facecolor' key.
    saturation: float
        The smaller the saturation, the darker the returned color.
        A saturation of 0 returns black.
        A saturation of 1 leaves the color untouched.
        A saturation of 2 returns white.
        Values outside of the interval [0, 2] are clipped to it.

    Returns
    -------
    color: string or dictionary
        The darker color as a hexadecimal RGB string (e.g. '#rrggbb').
        If `color` is a dictionary, a copy of the dictionary is returned
        with the value of 'color' or 'facecolor' set to the darker color.
    """
    # dictionaries: darken the entry stored under 'color', else 'facecolor'.
    # Anything that cannot be indexed by these keys is a plain color spec.
    for key in ('color', 'facecolor'):
        try:
            c = color[key]
            cd = dict(**color)
            cd[key] = darker(c, saturation)
            return cd
        except (KeyError, TypeError):
            pass
    # plain color spec:
    if saturation > 2:
        saturation = 2
    if saturation > 1:
        # saturations between 1 and 2 lighten the color instead:
        return lighter(color, 2.0-saturation)
    if saturation < 0:
        saturation = 0
    r, g, b = cc.to_rgb(color)
    rd = r*saturation
    gd = g*saturation
    bd = b*saturation
    return to_hex((rd, gd, bd)).upper()
def lighter(color, lightness):
    """Return a lighter version of a color.

    From bendalab/plottools package.

    Parameters
    ----------
    color: dict or matplotlib color spec
        A matplotlib color (hex string, name color string, rgb tuple)
        or a dictionary with an 'color' or 'facecolor' key.
    lightness: float
        The smaller the lightness, the lighter the returned color.
        A lightness of 0 returns white.
        A lightness of 1 leaves the color untouched.
        A lightness of 2 returns black.

    Returns
    -------
    color: string or dict
        The lighter color as a hexadecimal RGB string (e.g. '#rrggbb').
        If `color` is a dictionary, a copy of the dictionary is returned
        with the value of 'color' or 'facecolor' set to the lighter color.
    """
    # style dictionaries: try the 'color' entry first, then 'facecolor'.
    for entry in ('color', 'facecolor'):
        try:
            base = color[entry]
            styled = dict(**color)
            styled[entry] = lighter(base, lightness)
            return styled
        except (KeyError, TypeError):
            continue
    # a plain color specification:
    if lightness > 2:
        lightness = 2
    if lightness > 1:
        # values above one are handled by darkening the color
        return darker(color, 2.0-lightness)
    if lightness < 0:
        lightness = 0
    # move each channel towards white by the fraction (1 - lightness):
    channels = tuple(v + (1.0-lightness)*(1.0 - v) for v in cc.to_rgb(color))
    return to_hex(channels).upper()
def xscalebar(ax, x, y, width, wunit=None, wformat=None, ha='left', va='bottom',
              lw=None, color=None, capsize=None, clw=None, **kwargs):
    """Draw a labeled horizontal scale bar.

    From bendalab/plottools package.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where to draw the scale bar.
    x: float
        x-coordinate where to draw the scale bar in relative units of the axes.
    y: float
        y-coordinate where to draw the scale bar in relative units of the axes.
    width: float
        Length of the scale bar in units of the data's x-values.
    wunit: string or None
        Optional unit of the data's x-values.
    wformat: string or None
        Optional format string for formatting the label of the scale bar
        or simply a string used for labeling the scale bar.
        If None, '%.1f' is used for widths below one and '%.0f' otherwise.
    ha: 'left', 'right', or 'center'
        Scale bar aligned left, right, or centered to (x, y)
    va: 'top' or 'bottom'
        Label of the scale bar either above or below the scale bar.
    lw: int, float, None
        Line width of the scale bar. Defaults to 2.
    color: matplotlib color
        Color of the scalebar. Defaults to black.
    capsize: float or None
        If larger than zero draw cap lines at the ends of the bar.
        The length of the lines is given in points (same unit as linewidth).
    clw: int, float, None
        Line width of the cap lines. Defaults to 0.5.
    kwargs: key-word arguments
        Passed on to `ax.text()` used to print the scale bar label.

    Returns
    -------
    x0: float
        Left end of the scale bar in data coordinates.
    x1: float
        Right end of the scale bar in data coordinates.
    y: float
        y-position of the scale bar in data coordinates.
    """
    def pixel_box(artist):
        # bounding box of an artist in pixel coordinates
        return np.array(artist.get_window_extent(ax.get_figure().canvas.get_renderer()))

    ax.autoscale(False)
    # extent of the axes in pixels and in data units:
    corners = ax.get_window_extent().get_points()
    pixelx = np.abs(np.diff(corners[:,0]))[0]
    pixely = np.abs(np.diff(corners[:,1]))[0]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    unitx = xmax - xmin
    unity = ymax - ymin
    dxu = np.abs(unitx)/pixelx
    dyu = np.abs(unity)/pixely
    # anchor point in data coordinates:
    x = xmin + x*unitx
    y = ymin + y*unity
    # label text; a format string also rounds the bar to the printed value:
    if wformat is None:
        wformat = '%.1f' if width < 1.0 else '%.0f'
    try:
        ls = wformat % width
        width = float(ls)
    except TypeError:
        # wformat is a plain label without a conversion specifier
        ls = wformat
    # ends of the bar:
    if ha == 'left':
        x0, x1 = x, x+width
    elif ha == 'right':
        x0, x1 = x-width, x
    else:
        x0, x1 = x-0.5*width, x+0.5*width
    # defaults for line width and color:
    lw = 2 if lw is None else lw
    color = 'k' if color is None else color
    # the bar itself:
    bar = ax.plot([x0, x1], [y, y], '-', color=color, lw=lw,
                  solid_capstyle='butt', clip_on=False)[0]
    # y position of the bar in pixel coordinates:
    ly = pixel_box(bar)[0,1]
    # caps at both ends:
    capsize = 0 if capsize is None else capsize
    clw = 0.5 if clw is None else clw
    if capsize > 0.0:
        dy = capsize*dyu
        for xc in (x0, x1):
            ax.plot([xc, xc], [y-dy, y+dy], '-', color=color, lw=clw,
                    solid_capstyle='butt', clip_on=False)
    # label:
    if wunit:
        ls += u'\u2009%s' % wunit
    xm = 0.5*(x0+x1)
    if va == 'top':
        th = ax.text(xm, y, ls, clip_on=False,
                     ha='center', va='bottom', **kwargs)
        # lower edge of the text in pixel coordinates:
        ty = pixel_box(th)[0,1]
        dty = ly+0.5*lw + 2.0 - ty
    else:
        th = ax.text(xm, y, ls, clip_on=False,
                     ha='center', va='top', **kwargs)
        # upper edge of the text in pixel coordinates:
        ty = pixel_box(th)[1,1]
        dty = ly-0.5*lw - 2.0 - ty
    # shift the label so that it keeps a small gap to the bar:
    th.set_position((xm, y+dyu*dty))
    return x0, x1, y
def yscalebar(ax, x, y, height, hunit=None, hformat=None, ha='left', va='bottom',
              lw=None, color=None, capsize=None, clw=None, **kwargs):
    """Vertical scale bar with label.

    From bendalab/plottools package.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where to draw the scale bar.
    x: float
        x-coordinate where to draw the scale bar in relative units of the axes.
    y: float
        y-coordinate where to draw the scale bar in relative units of the axes.
    height: float
        Length of the scale bar in units of the data's y-values.
    hunit: string
        Unit of the data's y-values.
    hformat: string or None
        Optional format string for formatting the label of the scale bar
        or simply a string used for labeling the scale bar.
        If None, '%.1f' is used for heights below one and '%.0f' otherwise.
    ha: 'left' or 'right'
        Label of the scale bar either to the left or to the right
        of the scale bar.
    va: 'top', 'bottom', or 'center'
        Scale bar aligned above, below, or centered on (x, y).
    lw: int, float, None
        Line width of the scale bar. Defaults to 2.
    color: matplotlib color
        Color of the scalebar. Defaults to black.
    capsize: float or None
        If larger than zero draw cap lines at the ends of the bar.
        The length of the lines is given in points (same unit as linewidth).
    clw: int, float
        Line width of the cap lines. Defaults to 0.5.
    kwargs: key-word arguments
        Passed on to `ax.text()` used to print the scale bar label.

    Returns
    -------
    x: float
        x-position of the scale bar in data coordinates.
    y0: float
        Lower end of the scale bar in data coordinates.
    y1: float
        Upper end of the scale bar in data coordinates.
    """
    ax.autoscale(False)
    # ax dimensions:
    pixelx = np.abs(np.diff(ax.get_window_extent().get_points()[:,0]))[0]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    unitx = xmax - xmin
    unity = ymax - ymin
    dxu = np.abs(unitx)/pixelx
    # transform x, y from relative units to axis units:
    x = xmin + x*unitx
    y = ymin + y*unity
    # bar length:
    if hformat is None:
        hformat = '%.0f'
        if height < 1.0:
            hformat = '%.1f'
    try:
        ls = hformat % height
        # draw the bar with exactly the length that is printed in the label
        # (as xscalebar() does for its width):
        height = float(ls)
    except TypeError:
        # hformat is a plain label without a conversion specifier
        ls = hformat
    # bar:
    if va == 'bottom':
        y0 = y
        y1 = y+height
    elif va == 'top':
        y0 = y-height
        y1 = y
    else:
        y0 = y-0.5*height
        y1 = y+0.5*height
    # line width:
    if lw is None:
        lw = 2
    # color:
    if color is None:
        color = 'k'
    # scalebar:
    lh = ax.plot([x, x], [y0, y1], '-', color=color, lw=lw,
                 solid_capstyle='butt', clip_on=False)
    # get x position of line in figure pixel coordinates:
    lx = np.array(lh[0].get_window_extent(ax.get_figure().canvas.get_renderer()))[0,0]
    # caps:
    if capsize is None:
        capsize = 0
    if clw is None:
        clw = 0.5
    if capsize > 0.0:
        dx = capsize*dxu
        ax.plot([x-dx, x+dx], [y0, y0], '-', color=color, lw=clw, solid_capstyle='butt',
                clip_on=False)
        ax.plot([x-dx, x+dx], [y1, y1], '-', color=color, lw=clw, solid_capstyle='butt',
                clip_on=False)
    # label:
    if hunit:
        ls += u'\u2009%s' % hunit
    if ha == 'right':
        th = ax.text(x, 0.5*(y0+y1), ls, clip_on=False, rotation=90.0,
                     ha='left', va='center', **kwargs)
        # get x coordinate of the left edge of the text in figure pixel coordinates:
        tx = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[0,0]
        dtx = lx+0.5*lw + 2.0 - tx
    else:
        th = ax.text(x, 0.5*(y0+y1), ls, clip_on=False, rotation=90.0,
                     ha='right', va='center', **kwargs)
        # get x coordinate of the right edge of the text in figure pixel coordinates:
        tx = np.array(th.get_window_extent(ax.get_figure().canvas.get_renderer()))[1,0]
        dtx = lx-0.5*lw - 1.0 - tx
    # shift the label so that it keeps a small gap to the bar:
    th.set_position((x+dxu*dtx, 0.5*(y0+y1)))
    return x, y0, y1
def arrowed_spines(ax, ms=10):
    """Put an arrow head on top of the y-axis of a plot.

    A triangular marker is drawn at the upper left corner of the current
    axes limits. The limits are restored afterwards.

    Parameters
    ----------
    ax : matplotlib figure axis
        Axis on which the arrow should be plotted.
    ms : float
        Size of the arrow marker, passed as `s` to `ax.scatter()`.
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    left, top = xlim[0], ylim[1]
    ax.scatter([left], [top], s=ms, marker='^', clip_on=False, color='k')
    # re-apply the limits that were in place before adding the marker:
    ax.set_xlim(xlim[0], xlim[1])
    ax.set_ylim(ylim[0], ylim[1])
def loghist(ax, x, bmin, bmax, n, c, orientation='vertical', label=''):
    """Plot a histogram with logarithmically spaced bin edges.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to plot the histogram on.
    x : numpy array
        Input data for histogram.
    bmin : float
        Smallest bin edge.
    bmax : float
        Largest bin edge.
    n : int
        Number of bin edges between `bmin` and `bmax`.
    c : matplotlib color
        Color of histogram.
    orientation : string (optional)
        Histogram orientation.
        Defaults to 'vertical'.
    label : string (optional)
        Label for x.
        Defaults to '' (no label).

    Returns
    -------
    n : array
        The values of the histogram bins.
    bins : array
        The edges of the bins.
    patches : BarContainer
        Container of individual artists used to create the histogram.
    """
    # equally spaced in log-space, then mapped back to linear units:
    log_edges = np.linspace(np.log(bmin), np.log(bmax), n)
    edges = np.exp(log_edges)
    return ax.hist(x, bins=edges, color=c, orientation=orientation, label=label)
def plot_all(data, eod_p_times, eod_tr_times, fs, mean_eods):
    """Show the output of extract_pulsefish in a single overview figure.

    The upper row shows the recording with the peaks and troughs of each
    fish marked, the lower row shows one mean EOD per panel.
    Without any detected fish only the recording is plotted.

    Parameters
    ----------
    data: array
        Recording data.
    eod_p_times: list of arrays
        EOD peak positions for each fish. They are plotted on the time axis
        and multiplied by `fs` to index `data`.
    eod_tr_times: list of arrays
        EOD trough positions for each fish, used like `eod_p_times`.
    fs: float
        Sampling rate.
    mean_eods: list of numpy arrays
        Mean EODs of each pulsefish found in the recording.
        For each entry `m`, `m[0]` is plotted as time, `m[1]` as the mean
        waveform and `m[2]` as the half-width of the band around it.
    """
    fig = plt.figure(figsize=(10, 5))

    if not len(eod_p_times) > 0:
        # nothing detected: just show the recording
        plt.plot(np.arange(len(data))/fs, data, c='k', alpha=0.3)
        plt.tight_layout()
        return

    gs = gridspec.GridSpec(2, len(eod_p_times))

    # recording with marked peaks and troughs:
    trace_ax = fig.add_subplot(gs[0,:])
    trace_ax.plot(np.arange(len(data))/fs, data, c='k', alpha=0.3)
    for k, (peaks, troughs) in enumerate(zip(eod_p_times, eod_tr_times)):
        for times in (peaks, troughs):
            trace_ax.plot(times, data[(times*fs).astype('int')], 'o',
                          label=k+1, ms=10, c=cmap(k))
    trace_ax.set_xlabel('time [s]')
    trace_ax.set_ylabel('amplitude [V]')

    # one panel for each mean EOD:
    for k, eod in enumerate(mean_eods):
        eod_ax = fig.add_subplot(gs[1,k])
        eod_ax.plot(1000*eod[0], 1000*eod[1], c='k')
        eod_ax.fill_between(1000*eod[0], 1000*(eod[1]-eod[2]), 1000*(eod[1]+eod[2]),
                            color=cmap(k))
        eod_ax.set_xlabel('time [ms]')
        eod_ax.set_ylabel('amplitude [mV]')

    plt.tight_layout()
def plot_clustering(rate, eod_widths, eod_hights, eod_shapes, disc_masks, merge_masks):
    """Plot all clustering steps in one figure.

    The figure has five columns: the width histogram, the height histograms
    for each width cluster, features and snippets of the shape clusters,
    the mean EODs that pass the discarding step, and the mean EODs that
    pass the merge step. Grey lines connect related panels.

    Parameters
    ----------
    rate : float
        Sampling rate of EOD snippets. Not used by this function.
    eod_widths : list of three 1D numpy arrays
        The first list entry gives the unique labels of all width clusters as a list of ints.
        The second list entry gives the width values for each EOD in samples as a
        1D numpy array of ints.
        The third list entry gives the width labels for each EOD as a 1D numpy array of ints.
    eod_hights : nested lists (2 layers) of three 1D numpy arrays
        The first list entry gives the unique labels of all height clusters as a list of ints
        for each width cluster.
        The second list entry gives the height values for each EOD as a 1D numpy array
        of floats for each width cluster.
        The third list entry gives the height labels for each EOD as a 1D numpy array
        of ints for each width cluster.
    eod_shapes : nested lists (3 layers) of three 1D numpy arrays
        The first list entry gives the raw EOD snippets as a 2D numpy array for each
        height cluster in a width cluster.
        The second list entry gives the snippet PCA values for each EOD as a 2D numpy array
        of floats for each height cluster in a width cluster.
        The third list entry gives the shape labels for each EOD as a 1D numpy array of ints
        for each height cluster in a width cluster.
        Each innermost entry is indexed by 0 (peak-centered) or 1 (trough-centered).
    disc_masks : Nested lists (two layers) of 1D numpy arrays
        The masks of EODs that are discarded by the discarding step of the algorithm.
        The masks are boolean arrays where instances that are set to True
        are discarded by the algorithm. Discarding masks are saved in nested
        lists that represent the width and height clusters, with one mask
        for the peak-centered (index 0) and one for the trough-centered
        (index 1) clusters.
    merge_masks : Nested lists (two layers) of 2D numpy arrays
        The masks of the merging step of the algorithm, saved in nested
        lists that represent the width and height clusters.
        This function indexes each mask first by 0 (peak-centered) or
        1 (trough-centered) and then by EOD, and plots a cluster in the
        merge column if any of its entries is set.
    """
    fig = plt.figure(figsize=(12, 7))
    # transformation from display coordinates to figure coordinates:
    to_fig = fig.transFigure.inverted()

    def connect(ax_a, xy_a, ax_b, xy_b, linewidth=0.5):
        # join a data point of ax_a with a data point of ax_b by a grey line
        pa = to_fig.transform(ax_a.transData.transform(xy_a))
        pb = to_fig.transform(ax_b.transData.transform(xy_b))
        fig.lines.append(Line2D((pa[0], pb[0]), (pa[1], pb[1]),
                                transform=fig.transFigure, color='grey',
                                linewidth=linewidth))

    def n_selected(cluster_labels, mask):
        # number of distinct labels >= 0 among the entries selected by mask
        ids = np.unique((cluster_labels+1)*mask)
        return len(ids[ids>0])

    # one column for each step of the algorithm:
    outer = gridspec.GridSpec(1, 5, width_ratios=[1, 1, 2, 1, 2], left=0.05, right=0.95)

    # invisible axes that only carry the column titles:
    titles = ['1. Widths', '2. Heights', '3. Shape', '4. Pulse EODs', '5. Merge']
    for col, title in enumerate(titles):
        cell = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=outer[col])
        ax = fig.add_subplot(cell[0])
        ax.text(0, 110, title, ha='center', va='bottom', clip_on=False)
        ax.set_xlim(-100, 100)
        ax.set_ylim(-100, 100)
        ax.axis('off')

    # number of rows needed in each column:
    w_size = 1
    h_size = len(eod_hights[1])
    shape_size = np.sum([len(sl) for sl in eod_shapes[0]])
    disc_size = 0
    merge_size = 0
    for shapelabel, dmasks, mmasks in zip(eod_shapes[2], disc_masks, merge_masks):
        for sl, dm, mm in zip(shapelabel, dmasks, mmasks):
            disc_size = disc_size + n_selected(sl[0], np.invert(dm[0])) \
                        + n_selected(sl[1], np.invert(dm[1]))
            merge_size = merge_size + n_selected(sl[0], mm[0]) \
                         + n_selected(sl[1], mm[1])

    # running counters that keep track of the rows already used:
    disc_block = 0
    merge_block = 0
    shape_count = 0

    # grids for all columns:
    width_hist_ax = gridspec.GridSpecFromSubplotSpec(w_size, 1, subplot_spec=outer[0])
    hight_hist_ax = gridspec.GridSpecFromSubplotSpec(h_size, 1, subplot_spec=outer[1])
    shape_ax = gridspec.GridSpecFromSubplotSpec(shape_size, 1, subplot_spec=outer[2])
    shape_windows = [gridspec.GridSpecFromSubplotSpec(2, 2, hspace=0.0, wspace=0.0,
                                                      subplot_spec=shape_ax[k])
                     for k in range(shape_size)]
    EOD_delete_ax = gridspec.GridSpecFromSubplotSpec(disc_size, 1, subplot_spec=outer[3])
    EOD_merge_ax = gridspec.GridSpecFromSubplotSpec(merge_size, 1, subplot_spec=outer[4])

    # axes for the width histogram:
    ax1 = fig.add_subplot(width_hist_ax[0])
    ax1.set_xscale('log')
    for side in ('top', 'right', 'bottom'):
        ax1.spines[side].set_visible(False)
    ax1.axes.xaxis.set_visible(False)
    ax1.set_yticklabels([])

    # lightness values for the width clusters (dark to light):
    colidxsw = -np.linspace(-1.25, -0.5, h_size)

    width_clusters = zip(eod_widths[0], colidxsw, eod_hights[0], eod_hights[1],
                         eod_hights[2], eod_shapes[0], eod_shapes[1], eod_shapes[2],
                         disc_masks, merge_masks)
    for i, (wl, colw, uhl, eod_h, eod_h_labs, w_snip, w_feat, w_lab, w_dm, w_mm) in enumerate(width_clusters):

        # histogram of the widths in this width cluster:
        in_width = eod_widths[2]==wl
        hw, _, _ = ax1.hist(eod_widths[1][in_width],
                            bins=np.linspace(np.min(eod_widths[1]), np.max(eod_widths[1]), 100),
                            color=lighter(c_o, colw), orientation='horizontal')

        # the arrow needs the final axes limits, so add it with the last histogram:
        if i == h_size-1:
            arrowed_spines(ax1, ms=20)

        # largest count of the height histogram sets the extent of its axes:
        log_h = np.log(eod_h)
        my, b = np.histogram(eod_h, bins=np.exp(np.linspace(np.min(log_h),
                                                            np.max(log_h), 100)))
        maxy = np.max(my)

        # axes for the height histogram, filled from the bottom row upwards:
        ax2 = fig.add_subplot(hight_hist_ax[h_size-i-1])
        ax2.set_xscale('log')
        for side in ('top', 'right', 'bottom'):
            ax2.spines[side].set_visible(False)
        ax2.set_xlim(0.9, maxy)
        ax2.axes.xaxis.set_visible(False)
        ax2.set_yscale('log')
        ax2.yaxis.set_major_formatter(ticker.NullFormatter())
        ax2.yaxis.set_minor_formatter(ticker.NullFormatter())

        # lightness values for the height clusters:
        colidxsh = -np.linspace(-1.25, -0.5, len(uhl))

        height_clusters = zip(uhl, colidxsh, w_snip, w_feat, w_lab, w_dm, w_mm)
        for n, (hl, hcol, snippets, features, labels, dmasks, mmasks) in enumerate(height_clusters):

            in_height = eod_h_labs==hl
            hh, _, _ = loghist(ax2, eod_h[in_height], np.min(eod_h), np.max(eod_h), 100,
                               lighter(c_g, hcol), orientation='horizontal')

            # arrow only with the last histogram of these axes:
            if n == len(uhl)-1:
                arrowed_spines(ax2, ms=10)

            # line from the width histogram to the height histogram:
            if n == 0:
                connect(ax1, [np.median(hw[hw!=0]), np.median(eod_widths[1][in_width])],
                        ax2, [0.9, np.mean(eod_h)])

            # rows used up to and including this height cluster:
            disc_block = disc_block + n_selected(labels[0], ~dmasks[0]) \
                         + n_selected(labels[1], ~dmasks[1])
            merge_block = merge_block + n_selected(labels[0], mmasks[0]) \
                          + n_selected(labels[1], mmasks[1])

            axs = []
            disc_count = 0
            merge_count = 0

            # clusters of the peak-centered and trough-centered snippets:
            for pt, cmap_pt in zip([0, 1], cmap_pts):

                window = shape_windows[shape_size-1-shape_count]
                ax3 = fig.add_subplot(window[pt,0])
                ax4 = fig.add_subplot(window[pt,1])

                # neither panel shows ticks or tick labels:
                for panel in (ax3, ax4):
                    panel.axes.xaxis.set_visible(False)
                    panel.axes.yaxis.set_visible(False)

                # lightness values for the shape clusters:
                pt_labels = labels[pt]
                colidxss = -np.linspace(-1.25, -0.5, len(np.unique(pt_labels[pt_labels>=0])))
                j = 0
                for c in np.unique(pt_labels):
                    sel = pt_labels==c

                    if c<0:
                        # noise features and snippets:
                        ax3.plot(features[pt][sel,0], features[pt][sel,1],
                                 '.', color='lightgrey', label='-1', rasterized=True)
                        ax4.plot(snippets[pt][sel].T, linewidth=0.1,
                                 color='lightgrey', label='-1', rasterized=True)
                        continue

                    # features and snippets of the cluster:
                    col = lighter(cmap_pt, colidxss[j])
                    ax3.plot(features[pt][sel,0], features[pt][sel,1],
                             '.', color=col, label=c, rasterized=True)
                    ax4.plot(snippets[pt][sel].T, linewidth=0.1,
                             color=col, label=c, rasterized=True)

                    # clusters without any discarded EOD get a mean EOD panel:
                    if np.sum(dmasks[pt][sel]) == 0:

                        ax = fig.add_subplot(EOD_delete_ax[disc_size-disc_block+disc_count])
                        ax.axis('off')
                        ax.plot(np.mean(snippets[pt][sel], axis=0), color=col)
                        disc_count = disc_count + 1

                        # line from the snippets to the mean EOD:
                        ylo4, yhi4 = ax4.get_ylim()
                        ylo, yhi = ax.get_ylim()
                        connect(ax4, [ax4.get_xlim()[1], ylo4 + 0.5*(yhi4-ylo4)],
                                ax, [ax.get_xlim()[0], ylo + 0.5*(yhi-ylo)])
                        axs.append(ax)

                        # plot the mean EOD again if it survives the merge step:
                        if np.sum(mmasks[pt, sel])>0:

                            ax = fig.add_subplot(EOD_merge_ax[merge_size-merge_block+merge_count])
                            ax.axis('off')
                            ax.plot(np.mean(snippets[pt][sel], axis=0), color=col)
                            merge_count = merge_count + 1

                    j = j+1

                if pt==0:
                    # line from the height cluster to its shape clusters:
                    connect(ax2, [np.median(hh[hh!=0]), np.median(eod_h[in_height])],
                            ax3, [ax3.get_xlim()[0], ax3.get_ylim()[0]])

            shape_count = shape_count + 1

            if len(axs)>0:
                # line along the mean EODs that take part in the merge step:
                first, last = axs[0], axs[-1]
                fx0, fx1 = first.get_xlim()
                fy0, fy1 = first.get_ylim()
                lx0, lx1 = last.get_xlim()
                ly0, ly1 = last.get_ylim()
                connect(first, [fx1+0.1*(fx1-fx0), fy1-0.25*(fy1-fy0)],
                        last, [lx1+0.1*(lx1-lx0), ly0+0.25*(ly1-ly0)],
                        linewidth=1)
def plot_bgm(x, means, variances, weights, use_log, labels, labels_am, xlab):
    """Plot a BGM clustering step either on EOD width or height.

    Histograms of the data for each BGM label are drawn on the left y-axis,
    the weighted Gaussians of the model and their intersections on the
    right y-axis. Clusters that were merged afterwards are marked by a
    bracket with an annotation.

    Parameters
    ----------
    x : 1D numpy array of floats
        BGM input values.
    means : list of floats
        BGM Gaussian means
    variances : list of floats
        BGM Gaussian variances.
    weights : list of floats
        BGM Gaussian weights.
    use_log : boolean
        True if the z-scored logarithm of the data was used as BGM input.
    labels : 1D numpy array of ints
        Labels defined by BGM model (before merging based on merge factor).
    labels_am : 1D numpy array of ints
        Labels defined by BGM model (after merging based on merge factor).
    xlab : string
        Label for plot (defines the units of the BGM data).
        If it contains 'width' or 'height' the respective color is used
        for the histograms.
    """
    if 'width' in xlab:
        ccol = c_o
    elif 'height' in xlab:
        ccol = c_g
    else:
        ccol = 'b'

    # get the transform that was used as BGM input
    if use_log:
        x_transform = stats.zscore(np.log(x))
        xplot = np.exp(np.linspace(np.log(np.min(x)), np.log(np.max(x)), 1000))
    else:
        x_transform = stats.zscore(x)
        xplot = np.linspace(np.min(x), np.max(x), 1000)

    # compute the gaussians on the transformed x values
    xg = np.linspace(np.min(x_transform), np.max(x_transform), 1000)
    gaussians = []
    gmax = 0
    for w, m, var in zip(weights, means, variances):
        gaus = np.sqrt(w*stats.norm.pdf(xg, m, np.sqrt(var)))
        gaussians.append(gaus)
        gmax = max(np.max(gaus), gmax)

    # compute classes defined by gaussian intersections
    classes = np.argmax(np.vstack(gaussians), axis=0)
    uclasses = np.unique(classes)

    # find the minimum of any gaussian that is within its class
    gmin = 100
    for c in uclasses:
        gmin = min(gmin, np.min(gaussians[c][classes==c]))

    # set up the figure
    fig, ax1 = plt.subplots(figsize=(8, 4.8))
    ax2 = ax1.twinx()
    ax1.spines['top'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax1.set_xlabel('x [a.u.]')
    ax1.set_ylabel('#')
    ax2.set_ylabel('Likelihood')
    ax2.set_yscale('log')
    ax1.set_yscale('log')
    if use_log:
        ax1.set_xscale('log')
    ax1.set_xlabel(xlab)

    # define colors for plotting gaussians
    colidxs = -np.linspace(-1.25, -0.5, len(uclasses))

    # plot the gaussians
    for i, c in enumerate(uclasses):
        ax2.plot(xplot, gaussians[c], c=lighter(c_grey, colidxs[i]), linewidth=2,
                 label=r'$N(\mu_%i, \sigma_%i)$'%(c, c))

    # plot intersection lines
    ax2.vlines(xplot[1:][np.diff(classes)!=0], 0, gmax/gmin, color='k', linewidth=2,
               linestyle='--')
    ax2.set_ylim(gmin, np.max(np.vstack(gaussians))*1.1)

    # plot data distributions and classes
    ulabels = np.unique(labels)
    colidxs = -np.linspace(-1.25, -0.5, len(ulabels))
    for i, l in enumerate(ulabels):
        if use_log:
            loghist(ax1, x[labels==l], np.min(x), np.max(x), 100,
                    lighter(ccol, colidxs[i]), label=r'$x_%i$'%l)
        else:
            ax1.hist(x[labels==l], bins=np.linspace(np.min(x), np.max(x), 100),
                     color=lighter(ccol, colidxs[i]), label=r'$x_%i$'%l)

    # annotate merged clusters
    for l in np.unique(labels_am):
        maps = np.unique(labels[labels_am==l])
        if len(maps) > 1:
            # bracket between the medians of the first two merged clusters
            med1 = np.median(x[labels==maps[0]])
            med2 = np.median(x[labels==maps[1]])
            ax2.plot([med1, med2], [1.2*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.plot([med1, med1], [1.1*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.plot([med2, med2], [1.1*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.annotate(r'$\frac{|{\tilde{x}_%i-\tilde{x}_%i}|}{max(\tilde{x}_%i, \tilde{x}_%i)} < \epsilon$' % (maps[0], maps[1], maps[0], maps[1]),
                         [med1*1.1, gmax*1.2], xytext=(10, 10),
                         textcoords='offset points', fontsize=12,
                         annotation_clip=False, ha='center')

    # add legends and plot.
    ax2.legend(loc='lower left', frameon=False, bbox_to_anchor=(-0.05, 1.3),
               ncol=len(uclasses))
    ax1.legend(loc='upper left', frameon=False, bbox_to_anchor=(-0.05, 1.3),
               ncol=len(ulabels))
    plt.tight_layout()
def plot_feature_extraction(raw_snippets, normalized_snippets, features, labels, dt, pt):
    """Plot clustering step on EOD shape.

    The left column shows the raw and the normalized snippets, the right
    column shows scatter plots for all pairs of features. Each cluster is
    drawn in its own shade, snippets with negative labels in light grey.

    Parameters
    ----------
    raw_snippets : 2D numpy array
        Raw EOD snippets.
    normalized_snippets : 2D numpy array
        Normalized EOD snippets.
    features : 2D numpy array
        PCA values for each normalized EOD snippet.
        At least two features (columns) are needed for the pair plots.
    labels : 1D numpy array of ints
        Cluster labels.
    dt : float
        Sample interval of snippets.
    pt : int
        Set to 0 for peak-centered EODs and set to 1 for trough-centered EODs.
    """
    ccol = cmap_pts[pt]
    n_features = features.shape[1]
    fmin = np.min(features)
    fmax = np.max(features)

    # set up the figure layout
    fig = plt.figure(figsize=(((2+0.2)*4.8), 4.8))
    outer = gridspec.GridSpec(1, 2, wspace=0.2, hspace=0)

    # time axis of the snippets in milliseconds, centered on zero:
    x = np.arange(-dt*1000*raw_snippets.shape[1]/2, dt*1000*raw_snippets.shape[1]/2, dt*1000)

    snip_ax = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[0], hspace=0.35)
    pc_ax = gridspec.GridSpecFromSubplotSpec(n_features-1, n_features-1,
                                             subplot_spec=outer[1], hspace=0, wspace=0)

    # All axes are created exactly once. Calling fig.add_subplot() again for
    # the same grid cell inside the cluster loop would stack a new, opaque
    # axes on top of the previous one and hide the clusters drawn before.
    ax_raw_snip = fig.add_subplot(snip_ax[0])
    ax_normalized_snip = fig.add_subplot(snip_ax[1])

    # invisible axes spanning all feature panels, only used for the title:
    ax_overlay = fig.add_subplot(pc_ax[:,:])
    ax_overlay.set_title('Features')
    ax_overlay.axis('off')

    # one panel for each pair (m, n) of features with m < n:
    pc_axes = [(n, m, fig.add_subplot(pc_ax[n-1,m]))
               for n in range(n_features) for m in range(n)]

    # plot snippets and features of each cluster:
    colidxs = -np.linspace(-1.25, -0.5, len(np.unique(labels[labels>=0])))
    j = 0
    for c in np.unique(labels):
        if c<0:
            color = 'lightgrey'
        else:
            color = lighter(ccol, colidxs[j])
            j = j+1
        in_cluster = labels==c

        ax_raw_snip.plot(x, raw_snippets[in_cluster].T, color=color, label='-1',
                         rasterized=True, alpha=0.25)
        ax_normalized_snip.plot(x, normalized_snippets[in_cluster].T, color=color, alpha=0.25)
        for n, m, ax in pc_axes:
            ax.scatter(features[in_cluster,m], features[in_cluster,n], marker='.',
                       color=color, alpha=0.25)

    # decorate the snippet axes:
    ax_raw_snip.spines['top'].set_visible(False)
    ax_raw_snip.spines['right'].set_visible(False)
    ax_raw_snip.get_xaxis().set_ticklabels([])
    ax_raw_snip.set_title('Raw snippets')
    ax_raw_snip.set_ylabel('Amplitude [a.u.]')
    ax_normalized_snip.spines['top'].set_visible(False)
    ax_normalized_snip.spines['right'].set_visible(False)
    ax_normalized_snip.set_title('Normalized snippets')
    ax_normalized_snip.set_ylabel('Amplitude [a.u.]')
    ax_normalized_snip.set_xlabel('Time [ms]')

    ax_raw_snip.axis('off')
    ax_normalized_snip.axis('off')

    # decorate the feature panels, all share the same limits:
    for n, m, ax in pc_axes:
        ax.set_xlim(fmin, fmax)
        ax.set_ylim(fmin, fmax)
        ax.get_xaxis().set_ticklabels([])
        ax.get_yaxis().set_ticklabels([])
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])

        if m==0:
            ax.set_ylabel('PC %i'%(n+1))

        if n==n_features-1:
            ax.set_xlabel('PC %i'%(m+1))

    # scale bars in the upper right cell of the feature grid:
    ax = fig.add_subplot(pc_ax[0,n_features-2])
    ax.set_xlim(fmin, fmax)
    ax.set_ylim(fmin, fmax)

    # number of decimals and length of the scale bars, derived from the
    # range of the features and rounded down:
    size = max(1, int(np.ceil(-np.log10(fmax-fmin))))
    wbar = np.floor((fmax-fmin)*10**size)/10**size

    # should be smaller than the actual thing! so like x% of it?
    xscalebar(ax, 0, 0, wbar, wformat='%%.%if'%size)
    yscalebar(ax, 0, 0, wbar, hformat='%%.%if'%size)
    ax.axis('off')
def plot_moving_fish(ws, dts, clusterss, ts, fishcounts, T, ignore_stepss):
    """Plot moving fish detection step.

    For each width cluster two panels are drawn: an event plot of the EOD
    times of each cluster together with the sliding window, and the fish
    count over time, where ignored steps are shown transparently.

    Parameters
    ----------
    ws : list of floats
        Median width for each width cluster that the moving fish algorithm is computed on
        (in seconds).
    dts : list of floats
        Sliding window size (in seconds) for each width cluster.
    clusterss : list of 1D numpy int arrays
        Cluster labels for each EOD cluster in a width cluster.
    ts : list of 1D numpy float arrays
        EOD emission times for each EOD in a width cluster.
    fishcounts : list of lists
        Sliding window timepoints and fishcounts for each width cluster.
    T : float
        Length of analyzed recording in seconds.
    ignore_stepss : list of 1D int arrays
        Mask for fishcounts that were ignored (ignored if True) in the moving_fish analysis.
    """
    fig = plt.figure()

    # create gridspec
    outer = gridspec.GridSpec(len(ws), 1)

    for i, (w, dt, clusters, t, fishcount, ignore_steps) in enumerate(zip(ws, dts, clusterss, ts, fishcounts, ignore_stepss)):

        gs = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[i])

        # axis for clusters
        ax1 = fig.add_subplot(gs[0])
        # axis for fishcount (created only once, a second add_subplot()
        # for the same cell would stack another axes on top of this one)
        ax2 = fig.add_subplot(gs[1])

        # plot clusters as eventplot
        cluster_ids = np.unique(clusters[clusters>=0])
        for cnum, c in enumerate(cluster_ids):
            ax1.eventplot(t[clusters==c], lineoffsets=cnum, linelengths=0.5, color=cmap(i))
        # number of plotted clusters, well defined also without any cluster:
        cnum = len(cluster_ids)

        # Plot the sliding window
        rect = Rectangle((0, -0.5), dt, cnum, linewidth=1, linestyle='--', edgecolor='k',
                         facecolor='none', clip_on=False)
        ax1.add_patch(rect)
        ax1.arrow(dt+0.1, -0.5, 0.5, 0, head_width=0.1, head_length=0.1, facecolor='k',
                  edgecolor='k')

        # plot parameters
        ax1.set_title(r'$\tilde{w}_%i = %.3f ms$'%(i, 1000*w))
        ax1.set_ylabel('cluster #')
        ax1.set_yticks(range(0, cnum))
        ax1.set_xlabel('time')
        ax1.set_xlim(0, T)
        ax1.axes.xaxis.set_visible(False)
        ax1.spines['bottom'].set_visible(False)
        ax1.spines['top'].set_visible(False)
        ax1.spines['right'].set_visible(False)
        ax1.spines['left'].set_visible(False)

        # plot for fishcount
        x = fishcount[0]
        y = fishcount[1]

        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        ax2.axes.xaxis.set_visible(False)

        # float copy, so that ignored steps can be masked by NaN even
        # if the fish counts are integers:
        yplot = np.array(y, dtype=float)
        ax2.plot(x+dt/2, yplot, linestyle='-', marker='.', c=cmap(i), alpha=0.25)
        yplot[ignore_steps.astype(bool)] = np.nan
        ax2.plot(x+dt/2, yplot, linestyle='-', marker='.', c=cmap(i))
        ax2.set_ylabel('Fish count')
        ax2.set_yticks(range(int(np.min(y)), 1+int(np.max(y))))
        ax2.set_xlim(0, T)

        # only the last panel gets a time scale bar:
        if i >= len(ws)-1:
            xscalebar(ax2, 1, 0, 1, wunit='s', ha='right')

        # connect both ends of the sliding window with the first fish count:
        con = ConnectionPatch([0, -0.5], [dt/2, y[0]], "data", "data",
                              axesA=ax1, axesB=ax2, color='k')
        ax2.add_artist(con)
        con = ConnectionPatch([dt, -0.5], [dt/2, y[0]], "data", "data",
                              axesA=ax1, axesB=ax2, color='k')
        ax2.add_artist(con)

    plt.xlim(0, T)