Coverage for src/thunderfish/pulseplots.py: 0%
487 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 16:21 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 16:21 +0000
"""
Plot and save key steps in pulses.py for visualizing the algorithm.
"""
5import glob
6import numpy as np
7from scipy import stats
8from matplotlib import gridspec, ticker
9import matplotlib.pyplot as plt
10try:
11 from matplotlib.colors import colorConverter as cc
12except ImportError:
13 import matplotlib.colors as cc
14try:
15 from matplotlib.colors import to_hex
16except ImportError:
17 from matplotlib.colors import rgb2hex as to_hex
18from matplotlib.patches import ConnectionPatch, Rectangle
19from matplotlib.lines import Line2D
20import warnings
def warn(*args, **kwargs):
    """
    Ignore all warnings.

    No-op replacement for `warnings.warn()`: accepts any arguments and
    does nothing.
    """
    pass


# NOTE(review): this rebinds `warnings.warn` for the whole Python process as
# soon as this module is imported, so every later call that goes through the
# `warnings.warn` attribute (in any module, not only this one) is silenced.
# Consider wrapping the plotting code in `warnings.catch_warnings()` instead.
warnings.warn=warn


# Plotting parameters and colors, all drawn from matplotlib's "Dark2" colormap.
cmap = plt.get_cmap("Dark2")
# Base colors: c_g for EOD-height plots, c_o for EOD-width plots,
# c_grey for the BGM Gaussians.
c_g, c_o, c_grey = cmap(0), cmap(1), cmap(7)
# Cluster colors for peak-centered (index 0) and trough-centered (index 1) EODs.
cmap_pts = [cmap(idx) for idx in (2, 3)]


def darker(color, saturation):
    """Make a color darker.

    From bendalab/plottools package.

    Parameters
    ----------
    color: dict or matplotlib color spec
        A matplotlib color (hex string, name color string, rgb tuple)
        or a dictionary with an 'color' or 'facecolor' key.
    saturation: float
        The smaller the saturation, the darker the returned color.
        A saturation of 0 returns black.
        A saturation of 1 leaves the color untouched.
        A saturation of 2 returns white.
        Values below 0 or above 2 are clipped to this range.

    Returns
    -------
    color: string or dictionary
        The darker color as an upper-case hexadecimal RGB string
        (e.g. '#RRGGBB').
        If `color` is a dictionary, a copy of the dictionary is returned
        with the value of 'color' or 'facecolor' set to the darker color.
    """
    # Dictionary input: transform the 'color' entry (preferred) or else the
    # 'facecolor' entry. Only the key lookup is guarded, so that errors raised
    # while converting the nested color are reported instead of being
    # silently swallowed and retried with the next key.
    for key in ('color', 'facecolor'):
        try:
            c = color[key]
        except (KeyError, TypeError):
            continue
        cd = dict(**color)
        cd[key] = darker(c, saturation)
        return cd
    # Plain color spec: clip saturation to [0, 2].
    if saturation > 2:
        saturation = 2
    if saturation > 1:
        # saturations above 1 lighten the color instead
        return lighter(color, 2.0 - saturation)
    if saturation < 0:
        saturation = 0
    r, g, b = cc.to_rgb(color)
    rd = r*saturation
    gd = g*saturation
    bd = b*saturation
    return to_hex((rd, gd, bd)).upper()


def lighter(color, lightness):
    """Make a color lighter.

    From bendalab/plottools package.

    Parameters
    ----------
    color: dict or matplotlib color spec
        A matplotlib color (hex string, name color string, rgb tuple)
        or a dictionary with an 'color' or 'facecolor' key.
    lightness: float
        The smaller the lightness, the lighter the returned color.
        A lightness of 0 returns white.
        A lightness of 1 leaves the color untouched.
        A lightness of 2 returns black.
        Values below 0 or above 2 are clipped to this range.

    Returns
    -------
    color: string or dict
        The lighter color as an upper-case hexadecimal RGB string
        (e.g. '#RRGGBB').
        If `color` is a dictionary, a copy of the dictionary is returned
        with the value of 'color' or 'facecolor' set to the lighter color.
    """
    # Dictionary input: transform the 'color' entry (preferred) or else the
    # 'facecolor' entry. Only the key lookup is guarded, so that errors raised
    # while converting the nested color are not silently swallowed.
    for key in ('color', 'facecolor'):
        try:
            c = color[key]
        except (KeyError, TypeError):
            continue
        cd = dict(**color)
        cd[key] = lighter(c, lightness)
        return cd
    # Plain color spec: clip lightness to [0, 2].
    if lightness > 2:
        lightness = 2
    if lightness > 1:
        # lightness above 1 darkens the color instead
        return darker(color, 2.0 - lightness)
    if lightness < 0:
        lightness = 0
    r, g, b = cc.to_rgb(color)
    # blend each channel towards white (1.0) by a fraction of (1 - lightness)
    rl = r + (1.0 - lightness)*(1.0 - r)
    gl = g + (1.0 - lightness)*(1.0 - g)
    bl = b + (1.0 - lightness)*(1.0 - b)
    return to_hex((rl, gl, bl)).upper()


