Coverage for src/thunderfish/thunderbrowse.py: 0%

577 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 16:21 +0000

1import sys 

2import os 

3import warnings 

4import argparse 

5import numpy as np 

6from scipy.signal import butter, sosfiltfilt 

7import matplotlib.pyplot as plt 

8from audioio import PlayAudio, fade, write_audio 

9from thunderlab.dataloader import DataLoader 

10from thunderlab.eventdetection import detect_peaks, median_std_threshold 

11from .version import __version__, __year__ 

12from .pulses import detect_pulses 

13 

14 

class SignalPlot:
    """Interactive browser for multichannel recordings.

    Shows one axis per displayed channel. If pulse detection is requested
    and pulses are found, an additional bottom axis shows the
    instantaneous inter-pulse-interval (IPI) frequencies of each detected
    fish. The constructor blocks in `plt.show()` until the figure window
    is closed.

    Parameters
    ----------
    data: 2-D array-like
        Recorded data, indexed as `data[sample, channel]`.
    samplerate: float
        Sampling rate of the data in Hertz.
    unit: str
        Unit of the data values (used for axis labels).
    filename: str
        Name of the recording (window title and names of output files).
    show_channels: list of int or None
        Channels to be displayed. None or empty: all channels.
    tmax: float or None
        If given, only the first `tmax` seconds of the data are used.
    fcutoff: float or None
        If given, the data are high-pass filtered forward-backward with a
        2nd-order Butterworth filter with this cutoff frequency in Hertz.
    pulses: bool
        If True, detect pulse EODs in all channels and cluster them
        into fish.
    """

    def __init__(self, data, samplerate, unit, filename,
                 show_channels=None, tmax=None, fcutoff=None,
                 pulses=False):
        self.filename = filename
        self.samplerate = samplerate
        self.data = data
        self.channels = self.data.shape[1] if len(self.data.shape) > 1 else 1
        self.unit = unit
        self.tmax = (len(self.data) - 1)/self.samplerate
        if tmax is not None:
            # never claim more time than the recording provides:
            self.tmax = min(tmax, self.tmax)
            self.data = data[:int(tmax*self.samplerate), :]
        self.toffset = 0.0
        self.twindow = 10.0
        if self.twindow > self.tmax and self.tmax > 0:
            # next power of two above the recording duration. Not rounded
            # to an integer, because that produced a zero-length window
            # for recordings shorter than about half a second:
            self.twindow = 2.0**(np.floor(np.log2(self.tmax)) + 1.0)
        if tmax is not None:
            self.twindow = self.tmax
        # one row per pulse, columns: label, group id, channel,
        # peak index, trough index:
        self.pulses = np.zeros((0, 5), dtype=int)
        self.labels = []
        # per pulse group: [label, time index of largest pulse, heights]:
        self.fishes = []
        # per label: time indices of the pulse groups assigned to it:
        self.pulse_times = []
        # per label: group ids, parallel to self.pulse_times:
        self.pulse_gids = []
        if show_channels is None or len(show_channels) == 0:
            self.show_channels = np.arange(self.channels)
        else:
            self.show_channels = np.array(show_channels)
        self.traces = len(self.show_channels)
        self.ymin = -1.0*np.ones(self.traces)
        self.ymax = +1.0*np.ones(self.traces)
        self.fmax = 100.0
        self.trace_artist = [None]*self.traces
        self.show_gid = False
        self.pulse_artist = []
        # rows of self.pulses drawn by the corresponding pulse artist,
        # needed to map picked points back to pulses:
        self.pulse_artist_rows = []
        # one marker per trace plus one for the IPI axis:
        self.marker_artist = [None]*(self.traces + 1)
        self.ipis_artist = []
        self.ipis_labels = []
        self.figf = None
        self.axf = None
        self.pulse_colors = ['C1', 'C2', 'C3', 'C4', 'C5',
                             'C6', 'C7', 'C8', 'C9', 'C0']
        self.help = False
        self.helptext = []
        self.audio = PlayAudio()

        # filter data:
        if fcutoff is not None:
            sos = butter(2, fcutoff, 'high', fs=samplerate, output='sos')
            self.data = sosfiltfilt(sos, self.data[:], 0)

        # pulse detection:
        if pulses:
            self._detect_pulses()
            self._group_pulses()
            self._cluster_pulses()
            self.labels = np.unique(self.pulses[:, 0])
            print(f'found {len(self.pulse_times)} fish:')
            for k, times in enumerate(self.pulse_times):
                print(f'{k:3d}: {len(times):5d} pulses')

        self._setup_figure()
        self.update_plots()
        # feature plot, filled by on_pick():
        if len(self.labels) > 0:
            self.figf, self.axf = plt.subplots()
        try:
            plt.show()
        finally:
            # release the audio device once the browser is closed:
            self.audio.close()

    def _detect_pulses(self):
        """Detect pulses in every channel and store them sorted by peak index.

        Fills `self.pulses` with one row per pulse (label, group id,
        channel, peak index, trough index). Labels are running numbers
        per channel, group ids are zero.
        """
        all_pulses = [np.zeros((0, 5), dtype=int)]
        for c in range(self.channels):
            # NOTE(review): the original code computed an adaptive
            # threshold via median_std_threshold(self.data[:,c],
            # self.samplerate, thresh_fac=6.0) and immediately overwrote
            # it with this fixed value; only the fixed value was used.
            thresh = 0.01  # TODO: parameter
            p, t, w, h = detect_pulses(self.data[:, c], self.samplerate,
                                       thresh,
                                       min_rel_slope_diff=0.25,
                                       min_width=0.0001,
                                       max_width=0.01,
                                       width_fac=5.0)
            n = len(p)
            all_pulses.append(np.column_stack((np.arange(n),
                                               np.zeros(n, dtype=int),
                                               np.full(n, c, dtype=int),
                                               np.asarray(p, dtype=int),
                                               np.asarray(t, dtype=int))))
        all_pulses = np.vstack(all_pulses)
        self.pulses = all_pulses[np.argsort(all_pulses[:, 3]), :]

    def _group_pulses(self, max_di=None, min_height=0.02):
        """Group pulses of the same EOD across channels.

        Consecutive pulses (sorted by peak index) whose peak or trough lies
        within `max_di` samples of the peak or trough of the highest pulse
        of the group are joined into a group. Groups whose highest pulse is
        smaller than `min_height`, pulses too far from the highest pulse,
        and all but the highest pulse per channel are removed. Remaining
        pulses get the group index as label and group id.

        Parameters
        ----------
        max_di: int or None
            Maximum distance in samples. None: 0.2 ms.
        min_height: float
            Minimum peak-to-trough height of the highest pulse of a group.
        """
        if max_di is None:
            max_di = int(0.0002*self.samplerate)
        pulses = self.pulses
        data = self.data
        npulses = len(pulses)

        def height(row):
            # peak-to-trough amplitude of a pulse in its own channel:
            c = pulses[row, 2]
            return data[pulses[row, 3], c] - data[pulses[row, 4], c]

        def far(row, tp, tt):
            # neither peak nor trough close to the reference peak/trough:
            return (abs(pulses[row, 3] - tp) > max_di and
                    abs(pulses[row, 3] - tt) > max_di and
                    abs(pulses[row, 4] - tp) > max_di and
                    abs(pulses[row, 4] - tt) > max_di)

        label = -1
        k = 0
        while k < npulses:
            tp = pulses[k, 3]
            tt = pulses[k, 4]
            h = height(k)
            channel_counts = np.zeros(self.channels, dtype=int)
            channel_counts[pulses[k, 2]] += 1
            # extend the group by close-by pulses, at most two per channel:
            n = 1
            while (n < 3*self.channels and k + n < npulses and
                   channel_counts[pulses[k + n, 2]] <= 1 and
                   not far(k + n, tp, tt)):
                channel_counts[pulses[k + n, 2]] += 1
                hn = height(k + n)
                # highest pulse sets time reference:
                if hn > h:
                    tp = pulses[k + n, 3]
                    tt = pulses[k + n, 4]
                    h = hn
                n += 1
            # all pulses too small:
            if h < min_height:
                pulses[k:k + n, 0] = -1
                k += n
                continue
            label += 1
            # remove pulses lost by moving the time reference:
            for j in range(k, k + n):
                if far(j, tp, tt):
                    pulses[j, 0] = -1
                    channel_counts[pulses[j, 2]] -= 1
                else:
                    pulses[j, 0] = label
                pulses[j, 1] = label
            # keep only the largest pulse of each channel:
            group = pulses[k:k + n, :]
            for dc in np.where(channel_counts > 1)[0]:
                idx = np.where(group[:, 2] == dc)[0]
                heights = data[group[idx, 3], dc] - data[group[idx, 4], dc]
                imax = np.argmax(heights)
                for i, r in enumerate(idx):
                    if i != imax:
                        group[r, 0] = -1
            k += n
        self.pulses = pulses[pulses[:, 0] >= 0, :]

    def _cluster_pulses(self, max_dist=0.1, max_rate=300.0, max_age=0.2,
                        max_recent=5):
        """Assign pulse groups to fish based on their heights across channels.

        Each pulse group is compared to the recently seen groups. It joins
        the label of the closest one if the mean relative height difference
        is below `max_dist`, otherwise it starts a new label. Fills
        `self.pulse_times`, `self.pulse_gids`, `self.fishes` and the label
        column of `self.pulses`.

        Parameters
        ----------
        max_dist: float
            Maximum mean relative height difference for joining a label.
        max_rate: float
            Candidates closer in time than 1/max_rate seconds are penalized.
        max_age: float
            Only groups within this many seconds are considered recent.
        max_recent: int
            Maximum number of recent groups kept per label.
        """
        pulses = self.pulses
        recent = []
        k = 0
        while k < len(pulses):
            # select pulse group (consecutive rows with the same group id):
            j = k
            gid = pulses[j, 1]
            k += 1
            kmax = min(len(pulses), j + self.channels)
            while k < kmax and pulses[k, 1] == gid:
                k += 1
            channels = pulses[j:k, 2]
            heights = np.zeros(self.channels)
            heights[channels] = (self.data[pulses[j:k, 3], channels] -
                                 self.data[pulses[j:k, 4], channels])
            # time of largest pulse:
            pulse_time = pulses[j + np.argmax(heights[channels]), 3]
            # assign to cluster:
            label = -1
            if len(self.pulse_times) > 0:
                # mean relative height difference to recent groups:
                dists = np.array([np.mean(np.abs(hh - heights)/np.max(hh))
                                  for ll, tt, hh in recent])
                # time since recent groups:
                ipis = np.array([(pulse_time - tt)/self.samplerate
                                 for ll, tt, hh in recent])
                # ensure minimum inter-pulse interval (zero intervals give
                # an infinite rate, intentionally):
                with np.errstate(divide='ignore'):
                    too_close = 1/ipis > max_rate
                dists[too_close] = 2*np.max(dists)
                min_dist_idx = np.argmin(dists)
                if dists[min_dist_idx] < max_dist:
                    label = recent[min_dist_idx][0]
            if label < 0:
                label = len(self.pulse_times)
                self.pulse_times.append([])
                self.pulse_gids.append([])
            pulses[j:k, 0] = label
            self.pulse_times[label].append(pulse_time)
            self.pulse_gids[label].append(gid)
            self.fishes.append([label, pulse_time, heights])
            recent.append([label, pulse_time, heights])
            # remove old groups (the current one always stays):
            for i, (ll, tt, hh) in enumerate(recent):
                if (pulse_time - tt)/self.samplerate <= max_age:
                    recent = recent[i:]
                    break
            # only consider the most recent groups of a fish:
            labels = np.array([ll for ll, tt, hh in recent])
            if np.sum(labels == label) > max_recent:
                del recent[np.where(labels == label)[0][0]]
        self.pulse_times = [np.array(times) for times in self.pulse_times]

    def _setup_figure(self):
        """Create the browser figure, connect event handlers, add help texts."""
        # set key bindings:
        plt.rcParams['keymap.fullscreen'] = 'f'
        plt.rcParams['keymap.pan'] = 'ctrl+m'
        plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, q'
        plt.rcParams['keymap.yscale'] = ''
        plt.rcParams['keymap.xscale'] = ''
        plt.rcParams['keymap.grid'] = ''
        # the figure:
        plt.ioff()
        nplots = self.traces + (1 if len(self.pulses) > 0 else 0)
        self.fig, axs = plt.subplots(nplots, 1, squeeze=False,
                                     figsize=(15, 9), sharex=True)
        self.axs = axs.ravel()
        title = self.filename
        if self.traces != self.channels:
            title += f' c{self.show_channels[0]}'
        self.fig.canvas.manager.set_window_title(title)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)
        self.fig.canvas.mpl_connect('resize_event', self.resize)
        self.fig.canvas.mpl_connect('pick_event', self.on_pick)
        for t, c in enumerate(self.show_channels):
            self.axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
        if len(self.pulses) > 0:
            self.axs[-1].set_ylim(0, self.fmax)
            self.axs[-1].set_ylabel('IP freq [Hz]')
        self.axs[-1].set_xlabel('Time [s]')
        helps = [
            (0.05, '(ctrl+) page and arrow up, down, home, end: scroll'),
            (0.1, '+, -, X, x: zoom time in/out'),
            (0.2, 'y, Y, v, V, ctrl+v, ctrl+V: zoom amplitudes out/in/max/default/max per trace/global max per trace'),
            (0.3, 'i, I: zoom IPI frequency out/in'),
            (0.4, 'p, P: play audio (display, all)'),
            (0.5, 'f: full screen'),
            (0.6, 'w: plot waveforms into png file'),
            (0.7, 'S: save audiosegment'),
            (0.8, 'q: quit'),
            (0.9, 'h: toggle this help'),
        ]
        for y, text in helps:
            ht = self.axs[0].text(0.98, y, text, ha='right',
                                  transform=self.axs[0].transAxes)
            ht.set_visible(self.help)
            self.helptext.append(ht)

    def plot_pulses(self, axs, plot=True, tfac=1.0):
        """Plot detected pulses and IPI frequencies.

        Parameters
        ----------
        axs: sequence of matplotlib axes
            One axis per displayed trace; the last axis receives the
            IPI frequencies.
        plot: bool
            If True, plot new artists into `axs` (e.g. for saving a
            figure). If False, reuse and update the artists of the
            browser window, creating them as needed.
        tfac: float
            Factor converting seconds into the time unit of the x-axis.
        """
        ncolors = len(self.pulse_colors)

        def plot_pulse_traces(rows, i, pak):
            # plot pulses given by row indices into self.pulses with color i;
            # pak is the index of the next reusable pulse artist.
            color = self.pulse_colors[i % ncolors]
            for t, c in enumerate(self.show_channels):
                crows = rows[self.pulses[rows, 2] == c]
                if len(crows) == 0:
                    continue
                p = self.pulses[crows, 3]
                x = tfac*p/self.samplerate
                y = self.data[p, c]
                if plot or pak >= len(self.pulse_artist):
                    pa, = axs[t].plot(x, y, 'o', picker=5, color=color)
                    if not plot:
                        self.pulse_artist.append(pa)
                        self.pulse_artist_rows.append(crows)
                else:
                    self.pulse_artist[pak].set_data(x, y)
                    self.pulse_artist[pak].set_color(color)
                    self.pulse_artist_rows[pak] = crows
                pak += 1
            return pak

        # pulses:
        pak = 0
        if self.show_gid:
            gids = self.pulses[:, 1] % ncolors
            for g in range(ncolors):
                pak = plot_pulse_traces(np.where(gids == g)[0], g, pak)
        else:
            for l in self.labels:
                pak = plot_pulse_traces(np.where(self.pulses[:, 0] == l)[0],
                                        l, pak)
        if not plot:
            # hide browser artists that are not needed anymore:
            for k in range(pak, len(self.pulse_artist)):
                self.pulse_artist[k].set_data([], [])
                self.pulse_artist_rows[k] = np.zeros(0, dtype=int)
        # ipis:
        for l in self.labels:
            if l >= len(self.pulse_times):
                continue
            pt = self.pulse_times[l]/self.samplerate
            if len(pt) <= 10:
                continue
            x = tfac*pt[:-1]
            freqs = 1.0/np.diff(pt)
            if plot or l not in self.ipis_labels:
                pa, = axs[-1].plot(x, freqs, '-o', picker=5,
                                   color=self.pulse_colors[l % ncolors])
                if not plot:
                    self.ipis_artist.append(pa)
                    self.ipis_labels.append(l)
            else:
                self.ipis_artist[self.ipis_labels.index(l)].set_data(x, freqs)

    def update_plots(self):
        """Redraw traces and pulses for the current time window."""
        t0 = int(np.round(self.toffset*self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow)*self.samplerate))
        t1 = min(t1, len(self.data))
        time = np.arange(t0, t1)/self.samplerate
        for t, c in enumerate(self.show_channels):
            ax = self.axs[t]
            ax.set_xlim(self.toffset, self.toffset + self.twindow)
            if self.trace_artist[t] is None:
                self.trace_artist[t], = ax.plot(time, self.data[t0:t1, c])
            else:
                self.trace_artist[t].set_data(time, self.data[t0:t1, c])
            # show individual samples when zoomed in far enough:
            if t1 - t0 < 200:
                self.trace_artist[t].set_marker('o')
                self.trace_artist[t].set_markersize(3)
            else:
                self.trace_artist[t].set_marker('None')
            ax.set_ylim(self.ymin[t], self.ymax[t])
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()

    def on_pick(self, event):
        """Mark a picked pulse (or IPI point) and its IPI in all axes.

        Also adds the pulse heights across channels to the feature plot
        and prints label, group id and time of the IPI, if available.
        """
        ncolors = len(self.pulse_colors)
        pk = next((k for k, pa in enumerate(self.pulse_artist)
                   if event.artist == pa), -1)
        pi = -1
        if pk >= 0:
            # picked a pulse: look up its row in self.pulses:
            row = self.pulse_artist_rows[pk][event.ind[0]]
            ll = self.pulses[row, 0]
            gid = self.pulses[row, 1]
            if ll in self.ipis_labels:
                pi = self.pulse_gids[ll].index(gid)
        else:
            ik = next((k for k, ia in enumerate(self.ipis_artist)
                       if event.artist == ia), -1)
            if ik < 0:
                return
            ll = self.ipis_labels[ik]
            pi = event.ind[0]
            gid = self.pulse_gids[ll][pi]
        color = self.pulse_colors[ll % ncolors]
        # mark pulses of the picked group in all displayed channels:
        sel = self.pulses[(self.pulses[:, 0] == ll) &
                          (self.pulses[:, 1] == gid), :]
        for t, c in enumerate(self.show_channels):
            pt = sel[sel[:, 2] == c, 3]
            marker = self.marker_artist[t]
            if len(pt) > 0:
                # Line2D.set_data() requires sequences, not scalars:
                x = [pt[0]/self.samplerate]
                y = [self.data[pt[0], c]]
                if marker is None:
                    self.marker_artist[t], = self.axs[t].plot(x, y, 'o',
                                                              ms=10,
                                                              color=color)
                else:
                    marker.set_data(x, y)
                    marker.set_color(color)
            elif marker is not None:
                marker.set_data([], [])
        # mark the IPI starting at the picked pulse (none for the last one):
        pt0 = -1.0
        marker = self.marker_artist[self.traces]
        if 0 <= pi < len(self.pulse_times[ll]) - 1:
            pt0 = self.pulse_times[ll][pi]/self.samplerate
            pt1 = self.pulse_times[ll][pi + 1]/self.samplerate
            pf = 1.0/(pt1 - pt0)
            if marker is None:
                self.marker_artist[self.traces], = \
                    self.axs[self.traces].plot([pt0], [pf], 'o', ms=10,
                                               color=color)
            else:
                marker.set_data([pt0], [pf])
                marker.set_color(color)
        elif marker is not None:
            marker.set_data([], [])
        self.fig.canvas.draw()
        # show features:
        if self.axf is not None and self.figf is not None:
            heights = np.zeros(self.channels)
            heights[sel[:, 2]] = (self.data[sel[:, 3], sel[:, 2]] -
                                  self.data[sel[:, 4], sel[:, 2]])
            self.axf.plot(heights, color=color)
            print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
            self.figf.canvas.draw()

    def resize(self, event):
        """Keep fixed pixel margins around the axes when the window resizes."""
        if event.width <= 0 or event.height <= 0:
            # e.g. minimized window; nothing sensible to lay out:
            return
        leftpixel = 80.0
        rightpixel = 20.0
        bottompixel = 50.0
        toppixel = 20.0
        x0 = leftpixel/event.width
        x1 = 1.0 - rightpixel/event.width
        y0 = bottompixel/event.height
        y1 = 1.0 - toppixel/event.height
        self.fig.subplots_adjust(left=x0, right=x1, bottom=y0, top=y1,
                                 hspace=0)

    def _apply_ylims(self):
        """Set y-limits of all traces from self.ymin/self.ymax and redraw."""
        for t in range(self.traces):
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()

    def _fit_ylim(self, t, vmin, vmax):
        """Set y-range of trace t to [vmin, vmax] plus a 3% margin."""
        h = 0.53*(vmax - vmin)
        c = 0.5*(vmax + vmin)
        self.ymin[t] = c - h
        self.ymax[t] = c + h

    def _scroll_back(self, frac):
        """Move the time window back by frac windows, not before zero."""
        if self.toffset > 0.0:
            self.toffset = max(self.toffset - frac*self.twindow, 0.0)
            self.update_plots()

    def keypress(self, event):
        """Handle key presses (see the help text for the bindings)."""
        key = event.key
        if key is None:
            return
        if key in ('+', '=', 'X'):
            if self.twindow*self.samplerate > 20:
                self.twindow *= 0.5
                self.update_plots()
        elif key in ('-', 'x'):
            if self.twindow < self.tmax:
                self.twindow *= 2.0
                self.update_plots()
        elif key == 'pagedown':
            if self.toffset + 0.5*self.twindow < self.tmax:
                self.toffset += 0.5*self.twindow
                self.update_plots()
        elif key == 'pageup':
            self._scroll_back(0.5)
        elif key == 'ctrl+pagedown':
            if self.toffset + 5.0*self.twindow < self.tmax:
                self.toffset += 5.0*self.twindow
                self.update_plots()
        elif key == 'ctrl+pageup':
            self._scroll_back(5.0)
        elif key == 'down':
            if self.toffset + self.twindow < self.tmax:
                self.toffset += 0.05*self.twindow
                self.update_plots()
        elif key == 'up':
            self._scroll_back(0.05)
        elif key == 'home':
            if self.toffset > 0.0:
                self.toffset = 0.0
                self.update_plots()
        elif key == 'end':
            toffs = np.floor(self.tmax/self.twindow)*self.twindow
            if self.tmax - toffs <= 0.0:
                toffs -= self.twindow
            if self.tmax - toffs < self.twindow/2:
                toffs -= self.twindow/2
            if self.toffset < toffs:
                self.toffset = toffs
                self.update_plots()
        elif key == 'y':
            # zoom out: double the range around the center:
            h = self.ymax - self.ymin
            c = 0.5*(self.ymax + self.ymin)
            self.ymin[:] = c - h
            self.ymax[:] = c + h
            self._apply_ylims()
        elif key == 'Y':
            # zoom in: halve the range around the center:
            h = 0.25*(self.ymax - self.ymin)
            c = 0.5*(self.ymax + self.ymin)
            self.ymin[:] = c - h
            self.ymax[:] = c + h
            self._apply_ylims()
        elif key == 'v':
            # common range of all displayed traces in the current window:
            t0 = int(np.round(self.toffset*self.samplerate))
            t1 = int(np.round((self.toffset + self.twindow)*self.samplerate))
            segment = self.data[t0:t1, self.show_channels]
            for t in range(self.traces):
                self._fit_ylim(t, np.min(segment), np.max(segment))
            self._apply_ylims()
        elif key == 'ctrl+v':
            # individual range of each trace in the current window:
            t0 = int(np.round(self.toffset*self.samplerate))
            t1 = int(np.round((self.toffset + self.twindow)*self.samplerate))
            for t, c in enumerate(self.show_channels):
                segment = self.data[t0:t1, c]
                self._fit_ylim(t, np.min(segment), np.max(segment))
            self._apply_ylims()
        elif key == 'ctrl+V':
            # individual range of each trace over the whole recording:
            for t, c in enumerate(self.show_channels):
                trace = self.data[:, c]
                self._fit_ylim(t, np.min(trace), np.max(trace))
            self._apply_ylims()
        elif key == 'V':
            self.ymin[:] = -1.0
            self.ymax[:] = +1.0
            self._apply_ylims()
        elif key == 'c':
            # center ranges on zero:
            dy = self.ymax - self.ymin
            self.ymin[:] = -dy/2
            self.ymax[:] = +dy/2
            self._apply_ylims()
        elif key == 'g':
            self.show_gid = not self.show_gid
            self.plot_pulses(self.axs, False)
            self.fig.canvas.draw()
        elif key == 'i':
            if len(self.pulses) > 0:
                self.fmax *= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif key == 'I':
            if len(self.pulses) > 0:
                self.fmax /= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif len(key) == 1 and key.isdigit():
            # TODO: toggle visibility of channel int(key); not implemented.
            pass
        elif key == 'h':
            self.help = not self.help
            for ht in self.helptext:
                ht.set_visible(self.help)
            self.fig.canvas.draw()
        elif key == 'p':
            self.play_segment()
        elif key == 'P':
            self.play_all()
        elif key == 'S':
            self.save_segment()
        elif key == 'w':
            self.plot_traces()

    def play_segment(self):
        """Play the mean of the displayed traces in the current window."""
        t0 = int(np.round(self.toffset*self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow)*self.samplerate))
        playdata = 1.0*np.mean(self.data[t0:t1, self.show_channels], 1)
        fadetime = 0.1 if self.twindow > 0.5 else 0.1*self.twindow
        fade(playdata, self.samplerate, fadetime)
        self.audio.play(playdata, self.samplerate, blocking=False)

    def play_all(self):
        """Play the mean of the displayed traces of the whole recording."""
        self.audio.play(np.mean(self.data[:, self.show_channels], 1),
                        self.samplerate, blocking=False)

    def save_segment(self):
        """Write the data of the current time window into a wav file.

        All channels are saved if all are displayed, otherwise only the
        displayed channels. Start and end times (in seconds, 4 significant
        digits) are part of the file name.
        """
        t0 = int(np.round(self.toffset*self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow)*self.samplerate))
        name = os.path.splitext(self.filename)[0]
        # times used to be rounded to whole seconds, so that short
        # segments overwrote each other:
        times = f'{self.toffset:.4g}s-{self.toffset + self.twindow:.4g}s'
        if self.traces == self.channels:
            segment_filename = f'{name}-{times}.wav'
            write_audio(segment_filename, self.data[t0:t1, :],
                        self.samplerate)
        else:
            segment_filename = f'{name}-{times}-c{self.show_channels[0]}.wav'
            write_audio(segment_filename,
                        self.data[t0:t1, self.show_channels],
                        self.samplerate)
        print('saved segment to: ', segment_filename)

    def plot_traces(self):
        """Save the displayed traces, pulses and IPIs into a png file."""
        nplots = self.traces + (1 if len(self.pulses) > 0 else 0)
        fig, axs = plt.subplots(nplots, 1, squeeze=False, sharex=True,
                                figsize=(15, 9))
        axs = axs.ravel()
        fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                            hspace=0)
        name = os.path.splitext(self.filename)[0]
        if self.traces < self.channels:
            figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
        else:
            figfile = f'{name}-{self.toffset:.4g}s-traces.png'
        axs[0].set_title(self.filename)
        t0 = int(np.round(self.toffset*self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow)*self.samplerate))
        t1 = min(t1, len(self.data))
        time = np.arange(t0, t1)/self.samplerate
        # milliseconds for short windows at the beginning of the recording:
        if self.toffset < 1.0 and self.twindow < 1.0:
            tfac, tunit = 1000.0, 'ms'
        else:
            tfac, tunit = 1.0, 's'
        axs[-1].set_xlabel(f'Time [{tunit}]')
        for t, c in enumerate(self.show_channels):
            axs[t].set_xlim(tfac*self.toffset,
                            tfac*(self.toffset + self.twindow))
            axs[t].plot(tfac*time, self.data[t0:t1, c])
        self.plot_pulses(axs, True, tfac)
        for t, c in enumerate(self.show_channels):
            axs[t].set_ylim(self.ymin[t], self.ymax[t])
            axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
        if len(self.pulses) > 0:
            axs[-1].set_ylabel('IP freq [Hz]')
            axs[-1].set_ylim(0.0, self.fmax)
        fig.savefig(figfile, dpi=200)
        plt.close(fig)
        print('saved waveform figure to', figfile)

