Coverage for src/thunderfish/thunderbrowse.py: 0%
577 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 16:21 +0000
1import sys
2import os
3import warnings
4import argparse
5import numpy as np
6from scipy.signal import butter, sosfiltfilt
7import matplotlib.pyplot as plt
8from audioio import PlayAudio, fade, write_audio
9from thunderlab.dataloader import DataLoader
10from thunderlab.eventdetection import detect_peaks, median_std_threshold
11from .version import __version__, __year__
12from .pulses import detect_pulses
class SignalPlot:
    """Interactive browser for multichannel time series and pulse-type EODs.

    The constructor optionally high-pass filters the data, optionally
    detects and clusters pulses, builds a matplotlib figure with one
    subplot per displayed channel (plus one subplot for inter-pulse-interval
    frequencies if pulses were found), connects keyboard, resize and pick
    handlers, and finally calls `plt.show()`. Press 'h' in the figure for a
    list of key bindings.

    Parameters
    ----------
    data: 2-D array or DataLoader
        The data, indexed as `data[sample, channel]`.
    samplerate: float
        Sampling rate of the data in Hertz.
    unit: str
        Unit of the data values, used for the axis labels.
    filename: str
        Name of the data file, used for the window title and for the names
        of saved audio segments and figures.
    show_channels: list of int or None
        Channels to be displayed. None or an empty list: all channels.
    tmax: float or None
        If not None, process and show only the first `tmax` seconds.
    fcutoff: float or None
        If not None, cutoff frequency in Hertz of a second-order
        Butterworth high-pass filter applied to the data.
    pulses: bool
        If True, detect pulses, group them across channels, and cluster
        the groups into putative fish.
    pulse_thresh: float or None
        Detection threshold passed to `detect_pulses()`, in units of the
        data. If None, the threshold is estimated for each channel by
        `median_std_threshold()`.
    """

    def __init__(self, data, samplerate, unit, filename,
                 show_channels=None, tmax=None, fcutoff=None,
                 pulses=False, pulse_thresh=0.01):
        self.filename = filename
        self.samplerate = samplerate
        self.data = data
        self.channels = self.data.shape[1] if len(self.data.shape) > 1 else 1
        self.unit = unit
        self.tmax = (len(self.data)-1)/self.samplerate
        if tmax is not None:
            # never claim more time than there is data:
            self.tmax = min(tmax, self.tmax)
            self.data = data[:int(tmax*self.samplerate),:]
        self.toffset = 0.0
        if tmax is not None:
            self.twindow = self.tmax
        else:
            self.twindow = self._default_twindow(self.tmax)
        # label, group, channel, peak index, trough index:
        self.pulses = np.zeros((0, 5), dtype=int)
        self.labels = []
        self.fishes = []
        self.pulse_times = []
        self.pulse_gids = []
        if show_channels is None or len(show_channels) == 0:
            self.show_channels = np.arange(self.channels)
        else:
            self.show_channels = np.array(show_channels)
        self.traces = len(self.show_channels)
        self.ymin = -1.0 * np.ones(self.traces)
        self.ymax = +1.0 * np.ones(self.traces)
        self.fmax = 100.0
        self.trace_artist = [None] * self.traces
        self.show_gid = False
        self.pulse_artist = []
        # rows of self.pulses shown by each artist in self.pulse_artist,
        # in the order of the artist's data points (needed for picking):
        self.pulse_artist_rows = []
        self.marker_artist = [None] * (self.traces + 1)
        self.ipis_artist = []
        self.ipis_labels = []
        self.figf = None
        self.axf = None
        self.pulse_colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C0']
        self.help = False
        self.helptext = []
        self.audio = PlayAudio()

        # filter data:
        if fcutoff is not None:
            sos = butter(2, fcutoff, 'high', fs=samplerate, output='sos')
            self.data = sosfiltfilt(sos, self.data[:], 0)

        # pulse detection:
        if pulses:
            self._detect_pulses(pulse_thresh)
            self._group_pulses()
            self._cluster_pulses()
            # all labels:
            self.labels = np.unique(self.pulses[:,0])
            self._report_fish()

        self._setup_figure()

    @staticmethod
    def _default_twindow(duration, twindow=10.0):
        """Initial width of the displayed time window in seconds.

        Returns `twindow` if the data are at least that long, otherwise
        the smallest power of two (in seconds) larger than `duration`.
        Unlike rounding to whole seconds, this never yields a zero window
        for recordings shorter than half a second.
        """
        if twindow <= duration:
            return twindow
        if duration <= 0.0:
            # single sample or no data: log2 is undefined
            return twindow
        return 2.0 ** (np.floor(np.log2(duration)) + 1.0)

    def _detect_pulses(self, thresh=0.01):
        """Detect pulses on each channel and store them sorted by peak index.

        Fills `self.pulses` with one row per pulse:
        label (running index per channel), group (0), channel, peak index,
        trough index.

        Parameters
        ----------
        thresh: float or None
            Detection threshold. If None, estimated per channel with
            `median_std_threshold()`.
        """
        all_pulses = np.zeros((0, 5), dtype=int)
        for c in range(self.channels):
            if thresh is None:
                ch_thresh = median_std_threshold(self.data[:,c], self.samplerate,
                                                 thresh_fac=6.0)
            else:
                ch_thresh = thresh
            p, t, _, _ = detect_pulses(self.data[:,c], self.samplerate,
                                       ch_thresh,
                                       min_rel_slope_diff=0.25,
                                       min_width=0.0001,
                                       max_width=0.01,
                                       width_fac=5.0)
            # label, group, channel, peak, trough:
            ch_pulses = np.column_stack((np.arange(len(p)),
                                         np.zeros(len(p), dtype=int),
                                         np.full(len(p), c, dtype=int),
                                         p, t))
            all_pulses = np.vstack((all_pulses, ch_pulses))
        self.pulses = all_pulses[np.argsort(all_pulses[:,3]),:]

    def _group_pulses(self, max_dt=0.0002, min_height=0.02):
        """Group pulses that occur simultaneously on different channels.

        Consecutive pulses (sorted by peak index) are collected into a
        group as long as their peak or trough is within `max_dt` seconds
        of the peak or trough of the highest pulse collected so far, and
        no channel contributes more than two pulses. Groups whose highest
        pulse is smaller than `min_height` are dropped. Within a group,
        pulses not close to the highest pulse are dropped, and of several
        remaining pulses on the same channel only the highest one is kept.
        Kept pulses get the group index in columns 0 and 1 of `self.pulses`.
        All other pulses are removed.
        """
        max_di = int(max_dt*self.samplerate)

        def pulse_height(row):
            # peak minus trough value of pulse `row` on its channel
            ch = self.pulses[row,2]
            return self.data[self.pulses[row,3],ch] - self.data[self.pulses[row,4],ch]

        def far(row, tp, tt):
            # neither peak nor trough of pulse `row` is close to tp or tt
            return (np.abs(self.pulses[row,3] - tp) > max_di and
                    np.abs(self.pulses[row,3] - tt) > max_di and
                    np.abs(self.pulses[row,4] - tp) > max_di and
                    np.abs(self.pulses[row,4] - tt) > max_di)

        n = len(self.pulses)
        l = -1
        k = 0
        while k < n:
            tp = self.pulses[k,3]
            tt = self.pulses[k,4]
            max_height = pulse_height(k)
            channel_counts = np.zeros(self.channels, dtype=int)
            channel_counts[self.pulses[k,2]] += 1
            size = 1
            while size < 3*self.channels and k + size < n:
                row = k + size
                ch = self.pulses[row,2]
                # pulse too far away or channel already full:
                if channel_counts[ch] > 1 or far(row, tp, tt):
                    break
                channel_counts[ch] += 1
                h = pulse_height(row)
                # highest pulse sets time reference:
                if h > max_height:
                    tp = self.pulses[row,3]
                    tt = self.pulses[row,4]
                    max_height = h
                size += 1
            group = slice(k, k + size)
            # all pulses too small:
            if max_height < min_height:
                self.pulses[group,0] = -1
                k += size
                continue
            # new label:
            l += 1
            # remove pulses not close to the reference pulse:
            for row in range(k, k + size):
                if far(row, tp, tt):
                    self.pulses[row,0] = -1
                    channel_counts[self.pulses[row,2]] -= 1
                else:
                    self.pulses[row,0] = l
                    self.pulses[row,1] = l
            # keep only the largest of the remaining pulses of each channel:
            for dc in np.where(channel_counts > 1)[0]:
                rows = k + np.where((self.pulses[group,2] == dc) &
                                    (self.pulses[group,0] >= 0))[0]
                heights = self.data[self.pulses[rows,3],dc] - \
                          self.data[self.pulses[rows,4],dc]
                imax = np.argmax(heights)
                for i, row in enumerate(rows):
                    if i != imax:
                        channel_counts[dc] -= 1
                        self.pulses[row,0] = -1
            k += size
        self.pulses = self.pulses[self.pulses[:,0] >= 0,:]

    def _cluster_pulses(self, max_dist=0.1, max_rate=300.0,
                        max_age=0.2, n_recent=5):
        """Assign pulse groups to fish based on their heights across channels.

        Each group of simultaneous pulses is characterized by the vector
        of its pulse heights over channels and by the peak index of its
        largest pulse. A group is assigned to the recently seen fish whose
        height vector has the smallest mean difference relative to that
        fish's maximum height, provided this difference is below
        `max_dist`. Otherwise a new fish label is created. Fish whose
        pulses would follow at a rate above `max_rate` Hertz are excluded.
        Only fish seen within the last `max_age` seconds, and only their
        `n_recent` most recent pulses, are compared.

        Sets the fish label (column 0) of `self.pulses` and fills
        `self.pulse_times` (arrays of peak indices), `self.pulse_gids`,
        and `self.fishes`.
        """
        # NOTE(review): disabled drafts for recovering temporally missing
        # pulses and for removing labels that overlap other labels were
        # dropped here; see version control history.
        recent = []
        n = len(self.pulses)
        k = 0
        while k < n:
            # select pulse group (at most one pulse per channel):
            j = k
            gid = self.pulses[j,1]
            k += 1
            while k < n and k - j < self.channels and self.pulses[k,1] == gid:
                k += 1
            chans = self.pulses[j:k,2]
            heights = np.zeros(self.channels)
            heights[chans] = self.data[self.pulses[j:k,3],chans] - \
                             self.data[self.pulses[j:k,4],chans]
            # time of largest pulse:
            pulse_time = self.pulses[j+np.argmax(heights[chans]),3]
            # assign to cluster:
            label = -1
            if len(self.pulse_times) > 0:
                # mean relative height difference to recent fish:
                dists = np.array([np.mean(np.abs(hh - heights)/np.max(hh))
                                  for ll, tt, hh in recent])
                # time since recent pulses in seconds:
                ipis = np.array([(pulse_time - tt)/self.samplerate
                                 for ll, tt, hh in recent])
                # ensure minimum inter-pulse interval; np.inf always
                # exceeds max_dist, unlike a multiple of max(dists):
                dists[ipis < 1.0/max_rate] = np.inf
                min_dist_idx = np.argmin(dists)
                if dists[min_dist_idx] < max_dist:
                    label = recent[min_dist_idx][0]
            if label < 0:
                label = len(self.pulse_times)
                self.pulse_times.append([])
                self.pulse_gids.append([])
            self.pulses[j:k,0] = label
            self.pulse_times[label].append(pulse_time)
            self.pulse_gids[label].append(gid)
            self.fishes.append([label, pulse_time, heights])
            recent.append([label, pulse_time, heights])
            # forget pulses older than max_age seconds:
            for i, (ll, tt, hh) in enumerate(recent):
                if (pulse_time - tt)/self.samplerate <= max_age:
                    recent = recent[i:]
                    break
            # only consider the n_recent most recent pulses of a fish:
            labels = np.array([ll for ll, tt, hh in recent])
            if np.sum(labels == label) > n_recent:
                del recent[np.where(labels == label)[0][0]]
        # pulse times to arrays:
        self.pulse_times = [np.array(pts) for pts in self.pulse_times]

    def _report_fish(self):
        """Print the number of pulses assigned to each fish."""
        print(f'found {len(self.pulse_times)} fish:')
        for k, pts in enumerate(self.pulse_times):
            print(f'{k:3d}: {len(pts):5d} pulses')

    def _setup_figure(self):
        """Create the figure, connect event handlers, and call plt.show()."""
        # set key bindings:
        plt.rcParams['keymap.fullscreen'] = 'f'
        plt.rcParams['keymap.pan'] = 'ctrl+m'
        plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, q'
        plt.rcParams['keymap.yscale'] = ''
        plt.rcParams['keymap.xscale'] = ''
        plt.rcParams['keymap.grid'] = ''
        # the figure:
        plt.ioff()
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        self.fig, self.axs = plt.subplots(splts, 1, squeeze=False,
                                          figsize=(15, 9), sharex=True)
        self.axs = self.axs.flat
        if self.traces == self.channels:
            self.fig.canvas.manager.set_window_title(self.filename)
        else:
            cs = ' c%d' % self.show_channels[0]
            self.fig.canvas.manager.set_window_title(self.filename + ' ' + cs)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)
        self.fig.canvas.mpl_connect('resize_event', self.resize)
        self.fig.canvas.mpl_connect('pick_event', self.on_pick)
        # trace plots:
        for t in range(self.traces):
            self.axs[t].set_ylabel(f'C-{self.show_channels[t]+1} [{self.unit}]')
        if len(self.pulses) > 0:
            self.axs[-1].set_ylim(0, self.fmax)
            self.axs[-1].set_ylabel('IP freq [Hz]')
        self.axs[-1].set_xlabel('Time [s]')
        # help texts (vertical position, text), toggled with 'h':
        help_lines = [
            (0.05, '(ctrl+) page and arrow up, down, home, end: scroll'),
            (0.1, '+, -, X, x: zoom time in/out'),
            (0.2, 'y, Y, v, V, ctrl+v, ctrl+V: zoom amplitudes out/in/max/default/max per trace/global max per trace'),
            (0.3, 'i, I: zoom IPI frequency in/out'),
            (0.4, 'p, P: play audio (display, all)'),
            (0.5, 'f: full screen'),
            (0.6, 'w: plot waveforms into png file'),
            (0.7, 'S: save audiosegment'),
            (0.8, 'q: quit'),
            (0.9, 'h: toggle this help'),
        ]
        for y, text in help_lines:
            ht = self.axs[0].text(0.98, y, text, ha='right',
                                  transform=self.axs[0].transAxes)
            ht.set_visible(self.help)
            self.helptext.append(ht)
        # plot:
        self.update_plots()
        # feature plot:
        if len(self.labels) > 0:
            self.figf, self.axf = plt.subplots()
        plt.show()

    def __del__(self):
        # closing the audio device on deletion is disabled (self.audio.close()).
        pass

    def _window_indices(self):
        """Sample indices (start, stop) of the current time window, stop clipped to the data."""
        t0 = int(np.round(self.toffset * self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
        return t0, min(t1, len(self.data))

    def _basename(self):
        """Data file name without its extension (dots inside the name are kept)."""
        return os.path.splitext(self.filename)[0]

    @staticmethod
    def _amplitude_range(values):
        """Y-limits covering `values` with 3% margin on both sides."""
        lo = np.min(values)
        hi = np.max(values)
        h = 0.53 * (hi - lo)
        c = 0.5 * (hi + lo)
        return c - h, c + h

    def plot_pulses(self, axs, plot=True, tfac=1.0):
        """Plot pulses into the trace axes and IPI frequencies into the last axes.

        Parameters
        ----------
        axs: sequence of matplotlib axes
            One axes per displayed trace, followed by the axes for the
            inter-pulse-interval frequencies.
        plot: bool
            If True, always create new artists (e.g. for a separate figure)
            without storing them. If False, update and reuse the artists
            stored in `self.pulse_artist` and `self.ipis_artist`.
        tfac: float
            Factor converting seconds into the unit of the time axis.
        """
        ncolors = len(self.pulse_colors)

        def plot_pulse_traces(rows, i, pak):
            # plot pulses self.pulses[rows] in color i, one artist per trace
            pulses = self.pulses[rows,:]
            for t in range(self.traces):
                c = self.show_channels[t]
                sel = pulses[:,2] == c
                p = pulses[sel,3]
                if len(p) == 0:
                    continue
                color = self.pulse_colors[i % ncolors]
                if plot or pak >= len(self.pulse_artist):
                    pa, = axs[t].plot(tfac*p/self.samplerate,
                                      self.data[p,c], 'o', picker=5,
                                      color=color)
                    if not plot:
                        self.pulse_artist.append(pa)
                        self.pulse_artist_rows.append(rows[sel])
                else:
                    self.pulse_artist[pak].set_data(tfac*p/self.samplerate,
                                                    self.data[p,c])
                    self.pulse_artist[pak].set_color(color)
                    self.pulse_artist_rows[pak] = rows[sel]
                pak += 1
            return pak

        # pulses:
        pak = 0
        if self.show_gid:
            for g in range(ncolors):
                rows = np.where(self.pulses[:,1] % ncolors == g)[0]
                pak = plot_pulse_traces(rows, g, pak)
        else:
            for l in self.labels:
                rows = np.where(self.pulses[:,0] == l)[0]
                pak = plot_pulse_traces(rows, l, pak)
        if not plot:
            # clear stored artists that are not needed anymore:
            while pak < len(self.pulse_artist):
                self.pulse_artist[pak].set_data([], [])
                self.pulse_artist_rows[pak] = np.zeros(0, dtype=int)
                pak += 1
        # ipis:
        for l in self.labels:
            if l >= len(self.pulse_times):
                continue
            pt = self.pulse_times[l]/self.samplerate
            if len(pt) <= 10:
                continue
            ipf = 1.0/np.diff(pt)
            if plot or l not in self.ipis_labels:
                pa, = axs[-1].plot(tfac*pt[:-1], ipf, '-o', picker=5,
                                   color=self.pulse_colors[l % ncolors])
                if not plot:
                    self.ipis_artist.append(pa)
                    self.ipis_labels.append(l)
            else:
                iak = self.ipis_labels.index(l)
                self.ipis_artist[iak].set_data(tfac*pt[:-1], ipf)

    def update_plots(self):
        """Redraw traces, pulses, and IPI frequencies for the current time window."""
        t0, t1 = self._window_indices()
        time = np.arange(t0, t1) / self.samplerate
        for t in range(self.traces):
            c = self.show_channels[t]
            ax = self.axs[t]
            ax.set_xlim(self.toffset, self.toffset + self.twindow)
            if self.trace_artist[t] is None:
                self.trace_artist[t], = ax.plot(time, self.data[t0:t1,c])
            else:
                self.trace_artist[t].set_data(time, self.data[t0:t1,c])
            # show individual samples when zoomed in far enough:
            if t1 - t0 < 200:
                self.trace_artist[t].set_marker('o')
                self.trace_artist[t].set_markersize(3)
            else:
                self.trace_artist[t].set_marker('None')
            ax.set_ylim(self.ymin[t], self.ymax[t])
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()

    def on_pick(self, event):
        """Highlight a picked pulse group and its IPI, and plot its heights.

        Picking a pulse marker in a trace, or a point in the IPI frequency
        plot, marks the pulses of the corresponding group in all traces,
        marks the IPI starting at this group (if there is a following
        pulse), adds the group's pulse heights across channels to the
        feature figure, and prints label, group id, and IPI start time
        (-1 if there is no IPI).
        """
        ncolors = len(self.pulse_colors)
        pi = -1
        pk = next((k for k, pa in enumerate(self.pulse_artist)
                   if pa is event.artist), -1)
        if pk >= 0:
            # the artist knows which rows of self.pulses it shows:
            rows = self.pulse_artist_rows[pk]
            if len(rows) == 0:
                return
            row = rows[event.ind[0]]
            ll = self.pulses[row,0]
            gid = self.pulses[row,1]
            if ll in self.ipis_labels:
                pi = self.pulse_gids[ll].index(gid)
        else:
            ik = next((k for k, ia in enumerate(self.ipis_artist)
                       if ia is event.artist), -1)
            if ik < 0:
                return
            ll = self.ipis_labels[ik]
            pi = event.ind[0]
            gid = self.pulse_gids[ll][pi]
        color = self.pulse_colors[ll % ncolors]
        # mark pulses:
        pulses = self.pulses[(self.pulses[:,0] == ll) &
                             (self.pulses[:,1] == gid),:]
        for t in range(self.traces):
            c = self.show_channels[t]
            pt = pulses[pulses[:,2] == c,3]
            if len(pt) > 0:
                # Line2D.set_data requires sequences, not scalars:
                x = [pt[0]/self.samplerate]
                y = [self.data[pt[0],c]]
                if self.marker_artist[t] is None:
                    self.marker_artist[t], = self.axs[t].plot(x, y, 'o', ms=10,
                                                              color=color)
                else:
                    self.marker_artist[t].set_data(x, y)
                    self.marker_artist[t].set_color(color)
            elif self.marker_artist[t] is not None:
                self.marker_artist[t].set_data([], [])
        # mark ipi (needs a following pulse):
        pt0 = -1.0
        ipi_marker = self.marker_artist[self.traces]
        if 0 <= pi < len(self.pulse_times[ll]) - 1:
            pt0 = self.pulse_times[ll][pi]/self.samplerate
            pt1 = self.pulse_times[ll][pi+1]/self.samplerate
            pf = 1.0/(pt1-pt0)
            if ipi_marker is None:
                self.marker_artist[self.traces], = \
                    self.axs[self.traces].plot([pt0], [pf], 'o', ms=10,
                                               color=color)
            else:
                ipi_marker.set_data([pt0], [pf])
                ipi_marker.set_color(color)
        elif ipi_marker is not None:
            ipi_marker.set_data([], [])
        self.fig.canvas.draw()
        # show features:
        if self.axf is not None and self.figf is not None:
            heights = np.zeros(self.channels)
            heights[pulses[:,2]] = \
                self.data[pulses[:,3],pulses[:,2]] - \
                self.data[pulses[:,4],pulses[:,2]]
            self.axf.plot(heights, color=color)
            print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
            self.figf.canvas.draw()

    def resize(self, event):
        """Keep fixed pixel margins around the subplots when the window is resized."""
        leftpixel = 80.0
        rightpixel = 20.0
        bottompixel = 50.0
        toppixel = 20.0
        x0 = leftpixel / event.width
        x1 = 1.0 - rightpixel / event.width
        y0 = bottompixel / event.height
        y1 = 1.0 - toppixel / event.height
        self.fig.subplots_adjust(left=x0, right=x1, bottom=y0, top=y1,
                                 hspace=0)

    def _scroll_back(self, frac):
        """Move the time window back by `frac` windows, not before time zero."""
        if self.toffset > 0.0:
            self.toffset = max(self.toffset - frac * self.twindow, 0.0)
            self.update_plots()

    def _apply_ylims(self):
        """Set the y-limits of all trace axes from self.ymin and self.ymax and redraw."""
        for t in range(self.traces):
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()

    def keypress(self, event):
        """Handle key presses for scrolling, zooming, audio, and output files."""
        key = event.key
        # matplotlib reports None for keys it cannot name:
        if key is None:
            return
        if key in ('+', '=', 'X'):
            if self.twindow * self.samplerate > 20:
                self.twindow *= 0.5
                self.update_plots()
        elif key in ('-', 'x'):
            if self.twindow < self.tmax:
                self.twindow *= 2.0
                self.update_plots()
        elif key == 'pagedown':
            if self.toffset + 0.5 * self.twindow < self.tmax:
                self.toffset += 0.5 * self.twindow
                self.update_plots()
        elif key == 'pageup':
            self._scroll_back(0.5)
        elif key == 'ctrl+pagedown':
            if self.toffset + 5.0 * self.twindow < self.tmax:
                self.toffset += 5.0 * self.twindow
                self.update_plots()
        elif key == 'ctrl+pageup':
            self._scroll_back(5.0)
        elif key == 'down':
            if self.toffset + self.twindow < self.tmax:
                self.toffset += 0.05 * self.twindow
                self.update_plots()
        elif key == 'up':
            self._scroll_back(0.05)
        elif key == 'home':
            if self.toffset > 0.0:
                self.toffset = 0.0
                self.update_plots()
        elif key == 'end':
            toffs = np.floor(self.tmax / self.twindow) * self.twindow
            if self.tmax - toffs <= 0.0:
                toffs -= self.twindow
            if self.tmax - toffs < self.twindow/2:
                toffs -= self.twindow/2
            if self.toffset < toffs:
                self.toffset = toffs
                self.update_plots()
        elif key in ('y', 'Y'):
            # 'y' doubles, 'Y' halves the amplitude range:
            half = (1.0 if key == 'y' else 0.25) * (self.ymax - self.ymin)
            center = 0.5 * (self.ymax + self.ymin)
            self.ymin[:] = center - half
            self.ymax[:] = center + half
            self._apply_ylims()
        elif key == 'v':
            t0, t1 = self._window_indices()
            ylo, yhi = self._amplitude_range(self.data[t0:t1,self.show_channels])
            self.ymin[:] = ylo
            self.ymax[:] = yhi
            self._apply_ylims()
        elif key == 'ctrl+v':
            t0, t1 = self._window_indices()
            for t in range(self.traces):
                self.ymin[t], self.ymax[t] = \
                    self._amplitude_range(self.data[t0:t1,self.show_channels[t]])
            self._apply_ylims()
        elif key == 'ctrl+V':
            for t in range(self.traces):
                self.ymin[t], self.ymax[t] = \
                    self._amplitude_range(self.data[:,self.show_channels[t]])
            self._apply_ylims()
        elif key == 'V':
            self.ymin[:] = -1.0
            self.ymax[:] = +1.0
            self._apply_ylims()
        elif key == 'c':
            # center traces around zero:
            dy = self.ymax - self.ymin
            self.ymin[:] = -dy/2
            self.ymax[:] = +dy/2
            self._apply_ylims()
        elif key == 'g':
            self.show_gid = not self.show_gid
            self.plot_pulses(self.axs, False)
            self.fig.canvas.draw()
        elif key == 'i':
            if len(self.pulses) > 0:
                self.fmax *= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif key == 'I':
            if len(self.pulses) > 0:
                self.fmax /= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif key in tuple('0123456789'):
            # TODO: toggling the visibility of single channels is not
            # implemented yet.
            pass
        elif key == 'h':
            self.help = not self.help
            for ht in self.helptext:
                ht.set_visible(self.help)
            self.fig.canvas.draw()
        elif key == 'p':
            self.play_segment()
        elif key == 'P':
            self.play_all()
        elif key == 'S':
            self.save_segment()
        elif key == 'w':
            self.plot_traces()

    def play_segment(self):
        """Play the mean of the displayed channels in the current window (non-blocking)."""
        t0, t1 = self._window_indices()
        playdata = 1.0 * np.mean(self.data[t0:t1,self.show_channels], 1)
        fade_time = 0.1 if self.twindow > 0.5 else 0.1*self.twindow
        fade(playdata, self.samplerate, fade_time)
        self.audio.play(playdata, self.samplerate, blocking=False)

    def play_all(self):
        """Play the mean of the displayed channels over all data (non-blocking)."""
        self.audio.play(np.mean(self.data[:,self.show_channels], 1),
                        self.samplerate, blocking=False)

    def _segment_filename(self):
        """Name of the wave file for the current time window.

        Start and end times keep their fractional part (4 significant
        digits), so short segments get distinct names.
        """
        t0s = self.toffset
        t1s = self.toffset + self.twindow
        name = f'{self._basename()}-{t0s:.4g}s-{t1s:.4g}s'
        if self.traces != self.channels:
            name += f'-c{self.show_channels[0]}'
        return name + '.wav'

    def save_segment(self):
        """Write the current time window of the displayed channels to a wave file."""
        t0, t1 = self._window_indices()
        segment_filename = self._segment_filename()
        if self.traces == self.channels:
            write_audio(segment_filename, self.data[t0:t1,:], self.samplerate)
        else:
            write_audio(segment_filename,
                        self.data[t0:t1,self.show_channels], self.samplerate)
        print('saved segment to: ' , segment_filename)

    def plot_traces(self):
        """Plot the current time window with pulses into a png file.

        Times are given in milliseconds if both offset and window are
        below one second.
        """
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        fig, axs = plt.subplots(splts, 1, squeeze=False, sharex=True,
                                figsize=(15, 9))
        axs = axs.flat
        fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                            hspace=0)
        name = self._basename()
        if self.traces < self.channels:
            figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
        else:
            figfile = f'{name}-{self.toffset:.4g}s-traces.png'
        axs[0].set_title(self.filename)
        t0, t1 = self._window_indices()
        time = np.arange(t0, t1)/self.samplerate
        if self.toffset < 1.0 and self.twindow < 1.0:
            tfac = 1000.0
            axs[-1].set_xlabel('Time [ms]')
        else:
            tfac = 1.0
            axs[-1].set_xlabel('Time [s]')
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_xlim(tfac * self.toffset,
                            tfac * (self.toffset + self.twindow))
            axs[t].plot(tfac * time, self.data[t0:t1,c])
        self.plot_pulses(axs, True, tfac)
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_ylim(self.ymin[t], self.ymax[t])
            axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
        if len(self.pulses) > 0:
            axs[-1].set_ylabel('IP freq [Hz]')
            axs[-1].set_ylim(0.0, self.fmax)
        fig.savefig(figfile, dpi=200)
        plt.close(fig)
        print('saved waveform figure to', figfile)
def short_user_warning(message, category, filename, lineno, file=None, line=''):
    """Replacement for `warnings.showwarning` with a compact UserWarning format.

    UserWarnings are written as a single line containing the last two
    components of the file path, the line number, and the message. All
    other categories are formatted by `warnings.formatwarning()`.

    Parameters
    ----------
    message, category, filename, lineno, line:
        As passed by the warnings module to `showwarning()`.
    file: file-like object or None
        Stream to write to. Defaults to `sys.stderr`.
    """
    stream = sys.stderr if file is None else file
    if category != UserWarning:
        stream.write(warnings.formatwarning(message, category, filename,
                                            lineno, line))
        return
    short_path = '/'.join(filename.split('/')[-2:])
    stream.write('%s line %d: %s\n' % (short_path, lineno, message))
def main(cargs=None):
    """Command line script for browsing (multichannel) EOD recordings.

    Parses the command line, loads the data file with `DataLoader`, and
    opens a `SignalPlot` on it.

    Parameters
    ----------
    cargs: list of str or None
        Command line arguments. If None, `sys.argv[1:]` is used.
    """
    warnings.showwarning = short_user_warning

    # command line arguments:
    if cargs is None:
        cargs = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description='Browse multichannel EOD recordings.',
        epilog='version %s by Benda-Lab (2022-%s)' % (__version__, __year__))
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-v', action='count', dest='verbose', default=0,
                        help='verbosity level (increase by specifying -v multiple times)')
    parser.add_argument('-c', dest='channels', default='',
                        type=str, metavar='CHANNELS',
                        help='Comma separated list of channels to be displayed (first channel is 0).')
    parser.add_argument('-t', dest='tmax', default=None,
                        type=float, metavar='TMAX',
                        help='Process and show only the first TMAX seconds.')
    parser.add_argument('-f', dest='fcutoff', default=None,
                        type=float, metavar='FREQ',
                        help='Cutoff frequency of optional high-pass filter.')
    parser.add_argument('-p', dest='pulses', action='store_true',
                        help='detect pulse fish EODs')
    parser.add_argument('file', nargs=1, type=str,
                        help='name of the file with the time series data')
    args = parser.parse_args(cargs)
    filepath = args.file[0]
    cs = [s.strip() for s in args.channels.split(',')]
    try:
        channels = [int(c) for c in cs if len(c) > 0]
    except ValueError:
        # report malformed channel lists like any other usage error:
        parser.error(f'invalid list of channels "{args.channels}"')

    # load data:
    filename = os.path.basename(filepath)
    with DataLoader(filepath, 10*60.0, 5.0, args.verbose) as data:
        SignalPlot(data, data.samplerate, data.unit, filename,
                   channels, args.tmax, args.fcutoff, args.pulses)
if __name__ == "__main__":
    # command line entry point
    main()