Coverage for src / thunderfish / thunderbrowse.py: 0%

577 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-15 17:50 +0000

1import sys 

2import os 

3import warnings 

4import argparse 

5import numpy as np 

6import matplotlib.pyplot as plt 

7 

8from scipy.signal import butter, sosfiltfilt 

9from audioio import PlayAudio, fade, write_audio 

10from thunderlab.dataloader import DataLoader 

11from thunderlab.eventdetection import detect_peaks, median_std_threshold 

12 

13from .version import __version__, __year__ 

14from .pulses import detect_pulses 

15 

16 

class SignalPlot:
    """Interactive browser for multichannel recordings with optional
    detection of pulse-fish EODs.

    The recording is shown with one panel per displayed channel. With
    pulse detection enabled, detected pulses are marked on the traces
    and an additional bottom panel shows the inter-pulse-interval (IPI)
    frequency of each clustered fish. Keyboard shortcuts scroll, zoom,
    play, and save the data (press 'h' for an on-screen help).
    The constructor creates the figures and calls `plt.show()`.

    Parameters
    ----------
    data: 2-D array or DataLoader
        The time series, indexed as `data[sample, channel]`.
    rate: float
        Sampling rate of `data` in Hertz.
    unit: str
        Unit of the data values, used in the axis labels.
    filename: str
        Name of the recording, used for the window title and for the
        names of saved files.
    show_channels: sequence of int or None
        Channels to be displayed. None or an empty sequence displays
        all channels.
    tmax: float or None
        If given, only the first `tmax` seconds of the data are
        processed, and the initial time window spans them.
    fcutoff: float or None
        If given, the data are filtered forward and backward
        (`sosfiltfilt`) with a 2nd order Butterworth high-pass filter
        with this cutoff frequency in Hertz.
    pulses: bool
        If True, detect pulses on each channel, group them across
        channels, and cluster the groups into fish.
    """

    # Tuning parameters of the pulse analysis (override in a subclass).
    # Absolute detection threshold passed to detect_pulses(); if None,
    # the threshold is estimated per channel by median_std_threshold().
    pulse_threshold = 0.01
    # Maximum distance in seconds between the peaks/troughs of pulses
    # on different channels that are grouped together.
    group_max_dist = 0.0002
    # Groups whose largest peak-to-trough height is below this are discarded.
    group_min_height = 0.02
    # Maximum mean relative height difference for assigning a group
    # to the fish of a recent group.
    cluster_thresh = 0.1
    # Maximum pulse rate of a single fish in Hertz.
    max_pulse_rate = 300.0
    # Recent groups older than this (seconds) are not used for clustering.
    recent_time = 0.2
    # Maximum number of recent groups per fish used for clustering.
    recent_pulses = 5

    def __init__(self, data, rate, unit, filename,
                 show_channels=None, tmax=None, fcutoff=None,
                 pulses=False):
        self.filename = filename
        self.rate = rate
        self.data = data
        self.channels = data.shape[1] if len(data.shape) > 1 else 1
        self.unit = unit
        self.tmax = (len(data) - 1)/rate
        if tmax is not None:
            # never extend beyond the end of the recording:
            self.tmax = min(tmax, self.tmax)
            self.data = data[:int(tmax*rate), :]
        self.toffset = 0.0
        self.twindow = 10.0
        if self.twindow > self.tmax:
            # smallest power of two larger than tmax (no rounding, since
            # rounding turns windows below one second into zero):
            self.twindow = 2.0 ** (np.floor(np.log2(self.tmax)) + 1.0)
        if tmax is not None:
            self.twindow = self.tmax
        # pulses: label, group id, channel, peak index, trough index
        self.pulses = np.zeros((0, 5), dtype=int)
        self.labels = []
        self.fishes = []
        self.pulse_times = []
        self.pulse_gids = []
        if show_channels is None or len(show_channels) == 0:
            self.show_channels = np.arange(self.channels)
        else:
            self.show_channels = np.array(show_channels)
        self.traces = len(self.show_channels)
        self.ymin = np.full(self.traces, -1.0)
        self.ymax = np.full(self.traces, +1.0)
        self.fmax = 100.0
        self.trace_artist = [None] * self.traces
        self.show_gid = False
        self.pulse_artist = []
        # rows of self.pulses displayed by each pulse artist (for picking):
        self.pulse_artist_rows = []
        self.marker_artist = [None] * (self.traces + 1)
        self.ipis_artist = []
        self.ipis_labels = []
        self.figf = None
        self.axf = None
        self.pulse_colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7',
                             'C8', 'C9', 'C0']
        self.help = False
        self.helptext = []
        self.audio = PlayAudio()

        if fcutoff is not None:
            sos = butter(2, fcutoff, 'high', fs=rate, output='sos')
            self.data = sosfiltfilt(sos, self.data[:], 0)

        if pulses:
            self._detect_pulses()
            self._group_pulses()
            self._cluster_pulses()
            self.labels = np.unique(self.pulses[:, 0])
            print(f'found {len(self.pulse_times)} fish:')
            for k, times in enumerate(self.pulse_times):
                print(f'{k:3d}: {len(times):5d} pulses')

        self._setup_figure()
        self.update_plots()
        # feature plot, filled by on_pick():
        if len(self.labels) > 0:
            self.figf, self.axf = plt.subplots()
        plt.show()

    def __del__(self):
        # Closing self.audio here is currently disabled.
        pass

    def _detect_pulses(self):
        """Detect pulses on every channel.

        Sets `self.pulses` to an integer array with one row per pulse
        (per-channel pulse index, group id 0, channel, peak index,
        trough index), sorted by peak index.
        """
        per_channel = [np.zeros((0, 5), dtype=int)]
        for c in range(self.channels):
            if self.pulse_threshold is None:
                thresh = median_std_threshold(self.data[:, c], self.rate,
                                              thresh_fac=6.0)
            else:
                thresh = self.pulse_threshold
            p, t, _, _ = detect_pulses(self.data[:, c], self.rate, thresh,
                                       min_rel_slope_diff=0.25,
                                       min_width=0.0001,
                                       max_width=0.01,
                                       width_fac=5.0)
            # integer dtype even if a channel yields no pulses:
            p = np.asarray(p, dtype=int)
            t = np.asarray(t, dtype=int)
            n = len(p)
            per_channel.append(np.column_stack((np.arange(n),
                                                np.zeros(n, dtype=int),
                                                np.full(n, c, dtype=int),
                                                p, t)))
        all_pulses = np.vstack(per_channel)
        self.pulses = all_pulses[np.argsort(all_pulses[:, 3]), :]

    def _group_pulses(self):
        """Group pulses of different channels that belong to the same EOD.

        Pulses are processed in order of their peak index. Starting at
        the first ungrouped pulse, following pulses are added to the
        group as long as no channel contributes more than two pulses and
        the peak or trough of the new pulse is within `group_max_dist`
        seconds of the peak or trough of the largest pulse collected so
        far. Groups whose largest peak-to-trough height is below
        `group_min_height` are discarded. Within a group, pulses too far
        from the largest pulse and all but the largest pulse of a
        channel are discarded. The remaining pulses get the group index
        as label (column 0) and group id (column 1).
        """
        pulses = self.pulses
        n = len(pulses)
        max_di = int(self.group_max_dist*self.rate)

        def height(i):
            """Peak-to-trough height of pulse row i."""
            c = pulses[i, 2]
            return self.data[pulses[i, 3], c] - self.data[pulses[i, 4], c]

        def is_far(i, tp, tt):
            """True if peak and trough of row i are farther than max_di
            samples from both reference indices tp and tt."""
            return all(abs(x - r) > max_di
                       for x in (pulses[i, 3], pulses[i, 4])
                       for r in (tp, tt))

        label = -1
        k = 0
        while k < n:
            tp, tt = pulses[k, 3], pulses[k, 4]
            h = height(k)
            channel_counts = np.zeros(self.channels, dtype=int)
            channel_counts[pulses[k, 2]] += 1
            # collect following pulses; c ends as the group size:
            for c in range(1, 3*self.channels):
                i = k + c
                if i >= n or channel_counts[pulses[i, 2]] > 1 or \
                   is_far(i, tp, tt):
                    break
                channel_counts[pulses[i, 2]] += 1
                hi = height(i)
                if hi > h:
                    # the largest pulse sets the time reference:
                    tp, tt, h = pulses[i, 3], pulses[i, 4], hi
            group = pulses[k:k+c, :]   # view into self.pulses
            if h < self.group_min_height:
                group[:, 0] = -1
                k += c
                continue
            label += 1
            for j in range(c):
                if is_far(k + j, tp, tt):
                    group[j, 0] = -1
                    channel_counts[group[j, 2]] -= 1
                else:
                    group[j, 0] = label
                    group[j, 1] = label
            # keep only the largest pulse of each channel:
            for dc in np.flatnonzero(channel_counts > 1):
                idx = np.flatnonzero(group[:, 2] == dc)
                heights = self.data[group[idx, 3], dc] - \
                          self.data[group[idx, 4], dc]
                largest = np.argmax(heights)
                group[idx[np.arange(len(idx)) != largest], 0] = -1
            k += c
        self.pulses = pulses[pulses[:, 0] >= 0, :]

    def _cluster_pulses(self):
        """Assign pulse groups to fish.

        A group is described by its peak-to-trough heights on all
        channels and by the peak index of its largest pulse. It is
        assigned to the fish of the most similar recent group if the
        mean absolute height difference, relative to the maximum height
        of that recent group, is below `cluster_thresh`; otherwise it
        starts a new fish. Recent groups implying a pulse rate above
        `max_pulse_rate` are penalized. Only groups of the last
        `recent_time` seconds and at most `recent_pulses` groups per
        fish are candidates.

        Sets column 0 of `self.pulses` to the fish label and fills
        `self.pulse_times` (one array of peak indices per fish),
        `self.pulse_gids`, and `self.fishes`.
        """
        pulses = self.pulses
        n = len(pulses)
        recent = []   # [label, pulse time, heights] of recent groups

        def new_fish():
            self.pulse_times.append([])
            self.pulse_gids.append([])
            return len(self.pulse_times) - 1

        k = 0
        while k < n:
            # rows j:k share the same group id:
            j = k
            gid = pulses[j, 1]
            for _ in range(self.channels):
                k += 1
                if k >= n or pulses[k, 1] != gid:
                    break
            rows = pulses[j:k, :]
            chans = rows[:, 2]
            heights = np.zeros(self.channels)
            heights[chans] = self.data[rows[:, 3], chans] - \
                             self.data[rows[:, 4], chans]
            # time of the largest pulse:
            pulse_time = rows[np.argmax(heights[chans]), 3]
            if len(self.pulse_times) == 0:
                label = new_fish()
            else:
                dists = np.array([np.mean(np.abs(hh - heights)/np.max(hh))
                                  for _, _, hh in recent])
                ipis = np.array([(pulse_time - tt)/self.rate
                                 for _, tt, _ in recent])
                # a zero interval means an infinite rate:
                with np.errstate(divide='ignore'):
                    rates = 1.0/ipis
                dists[rates > self.max_pulse_rate] = 2*np.max(dists)
                best = np.argmin(dists)
                if dists[best] < self.cluster_thresh:
                    label = recent[best][0]
                else:
                    label = new_fish()
            pulses[j:k, 0] = label
            self.pulse_times[label].append(pulse_time)
            self.pulse_gids[label].append(gid)
            self.fishes.append([label, pulse_time, heights])
            recent.append([label, pulse_time, heights])
            # forget groups that are too old:
            for i, (_, tt, _) in enumerate(recent):
                if (pulse_time - tt)/self.rate <= self.recent_time:
                    recent = recent[i:]
                    break
            # keep only the most recent groups of the current fish:
            labels = np.array([ll for ll, _, _ in recent])
            if np.sum(labels == label) > self.recent_pulses:
                del recent[np.flatnonzero(labels == label)[0]]
        self.pulse_times = [np.array(times) for times in self.pulse_times]

    def _setup_figure(self):
        """Set key bindings and create the main figure with one panel
        per displayed trace plus an IPI panel if there are pulses."""
        plt.rcParams['keymap.fullscreen'] = 'f'
        plt.rcParams['keymap.pan'] = 'ctrl+m'
        plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, q'
        plt.rcParams['keymap.yscale'] = ''
        plt.rcParams['keymap.xscale'] = ''
        plt.rcParams['keymap.grid'] = ''
        plt.ioff()
        nrows = self.traces + (1 if len(self.pulses) > 0 else 0)
        self.fig, axs = plt.subplots(nrows, 1, squeeze=False,
                                     figsize=(15, 9), sharex=True)
        self.axs = axs.ravel()
        title = self.filename
        if self.traces != self.channels:
            title += f' c{self.show_channels[0]}'
        # non-GUI backends have no figure manager:
        if self.fig.canvas.manager is not None:
            self.fig.canvas.manager.set_window_title(title)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)
        self.fig.canvas.mpl_connect('resize_event', self.resize)
        self.fig.canvas.mpl_connect('pick_event', self.on_pick)
        for t, c in enumerate(self.show_channels):
            self.axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
        if len(self.pulses) > 0:
            self.axs[-1].set_ylim(0, self.fmax)
            self.axs[-1].set_ylabel('IP freq [Hz]')
        self.axs[-1].set_xlabel('Time [s]')
        help_lines = [
            (0.05, '(ctrl+) page and arrow up, down, home, end: scroll'),
            (0.1, '+, -, X, x: zoom time in/out'),
            (0.2, 'y, Y, v, V, ctrl+v, ctrl+V: zoom amplitudes out/in/max/default/max per trace/global max per trace'),
            (0.3, 'i, I: zoom IPI frequency in/out'),
            (0.4, 'p, P: play audio (display, all)'),
            (0.5, 'f: full screen'),
            (0.6, 'w: plot waveforms into png file'),
            (0.7, 'S: save audiosegment'),
            (0.8, 'q: quit'),
            (0.9, 'h: toggle this help'),
        ]
        for y, text in help_lines:
            ht = self.axs[0].text(0.98, y, text, ha='right',
                                  transform=self.axs[0].transAxes)
            ht.set_visible(self.help)
            self.helptext.append(ht)

    def _window_indices(self):
        """Sample indices (t0, t1) of the current time window,
        with t1 limited to the length of the data."""
        t0 = int(np.round(self.toffset*self.rate))
        t1 = int(np.round((self.toffset + self.twindow)*self.rate))
        return t0, min(t1, len(self.data))

    @staticmethod
    def _fit_range(lo, hi):
        """Limits centered on [lo, hi] and extending 3% beyond each end."""
        h = 0.53*(hi - lo)
        c = 0.5*(hi + lo)
        return c - h, c + h

    def _apply_ylims(self):
        """Apply `self.ymin` and `self.ymax` to the trace panels and redraw."""
        for t in range(self.traces):
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()

    def plot_pulses(self, axs, plot=True, tfac=1.0):
        """Plot pulses on the traces and IPI frequencies of the fish.

        Parameters
        ----------
        axs: sequence of matplotlib axes
            One axes per displayed trace, the last one for the IPIs.
        plot: bool
            If True, plot into new artists (used for exported figures).
            If False, create or update the interactive artists stored
            in `self.pulse_artist` and `self.ipis_artist`.
        tfac: float
            Factor converting seconds into the unit of the time axis.
        """
        ncolors = len(self.pulse_colors)

        def plot_pulse_traces(rows, color_index, pak):
            color = self.pulse_colors[color_index % ncolors]
            for t, c in enumerate(self.show_channels):
                crows = rows[self.pulses[rows, 2] == c]
                if len(crows) == 0:
                    continue
                p = self.pulses[crows, 3]
                x = tfac*p/self.rate
                y = self.data[p, c]
                if plot or pak >= len(self.pulse_artist):
                    pa, = axs[t].plot(x, y, 'o', picker=5, color=color)
                    if not plot:
                        self.pulse_artist.append(pa)
                        self.pulse_artist_rows.append(crows)
                else:
                    self.pulse_artist[pak].set_data(x, y)
                    self.pulse_artist[pak].set_color(color)
                    self.pulse_artist_rows[pak] = crows
                pak += 1
            return pak

        pak = 0
        if self.show_gid:
            for g in range(ncolors):
                rows = np.flatnonzero(self.pulses[:, 1] % ncolors == g)
                pak = plot_pulse_traces(rows, g, pak)
        else:
            for l in self.labels:
                rows = np.flatnonzero(self.pulses[:, 0] == l)
                pak = plot_pulse_traces(rows, l, pak)
        if not plot:
            # hide interactive artists that are not needed anymore:
            for k in range(pak, len(self.pulse_artist)):
                self.pulse_artist[k].set_data([], [])
                self.pulse_artist_rows[k] = np.zeros(0, dtype=int)
        # ipis:
        for l in self.labels:
            if l >= len(self.pulse_times):
                continue
            pt = self.pulse_times[l]/self.rate
            if len(pt) <= 10:
                continue
            x = tfac*pt[:-1]
            freqs = 1.0/np.diff(pt)
            if plot or l not in self.ipis_labels:
                pa, = axs[-1].plot(x, freqs, '-o', picker=5,
                                   color=self.pulse_colors[l % ncolors])
                if not plot:
                    self.ipis_artist.append(pa)
                    self.ipis_labels.append(l)
            else:
                self.ipis_artist[self.ipis_labels.index(l)].set_data(x, freqs)

    def update_plots(self):
        """Redraw traces, pulses, and IPIs for the current time window."""
        t0, t1 = self._window_indices()
        time = np.arange(t0, t1)/self.rate
        for t, c in enumerate(self.show_channels):
            ax = self.axs[t]
            ax.set_xlim(self.toffset, self.toffset + self.twindow)
            if self.trace_artist[t] is None:
                self.trace_artist[t], = ax.plot(time, self.data[t0:t1, c])
            else:
                self.trace_artist[t].set_data(time, self.data[t0:t1, c])
            # show individual samples when zoomed in:
            if t1 - t0 < 200:
                self.trace_artist[t].set_marker('o')
                self.trace_artist[t].set_markersize(3)
            else:
                self.trace_artist[t].set_marker('None')
            ax.set_ylim(self.ymin[t], self.ymax[t])
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()

    def on_pick(self, event):
        """Highlight a picked pulse group and its inter-pulse interval.

        Picking a pulse selects its pulse group; picking a point of an
        IPI curve selects the group at the start of that interval. The
        pulses of the group are marked on all traces, the interval to
        the next pulse of the same fish is marked in the IPI panel, and
        if the feature figure exists the peak-to-trough heights of the
        group are added to it.
        """
        ncolors = len(self.pulse_colors)
        pk = next((k for k, pa in enumerate(self.pulse_artist)
                   if pa is event.artist), -1)
        pi = -1
        if pk >= 0:
            # row of the picked pulse in self.pulses:
            row = self.pulse_artist_rows[pk][event.ind[0]]
            ll = self.pulses[row, 0]
            gid = self.pulses[row, 1]
            if ll in self.ipis_labels:
                pi = self.pulse_gids[ll].index(gid)
        else:
            ik = next((k for k, ia in enumerate(self.ipis_artist)
                       if ia is event.artist), -1)
            if ik < 0:
                return
            ll = self.ipis_labels[ik]
            pi = event.ind[0]
            gid = self.pulse_gids[ll][pi]
        color = self.pulse_colors[ll % ncolors]
        # mark pulses (Line2D.set_data() requires sequences):
        sel = (self.pulses[:, 0] == ll) & (self.pulses[:, 1] == gid)
        group = self.pulses[sel, :]
        for t, c in enumerate(self.show_channels):
            pt = group[group[:, 2] == c, 3]
            marker = self.marker_artist[t]
            if len(pt) > 0:
                x = [pt[0]/self.rate]
                y = [self.data[pt[0], c]]
                if marker is None:
                    self.marker_artist[t], = self.axs[t].plot(
                        x, y, 'o', ms=10, color=color)
                else:
                    marker.set_data(x, y)
                    marker.set_color(color)
            elif marker is not None:
                marker.set_data([], [])
        # mark ipi (the last pulse of a fish has no following interval):
        times = self.pulse_times[ll]
        pt0 = -1.0
        if 0 <= pi < len(times):
            pt0 = times[pi]/self.rate
        ipi_marker = self.marker_artist[self.traces]
        if 0 <= pi < len(times) - 1:
            pf = 1.0/(times[pi + 1]/self.rate - pt0)
            if ipi_marker is None:
                self.marker_artist[self.traces], = self.axs[self.traces].plot(
                    [pt0], [pf], 'o', ms=10, color=color)
            else:
                ipi_marker.set_data([pt0], [pf])
                ipi_marker.set_color(color)
        elif ipi_marker is not None:
            ipi_marker.set_data([], [])
        self.fig.canvas.draw()
        # show features:
        if self.axf is not None and self.figf is not None:
            heights = np.zeros(self.channels)
            heights[group[:, 2]] = self.data[group[:, 3], group[:, 2]] - \
                                   self.data[group[:, 4], group[:, 2]]
            self.axf.plot(heights, color=color)
            print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
            self.figf.canvas.draw()

    def resize(self, event):
        """Keep fixed pixel margins around the panels on resize."""
        if event.width <= 0 or event.height <= 0:
            return
        left, right, bottom, top = 80.0, 20.0, 50.0, 20.0
        self.fig.subplots_adjust(left=left/event.width,
                                 right=1.0 - right/event.width,
                                 bottom=bottom/event.height,
                                 top=1.0 - top/event.height,
                                 hspace=0)

    def keypress(self, event):
        """Handle key presses (see the help toggled by 'h').

        Keys not listed in the help: 'c' centers the amplitude ranges
        around zero, 'g' toggles coloring pulses by group id instead of
        by fish label.
        """
        key = event.key
        # key: (fraction of window that must remain, step in windows)
        forward = {'pagedown': (0.5, 0.5), 'ctrl+pagedown': (5.0, 5.0),
                   'down': (1.0, 0.05)}
        backward = {'pageup': 0.5, 'ctrl+pageup': 5.0, 'up': 0.05}
        if key in ('+', '=', 'X'):
            if self.twindow*self.rate > 20:
                self.twindow *= 0.5
                self.update_plots()
        elif key in ('-', 'x'):
            if self.twindow < self.tmax:
                self.twindow *= 2.0
                self.update_plots()
        elif key in forward:
            remain, step = forward[key]
            if self.toffset + remain*self.twindow < self.tmax:
                self.toffset += step*self.twindow
                self.update_plots()
        elif key in backward:
            if self.toffset > 0.0:
                self.toffset = max(self.toffset - backward[key]*self.twindow,
                                   0.0)
                self.update_plots()
        elif key == 'home':
            if self.toffset > 0.0:
                self.toffset = 0.0
                self.update_plots()
        elif key == 'end':
            toffs = np.floor(self.tmax/self.twindow)*self.twindow
            if self.tmax - toffs <= 0.0:
                toffs -= self.twindow
            if self.tmax - toffs < self.twindow/2:
                toffs -= self.twindow/2
            if self.toffset < toffs:
                self.toffset = toffs
                self.update_plots()
        elif key in ('y', 'Y'):
            # 'y' doubles, 'Y' halves the amplitude ranges:
            fac = 1.0 if key == 'y' else 0.25
            for t in range(self.traces):
                h = fac*(self.ymax[t] - self.ymin[t])
                c = 0.5*(self.ymax[t] + self.ymin[t])
                self.ymin[t] = c - h
                self.ymax[t] = c + h
            self._apply_ylims()
        elif key == 'v':
            t0, t1 = self._window_indices()
            segment = self.data[t0:t1, self.show_channels]
            self.ymin[:], self.ymax[:] = self._fit_range(np.min(segment),
                                                         np.max(segment))
            self._apply_ylims()
        elif key == 'ctrl+v':
            t0, t1 = self._window_indices()
            for t, c in enumerate(self.show_channels):
                trace = self.data[t0:t1, c]
                self.ymin[t], self.ymax[t] = self._fit_range(np.min(trace),
                                                             np.max(trace))
            self._apply_ylims()
        elif key == 'ctrl+V':
            for t, c in enumerate(self.show_channels):
                trace = self.data[:, c]
                self.ymin[t], self.ymax[t] = self._fit_range(np.min(trace),
                                                             np.max(trace))
            self._apply_ylims()
        elif key == 'V':
            self.ymin[:] = -1.0
            self.ymax[:] = +1.0
            self._apply_ylims()
        elif key == 'c':
            for t in range(self.traces):
                dy = self.ymax[t] - self.ymin[t]
                self.ymin[t] = -dy/2
                self.ymax[t] = +dy/2
            self._apply_ylims()
        elif key == 'g':
            self.show_gid = not self.show_gid
            self.plot_pulses(self.axs, False)
            self.fig.canvas.draw()
        elif key in ('i', 'I'):
            if len(self.pulses) > 0:
                self.fmax = self.fmax*2 if key == 'i' else self.fmax/2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif key == 'h':
            self.help = not self.help
            for ht in self.helptext:
                ht.set_visible(self.help)
            self.fig.canvas.draw()
        elif key == 'p':
            self.play_segment()
        elif key == 'P':
            self.play_all()
        elif key == 'S':
            self.save_segment()
        elif key == 'w':
            self.plot_traces()
        # TODO: keys '0'-'9' should toggle the visibility of channels.

    def play_segment(self):
        """Play the mean of the displayed channels in the current window."""
        t0, t1 = self._window_indices()
        playdata = 1.0*np.mean(self.data[t0:t1, self.show_channels], 1)
        fade(playdata, self.rate, 0.1 if self.twindow > 0.5 else 0.1*self.twindow)
        self.audio.play(playdata, self.rate, blocking=False)

    def play_all(self):
        """Play the mean of the displayed channels of the whole recording."""
        self.audio.play(np.mean(self.data[:, self.show_channels], 1),
                        self.rate, blocking=False)

    def save_segment(self):
        """Save the current time window of the displayed channels to a
        wave file named after the recording and the window times."""
        t0, t1 = self._window_indices()
        name = os.path.splitext(self.filename)[0]
        times = f'{self.toffset:.4g}s-{self.toffset + self.twindow:.4g}s'
        if self.traces == self.channels:
            segment_filename = f'{name}-{times}.wav'
            segment = self.data[t0:t1, :]
        else:
            segment_filename = f'{name}-{times}-c{self.show_channels[0]}.wav'
            segment = self.data[t0:t1, self.show_channels]
        write_audio(segment_filename, segment, self.rate)
        print('saved segment to: ', segment_filename)

    def plot_traces(self):
        """Save the traces, pulses, and IPIs of the current window to a
        png file."""
        nrows = self.traces + (1 if len(self.pulses) > 0 else 0)
        fig, axs = plt.subplots(nrows, 1, squeeze=False, sharex=True,
                                figsize=(15, 9))
        axs = axs.ravel()
        fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                            hspace=0)
        name = os.path.splitext(self.filename)[0]
        if self.traces < self.channels:
            figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
        else:
            figfile = f'{name}-{self.toffset:.4g}s-traces.png'
        axs[0].set_title(self.filename)
        t0, t1 = self._window_indices()
        time = np.arange(t0, t1)/self.rate
        # milliseconds for short windows at the start of the recording:
        if self.toffset < 1.0 and self.twindow < 1.0:
            tfac, tunit = 1000.0, 'ms'
        else:
            tfac, tunit = 1.0, 's'
        axs[-1].set_xlabel(f'Time [{tunit}]')
        for t, c in enumerate(self.show_channels):
            axs[t].set_xlim(tfac*self.toffset,
                            tfac*(self.toffset + self.twindow))
            axs[t].plot(tfac*time, self.data[t0:t1, c])
        self.plot_pulses(axs, True, tfac)
        for t, c in enumerate(self.show_channels):
            axs[t].set_ylim(self.ymin[t], self.ymax[t])
            axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
        if len(self.pulses) > 0:
            axs[-1].set_ylabel('IP freq [Hz]')
            axs[-1].set_ylim(0.0, self.fmax)
        fig.savefig(figfile, dpi=200)
        plt.close(fig)
        print('saved waveform figure to', figfile)