780 

781 

def short_user_warning(message, category, filename, lineno, file=None, line=''):
    """Write a warning, using a compact one-line format for UserWarnings.

    Intended as a replacement for `warnings.showwarning`. A UserWarning
    is written as "<dir>/<file> line <lineno>: <message>", showing only
    the last two components of the file path. All other warning
    categories are formatted by `warnings.formatwarning()`.

    Parameters
    ----------
    message: str or Warning
        The warning message.
    category: type
        The warning category.
    filename: str
        Path of the file that issued the warning.
    lineno: int
        Line number that issued the warning.
    file: file-like or None
        Where to write the warning. None: `sys.stderr`.
    line: str or None
        Source line passed on to `warnings.formatwarning()`.
    """
    if file is None:
        file = sys.stderr
    if category == UserWarning:
        # accept both POSIX and Windows path separators:
        parts = filename.replace('\\', '/').split('/')
        file.write('%s line %d: %s\n' % ('/'.join(parts[-2:]), lineno, message))
    else:
        file.write(warnings.formatwarning(message, category, filename,
                                          lineno, line))

790 

791 

def main(cargs=None):
    """Command line interface for browsing multichannel recordings.

    Parameters
    ----------
    cargs: list of str or None
        Command line arguments without the program name.
        If None, `sys.argv[1:]` is used.
    """
    warnings.showwarning = short_user_warning

    # command line arguments:
    if cargs is None:
        cargs = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description='Browse multichannel EOD recordings.',
        epilog='version %s by Benda-Lab (2022-%s)' % (__version__, __year__))
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-v', action='count', dest='verbose', default=0,
                        help='Increase verbosity level (repeat for more).')
    parser.add_argument('-c', dest='channels', default='',
                        type=str, metavar='CHANNELS',
                        help='Comma separated list of channels to be displayed (first channel is 0).')
    parser.add_argument('-t', dest='tmax', default=None,
                        type=float, metavar='TMAX',
                        help='Process and show only the first TMAX seconds.')
    parser.add_argument('-f', dest='fcutoff', default=None,
                        type=float, metavar='FREQ',
                        help='Cutoff frequency of optional high-pass filter.')
    parser.add_argument('-p', dest='pulses', action='store_true',
                        help='detect pulse fish EODs')
    parser.add_argument('file', nargs=1, default='', type=str,
                        help='name of the file with the time series data')
    args = parser.parse_args(cargs)
    filepath = args.file[0]
    # empty entries (e.g. trailing commas) are ignored:
    try:
        channels = [int(s) for s in
                    (cs.strip() for cs in args.channels.split(',')) if s]
    except ValueError:
        parser.error(f'invalid list of channels "{args.channels}"')
    verbose = args.verbose if args.verbose is not None else 0

    # load data:
    filename = os.path.basename(filepath)
    with DataLoader(filepath, 10*60.0, 5.0, verbose) as data:
        SignalPlot(data, data.samplerate, data.unit, filename,
                   channels, args.tmax, args.fcutoff, args.pulses)

837 

838 

839 

if __name__ == '__main__':
    # run the command line interface when executed as a script:
    main()