def xscalebar(ax, x, y, width, wunit=None, wformat=None, ha='left', va='bottom',
              lw=None, color=None, capsize=None, clw=None, **kwargs):
    """Horizontal scale bar with label.

    From bendalab/plottools package.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where to draw the scale bar.
    x: float
        x-coordinate where to draw the scale bar in relative units of the axes.
    y: float
        y-coordinate where to draw the scale bar in relative units of the axes.
    width: float
        Length of the scale bar in units of the data's x-values.
    wunit: string or None
        Optional unit of the data's x-values.
    wformat: string or None
        Optional format string for formatting the label of the scale bar
        or simply a string used for labeling the scale bar.
        If the formatted label is a plain number, the bar length is rounded
        to the value shown in the label.
    ha: 'left', 'right', or 'center'
        Scale bar aligned left, right, or centered to (x, y)
    va: 'top' or 'bottom'
        Label of the scale bar either above or below the scale bar.
    lw: int, float, None
        Line width of the scale bar.
    color: matplotlib color
        Color of the scalebar.
    capsize: float or None
        If larger then zero draw cap lines at the ends of the bar.
        The length of the lines is given in points (same unit as linewidth).
    clw: int, float, None
        Line width of the cap lines.
    kwargs: key-word arguments
        Passed on to `ax.text()` used to print the scale bar label.

    Returns
    -------
    x0: float
        Left end of the scale bar in data coordinates.
    x1: float
        Right end of the scale bar in data coordinates.
    y: float
        Vertical position of the scale bar in data coordinates.
    """
    ax.autoscale(False)
    # axes height in pixels, to convert pixel offsets into y data units:
    pixely = np.abs(np.diff(ax.get_window_extent().get_points()[:, 1]))[0]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    unitx = xmax - xmin
    unity = ymax - ymin
    dyu = np.abs(unity)/pixely
    # transform x, y from relative units to axis units:
    x = xmin + x*unitx
    y = ymin + y*unity
    # bar label:
    if wformat is None:
        wformat = '%.0f'
        if width < 1.0:
            wformat = '%.1f'
    try:
        ls = wformat % width
    except (TypeError, ValueError):
        # wformat is a plain label, not a format specifier:
        ls = wformat
    else:
        # make the bar exactly as long as the number shown in the label;
        # labels with additional text (e.g. '%.0f ms') keep the requested width:
        try:
            width = float(ls)
        except ValueError:
            pass
    # bar:
    if ha == 'left':
        x0 = x
        x1 = x+width
    elif ha == 'right':
        x0 = x-width
        x1 = x
    else:
        x0 = x-0.5*width
        x1 = x+0.5*width
    # line width:
    if lw is None:
        lw = 2
    # color:
    if color is None:
        color = 'k'
    # scalebar:
    lh = ax.plot([x0, x1], [y, y], '-', color=color, lw=lw,
                 solid_capstyle='butt', clip_on=False)
    renderer = ax.get_figure().canvas.get_renderer()
    # get y position of line in figure pixel coordinates:
    ly = np.array(lh[0].get_window_extent(renderer))[0, 1]
    # caps:
    if capsize is None:
        capsize = 0
    if clw is None:
        clw = 0.5
    if capsize > 0.0:
        dy = capsize*dyu
        ax.plot([x0, x0], [y-dy, y+dy], '-', color=color, lw=clw,
                solid_capstyle='butt', clip_on=False)
        ax.plot([x1, x1], [y-dy, y+dy], '-', color=color, lw=clw,
                solid_capstyle='butt', clip_on=False)
    # label:
    if wunit:
        ls += u'\u2009%s' % wunit
    if va == 'top':
        th = ax.text(0.5*(x0+x1), y, ls, clip_on=False,
                     ha='center', va='bottom', **kwargs)
        # get y coordinate of text bottom in figure pixel coordinates:
        ty = np.array(th.get_window_extent(renderer))[0, 1]
        dty = ly+0.5*lw + 2.0 - ty
    else:
        th = ax.text(0.5*(x0+x1), y, ls, clip_on=False,
                     ha='center', va='top', **kwargs)
        # get y coordinate of text top in figure pixel coordinates:
        ty = np.array(th.get_window_extent(renderer))[1, 1]
        dty = ly-0.5*lw - 2.0 - ty
    # shift the label so that it keeps a 2 pixel gap to the bar:
    th.set_position((0.5*(x0+x1), y+dyu*dty))
    return x0, x1, y


def yscalebar(ax, x, y, height, hunit=None, hformat=None, ha='left', va='bottom',
              lw=None, color=None, capsize=None, clw=None, **kwargs):
    """Vertical scale bar with label.

    From bendalab/plottools package.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where to draw the scale bar.
    x: float
        x-coordinate where to draw the scale bar in relative units of the axes.
    y: float
        y-coordinate where to draw the scale bar in relative units of the axes.
    height: float
        Length of the scale bar in units of the data's y-values.
    hunit: string
        Unit of the data's y-values.
    hformat: string or None
        Optional format string for formatting the label of the scale bar
        or simply a string used for labeling the scale bar.
        If the formatted label is a plain number, the bar length is rounded
        to the value shown in the label.
    ha: 'left' or 'right'
        Label of the scale bar either to the left or to the right
        of the scale bar.
    va: 'top', 'bottom', or 'center'
        Scale bar aligned above, below, or centered on (x, y).
    lw: int, float, None
        Line width of the scale bar.
    color: matplotlib color
        Color of the scalebar.
    capsize: float or None
        If larger then zero draw cap lines at the ends of the bar.
        The length of the lines is given in points (same unit as linewidth).
    clw: int, float
        Line width of the cap lines.
    kwargs: key-word arguments
        Passed on to `ax.text()` used to print the scale bar label.

    Returns
    -------
    x: float
        Horizontal position of the scale bar in data coordinates.
    y0: float
        Lower end of the scale bar in data coordinates.
    y1: float
        Upper end of the scale bar in data coordinates.
    """
    ax.autoscale(False)
    # axes width in pixels, to convert pixel offsets into x data units:
    pixelx = np.abs(np.diff(ax.get_window_extent().get_points()[:, 0]))[0]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    unitx = xmax - xmin
    unity = ymax - ymin
    dxu = np.abs(unitx)/pixelx
    # transform x, y from relative units to axis units:
    x = xmin + x*unitx
    y = ymin + y*unity
    # bar label:
    if hformat is None:
        hformat = '%.0f'
        if height < 1.0:
            hformat = '%.1f'
    try:
        ls = hformat % height
    except (TypeError, ValueError):
        # hformat is a plain label, not a format specifier:
        ls = hformat
    else:
        # make the bar exactly as long as the number shown in the label
        # (as xscalebar does); labels with additional text keep the height:
        try:
            height = float(ls)
        except ValueError:
            pass
    # bar:
    if va == 'bottom':
        y0 = y
        y1 = y+height
    elif va == 'top':
        y0 = y-height
        y1 = y
    else:
        y0 = y-0.5*height
        y1 = y+0.5*height
    # line width:
    if lw is None:
        lw = 2
    # color:
    if color is None:
        color = 'k'
    # scalebar:
    lh = ax.plot([x, x], [y0, y1], '-', color=color, lw=lw,
                 solid_capstyle='butt', clip_on=False)
    renderer = ax.get_figure().canvas.get_renderer()
    # get x position of line in figure pixel coordinates:
    lx = np.array(lh[0].get_window_extent(renderer))[0, 0]
    # caps:
    if capsize is None:
        capsize = 0
    if clw is None:
        clw = 0.5
    if capsize > 0.0:
        dx = capsize*dxu
        ax.plot([x-dx, x+dx], [y0, y0], '-', color=color, lw=clw, solid_capstyle='butt',
                clip_on=False)
        ax.plot([x-dx, x+dx], [y1, y1], '-', color=color, lw=clw, solid_capstyle='butt',
                clip_on=False)
    # label:
    if hunit:
        ls += u'\u2009%s' % hunit
    if ha == 'right':
        th = ax.text(x, 0.5*(y0+y1), ls, clip_on=False, rotation=90.0,
                     ha='left', va='center', **kwargs)
        # get x coordinate of the left text edge in figure pixel coordinates:
        tx = np.array(th.get_window_extent(renderer))[0, 0]
        dtx = lx+0.5*lw + 2.0 - tx
    else:
        th = ax.text(x, 0.5*(y0+y1), ls, clip_on=False, rotation=90.0,
                     ha='right', va='center', **kwargs)
        # get x coordinate of the right text edge in figure pixel coordinates:
        tx = np.array(th.get_window_extent(renderer))[1, 0]
        dtx = lx-0.5*lw - 1.0 - tx
    # shift the label horizontally next to the bar:
    th.set_position((x+dxu*dtx, 0.5*(y0+y1)))
    return x, y0, y1