782 

783 

def short_user_warning(message, category, filename, lineno, file=None, line=''):
    """Print user warnings in a compact one-line format.

    Drop-in replacement for `warnings.showwarning`. A `UserWarning` is
    written as "<dir>/<file> line <lineno>: <message>", showing only the
    last two components of `filename` (POSIX or Windows separators).
    All other categories are formatted by `warnings.formatwarning()`.

    Parameters
    ----------
    message: Warning or str
        The warning message.
    category: type
        The warning class.
    filename: str
        Path of the file that issued the warning.
    lineno: int
        Line number in `filename`.
    file: file-like or None
        Where to write the warning; defaults to `sys.stderr`.
    line: str
        Source line passed on to `warnings.formatwarning()`.
    """
    if file is None:
        file = sys.stderr
    if category == UserWarning:
        # normalize Windows separators before keeping the last two parts:
        short_name = '/'.join(filename.replace('\\', '/').split('/')[-2:])
        file.write('%s line %d: %s\n' % (short_name, lineno, message))
    else:
        s = warnings.formatwarning(message, category, filename, lineno, line)
        file.write(s)

792 

793 

def main(cargs=None):
    """Command line script for browsing multichannel EOD recordings.

    Parses the command line, loads the recording with `DataLoader`, and
    opens a `SignalPlot` on it.

    Parameters
    ----------
    cargs: list of str or None
        Command line arguments; if None, `sys.argv[1:]` is used.
    """
    warnings.showwarning = short_user_warning

    # command line arguments:
    if cargs is None:
        cargs = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description='Browse multichannel EOD recordings.',
        epilog='version %s by Benda-Lab (2022-%s)' % (__version__, __year__))
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-v', action='count', dest='verbose', default=0,
                        help='Increase verbosity level (repeatable).')
    parser.add_argument('-c', dest='channels', default='',
                        type=str, metavar='CHANNELS',
                        help='Comma separated list of channels to be displayed (first channel is 0).')
    parser.add_argument('-t', dest='tmax', default=None,
                        type=float, metavar='TMAX',
                        help='Process and show only the first TMAX seconds.')
    parser.add_argument('-f', dest='fcutoff', default=None,
                        type=float, metavar='FREQ',
                        help='Cutoff frequency of optional high-pass filter.')
    parser.add_argument('-p', dest='pulses', action='store_true',
                        help='detect pulse fish EODs')
    parser.add_argument('file', nargs=1, type=str,
                        help='name of the file with the time series data')
    args = parser.parse_args(cargs)
    filepath = args.file[0]
    try:
        channels = [int(s) for s in args.channels.split(',') if s.strip()]
    except ValueError:
        # report a usage error instead of a traceback:
        parser.error(f'invalid list of channels "{args.channels}"')

    # load data:
    filename = os.path.basename(filepath)
    with DataLoader(filepath, 10*60.0, 5.0, args.verbose) as data:
        SignalPlot(data, data.rate, data.unit, filename,
                   channels, args.tmax, args.fcutoff, args.pulses)

839 

840 

841 

if __name__ == "__main__":
    # Run the browser with this process's command line arguments.
    main()