Coverage for src / thunderfish / thunderbrowse.py: 0%
577 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-15 17:50 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-15 17:50 +0000
1import sys
2import os
3import warnings
4import argparse
5import numpy as np
6import matplotlib.pyplot as plt
8from scipy.signal import butter, sosfiltfilt
9from audioio import PlayAudio, fade, write_audio
10from thunderlab.dataloader import DataLoader
11from thunderlab.eventdetection import detect_peaks, median_std_threshold
13from .version import __version__, __year__
14from .pulses import detect_pulses
class SignalPlot:
    """Interactive browser for multi-channel recordings with optional
    pulse-fish EOD detection.

    The constructor optionally high-pass filters the data, optionally
    detects, groups, and clusters pulses, creates a matplotlib figure with
    one axes per displayed channel (plus one axes for inter-pulse-interval
    (IPI) frequencies if pulses were found), connects the key, resize, and
    pick handlers, and finally calls `plt.show()`.

    Parameters
    ----------
    data: array-like
        Data indexed as `data[sample, channel]`, e.g. a `DataLoader`.
    rate: float
        Sampling rate of the data in Hertz.
    unit: str
        Unit of the data values, used for the axis labels.
    filename: str
        Name of the recording, used for the window title and for the
        names of saved files.
    show_channels: list of int or None
        Channels to be displayed. None or an empty list shows all channels.
    tmax: float or None
        If given, only the first `tmax` seconds of data are used.
    fcutoff: float or None
        If given, the data are filtered forward and backward
        (`sosfiltfilt`) with a 2nd order Butterworth high-pass filter
        with this cutoff frequency in Hertz.
    pulses: bool
        If True, detect, group, and cluster pulse-fish EODs.
    """

    def __init__(self, data, rate, unit, filename,
                 show_channels=None, tmax=None, fcutoff=None,
                 pulses=False):
        self.filename = filename
        self.rate = rate
        self.data = data
        self.channels = self.data.shape[1] if len(self.data.shape) > 1 else 1
        self.unit = unit
        self.tmax = (len(self.data) - 1)/self.rate
        if tmax is not None:
            self.tmax = tmax
            self.data = data[:int(tmax*self.rate), :]
        self.toffset = 0.0
        self.twindow = 10.0
        if self.twindow > self.tmax:
            self.twindow = self._default_window(self.tmax)
        if tmax is not None:
            self.twindow = tmax
        # pulse table, one row per pulse with the columns
        # label, group id, channel, peak index, trough index.
        # Five columns even without detection, because the plotting
        # code indexes columns up to 4:
        self.pulses = np.zeros((0, 5), dtype=int)
        self.labels = []
        self.fishes = []
        self.pulse_times = []
        self.pulse_gids = []
        if show_channels is None or len(show_channels) == 0:
            self.show_channels = np.arange(self.channels)
        else:
            self.show_channels = np.array(show_channels)
        self.traces = len(self.show_channels)
        self.ymin = -1.0*np.ones(self.traces)
        self.ymax = +1.0*np.ones(self.traces)
        self.fmax = 100.0
        self.trace_artist = [None]*self.traces
        self.show_gid = False
        self.pulse_artist = []
        # rows of self.pulses displayed by the corresponding pulse_artist:
        self.pulse_artist_rows = []
        self.marker_artist = [None]*(self.traces + 1)
        self.ipis_artist = []
        self.ipis_labels = []
        self.fig = None
        self.axs = None
        self.figf = None
        self.axf = None
        self.pulse_colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C0']
        self.help = False
        self.helptext = []
        self.audio = PlayAudio()

        # filter data:
        if fcutoff is not None:
            sos = butter(2, fcutoff, 'high', fs=rate, output='sos')
            self.data = sosfiltfilt(sos, self.data[:], 0)

        # pulse detection:
        if pulses:
            self._detect_pulses()
            self._group_pulses()
            self._cluster_pulses()
            # all labels:
            self.labels = np.unique(self.pulses[:, 0])
            # report:
            print(f'found {len(self.pulse_times)} fish:')
            for k, pts in enumerate(self.pulse_times):
                print(f'{k:3d}: {len(pts):5d} pulses')

        self._setup_figure()
        plt.show()

    def __del__(self):
        # Closing the audio device on deletion is currently disabled:
        # self.audio.close()
        pass

    @staticmethod
    def _default_window(tmax):
        """Return the initial display window in seconds for a recording
        that is shorter than the default window.

        The window is the smallest power of two that is larger than
        `tmax`. It is not rounded to an integer, so recordings shorter
        than half a second still get a non-zero window. Falls back to
        1 s for non-positive `tmax`.
        """
        if tmax <= 0.0:
            return 1.0
        return 2.0 ** (np.floor(np.log(tmax) / np.log(2.0)) + 1.0)

    def _basename(self):
        """Return the file name of the recording without its extension."""
        return os.path.splitext(self.filename)[0]

    def _window_indices(self):
        """Return the first and one-past-last sample index of the displayed
        time window, clipped to the length of the data."""
        t0 = int(np.round(self.toffset*self.rate))
        t1 = int(np.round((self.toffset + self.twindow)*self.rate))
        return t0, min(t1, len(self.data))

    def _pulse_height(self, row):
        """Peak-to-trough height of the pulse described by a row of
        `self.pulses`."""
        return self.data[row[3], row[2]] - self.data[row[4], row[2]]

    def _detect_pulses(self):
        """Detect pulses on every channel.

        Sets `self.pulses` to all detected pulses sorted by peak index,
        with the columns label (index within its channel), group (0),
        channel, peak index, and trough index.
        """
        all_pulses = np.zeros((0, 5), dtype=int)
        for c in range(self.channels):
            # NOTE(review): a fixed threshold is used. An adaptive
            # median_std_threshold(data, rate, thresh_fac=6.0) used to be
            # computed here but was immediately overwritten by this value.
            thresh = 0.01
            p, t, w, h = detect_pulses(self.data[:, c], self.rate,
                                       thresh,
                                       min_rel_slope_diff=0.25,
                                       min_width=0.0001,
                                       max_width=0.01,
                                       width_fac=5.0)
            # label, group, channel, peak, trough:
            chan_pulses = np.hstack((np.arange(len(p))[:, np.newaxis],
                                     np.zeros((len(p), 1), dtype=int),
                                     np.ones((len(p), 1), dtype=int)*c,
                                     p[:, np.newaxis], t[:, np.newaxis]))
            all_pulses = np.vstack((all_pulses, chan_pulses))
        self.pulses = all_pulses[np.argsort(all_pulses[:, 3]), :]

    def _group_pulses(self):
        """Group pulses that coincide in time across channels.

        Pulses within `max_di` samples of the peak or trough of the
        highest pulse of a group get that group's new label and group id.
        Groups whose highest pulse is smaller than 0.02 are dropped, as
        are pulses too far from the reference. Only the highest pulse per
        channel is kept. Removed pulses are deleted from `self.pulses`.
        """
        max_di = int(0.0002*self.rate)  # TODO: parameter

        def far_from(row, tp, tt):
            # True if neither peak nor trough of the pulse is within
            # max_di samples of the reference peak tp or trough tt.
            return (np.abs(row[3] - tp) > max_di and
                    np.abs(row[3] - tt) > max_di and
                    np.abs(row[4] - tp) > max_di and
                    np.abs(row[4] - tt) > max_di)

        label = -1
        k = 0
        while k < len(self.pulses):
            tp = self.pulses[k, 3]
            tt = self.pulses[k, 4]
            height = self._pulse_height(self.pulses[k])
            channel_counts = np.zeros(self.channels, dtype=int)
            channel_counts[self.pulses[k, 2]] += 1
            # A channel can be counted at most twice, so this loop always
            # breaks and c is the number of pulses in the group:
            for c in range(1, 3*self.channels):
                if k + c >= len(self.pulses):
                    break
                # pulse too far away (or channel already used twice):
                if channel_counts[self.pulses[k+c, 2]] > 1 or \
                   far_from(self.pulses[k+c], tp, tt):
                    break
                channel_counts[self.pulses[k+c, 2]] += 1
                height_kc = self._pulse_height(self.pulses[k+c])
                # highest pulse sets time reference:
                if height_kc > height:
                    tp = self.pulses[k+c, 3]
                    tt = self.pulses[k+c, 4]
                    height = height_kc
            # all pulses too small:
            if height < 0.02:  # TODO parameter
                self.pulses[k:k+c, 0] = -1
                k += c
                continue
            # new label:
            label += 1
            # remove lost pulses:
            for j in range(c):
                if far_from(self.pulses[k+j], tp, tt):
                    self.pulses[k+j, 0] = -1
                    channel_counts[self.pulses[k+j, 2]] -= 1
                else:
                    self.pulses[k+j, 0] = label
                    self.pulses[k+j, 1] = label
            # keep only the largest pulse of each channel:
            group = self.pulses[k:k+c, :]
            for dc in np.where(channel_counts > 1)[0]:
                idx = np.where(group[:, 2] == dc)[0]
                heights = self.data[group[idx, 3], dc] - \
                          self.data[group[idx, 4], dc]
                imax = np.argmax(heights)
                for i in range(len(idx)):
                    if i != imax:
                        channel_counts[self.pulses[k+idx[i], 2]] -= 1
                        self.pulses[k+idx[i], 0] = -1
            k += c
        self.pulses = self.pulses[self.pulses[:, 0] >= 0, :]

    def _cluster_pulses(self):
        """Assign pulse groups to fish by their height profile across
        channels.

        Each group is compared with the recently seen groups (the last
        0.2 s, at most 5 per fish) by the mean relative height difference.
        Candidates closer than 1/300 s in time are penalized. A group joins
        the closest fish if the difference is below 0.1, otherwise it
        starts a new fish. Fills `self.pulse_times` (one array of sample
        indices per fish), `self.pulse_gids`, and `self.fishes`, and
        writes the fish label into column 0 of `self.pulses`.
        """
        recent = []
        k = 0
        while k < len(self.pulses):
            # select pulse group:
            j = k
            gid = self.pulses[j, 1]
            for _ in range(self.channels):
                k += 1
                if k >= len(self.pulses) or self.pulses[k, 1] != gid:
                    break
            heights = np.zeros(self.channels)
            heights[self.pulses[j:k, 2]] = \
                self.data[self.pulses[j:k, 3], self.pulses[j:k, 2]] - \
                self.data[self.pulses[j:k, 4], self.pulses[j:k, 2]]
            # time of largest pulse:
            pulse_time = self.pulses[j + np.argmax(heights[self.pulses[j:k, 2]]), 3]
            # assign to cluster:
            label = -1
            if len(self.pulse_times) > 0:
                # mean relative height difference to recent pulses:
                dists = np.array([np.mean(np.abs(hh - heights)/np.max(hh))
                                  for ll, tt, hh in recent])
                thresh = 0.1  # TODO: make parameter
                # time intervals to recent pulses:
                ipis = np.array([(pulse_time - tt)/self.rate
                                 for ll, tt, hh in recent])
                # ensure minimum IP distance (zero intervals map to inf):
                with np.errstate(divide='ignore'):
                    too_close = 1/ipis > 300.0  # TODO: make parameter
                dists[too_close] = 2*np.max(dists)
                min_dist_idx = np.argmin(dists)
                if dists[min_dist_idx] < thresh:
                    label = recent[min_dist_idx][0]
            if label < 0:
                label = len(self.pulse_times)
                self.pulse_times.append([])
                self.pulse_gids.append([])
            self.pulses[j:k, 0] = label
            self.pulse_times[label].append(pulse_time)
            self.pulse_gids[label].append(gid)
            self.fishes.append([label, pulse_time, heights])
            recent.append([label, pulse_time, heights])
            # remove old fish:
            for i, (ll, tt, hh) in enumerate(recent):
                # TODO: make parameter:
                if (pulse_time - tt)/self.rate <= 0.2:
                    recent = recent[i:]
                    break
            # only consider the n most recent pulses of a fish:
            n = 5  # TODO make parameter
            labels = np.array([ll for ll, tt, hh in recent])
            if np.sum(labels == label) > n:
                del recent[np.where(labels == label)[0][0]]
        # pulse times to arrays:
        self.pulse_times = [np.array(pts) for pts in self.pulse_times]

    def _setup_figure(self):
        """Set key bindings, create the figure with its axes and help
        texts, connect the event handlers, and draw the initial view."""
        # set key bindings:
        plt.rcParams['keymap.fullscreen'] = 'f'
        plt.rcParams['keymap.pan'] = 'ctrl+m'
        plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, q'
        plt.rcParams['keymap.yscale'] = ''
        plt.rcParams['keymap.xscale'] = ''
        plt.rcParams['keymap.grid'] = ''
        # the figure:
        plt.ioff()
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        self.fig, self.axs = plt.subplots(splts, 1, squeeze=False,
                                          figsize=(15, 9), sharex=True)
        self.axs = self.axs.flat
        if self.traces == self.channels:
            self.fig.canvas.manager.set_window_title(self.filename)
        else:
            cs = ' c%d' % self.show_channels[0]
            self.fig.canvas.manager.set_window_title(self.filename + ' ' + cs)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)
        self.fig.canvas.mpl_connect('resize_event', self.resize)
        self.fig.canvas.mpl_connect('pick_event', self.on_pick)
        # trace plots:
        for t in range(self.traces):
            self.axs[t].set_ylabel(f'C-{self.show_channels[t]+1} [{self.unit}]')
        if len(self.pulses) > 0:
            self.axs[-1].set_ylim(0, self.fmax)
            self.axs[-1].set_ylabel('IP freq [Hz]')
        self.axs[-1].set_xlabel('Time [s]')
        # help texts as (relative y position, text):
        help_lines = [
            (0.05, '(ctrl+) page and arrow up, down, home, end: scroll'),
            (0.1, '+, -, X, x: zoom time in/out'),
            (0.2, 'y, Y, v, V, ctrl+v, ctrl+V: zoom amplitudes out/in/max/default/max per trace/global max per trace'),
            (0.3, 'i, I: zoom IPI frequency in/out'),
            (0.4, 'p, P: play audio (display, all)'),
            (0.5, 'f: full screen'),
            (0.6, 'w: plot waveforms into png file'),
            (0.7, 'S: save audiosegment'),
            (0.8, 'q: quit'),
            (0.9, 'h: toggle this help'),
        ]
        for y, text in help_lines:
            ht = self.axs[0].text(0.98, y, text, ha='right',
                                  transform=self.axs[0].transAxes)
            ht.set_visible(self.help)
            self.helptext.append(ht)
        # plot:
        self.update_plots()
        # feature plot:
        if len(self.labels) > 0:
            self.figf, self.axf = plt.subplots()

    def plot_pulses(self, axs, plot=True, tfac=1.0):
        """Plot the detected pulses and the IPI frequencies.

        Parameters
        ----------
        axs: sequence of matplotlib axes
            One axes per displayed trace followed by the axes for the
            IPI frequencies (`axs[-1]`).
        plot: bool
            If True, create new artists on `axs` without registering them
            (used for exporting figures). If False, update the artists
            stored in this object and create new ones only when needed.
        tfac: float
            Factor applied to times in seconds (e.g. 1000 for ms).
        """
        ncolors = len(self.pulse_colors)

        def plot_pulse_traces(pulses, i, pak):
            # Plot the pulses of each displayed channel with color i,
            # starting at artist index pak. Returns the next artist index.
            for t in range(self.traces):
                c = self.show_channels[t]
                rows = pulses[pulses[:, 2] == c, :]
                p = rows[:, 3]
                if len(p) == 0:
                    continue
                color = self.pulse_colors[i % ncolors]
                if plot or pak >= len(self.pulse_artist):
                    pa, = axs[t].plot(tfac*p/self.rate, self.data[p, c], 'o',
                                      picker=5, color=color)
                    if not plot:
                        self.pulse_artist.append(pa)
                        self.pulse_artist_rows.append(rows)
                else:
                    self.pulse_artist[pak].set_data(tfac*p/self.rate,
                                                    self.data[p, c])
                    self.pulse_artist[pak].set_color(color)
                    self.pulse_artist_rows[pak] = rows
                pak += 1
            return pak

        # pulses:
        pak = 0
        if self.show_gid:
            for g in range(ncolors):
                pulses = self.pulses[self.pulses[:, 1] % ncolors == g, :]
                pak = plot_pulse_traces(pulses, g, pak)
        else:
            for l in self.labels:
                pulses = self.pulses[self.pulses[:, 0] == l, :]
                pak = plot_pulse_traces(pulses, l, pak)
        # hide stored artists that are not needed anymore:
        if not plot:
            while pak < len(self.pulse_artist):
                self.pulse_artist[pak].set_data([], [])
                self.pulse_artist_rows[pak] = self.pulses[:0, :]
                pak += 1
        # ipis:
        for l in self.labels:
            if l >= len(self.pulse_times):
                continue
            pt = self.pulse_times[l]/self.rate
            if len(pt) <= 10:
                continue
            if plot or l not in self.ipis_labels:
                pa, = axs[-1].plot(tfac*pt[:-1], 1.0/np.diff(pt),
                                   '-o', picker=5,
                                   color=self.pulse_colors[l % ncolors])
                if not plot:
                    self.ipis_artist.append(pa)
                    self.ipis_labels.append(l)
            else:
                iak = self.ipis_labels.index(l)
                self.ipis_artist[iak].set_data(tfac*pt[:-1], 1.0/np.diff(pt))

    def update_plots(self):
        """Redraw traces and pulses for the current time window."""
        t0, t1 = self._window_indices()
        time = np.arange(t0, t1)/self.rate
        for t in range(self.traces):
            c = self.show_channels[t]
            ax = self.axs[t]
            ax.set_xlim(self.toffset, self.toffset + self.twindow)
            if self.trace_artist[t] is None:
                self.trace_artist[t], = ax.plot(time, self.data[t0:t1, c])
            else:
                self.trace_artist[t].set_data(time, self.data[t0:t1, c])
            # show individual samples when zoomed in far enough:
            if t1 - t0 < 200:
                self.trace_artist[t].set_marker('o')
                self.trace_artist[t].set_markersize(3)
            else:
                self.trace_artist[t].set_marker('None')
            ax.set_ylim(self.ymin[t], self.ymax[t])
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()

    def on_pick(self, event):
        """Highlight a picked pulse group and its IPI frequency.

        Picking either a pulse marker in a trace axes or a point of an IPI
        frequency curve marks all pulses of the corresponding group and
        the IPI frequency following it. If the feature figure exists, the
        group's height profile across channels is added to it.
        """
        ncolors = len(self.pulse_colors)
        pi = -1
        pk = next((k for k, pa in enumerate(self.pulse_artist)
                   if pa is event.artist), -1)
        if pk >= 0:
            # label and group id of the picked pulse:
            rows = self.pulse_artist_rows[pk]
            if event.ind[0] >= len(rows):
                return
            ll = rows[event.ind[0], 0]
            gid = rows[event.ind[0], 1]
            if ll in self.ipis_labels:
                pi = self.pulse_gids[ll].index(gid)
        else:
            ik = next((k for k, ia in enumerate(self.ipis_artist)
                       if ia is event.artist), -1)
            if ik < 0:
                return
            ll = self.ipis_labels[ik]
            pi = event.ind[0]
            gid = self.pulse_gids[ll][pi]
        color = self.pulse_colors[ll % ncolors]
        # mark pulses:
        pulses = self.pulses[(self.pulses[:, 0] == ll) &
                             (self.pulses[:, 1] == gid), :]
        for t in range(self.traces):
            c = self.show_channels[t]
            pt = pulses[pulses[:, 2] == c, 3]
            if len(pt) > 0:
                # set_data() requires sequences, not scalars:
                x = [pt[0]/self.rate]
                y = [self.data[pt[0], c]]
                if self.marker_artist[t] is None:
                    pa, = self.axs[t].plot(x, y, 'o', ms=10, color=color)
                    self.marker_artist[t] = pa
                else:
                    self.marker_artist[t].set_data(x, y)
                    self.marker_artist[t].set_color(color)
            elif self.marker_artist[t] is not None:
                self.marker_artist[t].set_data([], [])
        # mark ipi (there is none after the last pulse of a fish):
        pt0 = -1.0
        if pi >= 0:
            pt0 = self.pulse_times[ll][pi]/self.rate
        if 0 <= pi < len(self.pulse_times[ll]) - 1:
            pt1 = self.pulse_times[ll][pi+1]/self.rate
            pf = 1.0/(pt1 - pt0)
            if self.marker_artist[self.traces] is None:
                pa, = self.axs[self.traces].plot([pt0], [pf], 'o', ms=10,
                                                 color=color)
                self.marker_artist[self.traces] = pa
            else:
                self.marker_artist[self.traces].set_data([pt0], [pf])
                self.marker_artist[self.traces].set_color(color)
        elif self.marker_artist[self.traces] is not None:
            self.marker_artist[self.traces].set_data([], [])
        self.fig.canvas.draw()
        # show features:
        if self.axf is not None and self.figf is not None:
            heights = np.zeros(self.channels)
            heights[pulses[:, 2]] = \
                self.data[pulses[:, 3], pulses[:, 2]] - \
                self.data[pulses[:, 4], pulses[:, 2]]
            self.axf.plot(heights, color=color)
            print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
            self.figf.canvas.draw()

    def resize(self, event):
        """Keep fixed pixel margins around the axes when the window is
        resized."""
        if event.width <= 0 or event.height <= 0:
            # e.g. minimized window; nothing sensible to lay out
            return
        leftpixel = 80.0
        rightpixel = 20.0
        bottompixel = 50.0
        toppixel = 20.0
        x0 = leftpixel / event.width
        x1 = 1.0 - rightpixel / event.width
        y0 = bottompixel / event.height
        y1 = 1.0 - toppixel / event.height
        self.fig.subplots_adjust(left=x0, right=x1, bottom=y0, top=y1,
                                 hspace=0)

    def _scroll_back(self, frac):
        """Move the time window back by `frac` window widths, not
        beyond the start of the recording."""
        if self.toffset > 0.0:
            self.toffset = max(self.toffset - frac*self.twindow, 0.0)
            self.update_plots()

    @staticmethod
    def _fit_range(values):
        """Return y-limits enclosing `values` with a 6% margin."""
        vmin = np.min(values)
        vmax = np.max(values)
        h = 0.53*(vmax - vmin)
        c = 0.5*(vmax + vmin)
        return c - h, c + h

    def _apply_ylims(self):
        """Apply self.ymin/self.ymax to the trace axes and redraw."""
        for t in range(self.traces):
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()

    def keypress(self, event):
        """Handle key presses (see the help texts for the bindings)."""
        key = event.key
        if key is None:
            # some key events (e.g. lone modifiers) carry no key
            return
        if key in ('+', '=', 'X'):
            if self.twindow*self.rate > 20:
                self.twindow *= 0.5
                self.update_plots()
        elif key in ('-', 'x'):
            if self.twindow < self.tmax:
                self.twindow *= 2.0
                self.update_plots()
        elif key == 'pagedown':
            if self.toffset + 0.5*self.twindow < self.tmax:
                self.toffset += 0.5*self.twindow
                self.update_plots()
        elif key == 'pageup':
            self._scroll_back(0.5)
        elif key == 'ctrl+pagedown':
            if self.toffset + 5.0*self.twindow < self.tmax:
                self.toffset += 5.0*self.twindow
                self.update_plots()
        elif key == 'ctrl+pageup':
            self._scroll_back(5.0)
        elif key == 'down':
            if self.toffset + self.twindow < self.tmax:
                self.toffset += 0.05*self.twindow
                self.update_plots()
        elif key == 'up':
            self._scroll_back(0.05)
        elif key == 'home':
            if self.toffset > 0.0:
                self.toffset = 0.0
                self.update_plots()
        elif key == 'end':
            toffs = np.floor(self.tmax / self.twindow) * self.twindow
            if self.tmax - toffs <= 0.0:
                toffs -= self.twindow
            if self.tmax - toffs < self.twindow/2:
                toffs -= self.twindow/2
            if self.toffset < toffs:
                self.toffset = toffs
                self.update_plots()
        elif key == 'y':
            # zoom out by a factor of two around the center:
            for t in range(self.traces):
                h = self.ymax[t] - self.ymin[t]
                c = 0.5*(self.ymax[t] + self.ymin[t])
                self.ymin[t] = c - h
                self.ymax[t] = c + h
            self._apply_ylims()
        elif key == 'Y':
            # zoom in by a factor of two around the center:
            for t in range(self.traces):
                h = 0.25*(self.ymax[t] - self.ymin[t])
                c = 0.5*(self.ymax[t] + self.ymin[t])
                self.ymin[t] = c - h
                self.ymax[t] = c + h
            self._apply_ylims()
        elif key == 'v':
            # common range of all displayed traces in the window:
            t0, t1 = self._window_indices()
            ymin, ymax = self._fit_range(self.data[t0:t1, self.show_channels])
            self.ymin[:] = ymin
            self.ymax[:] = ymax
            self._apply_ylims()
        elif key == 'ctrl+v':
            # range of each trace in the window:
            t0, t1 = self._window_indices()
            for t in range(self.traces):
                self.ymin[t], self.ymax[t] = \
                    self._fit_range(self.data[t0:t1, self.show_channels[t]])
            self._apply_ylims()
        elif key == 'ctrl+V':
            # range of each trace over the whole recording:
            for t in range(self.traces):
                self.ymin[t], self.ymax[t] = \
                    self._fit_range(self.data[:, self.show_channels[t]])
            self._apply_ylims()
        elif key == 'V':
            self.ymin[:] = -1.0
            self.ymax[:] = +1.0
            self._apply_ylims()
        elif key == 'c':
            # center the y-range around zero:
            for t in range(self.traces):
                dy = self.ymax[t] - self.ymin[t]
                self.ymin[t] = -dy/2
                self.ymax[t] = +dy/2
            self._apply_ylims()
        elif key == 'g':
            self.show_gid = not self.show_gid
            self.plot_pulses(self.axs, False)
            self.fig.canvas.draw()
        elif key == 'i':
            # note: doubles the displayed frequency range
            if len(self.pulses) > 0:
                self.fmax *= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif key == 'I':
            if len(self.pulses) > 0:
                self.fmax /= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif key in tuple('0123456789'):
            # TODO: toggle visibility of single channels (not implemented).
            pass
        elif key == 'h':
            self.help = not self.help
            for ht in self.helptext:
                ht.set_visible(self.help)
            self.fig.canvas.draw()
        elif key == 'p':
            self.play_segment()
        elif key == 'P':
            self.play_all()
        elif key == 'S':
            self.save_segment()
        elif key == 'w':
            self.plot_traces()

    def play_segment(self):
        """Play the mean of the displayed channels in the current window
        with faded edges."""
        t0, t1 = self._window_indices()
        playdata = 1.0 * np.mean(self.data[t0:t1, self.show_channels], 1)
        f = 0.1 if self.twindow > 0.5 else 0.1*self.twindow
        fade(playdata, self.rate, f)
        self.audio.play(playdata, self.rate, blocking=False)

    def play_all(self):
        """Play the mean of the displayed channels of the whole recording."""
        self.audio.play(np.mean(self.data[:, self.show_channels], 1),
                        self.rate, blocking=False)

    def save_segment(self):
        """Write the data of the current window into a wav file.

        Start and end time are put into the file name with four
        significant digits, so that segments shorter than a second get
        distinct names.
        """
        t0, t1 = self._window_indices()
        t0s = self.toffset
        t1s = self.toffset + self.twindow
        name = self._basename()
        if self.traces == self.channels:
            segment_filename = f'{name}-{t0s:.4g}s-{t1s:.4g}s.wav'
            write_audio(segment_filename, self.data[t0:t1, :], self.rate)
        else:
            segment_filename = f'{name}-{t0s:.4g}s-{t1s:.4g}s-c{self.show_channels[0]}.wav'
            write_audio(segment_filename,
                        self.data[t0:t1, self.show_channels], self.rate)
        print('saved segment to: ' , segment_filename)

    def plot_traces(self):
        """Save the current view (traces, pulses, IPI frequencies) into a
        png file."""
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        fig, axs = plt.subplots(splts, 1, squeeze=False, sharex=True,
                                figsize=(15, 9))
        axs = axs.flat
        fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                            hspace=0)
        name = self._basename()
        figfile = f'{name}-{self.toffset:.4g}s-traces.png'
        if self.traces < self.channels:
            figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
        axs[0].set_title(self.filename)
        t0, t1 = self._window_indices()
        time = np.arange(t0, t1)/self.rate
        # milliseconds for short windows at the start of the recording:
        if self.toffset < 1.0 and self.twindow < 1.0:
            tfac = 1000.0
            axs[-1].set_xlabel('Time [ms]')
        else:
            tfac = 1.0
            axs[-1].set_xlabel('Time [s]')
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_xlim(tfac * self.toffset,
                            tfac * (self.toffset + self.twindow))
            axs[t].plot(tfac * time, self.data[t0:t1, c])
        self.plot_pulses(axs, True, tfac)
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_ylim(self.ymin[t], self.ymax[t])
            axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
        if len(self.pulses) > 0:
            axs[-1].set_ylabel('IP freq [Hz]')
            axs[-1].set_ylim(0.0, self.fmax)
        fig.savefig(figfile, dpi=200)
        plt.close(fig)
        print('saved waveform figure to', figfile)