def arrowed_spines(ax, ms=10):
    """Put an arrow head on top of the y-axis spine.

    Parameters
    ----------
    ax : matplotlib axes
        Axes whose y-axis receives the arrow head.
    ms : float
        Size of the arrow head, passed as `s` to `ax.scatter()`.
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    # a black upward triangle in the upper-left corner serves as arrow head
    ax.scatter([xlim[0]], [ylim[1]], s=ms, marker='^', clip_on=False, color='k')
    # scatter() may autoscale, so restore the original limits
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)


def loghist(ax, x, bmin, bmax, n, c, orientation='vertical', label='', **kwargs):
    """ Plot histogram with logarithmically spaced bins.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to plot the histogram on.
    x : numpy array
        Input data for histogram.
    bmin : float
        Minimum value for the histogram bins. Must be positive.
    bmax : float
        Maximum value for the histogram bins. Must be positive.
    n : int
        Number of bin edges; the histogram has `n - 1` logarithmically
        spaced bins between `bmin` and `bmax`.
    c : matplotlib color
        Color of histogram.
    orientation : string (optional)
        Histogram orientation.
        Defaults to 'vertical'.
    label : string (optional)
        Label for x.
        Defaults to '' (no label).
    kwargs : key-word arguments
        Further arguments passed on to `ax.hist()`.

    Returns
    -------
    n : array
        The values of the histogram bins.
    bins : array
        The edges of the bins.
    patches : BarContainer
        Container of individual artists used to create the histogram.

    Raises
    ------
    ValueError
        If `bmin` or `bmax` is not positive.
    """
    if bmin <= 0 or bmax <= 0:
        raise ValueError('bmin and bmax must be positive for logarithmic bins')
    # n edges equally spaced on a logarithmic scale:
    bins = np.exp(np.linspace(np.log(bmin), np.log(bmax), n))
    return ax.hist(x, bins=bins, color=c, orientation=orientation, label=label,
                   **kwargs)


def plot_all(data, eod_p_times, eod_tr_times, fs, mean_eods):
    """Quick way to view the output of extract_pulsefish in a single plot.

    The top row shows the recording with the detected EOD peaks and troughs
    of each fish marked in its own color. The bottom row has one panel per
    mean EOD, showing the mean waveform and a band of mean +/- the spread
    given in the third row of each mean EOD.

    Parameters
    ----------
    data: 1D array
        Recording data.
    eod_p_times: list of 1D arrays of floats
        EOD peak times in seconds, one array per fish.
    eod_tr_times: list of 1D arrays of floats
        EOD trough times in seconds, one array per fish.
    fs: float
        Samplerate in Hertz.
    mean_eods: list of 2D numpy arrays
        Mean EODs of each pulsefish found in the recording. Row 0 is time in
        seconds, row 1 the mean amplitude and row 2 its spread
        (presumably the standard deviation -- confirm with extract_pulsefish).
    """
    fig = plt.figure(figsize=(10, 5))
    time = np.arange(len(data))/fs

    if len(eod_p_times) > 0:
        # one column per fish, but never fewer than the number of mean EODs,
        # so that every mean EOD gets its own panel:
        ncols = max(len(eod_p_times), len(mean_eods))
        gs = gridspec.GridSpec(2, ncols)
        ax = fig.add_subplot(gs[0, :])
        ax.plot(time, data, c='k', alpha=0.3)

        for i, (pt, tt) in enumerate(zip(eod_p_times, eod_tr_times)):
            # convert times (s) to sample indices to look up the amplitudes
            ax.plot(pt, data[(pt*fs).astype('int')], 'o', label=i+1, ms=10, c=cmap(i))
            ax.plot(tt, data[(tt*fs).astype('int')], 'o', label=i+1, ms=10, c=cmap(i))

        ax.set_xlabel('time [s]')
        ax.set_ylabel('amplitude [V]')

        for i, m in enumerate(mean_eods):
            ax = fig.add_subplot(gs[1, i])
            ax.plot(1000*m[0], 1000*m[1], c='k')
            ax.fill_between(1000*m[0], 1000*(m[1]-m[2]), 1000*(m[1]+m[2]), color=cmap(i))
            ax.set_xlabel('time [ms]')
            ax.set_ylabel('amplitude [mV]')
    else:
        # no fish detected: show the raw recording only
        ax = fig.add_subplot(111)
        ax.plot(time, data, c='k', alpha=0.3)
        ax.set_xlabel('time [s]')
        ax.set_ylabel('amplitude [V]')

    plt.tight_layout()


def plot_clustering(samplerate, eod_widths, eod_hights, eod_shapes, disc_masks, merge_masks):
    """Plot all clustering steps.

    Plot clustering steps on width, height and shape. Then plot the remaining EODs after
    the EOD assessment step and the EODs after the merge step.

    The figure has five columns:
    (1) histograms of EOD widths, one color per width cluster;
    (2) histograms of EOD heights, one axes per width cluster;
    (3) for each height cluster a 2x2 panel with PCA features (left) and
        snippets (right) of the peak-centered (top) and trough-centered
        (bottom) shape clusters;
    (4) mean snippet of each shape cluster that has no EOD set in its
        discard mask;
    (5) mean snippet of each of those clusters that has at least one EOD
        set in its merge mask.
    Grey lines connect related panels across the columns.

    Parameters
    ----------
    samplerate : float
        Samplerate of EOD snippets.
        NOTE(review): not used anywhere in the body of this function.
    eod_widths : list of three 1D numpy arrays
        The first list entry gives the unique labels of all width clusters as a list of ints.
        The second list entry gives the width values for each EOD in samples as a
        1D numpy array of ints.
        The third list entry gives the width labels for each EOD as a 1D numpy array of ints.
    eod_hights : nested lists (2 layers) of three 1D numpy arrays
        The first list entry gives the unique labels of all height clusters as a list of ints
        for each width cluster.
        The second list entry gives the height values for each EOD as a 1D numpy array
        of floats for each width cluster.
        The third list entry gives the height labels for each EOD as a 1D numpy array
        of ints for each width cluster.
    eod_shapes : nested lists (3 layers) of three 1D numpy arrays
        The first list entry gives the raw EOD snippets as a 2D numpy array for each
        height cluster in a width cluster.
        The second list entry gives the snippet PCA values for each EOD as a 2D numpy array
        of floats for each height cluster in a width cluster.
        The third list entry gives the shape labels for each EOD as a 1D numpy array of ints
        for each height cluster in a width cluster.
    disc_masks : Nested lists (two layers) of 1D numpy arrays
        The masks of EODs that are discarded by the discarding step of the algorithm.
        The masks are 1D boolean arrays where
        instances that are set to True are discarded by the algorithm. Discarding masks
        are saved in nested lists that represent the width and height clusters.
    merge_masks : Nested lists (two layers) of 2D numpy arrays
        The masks of EODs that are discarded by the merging step of the algorithm.
        The masks are 2D boolean arrays where
        for each sample point `i` either `merge_mask[i,0]` or `merge_mask[i,1]` is set to True.
        Here, merge_mask[:,0] represents the
        peak-centered clusters and `merge_mask[:,1]` represents the trough-centered clusters.
        Merge masks are saved in nested lists
        that represent the width and height clusters.
    """
    # create figure and the transform from display to figure coordinates,
    # used below to draw connecting lines across different axes.
    fig = plt.figure(figsize=(12, 7))
    transFigure = fig.transFigure.inverted()

    # set up the figure layout: one column per clustering step
    outer = gridspec.GridSpec(1, 5, width_ratios=[1, 1, 2, 1, 2], left=0.05, right=0.95)

    # set titles for each clustering step (invisible axes holding only the text)
    titles = ['1. Widths', '2. Heights', '3. Shape', '4. Pulse EODs', '5. Merge']
    for i, title in enumerate(titles):
        title_ax = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec = outer[i])
        ax = fig.add_subplot(title_ax[0])
        ax.text(0, 110, title, ha='center', va='bottom', clip_on=False)
        ax.set_xlim(-100, 100)
        ax.set_ylim(-100, 100)
        ax.axis('off')

    # compute sizes for each axis
    w_size = 1  # a single axes holds all width histograms
    h_size = len(eod_hights[1])  # one height-histogram axes per width cluster

    # one shape panel per height cluster, summed over all width clusters
    shape_size = np.sum([len(sl) for sl in eod_shapes[0]])

    # count required axes sized for the last two plot columns.
    # Labels are shifted by +1 so that noise (-1) and masked-out EODs map to 0;
    # the number of distinct positive values is then the number of clusters.
    # NOTE(review): column 4 counts clusters with at least one non-discarded
    # EOD, but plots clusters with no discarded EOD; this assumes discard
    # masks are constant within each cluster -- confirm.
    disc_size = 0
    merge_size= 0
    for shapelabel, dmasks, mmasks in zip(eod_shapes[2], disc_masks, merge_masks):
        for sl, dm, mm in zip(shapelabel, dmasks, mmasks):
            uld1 = np.unique((sl[0]+1)*np.invert(dm[0]))
            uld2 = np.unique((sl[1]+1)*np.invert(dm[1]))
            disc_size = disc_size+len(uld1[uld1>0])+len(uld2[uld2>0])

            uld1 = np.unique((sl[0]+1)*mm[0])
            uld2 = np.unique((sl[1]+1)*mm[1])
            merge_size = merge_size+len(uld1[uld1>0])+len(uld2[uld2>0])

    # set counters to keep track of the plot axes
    disc_block = 0
    merge_block = 0
    shape_count = 0

    # create all axes
    width_hist_ax = gridspec.GridSpecFromSubplotSpec(w_size, 1, subplot_spec = outer[0])
    hight_hist_ax = gridspec.GridSpecFromSubplotSpec(h_size, 1, subplot_spec = outer[1])
    shape_ax = gridspec.GridSpecFromSubplotSpec(shape_size, 1, subplot_spec = outer[2])
    shape_windows = [gridspec.GridSpecFromSubplotSpec(2, 2, hspace=0.0, wspace=0.0,
                                                      subplot_spec=shape_ax[i])
                     for i in range(shape_size)]

    EOD_delete_ax = gridspec.GridSpecFromSubplotSpec(disc_size, 1, subplot_spec=outer[3])
    EOD_merge_ax = gridspec.GridSpecFromSubplotSpec(merge_size, 1, subplot_spec=outer[4])

    # plot width labels histogram
    ax1 = fig.add_subplot(width_hist_ax[0])
    # set axes features.
    ax1.set_xscale('log')
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['bottom'].set_visible(False)
    ax1.axes.xaxis.set_visible(False)
    ax1.set_yticklabels([])

    # indices for plot colors (dark to light), one per width cluster
    colidxsw = -np.linspace(-1.25, -0.5, h_size)

    # loop over width clusters
    for i, (wl, colw, uhl, eod_h, eod_h_labs, w_snip, w_feat, w_lab, w_dm, w_mm) in enumerate(zip(eod_widths[0], colidxsw, eod_hights[0], eod_hights[1], eod_hights[2], eod_shapes[0], eod_shapes[1], eod_shapes[2], disc_masks, merge_masks)):

        # plot width hist
        hw, _, _ = ax1.hist(eod_widths[1][eod_widths[2]==wl],
                            bins=np.linspace(np.min(eod_widths[1]), np.max(eod_widths[1]), 100),
                            color=lighter(c_o, colw), orientation='horizontal')

        # set arrow when the last hist is plot so the size of the axes are known.
        if i == h_size-1:
            arrowed_spines(ax1, ms=20)

        # determine the largest bin count of the height histogram of this
        # width cluster, used as x-limit of its axes.
        my, b = np.histogram(eod_h, bins=np.exp(np.linspace(np.min(np.log(eod_h)),
                             np.max(np.log(eod_h)), 100)))
        maxy = np.max(my)

        # set axes features for hight hist (first width cluster at the bottom).
        ax2 = fig.add_subplot(hight_hist_ax[h_size-i-1])
        ax2.set_xscale('log')
        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        ax2.set_xlim(0.9, maxy)
        ax2.axes.xaxis.set_visible(False)
        ax2.set_yscale('log')
        ax2.yaxis.set_major_formatter(ticker.NullFormatter())
        ax2.yaxis.set_minor_formatter(ticker.NullFormatter())

        # define colors for plots, one per height cluster
        colidxsh = -np.linspace(-1.25, -0.5, len(uhl))

        # loop over height clusters of this width cluster
        for n, (hl, hcol, snippets, features, labels, dmasks, mmasks) in enumerate(zip(uhl, colidxsh, w_snip, w_feat, w_lab, w_dm, w_mm)):

            hh, _, _ = loghist(ax2, eod_h[eod_h_labs==hl], np.min(eod_h), np.max(eod_h), 100,
                               lighter(c_g, hcol), orientation='horizontal')

            # set arrow spines only on last plot
            if n == len(uhl)-1:
                arrowed_spines(ax2, ms=10)

            # plot line from the width histogram to the height histogram.
            if n == 0:
                coord1 = transFigure.transform(ax1.transData.transform([np.median(hw[hw!=0]),
                                               np.median(eod_widths[1][eod_widths[2]==wl])]))
                coord2 = transFigure.transform(ax2.transData.transform([0.9, np.mean(eod_h)]))
                line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                              transform=fig.transFigure, color='grey', linewidth=0.5)
                fig.lines.append(line)

            # compute sizes of the eod_discarding and merge steps; the running
            # totals place the panels of this height cluster from the bottom.
            s1 = np.unique((labels[0]+1)*(~dmasks[0]))
            s2 = np.unique((labels[1]+1)*(~dmasks[1]))
            disc_block = disc_block + len(s1[s1>0]) + len(s2[s2>0])

            s1 = np.unique((labels[0]+1)*(mmasks[0]))
            s2 = np.unique((labels[1]+1)*(mmasks[1]))
            merge_block = merge_block + len(s1[s1>0]) + len(s2[s2>0])

            axs = []  # column-4 axes of this height cluster
            disc_count = 0
            merge_count = 0

            # now plot the clusters for peak and trough centerings
            for pt, cmap_pt in zip([0, 1], cmap_pts):

                ax3 = fig.add_subplot(shape_windows[shape_size-1-shape_count][pt,0])
                ax4 = fig.add_subplot(shape_windows[shape_size-1-shape_count][pt,1])

                # remove axes
                ax3.axes.xaxis.set_visible(False)
                ax4.axes.yaxis.set_visible(False)
                ax3.axes.yaxis.set_visible(False)
                ax4.axes.xaxis.set_visible(False)

                # set color indices, one per non-noise shape cluster
                colidxss = -np.linspace(-1.25, -0.5, len(np.unique(labels[pt][labels[pt]>=0])))
                j=0
                for c in np.unique(labels[pt]):

                    if c<0:
                        # plot noise features + snippets
                        ax3.plot(features[pt][labels[pt]==c,0], features[pt][labels[pt]==c,1],
                                 '.', color='lightgrey', label='-1', rasterized=True)
                        ax4.plot(snippets[pt][labels[pt]==c].T, linewidth=0.1,
                                 color='lightgrey', label='-1', rasterized=True)
                    else:
                        # plot cluster features and snippets
                        ax3.plot(features[pt][labels[pt]==c,0], features[pt][labels[pt]==c,1],
                                 '.', color=lighter(cmap_pt, colidxss[j]), label=c,
                                 rasterized=True)
                        ax4.plot(snippets[pt][labels[pt]==c].T, linewidth=0.1,
                                 color=lighter(cmap_pt, colidxss[j]), label=c, rasterized=True)

                        # check if the current cluster is an EOD (no EOD of it
                        # discarded), if yes, plot it.
                        if np.sum(dmasks[pt][labels[pt]==c]) == 0:

                            ax = fig.add_subplot(EOD_delete_ax[disc_size-disc_block+disc_count])
                            ax.axis('off')

                            # plot mean EOD snippet
                            ax.plot(np.mean(snippets[pt][labels[pt]==c], axis=0),
                                    color=lighter(cmap_pt, colidxss[j]))
                            disc_count = disc_count + 1

                            # match colors and draw line from the snippet panel
                            # to the mean EOD panel.
                            coord1 = transFigure.transform(ax4.transData.transform([ax4.get_xlim()[1],
                                                           ax4.get_ylim()[0] + 0.5*(ax4.get_ylim()[1]-ax4.get_ylim()[0])]))
                            coord2 = transFigure.transform(ax.transData.transform([ax.get_xlim()[0],ax.get_ylim()[0] + 0.5*(ax.get_ylim()[1]-ax.get_ylim()[0])]))
                            line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                                          transform=fig.transFigure, color='grey',
                                          linewidth=0.5)
                            fig.lines.append(line)
                            axs.append(ax)

                            # check if the current EOD survives the merge step
                            # if so, plot it.
                            if np.sum(mmasks[pt, labels[pt]==c])>0:

                                ax = fig.add_subplot(EOD_merge_ax[merge_size-merge_block+merge_count])
                                ax.axis('off')

                                ax.plot(np.mean(snippets[pt][labels[pt]==c], axis=0),
                                        color=lighter(cmap_pt, colidxss[j]))
                                merge_count = merge_count + 1

                        j=j+1

                if pt==0:
                    # draw line from hight cluster to EOD shape clusters.
                    coord1 = transFigure.transform(ax2.transData.transform([np.median(hh[hh!=0]),
                                                   np.median(eod_h[eod_h_labs==hl])]))
                    coord2 = transFigure.transform(ax3.transData.transform([ax3.get_xlim()[0],
                                                   ax3.get_ylim()[0]]))
                    line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                                  transform=fig.transFigure, color='grey', linewidth=0.5)
                    fig.lines.append(line)

            shape_count = shape_count + 1

            if len(axs)>0:
                # plot a vertical line to the right of the column-4 panels of
                # this height cluster, from the first to the last one.
                coord1 = transFigure.transform(axs[0].transData.transform([axs[0].get_xlim()[1]+0.1*(axs[0].get_xlim()[1]-axs[0].get_xlim()[0]),
                                               axs[0].get_ylim()[1]-0.25*(axs[0].get_ylim()[1]-axs[0].get_ylim()[0])]))
                coord2 = transFigure.transform(axs[-1].transData.transform([axs[-1].get_xlim()[1]+0.1*(axs[-1].get_xlim()[1]-axs[-1].get_xlim()[0]),
                                               axs[-1].get_ylim()[0]+0.25*(axs[-1].get_ylim()[1]-axs[-1].get_ylim()[0])]))
                line = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                              transform=fig.transFigure, color='grey', linewidth=1)
                fig.lines.append(line)


def plot_bgm(x, means, variances, weights, use_log, labels, labels_am, xlab):
    """Plot a BGM clustering step either on EOD width or height.

    Histograms of `x` for each BGM label are drawn on the left (count) axis.
    The square-rooted weighted BGM Gaussians are drawn on a logarithmic right
    (likelihood) axis, with dashed lines where the most likely Gaussian
    changes. BGM labels that were merged into one cluster (see `labels_am`)
    are connected by a bracket annotated with the merge criterion.

    Parameters
    ----------
    x : 1D numpy array of floats
        BGM input values.
    means : list of floats
        BGM Gaussian means
    variances : list of floats
        BGM Gaussian variances.
    weights : list of floats
        BGM Gaussian weights.
    use_log : boolean
        True if the z-scored logarithm of the data was used as BGM input.
    labels : 1D numpy array of ints
        Labels defined by BGM model (before merging based on merge factor).
    labels_am : 1D numpy array of ints
        Labels defined by BGM model (after merging based on merge factor).
    xlab : string
        Label for plot (defines the units of the BGM data).
    """
    if 'width' in xlab:
        ccol = c_o
    elif 'height' in xlab:
        ccol = c_g
    else:
        ccol = 'b'

    # get the transform that was used as BGM input
    if use_log:
        x_transform = stats.zscore(np.log(x))
        xplot = np.exp(np.linspace(np.log(np.min(x)), np.log(np.max(x)), 1000))
    else:
        x_transform = stats.zscore(x)
        xplot = np.linspace(np.min(x), np.max(x), 1000)

    # evaluate the Gaussians on a grid spanning the transformed data; the
    # grid maps point by point onto xplot because the transform is monotonic.
    x2 = np.linspace(np.min(x_transform), np.max(x_transform), 1000)
    gaussians = []
    gmax = 0
    for w, m, var in zip(weights, means, variances):
        gaus = np.sqrt(w*stats.norm.pdf(x2, m, np.sqrt(var)))
        gaussians.append(gaus)
        gmax = max(np.max(gaus), gmax)

    # compute classes defined by gaussian intersections
    classes = np.argmax(np.vstack(gaussians), axis=0)
    uclasses = np.unique(classes)

    # find the minimum of any gaussian that is within its class
    gmin = np.inf
    for c in uclasses:
        gmin = min(gmin, np.min(gaussians[c][classes == c]))

    # set up the figure
    fig, ax1 = plt.subplots(figsize=(8, 4.8))
    ax2 = ax1.twinx()
    ax1.spines['top'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax1.set_ylabel('#')
    ax2.set_ylabel('Likelihood')
    ax2.set_yscale('log')
    ax1.set_yscale('log')
    if use_log:
        ax1.set_xscale('log')
    ax1.set_xlabel(xlab)

    # plot the gaussians (colors dark to light)
    colidxs = -np.linspace(-1.25, -0.5, len(uclasses))
    for i, c in enumerate(uclasses):
        ax2.plot(xplot, gaussians[c], c=lighter(c_grey, colidxs[i]), linewidth=2,
                 label=r'$N(\mu_%i, \sigma_%i)$' % (c, c))

    # plot intersection lines where the most likely gaussian changes
    ax2.vlines(xplot[1:][np.diff(classes) != 0], 0, gmax/gmin, color='k', linewidth=2,
               linestyle='--')
    ax2.set_ylim(gmin, gmax*1.1)

    # plot data distributions and classes
    ulabels = np.unique(labels)
    colidxs = -np.linspace(-1.25, -0.5, len(ulabels))
    for i, l in enumerate(ulabels):
        if use_log:
            loghist(ax1, x[labels == l], np.min(x), np.max(x), 100,
                    lighter(ccol, colidxs[i]), label=r'$x_%i$' % l)
        else:
            ax1.hist(x[labels == l], bins=np.linspace(np.min(x), np.max(x), 100),
                     color=lighter(ccol, colidxs[i]), label=r'$x_%i$' % l)

    # annotate merged clusters with a bracket between the medians of the
    # first two BGM labels that were merged
    for l in np.unique(labels_am):
        maps = np.unique(labels[labels_am == l])
        if len(maps) > 1:
            med1 = np.median(x[labels == maps[0]])
            med2 = np.median(x[labels == maps[1]])
            ax2.plot([med1, med2], [1.2*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.plot([med1, med1], [1.1*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.plot([med2, med2], [1.1*gmax, 1.2*gmax], c='k', clip_on=False)
            ax2.annotate(r'$\frac{|{\tilde{x}_%i-\tilde{x}_%i}|}{max(\tilde{x}_%i, \tilde{x}_%i)} < \epsilon$' % (maps[0], maps[1], maps[0], maps[1]),
                         [med1*1.1, gmax*1.2], xytext=(10, 10), textcoords='offset points',
                         fontsize=12, annotation_clip=False, ha='center')

    # add legends and plot.
    ax2.legend(loc='lower left', frameon=False, bbox_to_anchor=(-0.05, 1.3),
               ncol=len(uclasses))
    ax1.legend(loc='upper left', frameon=False, bbox_to_anchor=(-0.05, 1.3),
               ncol=len(ulabels))
    plt.tight_layout()


def plot_feature_extraction(raw_snippets, normalized_snippets, features, labels, dt, pt):
    """Plot clustering step on EOD shape.

    The left column shows the raw and the normalized EOD snippets, colored
    by cluster. The right column shows a lower-triangular matrix of pairwise
    scatter plots of the PCA features, colored by cluster, plus a scale bar.

    Parameters
    ----------
    raw_snippets : 2D numpy array
        Raw EOD snippets, one snippet per row.
    normalized_snippets : 2D numpy array
        Normalized EOD snippets, one snippet per row.
    features : 2D numpy array
        PCA values for each normalized EOD snippet, one row per snippet and
        one column per principal component (at least two columns).
    labels : 1D numpy array of ints
        Cluster labels. Negative labels mark noise and are drawn in light grey.
    dt : float
        Sample interval of snippets in seconds.
    pt : int
        Set to 0 for peak-centered EODs and set to 1 for trough-centered EODs.
    """
    ccol = cmap_pts[pt]
    n_feat = features.shape[1]
    fmin = np.min(features)
    fmax = np.max(features)

    # set up the figure layout
    fig = plt.figure(figsize=(((2+0.2)*4.8), 4.8))
    outer = gridspec.GridSpec(1, 2, wspace=0.2, hspace=0)
    snip_ax = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[0], hspace=0.35)
    pc_ax = gridspec.GridSpecFromSubplotSpec(n_feat-1, n_feat-1,
                                             subplot_spec=outer[1], hspace=0, wspace=0)

    # time axis in ms centered on the snippet middle; built from the sample
    # count so that its length always matches the snippets (a float
    # np.arange can produce one extra sample due to rounding).
    nsamples = raw_snippets.shape[1]
    x = (np.arange(nsamples) - nsamples/2)*dt*1000

    # create every axes exactly once: raw snippets, normalized snippets,
    # the feature overlay with the title, and the pairwise feature scatters.
    ax_raw_snip = fig.add_subplot(snip_ax[0])
    ax_normalized_snip = fig.add_subplot(snip_ax[1])

    ax_overlay = fig.add_subplot(pc_ax[:, :])
    ax_overlay.set_title('Features')
    ax_overlay.axis('off')

    pc_axes = []
    for n in range(1, n_feat):
        for m in range(n):
            pc_axes.append((n, m, fig.add_subplot(pc_ax[n-1, m])))

    # plot all clusters (noise in grey, clusters dark to light)
    colidxs = -np.linspace(-1.25, -0.5, len(np.unique(labels[labels >= 0])))
    j = 0
    for c in np.unique(labels):
        if c < 0:
            color = 'lightgrey'
        else:
            color = lighter(ccol, colidxs[j])
            j = j+1
        sel = labels == c
        ax_raw_snip.plot(x, raw_snippets[sel].T, color=color, label=str(c),
                         rasterized=True, alpha=0.25)
        ax_normalized_snip.plot(x, normalized_snippets[sel].T, color=color, alpha=0.25)
        for n, m, ax in pc_axes:
            ax.scatter(features[sel, m], features[sel, n], marker='.',
                       color=color, alpha=0.25)

    # decorate snippet axes
    ax_raw_snip.spines['top'].set_visible(False)
    ax_raw_snip.spines['right'].set_visible(False)
    ax_raw_snip.get_xaxis().set_ticklabels([])
    ax_raw_snip.set_title('Raw snippets')
    ax_raw_snip.set_ylabel('Amplitude [a.u.]')
    ax_normalized_snip.spines['top'].set_visible(False)
    ax_normalized_snip.spines['right'].set_visible(False)
    ax_normalized_snip.set_title('Normalized snippets')
    ax_normalized_snip.set_ylabel('Amplitude [a.u.]')
    ax_normalized_snip.set_xlabel('Time [ms]')
    ax_raw_snip.axis('off')
    ax_normalized_snip.axis('off')

    # decorate feature axes: common limits, no ticks, PC labels on the margins
    for n, m, ax in pc_axes:
        ax.set_xlim(fmin, fmax)
        ax.set_ylim(fmin, fmax)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        if m == 0:
            ax.set_ylabel('PC %i' % (n+1))
        if n == n_feat-1:
            ax.set_xlabel('PC %i' % (m+1))

    # scale bar in the upper right corner of the feature matrix
    ax = fig.add_subplot(pc_ax[0, n_feat-2])
    ax.set_xlim(fmin, fmax)
    ax.set_ylim(fmin, fmax)

    # bar length: feature range truncated to `size` decimals.
    # TODO: should be smaller than the actual data range (e.g. a fraction of it).
    size = max(1, int(np.ceil(-np.log10(fmax-fmin))))
    wbar = np.floor((fmax-fmin)*10**size)/10**size
    xscalebar(ax, 0, 0, wbar, wformat='%%.%if' % size)
    yscalebar(ax, 0, 0, wbar, hformat='%%.%if' % size)
    ax.axis('off')


def plot_moving_fish(ws, dts, clusterss, ts, fishcounts, T, ignore_stepss):
    """Plot moving fish detection step.

    For each width cluster two stacked panels are drawn. The upper panel is
    an event plot of the EOD times of every EOD cluster, with the first
    sliding window sketched as a dashed rectangle. The lower panel shows the
    fish count of each sliding-window step; steps marked in `ignore_stepss`
    are only drawn faded.

    Parameters
    ----------
    ws : list of floats
        Median width for each width cluster that the moving fish algorithm is computed on
        (in seconds).
    dts : list of floats
        Sliding window size (in seconds) for each width cluster.
    clusterss : list of 1D numpy int arrays
        Cluster labels for each EOD cluster in a width cluster.
    ts : list of 1D numpy float arrays
        EOD emission times for each EOD in a width cluster.
    fishcounts : list of lists
        Sliding window timepoints and fishcounts for each width cluster.
    T : float
        Length of analyzed recording in seconds.
    ignore_stepss : list of 1D int arrays
        Mask for fishcounts that were ignored (ignored if True) in the moving_fish analysis.
    """
    fig = plt.figure()

    # create gridspec: one row per width cluster
    outer = gridspec.GridSpec(len(ws), 1)

    for i, (w, dt, clusters, t, fishcount, ignore_steps) in enumerate(
            zip(ws, dts, clusterss, ts, fishcounts, ignore_stepss)):

        gs = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[i])

        # axis for clusters
        ax1 = fig.add_subplot(gs[0])
        # axis for fishcount
        ax2 = fig.add_subplot(gs[1])

        # plot clusters as eventplot, one row per non-noise cluster
        uclusters = np.unique(clusters[clusters >= 0])
        for cnum, c in enumerate(uclusters):
            ax1.eventplot(t[clusters == c], lineoffsets=cnum, linelengths=0.5, color=cmap(i))
        cnum = len(uclusters)

        # Plot the sliding window
        rect = Rectangle((0, -0.5), dt, cnum, linewidth=1, linestyle='--', edgecolor='k',
                         facecolor='none', clip_on=False)
        ax1.add_patch(rect)
        ax1.arrow(dt+0.1, -0.5, 0.5, 0, head_width=0.1, head_length=0.1, facecolor='k',
                  edgecolor='k')

        # plot parameters
        ax1.set_title(r'$\tilde{w}_%i = %.3f ms$' % (i, 1000*w))
        ax1.set_ylabel('cluster #')
        ax1.set_yticks(range(0, cnum))
        ax1.set_xlabel('time')
        ax1.set_xlim(0, T)
        ax1.axes.xaxis.set_visible(False)
        ax1.spines['bottom'].set_visible(False)
        ax1.spines['top'].set_visible(False)
        ax1.spines['right'].set_visible(False)
        ax1.spines['left'].set_visible(False)

        # plot for fishcount
        x = fishcount[0]
        y = fishcount[1]

        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        ax2.axes.xaxis.set_visible(False)

        # float copy, so that ignored steps can be blanked out with NaN even
        # for integer fish counts:
        yplot = np.array(y, dtype=float)
        ax2.plot(x+dt/2, yplot, linestyle='-', marker='.', c=cmap(i), alpha=0.25)
        yplot[ignore_steps.astype(bool)] = np.nan
        ax2.plot(x+dt/2, yplot, linestyle='-', marker='.', c=cmap(i))
        ax2.set_ylabel('Fish count')
        ax2.set_yticks(range(int(np.min(y)), 1+int(np.max(y))))
        ax2.set_xlim(0, T)

        if i == len(ws)-1:
            # time scale bar below the last panel only
            xscalebar(ax2, 1, 0, 1, wunit='s', ha='right')

        # connect the sliding window to the first fish count
        con = ConnectionPatch([0, -0.5], [dt/2, y[0]], "data", "data",
                              axesA=ax1, axesB=ax2, color='k')
        ax2.add_artist(con)
        con = ConnectionPatch([dt, -0.5], [dt/2, y[0]], "data", "data",
                              axesA=ax1, axesB=ax2, color='k')
        ax2.add_artist(con)

    plt.xlim(0, T)