def short_user_warning(message, category, filename, lineno, file=None, line=''):
    """Replacement for `warnings.showwarning` that prints user warnings tersely.

    A `UserWarning` is written as a single line of the form
    "dir/file.py line N: message". Here "dir/file.py" is made of the last
    two '/'-separated components of `filename`. All other categories are
    written using the standard `warnings.formatwarning()` text.

    Parameters
    ----------
    message: str or Warning
        The warning message.
    category: type
        The warning class.
    filename: str
        Path of the file that issued the warning.
    lineno: int
        Line number where the warning was issued.
    file: file-like or None
        Stream to write to. Defaults to `sys.stderr`.
    line: str
        Source line handed on to `warnings.formatwarning()`.
    """
    out = sys.stderr if file is None else file
    if category != UserWarning:
        out.write(warnings.formatwarning(message, category, filename, lineno, line))
        return
    short_path = '/'.join(filename.split('/')[-2:])
    out.write('%s line %d: %s\n' % (short_path, lineno, message))
def main(cargs=None):
    """Command line entry point of thunderbrowse.

    Parses the command line arguments, loads the recording with a
    `DataLoader`, and opens a `SignalPlot` browser on it.

    Parameters
    ----------
    cargs: list of str or None
        Command line arguments without the program name.
        If None, `sys.argv[1:]` is used.
    """
    warnings.showwarning = short_user_warning

    # command line arguments:
    if cargs is None:
        cargs = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description='Browse multichannel EOD recordings.',
        epilog='version %s by Benda-Lab (2022-%s)' % (__version__, __year__))
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-v', action='count', dest='verbose', default=0)
    parser.add_argument('-c', dest='channels', default='',
                        type=str, metavar='CHANNELS',
                        help='Comma separated list of channels to be displayed (first channel is 0).')
    parser.add_argument('-t', dest='tmax', default=None,
                        type=float, metavar='TMAX',
                        help='Process and show only the first TMAX seconds.')
    parser.add_argument('-f', dest='fcutoff', default=None,
                        type=float, metavar='FREQ',
                        help='Cutoff frequency of optional high-pass filter.')
    parser.add_argument('-p', dest='pulses', action='store_true',
                        help='detect pulse fish EODs')
    parser.add_argument('file', nargs=1, default='', type=str,
                        help='name of the file with the time series data')
    args = parser.parse_args(cargs)
    filepath = args.file[0]

    # channels to be displayed; malformed input is reported as a usage
    # error (parser.error() exits) instead of a traceback:
    channel_strs = [s.strip() for s in args.channels.split(',')]
    try:
        channels = [int(s) for s in channel_strs if len(s) > 0]
    except ValueError:
        parser.error(f'invalid channel list "{args.channels}": expected comma separated integers')
    if any(c < 0 for c in channels):
        parser.error(f'invalid channel list "{args.channels}": channels must not be negative')

    # verbosity level from command line (number of -v flags):
    verbose = args.verbose

    # load data:
    filename = os.path.basename(filepath)
    with DataLoader(filepath, 10*60.0, 5.0, verbose) as data:
        SignalPlot(data, data.rate, data.unit, filename,
                   channels, args.tmax, args.fcutoff, args.pulses)
if __name__ == '__main__':
    # Run the command line interface when this module is executed
    # directly (e.g. via `python -m`) rather than imported.
    main()