Module thunderfish.thunderbrowse

Expand source code
import sys
import os
import warnings
import argparse
import numpy as np
from scipy.signal import butter, sosfiltfilt
import matplotlib.pyplot as plt
from audioio import PlayAudio, fade, write_audio
from .eventdetection import detect_peaks, median_std_threshold
from .pulses import detect_pulses
from .version import __version__, __year__
from .dataloader import DataLoader


class SignalPlot:
    def __init__(self, data, samplerate, unit, filename,
                 show_channels=[], tmax=None, fcutoff=None,
                 pulses=False):
        """Set up an interactive browser for multi-channel recordings.

        Optionally high-pass filters the data and detects, groups and
        clusters pulses across channels, then creates the trace figure
        (plus an inter-pulse-interval frequency panel when pulses were
        found) and calls `plt.show()`.

        Parameters
        ----------
        data: 2-D array
            Recorded data indexed as ``data[sample, channel]``. (The channel
            count also accepts 1-D data, but truncation by `tmax`, pulse
            detection and plotting index a second axis.)
        samplerate: float
            Sampling rate of `data` in Hertz.
        unit: str
            Unit of the data, used in the y-axis labels.
        filename: str
            Name of the recording; used as window title and as base name
            for saved segments.
        show_channels: list of int
            Zero-based indices of the channels to display. An empty list
            (default) shows all channels. The list is only read.
        tmax: float or None
            If given, only the first `tmax` seconds of data are used.
        fcutoff: float or None
            If given, cutoff frequency in Hertz of a 2nd order Butterworth
            high-pass filter applied forward-backward to the data.
        pulses: bool
            If True, detect pulses on every channel, group them across
            channels and assign the groups to fish labels.
        """
        self.filename = filename
        self.samplerate = samplerate
        self.data = data
        self.channels = self.data.shape[1] if len(self.data.shape) > 1 else 1
        self.unit = unit
        self.tmax = (len(self.data)-1)/self.samplerate
        if not tmax is None:
            self.tmax = tmax
            # NOTE(review): this slice requires 2-D data.
            self.data = data[:int(tmax*self.samplerate),:]
        self.toffset = 0.0
        self.twindow = 10.0
        if self.twindow > self.tmax:
            # recording shorter than the default window: use the next power
            # of two (in seconds) above tmax, or tmax itself if it was given.
            # NOTE(review): np.round() makes this 0 for tmax < 0.5 s.
            self.twindow = np.round(2 ** (np.floor(np.log(self.tmax) / np.log(2.0)) + 1.0))
            if not tmax is None:
                self.twindow = tmax
        # NOTE(review): starts with 3 columns, but the detection code below
        # indexes columns up to 4 (label, group, channel, peak, trough).
        self.pulses = np.zeros((0, 3), dtype=int)
        self.labels = []        # fish labels remaining in self.pulses
        self.fishes = []        # [label, peak index, channel heights] per pulse group
        self.pulse_times = []   # per label: peak indices (samples) of its pulse groups
        self.pulse_gids = []    # per label: group ids, parallel to pulse_times
        if len(show_channels) == 0:
            self.show_channels = np.arange(self.channels)
        else:
            self.show_channels = np.array(show_channels)
        self.traces = len(self.show_channels)
        self.ymin = -1.0 * np.ones(self.traces)
        self.ymax = +1.0 * np.ones(self.traces)
        self.fmax = 100.0
        self.trace_artist = [None] * self.traces
        self.show_gid = False
        self.pulse_artist = []
        # one highlight marker per trace plus one for the IPI panel:
        self.marker_artist = [None] * (self.traces + 1)
        self.ipis_artist = []
        self.ipis_labels = []
        self.figf = None
        self.axf = None
        self.pulse_colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C0']
        self.help = False
        self.helptext = []
        self.audio = PlayAudio()

        # filter data:
        if not fcutoff is None:
            # 2nd order Butterworth high-pass, run forward and backward
            # along the time axis (axis 0):
            sos = butter(2, fcutoff, 'high', fs=samplerate, output='sos')
            self.data = sosfiltfilt(sos, self.data[:], 0)

        # pulse detection:
        if pulses:
            # label, group, channel, peak index, trough index
            all_pulses = np.zeros((0, 5), dtype=int)
            for c in range(self.channels):
                #thresh = 1*np.std(self.data[:int(2*self.samplerate),c])
                thresh = median_std_threshold(self.data[:,c], self.samplerate,
                                              thresh_fac=6.0)
                # NOTE(review): the threshold computed above is immediately
                # overwritten by this fixed value.
                thresh = 0.01
                #p, t = detect_peaks(self.data[:,c], thresh)
                p, t, w, h = detect_pulses(self.data[:,c], self.samplerate,
                                           thresh,
                                           min_rel_slope_diff=0.25,
                                           min_width=0.0001,
                                           max_width=0.01,
                                           width_fac=5.0)
                # label, group, channel, peak, trough:
                # (note: this rebinds the `pulses` argument)
                pulses = np.hstack((np.arange(len(p))[:,np.newaxis],
                                    np.zeros((len(p), 1), dtype=int),
                                    np.ones((len(p), 1), dtype=int)*c,
                                    p[:,np.newaxis], t[:,np.newaxis]))
                all_pulses = np.vstack((all_pulses, pulses))
            # pulses of all channels ordered by their peak index:
            self.pulses = all_pulses[np.argsort(all_pulses[:,3]),:]
            # grouping over channels:
            # Starting at pulse k, following pulses whose peak or trough lies
            # within max_di samples of the reference peak/trough join the
            # group; the highest pulse becomes the reference. Group ids are
            # written to column 1, discarded pulses get label -1.
            max_di = int(0.0002*self.samplerate)   # TODO: parameter
            l = -1
            k = 0
            while k < len(self.pulses):
                tp = self.pulses[k,3]
                tt = self.pulses[k,4]
                height = self.data[self.pulses[k,3],self.pulses[k,2]] - \
                    self.data[self.pulses[k,4],self.pulses[k,2]]
                channel_counts = np.zeros(self.channels, dtype=int)
                channel_counts[self.pulses[k,2]] += 1
                for c in range(1, 3*self.channels):
                    if k+c >= len(self.pulses):
                        break
                    # pulse too far away:
                    if channel_counts[self.pulses[k+c,2]] > 1 or \
                       (np.abs(self.pulses[k+c,3] - tp) > max_di and
                        np.abs(self.pulses[k+c,3] - tt) > max_di and
                        np.abs(self.pulses[k+c,4] - tp) > max_di and
                        np.abs(self.pulses[k+c,4] - tt) > max_di):
                        break
                    channel_counts[self.pulses[k+c,2]] += 1
                    height_kc = self.data[self.pulses[k+c,3],self.pulses[k+c,2]] - \
                        self.data[self.pulses[k+c,4],self.pulses[k+c,2]]
                    # heighest pulse sets time reference:
                    if height_kc > height:
                        tp = self.pulses[k+c,3]
                        tt = self.pulses[k+c,4]
                        height = height_kc
                # all pulses too small:
                if height < 0.02:    # TODO parameter
                    # NOTE(review): if the loop above ended without `break`,
                    # pulse k+c was examined but lies outside k:k+c.
                    self.pulses[k:k+c,0] = -1
                    k += c
                    continue
                # new label:
                l += 1
                # remove lost pulses:
                for j in range(c):
                    if (np.abs(self.pulses[k+j,3] - tp) > max_di and
                        np.abs(self.pulses[k+j,3] - tt) > max_di and
                        np.abs(self.pulses[k+j,4] - tp) > max_di and
                        np.abs(self.pulses[k+j,4] - tt) > max_di):
                        self.pulses[k+j,0] = -1
                        channel_counts[self.pulses[k+j,2]] -= 1
                    else:
                        self.pulses[k+j,0] = l
                        self.pulses[k+j,1] = l
                # keep only the largest pulse of each channel:
                pulses = self.pulses[k:k+c,:]
                for dc in np.where(channel_counts > 1)[0]:
                    idx = np.where(self.pulses[k:k+c,2] == dc)[0]
                    heights = self.data[pulses[idx,3],dc] - \
                        self.data[pulses[idx,4],dc]
                    for i in range(len(idx)):
                        if i != np.argmax(heights):
                            channel_counts[self.pulses[k+idx[i],2]] -= 1
                            self.pulses[k+idx[i],0] = -1
                k += c
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]

            # clustering:
            # Each pulse group gets a per-channel height vector and the peak
            # index of its highest pulse. It joins the label of the most
            # similar recent group (mean absolute height difference relative
            # to that group's maximum height, compared against thresh).
            # Candidates implying an IPI rate above 300 Hz are penalized.
            # Otherwise a new label is started. `recent` keeps only groups
            # from the last 0.2 s and at most n groups per label.
            min_dists = []
            recent = []
            k = 0
            while k < len(self.pulses):
                # select pulse group:
                # (consecutive rows with the same group id, at most
                # self.channels of them)
                j = k
                gid = self.pulses[j,1]
                for c in range(self.channels):
                    k += 1
                    if k >= len(self.pulses) or \
                       self.pulses[k,1] != gid:
                        break
                heights = np.zeros(self.channels)
                heights[self.pulses[j:k,2]] = \
                    self.data[self.pulses[j:k,3],self.pulses[j:k,2]] - \
                    self.data[self.pulses[j:k,4],self.pulses[j:k,2]]
                # time of largest pulse:
                pulse_time = self.pulses[j+np.argmax(heights[self.pulses[j:k,2]]),3]
                # assign to cluster:
                if len(self.pulse_times) == 0:
                    label = len(self.pulse_times)
                    self.pulse_times.append([])
                    self.pulse_gids.append([])
                else:
                    # compute metrics of recent fishes:
                    # mean relative height difference:
                    dists = np.array([np.mean(np.abs(hh - heights)/np.max(hh))
                                        for ll, tt, hh in recent])
                    thresh = 0.1   # TODO: make parameter
                    # distance between pulses:
                    ipis = np.array([(pulse_time - tt)/self.samplerate
                                     for ll, tt, hh in recent])
                    ## how can ipis be 0, or just one sample?
                    ##if len(ipis[ipis<0.001]) > 0:
                    ##    print(ipis[ipis<0.001])
                    # ensure minimum IP distance:
                    # (an ipi of 0 makes 1/ipis infinite, with a numpy warning)
                    dists[1/ipis > 300.0] = 2*np.max(dists)  # TODO: make parameter
                    # minimum ditance:
                    min_dist_idx = np.argmin(dists)
                    min_dists.append(dists[min_dist_idx])
                    if dists[min_dist_idx] < thresh:
                        label = recent[min_dist_idx][0]
                    else:
                        label = len(self.pulse_times)
                        self.pulse_times.append([])
                        self.pulse_gids.append([])
                self.pulses[j:k,0] = label
                self.pulse_times[label].append(pulse_time)
                self.pulse_gids[label].append(gid)
                self.fishes.append([label, pulse_time, heights])
                recent.append([label, pulse_time, heights])
                # remove old fish:
                for i, (ll, tt, hh) in enumerate(recent):
                    # TODO: make parameter:
                    if (pulse_time - tt)/self.samplerate <= 0.2:
                        recent = recent[i:]
                        break
                # only consider the n most recent pulses of a fish:
                n = 5    # TODO make parameter
                labels = np.array([ll for ll, tt, hh in recent])
                if np.sum(labels == label) > n:
                    del recent[np.where(labels == label)[0][0]]
            # pulse times to arrays:
            # (pulse_gids stay Python lists)
            for k in range(len(self.pulse_times)):
                self.pulse_times[k] = np.array(self.pulse_times[k])


                
            # Disabled code: the following bare string literal is never
            # executed (search for temporally missing pulses).
            """
            # find temporally missing pulses:
            npulses = np.array([len(pts) for pts in self.pulse_times],
                               dtype=int)
            idx = np.argsort(npulses)
            for i in range(len(idx)):
                li = idx[len(idx)-1-i]
                if len(self.pulse_times[li]) < 10 or \
                   len(self.pulse_times[li])/npulses[li] < 0.5:
                    continue
                ipis = np.diff(self.pulse_times[li])
                n = 4 # TODO: make parameter
                k = 0
                while k < len(ipis)-n:
                    mipi = np.median(ipis[k:k+n])
                    if ipis[k+n-2] > 1.8*mipi:
                        # search for pulse closest to pt:
                        pt = self.pulse_times[li][k+n-2] + mipi
                        mlj = -1
                        mpj = -1
                        mdj = 10*mipi
                        for lj in range(len(self.pulse_times)):
                            if lj == li or len(self.pulse_times[lj]) == 0:
                                continue
                            pj = np.argmin(np.abs(self.pulse_times[lj] - pt))
                            dj = np.abs(self.pulse_times[lj][pj] - pt)
                            if dj < int(0.001*self.samplerate) and dj < mdj:
                                mdj = dj
                                mpj = pj
                                mlj = lj
                        if mlj >= 0:
                            # there is a pulse close to pt:
                            ptj = self.pulse_times[mlj][mpj]
                            pulses = self.pulses[self.pulses[:,0] == mlj,:]
                            gid = pulses[np.argmin(np.abs(pulses[:,3] - ptj)),1]
                            self.pulse_times[li] = np.insert(self.pulse_times[li], k+n-1, ptj)
                            self.pulse_gids[li].insert(k+n-1, gid)
                            # maybe don't delete but always duplicate and flag it:
                            if False:  # can be deleted
                                self.pulse_times[mlj] = np.delete(self.pulse_times[mlj], mpj)
                                self.pulse_gids[mlj].pop(mpj)
                                self.pulses[self.pulses[:,1] == gid,0] = li
                            else:     # pulse needs to be duplicated:
                                self.pulses[self.pulses[:,1] == gid,0] = li
                            ipis = np.diff(self.pulse_times[li])
                    k += 1


                    
            # clean up pulses:
            for l in range(len(self.pulse_times)):
                if len(self.pulse_times[l])/npulses[l] < 0.5:
                    self.pulse_times[l] = np.array([])
                    self.pulse_gids[l] = []
                    self.pulses[self.pulses[:,0] == l,0] = -1
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]
            """
            
            # Disabled code: the following bare string literal is never
            # executed (removal of labels too close to others).
            """
            # remove labels that are too close to others:
            widths = np.zeros(len(self.pulse_times), dtype=int)
            for k in range(len(self.pulse_times)):
                widths[k] = int(np.mean(np.abs(self.pulses[self.pulses[:,0] == k,3] - self.pulses[self.pulses[:,0] == k,4])))
            for k in range(len(self.pulse_times)):
                if len(self.pulse_times[k]) > 1:
                    for j in range(k+1, len(self.pulse_times)):
                        if len(self.pulse_times[j]) > 1:
                            di = 10*max(widths[k], widths[j])
                            dts = np.array([np.min(np.abs(self.pulse_times[k] - pt)) for pt in self.pulse_times[j]])
                            if k == 1 and j == 2:
                                print(di, np.sum(dts < di), len(dts))
                                plt.hist(dts, 50)
                                plt.show()
                            if np.sum(dts < 2*max_di)/len(dts) > 0.6:
                                r = k
                                if np.sum(self.fishes[k][2]) > np.sum(self.fishes[j][2]):
                                    r = j
                                self.pulse_times[r] = np.array([])
                                self.pulses[self.pulses[:,0] == r] = -1
                                self.fishes[r] = []
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]
            """
            # all labels:
            self.labels = np.unique(self.pulses[:,0])
            # report:
            print(f'found {len(self.pulse_times)} fish:')
            for k in range(len(self.pulse_times)):
                print(f'{k:3d}: {len(self.pulse_times[k]):5d} pulses')
            ## plot histogtram of distances:
            #plt.hist(min_dists, 100)
            #plt.show()
            ## plot features:
            # Disabled code: the following bare string literal is never
            # executed (feature plot of the largest clusters).
            """
            nn = np.array([(k, len(self.pulse_times[k]))
                           for k in range(len(self.pulse_times))])
            fig, axs = plt.subplots(5, 5, figsize=(15, 9),
                                    constrained_layout=True)
            ni = np.argsort(nn[:,1])           # largest cluster ...
            ln = np.sort(nn[ni[-axs.size:],0]) # ... sort by label
            for l, ax in zip(ln, axs.flat):
                h = np.array([hh for ll, tt, hh in self.fishes if ll == l])
                ax.plot(h.T, 'o-', ms=2, lw=0.5,
                        color=self.pulse_colors[l%len(self.pulse_colors)])
                ax.text(0.05, 0.9, f'label: {l}', transform=ax.transAxes)
            """
        
        # set key bindings:
        # (frees matplotlib's default bindings for keys used by keypress())
        plt.rcParams['keymap.fullscreen'] = 'f'
        plt.rcParams['keymap.pan'] = 'ctrl+m'
        plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, q'
        plt.rcParams['keymap.yscale'] = ''
        plt.rcParams['keymap.xscale'] = ''
        plt.rcParams['keymap.grid'] = ''
        #plt.rcParams['keymap.all_axes'] = ''

        # the figure:
        plt.ioff()
        # one subplot per shown trace, plus one for IPI frequencies:
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        self.fig, self.axs = plt.subplots(splts, 1, squeeze=False,
                                          figsize=(15, 9), sharex=True)
        self.axs = self.axs.flat
        if self.traces == self.channels:
            self.fig.canvas.manager.set_window_title(self.filename)
        else:
            cs = ' c%d' % self.show_channels[0]
            self.fig.canvas.manager.set_window_title(self.filename + ' ' + cs)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)
        self.fig.canvas.mpl_connect('resize_event', self.resize)
        self.fig.canvas.mpl_connect('pick_event', self.on_pick)
        # trace plots:
        for t in range(self.traces):
            self.axs[t].set_ylabel(f'C-{self.show_channels[t]+1} [{self.unit}]')
        #for t in range(self.traces-1):
        #    self.axs[t].xaxis.set_major_formatter(plt.NullFormatter())
        if len(self.pulses) > 0:
            self.axs[-1].set_ylim(0, self.fmax)
            self.axs[-1].set_ylabel('IP freq [Hz]')
        self.axs[-1].set_xlabel('Time [s]')
        # help texts, shown/hidden together with the 'h' key:
        ht = self.axs[0].text(0.98, 0.05, '(ctrl+) page and arrow up, down, home, end: scroll', ha='right',
                           transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.1, '+, -, X, x: zoom time in/out', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.2, 'y, Y, v, V, ctrl+v, ctrl+V: zoom amplitudes out/in/max/default/max per trace/global max per trace', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.3, 'i, I: zoom IPI frequency in/out', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.4, 'p, P: play audio (display, all)', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.5, 'f: full screen', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.6, 'w: plot waveforms into png file', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.7, 'S: save audiosegment', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.8, 'q: quit', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.9, 'h: toggle this help', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        # plot:
        for ht in self.helptext:
            ht.set_visible(self.help)
        self.update_plots()
        # feature plot:
        # (on_pick() adds the heights of picked pulse groups to it)
        if len(self.labels) > 0:
            self.figf, self.axf = plt.subplots()
        # plt.ioff() was called above; in non-interactive mode matplotlib's
        # show() blocks until the figures are closed.
        plt.show()

    def __del__(self):
        pass
        #self.audio.close()

    def plot_pulses(self, axs, plot=True, tfac=1.0):
        
        def plot_pulse_traces(pulses, i, pak):
            for t in range(self.traces):
                c = self.show_channels[t]
                p = pulses[pulses[:,2] == c,3]
                if len(p) == 0:
                    continue
                if plot or pak >= len(self.pulse_artist):
                    pa, = axs[t].plot(tfac*p/self.samplerate,
                                      self.data[p,c], 'o', picker=5,
                                      color=self.pulse_colors[i%len(self.pulse_colors)])
                    if not plot:
                        self.pulse_artist.append(pa)
                else:
                    self.pulse_artist[pak].set_data(tfac*p/self.samplerate,
                                                    self.data[p,c])
                    self.pulse_artist[pak].set_color(self.pulse_colors[i%len(self.pulse_colors)])
                #if len(p) > 1 and len(p) <= 10:
                #    self.pulse_artist[pak].set_markersize(15)
                pak += 1
            return pak

        # pulses:
        pak = 0
        if self.show_gid:
            for g in range(len(self.pulse_colors)):
                pulses = self.pulses[self.pulses[:,1] % len(self.pulse_colors) == g,:]
                pak = plot_pulse_traces(pulses, g, pak)
        else:
            for l in self.labels:
                pulses = self.pulses[self.pulses[:,0] == l,:]
                pak = plot_pulse_traces(pulses, l, pak)
        while pak < len(self.pulse_artist):
            self.pulse_artist[pak].set_data([], [])
            pak += 1
        # ipis:
        for l in self.labels:
            if l < len(self.pulse_times):
                pt = self.pulse_times[l]/self.samplerate
                if len(pt) > 10:
                    if plot or not l in self.ipis_labels:
                        pa, = axs[-1].plot(tfac*pt[:-1], 1.0/np.diff(pt),
                                           '-o', picker=5,
                                           color=self.pulse_colors[l%len(self.pulse_colors)])
                        if not plot:
                            self.ipis_artist.append(pa)
                            self.ipis_labels.append(l)
                    else:
                        iak = self.ipis_labels.index(l)
                        self.ipis_artist[iak].set_data(tfac*pt[:-1],
                                                       1.0/np.diff(pt))

    def update_plots(self):
        t0 = int(np.round(self.toffset * self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
        if t1 > len(self.data):
            t1 = len(self.data)
        time = np.arange(t0, t1) / self.samplerate
        for t in range(self.traces):
            c = self.show_channels[t]
            self.axs[t].set_xlim(self.toffset, self.toffset + self.twindow)
            if self.trace_artist[t] == None:
                self.trace_artist[t], = self.axs[t].plot(time, self.data[t0:t1,c])
            else:
                self.trace_artist[t].set_data(time, self.data[t0:t1,c])
            if t1 - t0 < 200:
                self.trace_artist[t].set_marker('o')
                self.trace_artist[t].set_markersize(3)
            else:
                self.trace_artist[t].set_marker('None')
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()

    def on_pick(self, event):
        # index of pulse artist:
        pk = -1
        for k, pa in enumerate(self.pulse_artist):
            if event.artist == pa:
                pk = k
                break
        li = -1
        pi = -1
        if pk >= 0:
            # find label and pulses of pulse artist:
            ll = self.labels[pk//self.traces]
            cc = self.show_channels[pk % self.traces]
            pulses = self.pulses[self.pulses[:,0] == ll,:]
            gid = pulses[pulses[:,2] == cc,1][event.ind[0]]
            if ll in self.ipis_labels:
                li = self.ipis_labels.index(ll)
                pi = self.pulse_gids[ll].index(gid)
        else:
            ik = -1
            for k, ia in enumerate(self.ipis_artist):
                if event.artist == ia:
                    ik = k
                    break
            if ik < 0:
                return
            li = ik
            ll = self.ipis_labels[li]
            pi = event.ind[0]
            gid = self.pulse_gids[ll][pi]
        # mark pulses:
        pulses = self.pulses[self.pulses[:,0] == ll,:]
        pulses = pulses[pulses[:,1] == gid,:]
        for t in range(self.traces):
            c = self.show_channels[t]
            pt = pulses[pulses[:,2] == c,3]
            if len(pt) > 0:
                if self.marker_artist[t] is None:
                    pa, = self.axs[t].plot(pt[0]/self.samplerate,
                                           self.data[pt[0],c], 'o', ms=10,
                                           color=self.pulse_colors[ll%len(self.pulse_colors)])
                    self.marker_artist[t] = pa
                else:
                    self.marker_artist[t].set_data(pt[0]/self.samplerate,
                                                   self.data[pt[0],c])
                    self.marker_artist[t].set_color(self.pulse_colors[ll%len(self.pulse_colors)])
            elif self.marker_artist[t] is not None:
                self.marker_artist[t].set_data([], [])
        # mark ipi:
        pt0 = -1.0
        pt1 = -1.0
        pf = -1.0
        if pi >= 0:
            pt0 = self.pulse_times[ll][pi]/self.samplerate
            pt1 = self.pulse_times[ll][pi+1]/self.samplerate
            pf = 1.0/(pt1-pt0)
            if self.marker_artist[self.traces] is None:
                pa, = self.axs[self.traces].plot(pt0, pf, 'o', ms=10,
                                                 color=self.pulse_colors[ll%len(self.pulse_colors)])
                self.marker_artist[self.traces] = pa
            else:
                self.marker_artist[self.traces].set_data(pt0, pf)
                self.marker_artist[self.traces].set_color(self.pulse_colors[ll%len(self.pulse_colors)])
        elif not self.marker_artist[self.traces] is None:
            self.marker_artist[self.traces].set_data([], [])
        self.fig.canvas.draw()
        # show features:
        if not self.axf is None and not self.fig is None:
            heights = np.zeros(self.channels)
            heights[pulses[:,2]] = \
                self.data[pulses[:,3],pulses[:,2]] - \
                self.data[pulses[:,4],pulses[:,2]]
            self.axf.plot(heights, color=self.pulse_colors[ll%len(self.pulse_colors)])
            print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
            self.figf.canvas.draw()

    def resize(self, event):
        # print('resized', event.width, event.height)
        leftpixel = 80.0
        rightpixel = 20.0
        bottompixel = 50.0
        toppixel = 20.0
        x0 = leftpixel / event.width
        x1 = 1.0 - rightpixel / event.width
        y0 = bottompixel / event.height
        y1 = 1.0 - toppixel / event.height
        self.fig.subplots_adjust(left=x0, right=x1, bottom=y0, top=y1,
                                 hspace=0)

    def keypress(self, event):
        # print('pressed', event.key)
        if event.key in '+=X':
            if self.twindow * self.samplerate > 20:
                self.twindow *= 0.5
                self.update_plots()
        elif event.key in '-x':
            if self.twindow < self.tmax:
                self.twindow *= 2.0
                self.update_plots()
        elif event.key == 'pagedown':
            if self.toffset + 0.5 * self.twindow < self.tmax:
                self.toffset += 0.5 * self.twindow
                self.update_plots()
        elif event.key == 'pageup':
            if self.toffset > 0:
                self.toffset -= 0.5 * self.twindow
                if self.toffset < 0.0:
                    self.toffset = 0.0
                self.update_plots()
        elif event.key == 'ctrl+pagedown':
            if self.toffset + 5.0 * self.twindow < self.tmax:
                self.toffset += 5.0 * self.twindow
                self.update_plots()
        elif event.key == 'ctrl+pageup':
            if self.toffset > 0:
                self.toffset -= 5.0 * self.twindow
                if self.toffset < 0.0:
                    self.toffset = 0.0
                self.update_plots()
        elif event.key == 'down':
            if self.toffset + self.twindow < self.tmax:
                self.toffset += 0.05 * self.twindow
                self.update_plots()
        elif event.key == 'up':
            if self.toffset > 0.0:
                self.toffset -= 0.05 * self.twindow
                if self.toffset < 0.0:
                    self.toffset = 0.0
                self.update_plots()
        elif event.key == 'home':
            if self.toffset > 0.0:
                self.toffset = 0.0
                self.update_plots()
        elif event.key == 'end':
            toffs = np.floor(self.tmax / self.twindow) * self.twindow
            if self.tmax - toffs <= 0.0:
                toffs -= self.twindow
            if self.tmax - toffs < self.twindow/2:
                toffs -= self.twindow/2
            if self.toffset < toffs:
                self.toffset = toffs
                self.update_plots()
        elif event.key == 'y':
            for t in range(self.traces):
                h = self.ymax[t] - self.ymin[t]
                c = 0.5 * (self.ymax[t] + self.ymin[t])
                self.ymin[t] = c - h
                self.ymax[t] = c + h
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'Y':
            for t in range(self.traces):
                h = 0.25 * (self.ymax[t] - self.ymin[t])
                c = 0.5 * (self.ymax[t] + self.ymin[t])
                self.ymin[t] = c - h
                self.ymax[t] = c + h
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'v':
            t0 = int(np.round(self.toffset * self.samplerate))
            t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
            min = np.min(self.data[t0:t1,self.show_channels])
            max = np.max(self.data[t0:t1,self.show_channels])
            h = 0.53 * (max - min)
            c = 0.5 * (max + min)
            self.ymin[:] = c - h
            self.ymax[:] = c + h
            for t in range(self.traces):
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'ctrl+v':
            t0 = int(np.round(self.toffset * self.samplerate))
            t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
            for t in range(self.traces):
                min = np.min(self.data[t0:t1,self.show_channels[t]])
                max = np.max(self.data[t0:t1,self.show_channels[t]])
                h = 0.53 * (max - min)
                c = 0.5 * (max + min)
                self.ymin[t] = c - h
                self.ymax[t] = c + h
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'ctrl+V':
            for t in range(self.traces):
                min = np.min(self.data[:,self.show_channels[t]])
                max = np.max(self.data[:,self.show_channels[t]])
                h = 0.53 * (max - min)
                c = 0.5 * (max + min)
                self.ymin[t] = c - h
                self.ymax[t] = c + h
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'V':
            self.ymin[:] = -1.0
            self.ymax[:] = +1.0
            for t in range(self.traces):
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'c':
            for t in range(self.traces):
                dy = self.ymax[t] - self.ymin[t]
                self.ymin[t] = -dy/2
                self.ymax[t] = +dy/2
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'g':
            self.show_gid = not self.show_gid
            self.plot_pulses(self.axs, False)
            self.fig.canvas.draw()
        elif event.key == 'i':
            if len(self.pulses) > 0:
                self.fmax *= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif event.key == 'I':
            if len(self.pulses) > 0:
                self.fmax /= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif event.key in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            cc = int(event.key)
            # TODO: this is not yet what we want:
            """
            if cc < self.channels:
                self.axs[cc].set_visible(not self.axs[cc].get_visible())
            self.fig.canvas.draw()
            """
        elif event.key in 'h':
            self.help = not self.help
            for ht in self.helptext:
                ht.set_visible(self.help)
            self.fig.canvas.draw()
        elif event.key in 'p':
            self.play_segment()
        elif event.key in 'P':
            self.play_all()
        elif event.key in 'S':
            self.save_segment()
        elif event.key in 'w':
            self.plot_traces()

    def play_segment(self):
        """Play the currently displayed time window (non-blocking).

        The shown channels are averaged into a single trace, which gets
        faded in and out before it is handed to the audio player.
        """
        start = int(np.round(self.toffset * self.samplerate))
        stop = int(np.round((self.toffset + self.twindow) * self.samplerate))
        segment = 1.0 * np.mean(self.data[start:stop, self.show_channels], 1)
        # fade duration in seconds, shortened for very short windows:
        if self.twindow > 0.5:
            fade_time = 0.1
        else:
            fade_time = 0.1 * self.twindow
        fade(segment, self.samplerate, fade_time)
        self.audio.play(segment, self.samplerate, blocking=False)
        
    def play_all(self):
        """Play the whole recording averaged over the shown channels (non-blocking)."""
        mono = np.mean(self.data[:, self.show_channels], axis=1)
        self.audio.play(mono, self.samplerate, blocking=False)

    def save_segment(self):
        """Write the currently displayed time window into a WAV file.

        The file is named after the recording, with the start and end time
        of the window in seconds appended. If only some channels are shown,
        only these channels are written and the first shown channel is
        added to the file name.
        """
        t0 = int(np.round(self.toffset * self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
        # splitext keeps dots that are part of the name (e.g. dates),
        # split('.')[0] would cut the name at the first dot:
        basename = os.path.splitext(self.filename)[0]
        # format the actual window times; rounding them to whole seconds
        # first made names of sub-second segments collide (e.g. "-0s-0s"):
        tstart = self.toffset
        tend = self.toffset + self.twindow
        if self.traces == self.channels:
            segment_filename = f'{basename}-{tstart:.4g}s-{tend:.4g}s.wav'
            write_audio(segment_filename, self.data[t0:t1,:], self.samplerate)
        else:
            segment_filename = f'{basename}-{tstart:.4g}s-{tend:.4g}s-c{self.show_channels[0]}.wav'
            write_audio(segment_filename,
                        self.data[t0:t1,self.show_channels], self.samplerate)
        print('saved segment to: ' , segment_filename)

    def plot_traces(self):
        """Save the traces of the current time window into a PNG file.

        A new figure with one panel per shown channel (plus a panel with
        the inter-pulse-interval frequencies if pulses were detected) is
        created, saved as ``<name>-<toffset>s[-c<channel>]-traces.png``
        with 200 dpi, and closed again. Windows that start and last less
        than one second get a time axis in milliseconds.
        """
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        fig, axs = plt.subplots(splts, 1, squeeze=False, sharex=True,
                                figsize=(15, 9))
        try:
            axs = axs.flat
            fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                                hspace=0)
            # splitext keeps dots that are part of the name (e.g. dates):
            name = os.path.splitext(self.filename)[0]
            if self.traces < self.channels:
                figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
            else:
                figfile = f'{name}-{self.toffset:.4g}s-traces.png'
            axs[0].set_title(self.filename)
            t0 = int(np.round(self.toffset * self.samplerate))
            t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
            t1 = min(t1, len(self.data))
            time = np.arange(t0, t1)/self.samplerate
            # short windows at the start of the recording in milliseconds:
            if self.toffset < 1.0 and self.twindow < 1.0:
                tfac = 1000.0
                axs[-1].set_xlabel('Time [ms]')
            else:
                tfac = 1.0
                axs[-1].set_xlabel('Time [s]')
            for t in range(self.traces):
                c = self.show_channels[t]
                axs[t].set_xlim(tfac * self.toffset,
                                tfac * (self.toffset + self.twindow))
                axs[t].plot(tfac * time, self.data[t0:t1,c])
            self.plot_pulses(axs, True, tfac)
            for t in range(self.traces):
                c = self.show_channels[t]
                axs[t].set_ylim(self.ymin[t], self.ymax[t])
                axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
            if len(self.pulses) > 0:
                axs[-1].set_ylabel('IP freq [Hz]')
                axs[-1].set_ylim(0.0, self.fmax)
            fig.savefig(figfile, dpi=200)
        finally:
            # close the figure also when saving fails, so that it does not
            # stay registered with pyplot:
            plt.close(fig)
        print('saved waveform figure to', figfile)
        

def short_user_warning(message, category, filename, lineno, file=None, line=''):
    """Write a warning in a compact format.

    Replacement for `warnings.showwarning`. A `UserWarning` is written as
    ``<dir>/<file> line <lineno>: <message>``, i.e. only the last two
    path components of `filename` are shown. All other categories are
    formatted by `warnings.formatwarning`.

    Parameters
    ----------
    message: Warning or str
        The warning message.
    category: type
        The warning category.
    filename: str
        Path of the file that issued the warning.
    lineno: int
        Line number of the warning.
    file: file-like or None
        Where to write the warning to. Defaults to `sys.stderr`.
    line: str
        Source line passed on to `warnings.formatwarning`.
    """
    if file is None:
        file = sys.stderr
        if file is None:
            # no stderr (e.g. pythonw): the warning is lost, as with the
            # default warnings.showwarning
            return
    if category == UserWarning:
        # normalize native path separators (backslashes on Windows):
        parts = filename.replace(os.sep, '/').split('/')
        text = '%s line %d: %s\n' % ('/'.join(parts[-2:]), lineno, message)
    else:
        text = warnings.formatwarning(message, category, filename, lineno, line)
    try:
        file.write(text)
    except OSError:
        # invalid stream (e.g. closed stderr): drop the warning like the
        # default warnings.showwarning does
        pass


def main(cargs=None):
    """Browse a (multichannel) recording from the command line.

    Parses the command line, loads the recording given as the single
    positional argument, and opens a `SignalPlot` window on it.

    Parameters
    ----------
    cargs: list of str or None
        Command line arguments without the program name.
        If None, `sys.argv[1:]` is used.
    """
    warnings.showwarning = short_user_warning

    # command line arguments:
    if cargs is None:
        cargs = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description='Browse multichannel EOD recordings.',
        epilog='version %s by Benda-Lab (2022-%s)' % (__version__, __year__))
    parser.add_argument('--version', action='version', version=__version__)
    # default=0: no -v at all means verbosity 0 instead of None
    parser.add_argument('-v', action='count', dest='verbose', default=0)
    parser.add_argument('-c', dest='channels', default='',
                        type=str, metavar='CHANNELS',
                        help='Comma separated list of channels to be displayed (first channel is 0).')
    parser.add_argument('-t', dest='tmax', default=None,
                        type=float, metavar='TMAX',
                        help='Process and show only the first TMAX seconds.')
    parser.add_argument('-f', dest='fcutoff', default=None,
                        type=float, metavar='FREQ',
                        help='Cutoff frequency of optional high-pass filter.')
    parser.add_argument('-p', dest='pulses', action='store_true',
                        help='detect pulse fish EODs')
    parser.add_argument('file', nargs=1, default='', type=str,
                        help='name of the file with the time series data')
    args = parser.parse_args(cargs)
    filepath = args.file[0]
    # channel list like "0, 2,3"; empty entries are ignored:
    cs = [s.strip() for s in args.channels.split(',')]
    try:
        channels = [int(c) for c in cs if len(c) > 0]
    except ValueError:
        # report a usage error instead of a traceback:
        parser.error(f'invalid channel list "{args.channels}": '
                     'expected comma separated integers')
    tmax = args.tmax
    fcutoff = args.fcutoff
    pulses = args.pulses
    verbose = args.verbose

    # load data:
    filename = os.path.basename(filepath)
    # NOTE(review): 10*60.0 and 5.0 are presumably buffer and back-buffer
    # durations in seconds of the DataLoader - confirm.
    with DataLoader(filepath, 10*60.0, 5.0, verbose) as data:
        SignalPlot(data, data.samplerate, data.unit, filename,
                   channels, tmax, fcutoff, pulses)


if __name__ == '__main__':
    main()

Functions

def short_user_warning(message, category, filename, lineno, file=None, line='')
Expand source code
def short_user_warning(message, category, filename, lineno, file=None, line=''):
    """Write a warning in a compact format.

    Replacement for `warnings.showwarning`. A `UserWarning` is written as
    ``<dir>/<file> line <lineno>: <message>``, i.e. only the last two
    path components of `filename` are shown. All other categories are
    formatted by `warnings.formatwarning`.

    Parameters
    ----------
    message: Warning or str
        The warning message.
    category: type
        The warning category.
    filename: str
        Path of the file that issued the warning.
    lineno: int
        Line number of the warning.
    file: file-like or None
        Where to write the warning to. Defaults to `sys.stderr`.
    line: str
        Source line passed on to `warnings.formatwarning`.
    """
    if file is None:
        file = sys.stderr
        if file is None:
            # no stderr (e.g. pythonw): the warning is lost, as with the
            # default warnings.showwarning
            return
    if category == UserWarning:
        # normalize native path separators (backslashes on Windows):
        parts = filename.replace(os.sep, '/').split('/')
        text = '%s line %d: %s\n' % ('/'.join(parts[-2:]), lineno, message)
    else:
        text = warnings.formatwarning(message, category, filename, lineno, line)
    try:
        file.write(text)
    except OSError:
        # invalid stream (e.g. closed stderr): drop the warning like the
        # default warnings.showwarning does
        pass
def main(cargs=None)
Expand source code
def main(cargs=None):
    """Browse a (multichannel) recording from the command line.

    Parses the command line, loads the recording given as the single
    positional argument, and opens a `SignalPlot` window on it.

    Parameters
    ----------
    cargs: list of str or None
        Command line arguments without the program name.
        If None, `sys.argv[1:]` is used.
    """
    warnings.showwarning = short_user_warning

    # command line arguments:
    if cargs is None:
        cargs = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description='Browse multichannel EOD recordings.',
        epilog='version %s by Benda-Lab (2022-%s)' % (__version__, __year__))
    parser.add_argument('--version', action='version', version=__version__)
    # default=0: no -v at all means verbosity 0 instead of None
    parser.add_argument('-v', action='count', dest='verbose', default=0)
    parser.add_argument('-c', dest='channels', default='',
                        type=str, metavar='CHANNELS',
                        help='Comma separated list of channels to be displayed (first channel is 0).')
    parser.add_argument('-t', dest='tmax', default=None,
                        type=float, metavar='TMAX',
                        help='Process and show only the first TMAX seconds.')
    parser.add_argument('-f', dest='fcutoff', default=None,
                        type=float, metavar='FREQ',
                        help='Cutoff frequency of optional high-pass filter.')
    parser.add_argument('-p', dest='pulses', action='store_true',
                        help='detect pulse fish EODs')
    parser.add_argument('file', nargs=1, default='', type=str,
                        help='name of the file with the time series data')
    args = parser.parse_args(cargs)
    filepath = args.file[0]
    # channel list like "0, 2,3"; empty entries are ignored:
    cs = [s.strip() for s in args.channels.split(',')]
    try:
        channels = [int(c) for c in cs if len(c) > 0]
    except ValueError:
        # report a usage error instead of a traceback:
        parser.error(f'invalid channel list "{args.channels}": '
                     'expected comma separated integers')
    tmax = args.tmax
    fcutoff = args.fcutoff
    pulses = args.pulses
    verbose = args.verbose

    # load data:
    filename = os.path.basename(filepath)
    # NOTE(review): 10*60.0 and 5.0 are presumably buffer and back-buffer
    # durations in seconds of the DataLoader - confirm.
    with DataLoader(filepath, 10*60.0, 5.0, verbose) as data:
        SignalPlot(data, data.samplerate, data.unit, filename,
                   channels, tmax, fcutoff, pulses)

Classes

class SignalPlot (data, samplerate, unit, filename, show_channels=[], tmax=None, fcutoff=None, pulses=False)
Expand source code
class SignalPlot:
    def __init__(self, data, samplerate, unit, filename,
                 show_channels=[], tmax=None, fcutoff=None,
                 pulses=False):
        """Set up the browser for a recording and show it.

        Optionally truncates the recording to `tmax` seconds, high-pass
        filters it, and detects pulse EODs, groups them across channels
        and clusters them into fish. Then builds the figure with one
        panel per shown channel (plus a panel for inter-pulse-interval
        frequencies if pulses were detected), connects the event
        handlers, and calls `plt.show()`.

        Parameters
        ----------
        data: 2-D array-like
            The recording, indexed as `data[sample, channel]`.
            NOTE(review): 1-D data is only partially supported: it sets
            `self.channels` to 1, but the 2-D indexing used with `tmax`
            and for pulse detection would fail - confirm.
        samplerate: float
            Sampling rate of the data in Hertz.
        unit: str
            Unit of the data, used in the y-axis labels.
        filename: str
            Name of the recording, used as window title and as base
            for the names of saved files.
        show_channels: list of int
            Indices of the channels to be shown. All channels if empty.
            The default list is only read, never modified.
        tmax: float or None
            If given, only the first `tmax` seconds of the data are used.
        fcutoff: float or None
            If given, cutoff frequency in Hertz of a second-order
            Butterworth high-pass filter that is applied forward and
            backward (`sosfiltfilt`) to the data.
        pulses: bool
            If True, detect and cluster pulse EODs.
        """
        self.filename = filename
        self.samplerate = samplerate
        self.data = data
        self.channels = self.data.shape[1] if len(self.data.shape) > 1 else 1
        self.unit = unit
        # duration of the (possibly truncated) recording in seconds:
        self.tmax = (len(self.data)-1)/self.samplerate
        if not tmax is None:
            self.tmax = tmax
            self.data = data[:int(tmax*self.samplerate),:]
        # start and width of the displayed time window in seconds:
        self.toffset = 0.0
        self.twindow = 10.0
        if self.twindow > self.tmax:
            # recording shorter than the default window: use the smallest
            # power of two larger than the duration, rounded to an integer,
            # or tmax if it was given.
            # NOTE(review): for durations below 0.5 s np.round() yields 0.
            self.twindow = np.round(2 ** (np.floor(np.log(self.tmax) / np.log(2.0)) + 1.0))
            if not tmax is None:
                self.twindow = tmax
        # empty until the pulse detection below fills in rows of
        # (label, group, channel, peak index, trough index):
        self.pulses = np.zeros((0, 3), dtype=int)
        self.labels = []
        self.fishes = []
        self.pulse_times = []
        self.pulse_gids = []
        if len(show_channels) == 0:
            self.show_channels = np.arange(self.channels)
        else:
            self.show_channels = np.array(show_channels)
        self.traces = len(self.show_channels)
        # y-axis limits of each trace panel:
        self.ymin = -1.0 * np.ones(self.traces)
        self.ymax = +1.0 * np.ones(self.traces)
        # upper limit of the IPI frequency axis in Hertz:
        self.fmax = 100.0
        self.trace_artist = [None] * self.traces
        self.show_gid = False
        self.pulse_artist = []
        # one marker per trace plus one for the IPI panel:
        self.marker_artist = [None] * (self.traces + 1)
        self.ipis_artist = []
        self.ipis_labels = []
        self.figf = None
        self.axf = None
        self.pulse_colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C0']
        self.help = False
        self.helptext = []
        self.audio = PlayAudio()

        # filter data:
        if not fcutoff is None:
            # high-pass filter along the time axis (axis 0):
            sos = butter(2, fcutoff, 'high', fs=samplerate, output='sos')
            self.data = sosfiltfilt(sos, self.data[:], 0)

        # pulse detection:
        if pulses:
            # label, group, channel, peak index, trough index
            all_pulses = np.zeros((0, 5), dtype=int)
            for c in range(self.channels):
                #thresh = 1*np.std(self.data[:int(2*self.samplerate),c])
                thresh = median_std_threshold(self.data[:,c], self.samplerate,
                                              thresh_fac=6.0)
                # NOTE(review): this fixed value overrides the threshold
                # computed just above, so median_std_threshold() has no
                # effect - intended?
                thresh = 0.01
                #p, t = detect_peaks(self.data[:,c], thresh)
                p, t, w, h = detect_pulses(self.data[:,c], self.samplerate,
                                           thresh,
                                           min_rel_slope_diff=0.25,
                                           min_width=0.0001,
                                           max_width=0.01,
                                           width_fac=5.0)
                # label, group, channel, peak, trough:
                # (note: this rebinds the `pulses` flag to an array)
                pulses = np.hstack((np.arange(len(p))[:,np.newaxis],
                                    np.zeros((len(p), 1), dtype=int),
                                    np.ones((len(p), 1), dtype=int)*c,
                                    p[:,np.newaxis], t[:,np.newaxis]))
                all_pulses = np.vstack((all_pulses, pulses))
            # pulses of all channels sorted by peak index:
            self.pulses = all_pulses[np.argsort(all_pulses[:,3]),:]
            # grouping over channels:
            # pulses whose peak or trough lies within max_di samples of the
            # reference peak or trough belong to the same group:
            max_di = int(0.0002*self.samplerate)   # TODO: parameter
            l = -1
            k = 0
            while k < len(self.pulses):
                tp = self.pulses[k,3]
                tt = self.pulses[k,4]
                height = self.data[self.pulses[k,3],self.pulses[k,2]] - \
                    self.data[self.pulses[k,4],self.pulses[k,2]]
                channel_counts = np.zeros(self.channels, dtype=int)
                channel_counts[self.pulses[k,2]] += 1
                # try to extend the group by the following pulses:
                for c in range(1, 3*self.channels):
                    if k+c >= len(self.pulses):
                        break
                    # pulse too far away:
                    if channel_counts[self.pulses[k+c,2]] > 1 or \
                       (np.abs(self.pulses[k+c,3] - tp) > max_di and
                        np.abs(self.pulses[k+c,3] - tt) > max_di and
                        np.abs(self.pulses[k+c,4] - tp) > max_di and
                        np.abs(self.pulses[k+c,4] - tt) > max_di):
                        break
                    channel_counts[self.pulses[k+c,2]] += 1
                    height_kc = self.data[self.pulses[k+c,3],self.pulses[k+c,2]] - \
                        self.data[self.pulses[k+c,4],self.pulses[k+c,2]]
                    # highest pulse sets time reference:
                    if height_kc > height:
                        tp = self.pulses[k+c,3]
                        tt = self.pulses[k+c,4]
                        height = height_kc
                # The group below is pulses k .. k+c-1.
                # NOTE(review): if the loop above finishes without `break`,
                # pulse k+c was counted into the group but is excluded here
                # and starts the next group - confirm this is intended.
                # all pulses too small:
                if height < 0.02:    # TODO parameter
                    self.pulses[k:k+c,0] = -1
                    k += c
                    continue
                # new label:
                l += 1
                # remove lost pulses (too far from the final reference):
                for j in range(c):
                    if (np.abs(self.pulses[k+j,3] - tp) > max_di and
                        np.abs(self.pulses[k+j,3] - tt) > max_di and
                        np.abs(self.pulses[k+j,4] - tp) > max_di and
                        np.abs(self.pulses[k+j,4] - tt) > max_di):
                        self.pulses[k+j,0] = -1
                        channel_counts[self.pulses[k+j,2]] -= 1
                    else:
                        self.pulses[k+j,0] = l
                        self.pulses[k+j,1] = l
                # keep only the largest pulse of each channel:
                pulses = self.pulses[k:k+c,:]
                for dc in np.where(channel_counts > 1)[0]:
                    idx = np.where(self.pulses[k:k+c,2] == dc)[0]
                    heights = self.data[pulses[idx,3],dc] - \
                        self.data[pulses[idx,4],dc]
                    for i in range(len(idx)):
                        if i != np.argmax(heights):
                            channel_counts[self.pulses[k+idx[i],2]] -= 1
                            self.pulses[k+idx[i],0] = -1
                k += c
            # drop all pulses marked with label -1:
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]

            # clustering:
            # assign each pulse group to a fish label by comparing its
            # heights across channels with those of recent groups:
            min_dists = []
            recent = []
            k = 0
            while k < len(self.pulses):
                # select pulse group:
                j = k
                gid = self.pulses[j,1]
                for c in range(self.channels):
                    k += 1
                    if k >= len(self.pulses) or \
                       self.pulses[k,1] != gid:
                        break
                # pulse height per channel (0 for channels without pulse):
                heights = np.zeros(self.channels)
                heights[self.pulses[j:k,2]] = \
                    self.data[self.pulses[j:k,3],self.pulses[j:k,2]] - \
                    self.data[self.pulses[j:k,4],self.pulses[j:k,2]]
                # time of largest pulse:
                pulse_time = self.pulses[j+np.argmax(heights[self.pulses[j:k,2]]),3]
                # assign to cluster:
                if len(self.pulse_times) == 0:
                    label = len(self.pulse_times)
                    self.pulse_times.append([])
                    self.pulse_gids.append([])
                else:
                    # compute metrics of recent fishes:
                    # mean relative height difference:
                    dists = np.array([np.mean(np.abs(hh - heights)/np.max(hh))
                                        for ll, tt, hh in recent])
                    thresh = 0.1   # TODO: make parameter
                    # distance between pulses (in seconds):
                    ipis = np.array([(pulse_time - tt)/self.samplerate
                                     for ll, tt, hh in recent])
                    ## how can ipis be 0, or just one sample?
                    ##if len(ipis[ipis<0.001]) > 0:
                    ##    print(ipis[ipis<0.001])
                    # ensure minimum IP distance:
                    # (penalize candidates implying more than 300 Hz)
                    dists[1/ipis > 300.0] = 2*np.max(dists)  # TODO: make parameter
                    # minimum distance:
                    min_dist_idx = np.argmin(dists)
                    min_dists.append(dists[min_dist_idx])
                    if dists[min_dist_idx] < thresh:
                        label = recent[min_dist_idx][0]
                    else:
                        label = len(self.pulse_times)
                        self.pulse_times.append([])
                        self.pulse_gids.append([])
                self.pulses[j:k,0] = label
                self.pulse_times[label].append(pulse_time)
                self.pulse_gids[label].append(gid)
                self.fishes.append([label, pulse_time, heights])
                recent.append([label, pulse_time, heights])
                # remove old fish:
                # (drop leading entries of `recent` up to the first one
                # within 0.2 s of the current pulse)
                for i, (ll, tt, hh) in enumerate(recent):
                    # TODO: make parameter:
                    if (pulse_time - tt)/self.samplerate <= 0.2:
                        recent = recent[i:]
                        break
                # only consider the n most recent pulses of a fish:
                n = 5    # TODO make parameter
                labels = np.array([ll for ll, tt, hh in recent])
                if np.sum(labels == label) > n:
                    del recent[np.where(labels == label)[0][0]]
            # pulse times to arrays:
            for k in range(len(self.pulse_times)):
                self.pulse_times[k] = np.array(self.pulse_times[k])


                
            # The following triple-quoted blocks are disabled code kept as
            # string literals; they have no effect.
            """
            # find temporally missing pulses:
            npulses = np.array([len(pts) for pts in self.pulse_times],
                               dtype=int)
            idx = np.argsort(npulses)
            for i in range(len(idx)):
                li = idx[len(idx)-1-i]
                if len(self.pulse_times[li]) < 10 or \
                   len(self.pulse_times[li])/npulses[li] < 0.5:
                    continue
                ipis = np.diff(self.pulse_times[li])
                n = 4 # TODO: make parameter
                k = 0
                while k < len(ipis)-n:
                    mipi = np.median(ipis[k:k+n])
                    if ipis[k+n-2] > 1.8*mipi:
                        # search for pulse closest to pt:
                        pt = self.pulse_times[li][k+n-2] + mipi
                        mlj = -1
                        mpj = -1
                        mdj = 10*mipi
                        for lj in range(len(self.pulse_times)):
                            if lj == li or len(self.pulse_times[lj]) == 0:
                                continue
                            pj = np.argmin(np.abs(self.pulse_times[lj] - pt))
                            dj = np.abs(self.pulse_times[lj][pj] - pt)
                            if dj < int(0.001*self.samplerate) and dj < mdj:
                                mdj = dj
                                mpj = pj
                                mlj = lj
                        if mlj >= 0:
                            # there is a pulse close to pt:
                            ptj = self.pulse_times[mlj][mpj]
                            pulses = self.pulses[self.pulses[:,0] == mlj,:]
                            gid = pulses[np.argmin(np.abs(pulses[:,3] - ptj)),1]
                            self.pulse_times[li] = np.insert(self.pulse_times[li], k+n-1, ptj)
                            self.pulse_gids[li].insert(k+n-1, gid)
                            # maybe don't delete but always duplicate and flag it:
                            if False:  # can be deleted
                                self.pulse_times[mlj] = np.delete(self.pulse_times[mlj], mpj)
                                self.pulse_gids[mlj].pop(mpj)
                                self.pulses[self.pulses[:,1] == gid,0] = li
                            else:     # pulse needs to be duplicated:
                                self.pulses[self.pulses[:,1] == gid,0] = li
                            ipis = np.diff(self.pulse_times[li])
                    k += 1


                    
            # clean up pulses:
            for l in range(len(self.pulse_times)):
                if len(self.pulse_times[l])/npulses[l] < 0.5:
                    self.pulse_times[l] = np.array([])
                    self.pulse_gids[l] = []
                    self.pulses[self.pulses[:,0] == l,0] = -1
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]
            """
            
            """
            # remove labels that are too close to others:
            widths = np.zeros(len(self.pulse_times), dtype=int)
            for k in range(len(self.pulse_times)):
                widths[k] = int(np.mean(np.abs(self.pulses[self.pulses[:,0] == k,3] - self.pulses[self.pulses[:,0] == k,4])))
            for k in range(len(self.pulse_times)):
                if len(self.pulse_times[k]) > 1:
                    for j in range(k+1, len(self.pulse_times)):
                        if len(self.pulse_times[j]) > 1:
                            di = 10*max(widths[k], widths[j])
                            dts = np.array([np.min(np.abs(self.pulse_times[k] - pt)) for pt in self.pulse_times[j]])
                            if k == 1 and j == 2:
                                print(di, np.sum(dts < di), len(dts))
                                plt.hist(dts, 50)
                                plt.show()
                            if np.sum(dts < 2*max_di)/len(dts) > 0.6:
                                r = k
                                if np.sum(self.fishes[k][2]) > np.sum(self.fishes[j][2]):
                                    r = j
                                self.pulse_times[r] = np.array([])
                                self.pulses[self.pulses[:,0] == r] = -1
                                self.fishes[r] = []
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]
            """
            # all labels:
            self.labels = np.unique(self.pulses[:,0])
            # report:
            print(f'found {len(self.pulse_times)} fish:')
            for k in range(len(self.pulse_times)):
                print(f'{k:3d}: {len(self.pulse_times[k]):5d} pulses')
            ## plot histogram of distances:
            #plt.hist(min_dists, 100)
            #plt.show()
            ## plot features:
            """
            nn = np.array([(k, len(self.pulse_times[k]))
                           for k in range(len(self.pulse_times))])
            fig, axs = plt.subplots(5, 5, figsize=(15, 9),
                                    constrained_layout=True)
            ni = np.argsort(nn[:,1])           # largest cluster ...
            ln = np.sort(nn[ni[-axs.size:],0]) # ... sort by label
            for l, ax in zip(ln, axs.flat):
                h = np.array([hh for ll, tt, hh in self.fishes if ll == l])
                ax.plot(h.T, 'o-', ms=2, lw=0.5,
                        color=self.pulse_colors[l%len(self.pulse_colors)])
                ax.text(0.05, 0.9, f'label: {l}', transform=ax.transAxes)
            """
        
        # set key bindings:
        # (adapt matplotlib's default key map so it does not interfere with
        # keys like 'p' and 'g' that are handled by keypress())
        plt.rcParams['keymap.fullscreen'] = 'f'
        plt.rcParams['keymap.pan'] = 'ctrl+m'
        plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, q'
        plt.rcParams['keymap.yscale'] = ''
        plt.rcParams['keymap.xscale'] = ''
        plt.rcParams['keymap.grid'] = ''
        #plt.rcParams['keymap.all_axes'] = ''

        # the figure:
        plt.ioff()
        # one panel per shown trace, plus one for IPI frequencies:
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        self.fig, self.axs = plt.subplots(splts, 1, squeeze=False,
                                          figsize=(15, 9), sharex=True)
        self.axs = self.axs.flat
        if self.traces == self.channels:
            self.fig.canvas.manager.set_window_title(self.filename)
        else:
            cs = ' c%d' % self.show_channels[0]
            self.fig.canvas.manager.set_window_title(self.filename + ' ' + cs)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)
        self.fig.canvas.mpl_connect('resize_event', self.resize)
        self.fig.canvas.mpl_connect('pick_event', self.on_pick)
        # trace plots:
        for t in range(self.traces):
            self.axs[t].set_ylabel(f'C-{self.show_channels[t]+1} [{self.unit}]')
        #for t in range(self.traces-1):
        #    self.axs[t].xaxis.set_major_formatter(plt.NullFormatter())
        if len(self.pulses) > 0:
            self.axs[-1].set_ylim(0, self.fmax)
            self.axs[-1].set_ylabel('IP freq [Hz]')
        self.axs[-1].set_xlabel('Time [s]')
        # help texts in the first panel, toggled by the 'h' key:
        ht = self.axs[0].text(0.98, 0.05, '(ctrl+) page and arrow up, down, home, end: scroll', ha='right',
                           transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.1, '+, -, X, x: zoom time in/out', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.2, 'y, Y, v, V, ctrl+v, ctrl+V: zoom amplitudes out/in/max/default/max per trace/global max per trace', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.3, 'i, I: zoom IPI frequency in/out', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.4, 'p, P: play audio (display, all)', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.5, 'f: full screen', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.6, 'w: plot waveforms into png file', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.7, 'S: save audiosegment', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.8, 'q: quit', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.9, 'h: toggle this help', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        # plot:
        for ht in self.helptext:
            ht.set_visible(self.help)
        self.update_plots()
        # feature plot:
        # (a second figure, only created if fish labels were found)
        if len(self.labels) > 0:
            self.figf, self.axf = plt.subplots()
        plt.show()

    def __del__(self):
        pass
        #self.audio.close()

    def plot_pulses(self, axs, plot=True, tfac=1.0):
        """Plot detected pulses on the trace panels and their IPI frequencies.

        Pulses are colored by fish label or, if `self.show_gid` is set, by
        group id modulo the number of colors. Inverse inter-pulse intervals
        of each label with more than 10 pulses are drawn on the last axes.

        Parameters
        ----------
        axs: sequence of matplotlib axes
            One axes per shown trace followed by the IPI axes.
        plot: bool
            If True, plot new lines into `axs` without storing them
            (e.g. for a figure that is saved to a file). If False, update
            the artists stored in `self.pulse_artist` and
            `self.ipis_artist`, creating new ones as needed.
        tfac: float
            Factor converting seconds into the unit of the time axis
            (e.g. 1000 for milliseconds).
        """

        def plot_pulse_traces(pulses, i, pak):
            # Plot `pulses` on all traces with color i and return the index
            # of the next free pulse artist. In update mode (plot=False)
            # every trace consumes exactly one artist, even without pulses,
            # so that artist pak always lives on axes pak % self.traces and
            # pak // self.traces is the index of the label (or group) being
            # drawn. on_pick() relies on this mapping.
            color = self.pulse_colors[i % len(self.pulse_colors)]
            for t in range(self.traces):
                c = self.show_channels[t]
                p = pulses[pulses[:,2] == c,3]
                if plot:
                    if len(p) > 0:
                        axs[t].plot(tfac*p/self.samplerate, self.data[p,c],
                                    'o', picker=5, color=color)
                    continue
                if len(p) > 0:
                    x = tfac*p/self.samplerate
                    y = self.data[p,c]
                else:
                    x = []
                    y = []
                if pak >= len(self.pulse_artist):
                    pa, = axs[t].plot(x, y, 'o', picker=5, color=color)
                    self.pulse_artist.append(pa)
                else:
                    self.pulse_artist[pak].set_data(x, y)
                    self.pulse_artist[pak].set_color(color)
                pak += 1
            return pak

        # pulses:
        pak = 0
        if self.show_gid:
            for g in range(len(self.pulse_colors)):
                pulses = self.pulses[self.pulses[:,1] % len(self.pulse_colors) == g,:]
                pak = plot_pulse_traces(pulses, g, pak)
        else:
            for l in self.labels:
                pulses = self.pulses[self.pulses[:,0] == l,:]
                pak = plot_pulse_traces(pulses, l, pak)
        if not plot:
            # blank stored artists left over from a previous, larger plot
            # (only in update mode - with plot=True self.pulse_artist
            # belongs to another figure):
            for pa in self.pulse_artist[pak:]:
                pa.set_data([], [])
        # ipis:
        for l in self.labels:
            if l >= len(self.pulse_times):
                continue
            pt = self.pulse_times[l]/self.samplerate
            if len(pt) <= 10:
                continue
            if plot or l not in self.ipis_labels:
                pa, = axs[-1].plot(tfac*pt[:-1], 1.0/np.diff(pt),
                                   '-o', picker=5,
                                   color=self.pulse_colors[l%len(self.pulse_colors)])
                if not plot:
                    self.ipis_artist.append(pa)
                    self.ipis_labels.append(l)
            else:
                iak = self.ipis_labels.index(l)
                self.ipis_artist[iak].set_data(tfac*pt[:-1],
                                               1.0/np.diff(pt))

    def update_plots(self):
        """Redraw all trace panels for the current time window.

        The trace artist of each shown channel is created on the first
        call and only receives new data afterwards. Windows with fewer
        than 200 samples are drawn with small circle markers. Pulse
        markers are refreshed via `plot_pulses()` before the canvas is
        redrawn.
        """
        start = int(np.round(self.toffset * self.samplerate))
        stop = min(int(np.round((self.toffset + self.twindow) * self.samplerate)),
                   len(self.data))
        time = np.arange(start, stop) / self.samplerate
        few_samples = stop - start < 200
        for t in range(self.traces):
            ax = self.axs[t]
            channel = self.show_channels[t]
            ax.set_xlim(self.toffset, self.toffset + self.twindow)
            trace = self.data[start:stop, channel]
            if self.trace_artist[t] is None:
                self.trace_artist[t], = ax.plot(time, trace)
            else:
                self.trace_artist[t].set_data(time, trace)
            line = self.trace_artist[t]
            if few_samples:
                line.set_marker('o')
                line.set_markersize(3)
            else:
                line.set_marker('None')
            ax.set_ylim(self.ymin[t], self.ymax[t])
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()

    def on_pick(self, event):
        """Mark the pulse or inter-pulse interval picked with the mouse.

        A pick on a pulse marker (one of ``self.pulse_artist``) selects the
        picked pulse. A pick on an inter-pulse frequency curve (one of
        ``self.ipis_artist``) selects the pulse starting the picked
        interval. All pulses with the selected label and group id (gid)
        are then highlighted with large circles on the traces. The interval
        starting at the selected pulse is highlighted on the inter-pulse
        frequency axes, if that pulse has a successor. If a feature figure
        exists, the pulse heights on all channels are plotted there. Picks
        on other artists are ignored.

        Parameters
        ----------
        event: matplotlib PickEvent
            Provides the picked ``artist`` and the indices ``ind`` of the
            picked data points.
        """
        # index of pulse artist:
        pk = -1
        for k, pa in enumerate(self.pulse_artist):
            if event.artist == pa:
                pk = k
                break
        pi = -1
        if pk >= 0:
            # Pulse artists come in groups of self.traces, one group per
            # label, or per color class of the gids when self.show_gid is set.
            # NOTE(review): assumes plot_pulses() stored one artist for
            # every shown trace of each group - verify.
            group = pk // self.traces
            cc = self.show_channels[pk % self.traces]
            if self.show_gid:
                gsel = self.pulses[:,1] % len(self.pulse_colors) == group
                gpulses = self.pulses[gsel,:]
                picked = gpulses[gpulses[:,2] == cc,:][event.ind[0]]
                ll = picked[0]
                gid = picked[1]
            else:
                ll = self.labels[group]
                lpulses = self.pulses[self.pulses[:,0] == ll,:]
                gid = lpulses[lpulses[:,2] == cc,1][event.ind[0]]
            if ll in self.ipis_labels:
                pi = self.pulse_gids[ll].index(gid)
        else:
            ik = -1
            for k, ia in enumerate(self.ipis_artist):
                if event.artist == ia:
                    ik = k
                    break
            if ik < 0:
                return
            ll = self.ipis_labels[ik]
            pi = event.ind[0]
            gid = self.pulse_gids[ll][pi]
        color = self.pulse_colors[ll % len(self.pulse_colors)]
        # mark pulses:
        pulses = self.pulses[self.pulses[:,0] == ll,:]
        pulses = pulses[pulses[:,1] == gid,:]
        for t in range(self.traces):
            c = self.show_channels[t]
            pt = pulses[pulses[:,2] == c,3]
            marker = self.marker_artist[t]
            if len(pt) > 0:
                # Line2D.set_data() requires sequences; recent matplotlib
                # versions reject scalars:
                x = [pt[0]/self.samplerate]
                y = [self.data[pt[0],c]]
                if marker is None:
                    self.marker_artist[t], = self.axs[t].plot(x, y, 'o', ms=10,
                                                              color=color)
                else:
                    marker.set_data(x, y)
                    marker.set_color(color)
            elif marker is not None:
                marker.set_data([], [])
        # mark ipi:
        pt0 = -1.0
        ipi_marker = self.marker_artist[self.traces]
        # the last pulse of a train has no successor and thus no interval:
        if 0 <= pi < len(self.pulse_times[ll]) - 1:
            pt0 = self.pulse_times[ll][pi]/self.samplerate
            pt1 = self.pulse_times[ll][pi+1]/self.samplerate
            pf = 1.0/(pt1-pt0)
            if ipi_marker is None:
                self.marker_artist[self.traces], = \
                    self.axs[self.traces].plot([pt0], [pf], 'o', ms=10,
                                               color=color)
            else:
                ipi_marker.set_data([pt0], [pf])
                ipi_marker.set_color(color)
        elif ipi_marker is not None:
            ipi_marker.set_data([], [])
        self.fig.canvas.draw()
        # show features (only if the feature figure exists):
        if self.axf is not None and self.figf is not None:
            # difference of the data at the sample indices in columns 3 and 4
            # of the pulse table for each channel (presumably peak and
            # trough - confirm); channels without this pulse stay zero:
            heights = np.zeros(self.channels)
            heights[pulses[:,2]] = \
                self.data[pulses[:,3],pulses[:,2]] - \
                self.data[pulses[:,4],pulses[:,2]]
            self.axf.plot(heights, color=color)
            print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
            self.figf.canvas.draw()

    def resize(self, event):
        """Keep constant pixel margins around the axes on window resizes.

        Parameters
        ----------
        event: matplotlib ResizeEvent
            Carries the new canvas size in pixels (``width``, ``height``).
        """
        left, right = 80.0, 20.0     # horizontal margins in pixels
        bottom, top = 50.0, 20.0     # vertical margins in pixels
        w = event.width
        h = event.height
        self.fig.subplots_adjust(left=left/w, right=1.0 - right/w,
                                 bottom=bottom/h, top=1.0 - top/h,
                                 hspace=0)

    def keypress(self, event):
        """Handle key presses in the trace window.

        Keys:

        - ``+``, ``=``, ``X``: zoom in (halve the time window while it
          holds more than 20 samples).
        - ``-``, ``x``: zoom out (double the time window while it is
          shorter than the recording).
        - ``pageup``/``pagedown``: move by half a window, with ``ctrl``
          by five windows.
        - ``up``/``down``: move by 5% of the window.
        - ``home``/``end``: jump to the beginning/end of the recording.
        - ``y``/``Y``: double/halve the y-range of each trace around its
          center.
        - ``v``: common y-range fitted to all shown traces in the window.
        - ``ctrl+v``: y-range fitted to each trace in the window.
        - ``ctrl+V``: y-range fitted to each trace over the whole data.
        - ``V``: reset the y-ranges to [-1, 1].
        - ``c``: center the y-ranges on zero, keeping their widths.
        - ``g``: toggle coloring pulses by group id instead of label.
        - ``i``/``I``: double/halve the range of the inter-pulse
          frequency axes.
        - ``h``: toggle the help text.
        - ``p``/``P``: play the time window/the whole recording.
        - ``S``: save the time window to an audio file.
        - ``w``: save a figure of the time window.
        - digits: reserved for toggling channels (not implemented yet).

        Parameters
        ----------
        event: matplotlib KeyEvent
            The pressed key is taken from ``event.key``, which may be None.
        """
        key = event.key
        if key is None:
            # matplotlib reports None for keys it cannot name:
            return

        def window_indices():
            # sample indices of the current time window:
            t0 = int(np.round(self.toffset * self.samplerate))
            t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
            return t0, t1

        def fit_ylim(t, lo, hi):
            # y-range of trace t covering [lo, hi] plus a 3% margin per side:
            h = 0.53 * (hi - lo)
            c = 0.5 * (hi + lo)
            self.ymin[t] = c - h
            self.ymax[t] = c + h

        def apply_ylims():
            for t in range(self.traces):
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()

        if key in ('+', '=', 'X'):
            if self.twindow * self.samplerate > 20:
                self.twindow *= 0.5
                self.update_plots()
        elif key in ('-', 'x'):
            if self.twindow < self.tmax:
                self.twindow *= 2.0
                self.update_plots()
        elif key == 'pagedown':
            if self.toffset + 0.5 * self.twindow < self.tmax:
                self.toffset += 0.5 * self.twindow
                self.update_plots()
        elif key == 'pageup':
            if self.toffset > 0:
                self.toffset = max(self.toffset - 0.5 * self.twindow, 0.0)
                self.update_plots()
        elif key == 'ctrl+pagedown':
            if self.toffset + 5.0 * self.twindow < self.tmax:
                self.toffset += 5.0 * self.twindow
                self.update_plots()
        elif key == 'ctrl+pageup':
            if self.toffset > 0:
                self.toffset = max(self.toffset - 5.0 * self.twindow, 0.0)
                self.update_plots()
        elif key == 'down':
            if self.toffset + self.twindow < self.tmax:
                self.toffset += 0.05 * self.twindow
                self.update_plots()
        elif key == 'up':
            if self.toffset > 0.0:
                self.toffset = max(self.toffset - 0.05 * self.twindow, 0.0)
                self.update_plots()
        elif key == 'home':
            if self.toffset > 0.0:
                self.toffset = 0.0
                self.update_plots()
        elif key == 'end':
            toffs = np.floor(self.tmax / self.twindow) * self.twindow
            if self.tmax - toffs <= 0.0:
                toffs -= self.twindow
            if self.tmax - toffs < self.twindow/2:
                toffs -= self.twindow/2
            if self.toffset < toffs:
                self.toffset = toffs
                self.update_plots()
        elif key in ('y', 'Y'):
            # 'y' doubles, 'Y' halves the y-range around its center:
            fac = 1.0 if key == 'y' else 0.25
            for t in range(self.traces):
                h = fac * (self.ymax[t] - self.ymin[t])
                c = 0.5 * (self.ymax[t] + self.ymin[t])
                self.ymin[t] = c - h
                self.ymax[t] = c + h
            apply_ylims()
        elif key == 'v':
            t0, t1 = window_indices()
            window = self.data[t0:t1, self.show_channels]
            lo = np.min(window)
            hi = np.max(window)
            for t in range(self.traces):
                fit_ylim(t, lo, hi)
            apply_ylims()
        elif key == 'ctrl+v':
            t0, t1 = window_indices()
            for t in range(self.traces):
                trace = self.data[t0:t1, self.show_channels[t]]
                fit_ylim(t, np.min(trace), np.max(trace))
            apply_ylims()
        elif key == 'ctrl+V':
            for t in range(self.traces):
                trace = self.data[:, self.show_channels[t]]
                fit_ylim(t, np.min(trace), np.max(trace))
            apply_ylims()
        elif key == 'V':
            self.ymin[:] = -1.0
            self.ymax[:] = +1.0
            apply_ylims()
        elif key == 'c':
            for t in range(self.traces):
                dy = self.ymax[t] - self.ymin[t]
                self.ymin[t] = -dy/2
                self.ymax[t] = +dy/2
            apply_ylims()
        elif key == 'g':
            self.show_gid = not self.show_gid
            self.plot_pulses(self.axs, False)
            self.fig.canvas.draw()
        elif key in ('i', 'I'):
            # 'i' doubles, 'I' halves the inter-pulse frequency range:
            if len(self.pulses) > 0:
                if key == 'i':
                    self.fmax *= 2
                else:
                    self.fmax /= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif key in tuple('0123456789'):
            # TODO: toggle the visibility of the corresponding channel;
            # simply hiding self.axs[int(key)] is not yet what we want.
            pass
        elif key == 'h':
            self.help = not self.help
            for ht in self.helptext:
                ht.set_visible(self.help)
            self.fig.canvas.draw()
        elif key == 'p':
            self.play_segment()
        elif key == 'P':
            self.play_all()
        elif key == 'S':
            self.save_segment()
        elif key == 'w':
            self.plot_traces()

    def play_segment(self):
        """Play the currently shown time window as a mono signal.

        The shown channels are averaged. The result is passed to audioio's
        ``fade()`` (in place) with a fade time of 0.1 s, or a tenth of the
        window for windows of at most 0.5 s. It is then played without
        blocking the GUI.
        """
        start = int(np.round(self.toffset * self.samplerate))
        stop = int(np.round((self.toffset + self.twindow) * self.samplerate))
        segment = self.data[start:stop, self.show_channels]
        playdata = 1.0 * np.mean(segment, axis=1)
        if self.twindow > 0.5:
            fade_time = 0.1
        else:
            fade_time = 0.1 * self.twindow
        fade(playdata, self.samplerate, fade_time)
        self.audio.play(playdata, self.samplerate, blocking=False)
        
    def play_all(self):
        """Play the entire recording without blocking the GUI.

        The shown channels are averaged into a single mono signal.
        """
        mono = np.mean(self.data[:, self.show_channels], axis=1)
        self.audio.play(mono, self.samplerate, blocking=False)

    def save_segment(self):
        """Write the currently shown time window to a wav file.

        The file name is ``self.filename`` without its extension, followed
        by the start and end times of the window in seconds, e.g.
        ``rec-1.5s-2s.wav``. If not all channels are shown, only the shown
        channels are written and the index of the first shown channel is
        appended (``rec-1.5s-2s-c3.wav``).
        """
        # Window boundaries in seconds, formatted with '.4g' in the file
        # name. Rounding them to whole seconds first produced misleading
        # and clashing names for windows shorter than a second.
        t0s = self.toffset
        t1s = self.toffset + self.twindow
        t0 = int(np.round(t0s * self.samplerate))
        t1 = int(np.round(t1s * self.samplerate))
        # Strip only the extension. Splitting at the first '.' truncated
        # paths like './data/rec.wav' or names like 'rec.2020.wav'.
        filename = os.path.splitext(self.filename)[0]
        if self.traces == self.channels:
            segment_filename = f'{filename}-{t0s:.4g}s-{t1s:.4g}s.wav'
            write_audio(segment_filename, self.data[t0:t1,:], self.samplerate)
        else:
            segment_filename = f'{filename}-{t0s:.4g}s-{t1s:.4g}s-c{self.show_channels[0]}.wav'
            write_audio(segment_filename,
                        self.data[t0:t1,self.show_channels], self.samplerate)
        print('saved segment to: ' , segment_filename)

    def plot_traces(self):
        """Save a figure of the traces in the current time window.

        One panel is drawn per shown trace, plus a panel with the
        inter-pulse frequencies if there are pulses. The time axis is in
        milliseconds if both the window offset and its width are below one
        second, and in seconds otherwise. The figure is saved at 200 dpi as
        ``<name>-<toffset>s-traces.png``, where ``<name>`` is
        ``self.filename`` without its extension. If not all channels are
        shown, the first shown channel is inserted before ``-traces``.
        """
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        fig, axs = plt.subplots(splts, 1, squeeze=False, sharex=True,
                                figsize=(15, 9))
        try:
            axs = axs.flat
            fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                                hspace=0)
            # strip only the extension, keep dots in directory or file names:
            name = os.path.splitext(self.filename)[0]
            if self.traces < self.channels:
                figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
            else:
                figfile = f'{name}-{self.toffset:.4g}s-traces.png'
            axs[0].set_title(self.filename)
            t0 = int(np.round(self.toffset * self.samplerate))
            t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
            t1 = min(t1, len(self.data))
            time = np.arange(t0, t1)/self.samplerate
            # milliseconds for short windows near the start, else seconds:
            if self.toffset < 1.0 and self.twindow < 1.0:
                tfac = 1000.0
                axs[-1].set_xlabel('Time [ms]')
            else:
                tfac = 1.0
                axs[-1].set_xlabel('Time [s]')
            for t in range(self.traces):
                c = self.show_channels[t]
                axs[t].set_xlim(tfac * self.toffset,
                                tfac * (self.toffset + self.twindow))
                axs[t].plot(tfac * time, self.data[t0:t1,c])
            self.plot_pulses(axs, True, tfac)
            for t in range(self.traces):
                c = self.show_channels[t]
                axs[t].set_ylim(self.ymin[t], self.ymax[t])
                axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
            if len(self.pulses) > 0:
                axs[-1].set_ylabel('IP freq [Hz]')
                axs[-1].set_ylim(0.0, self.fmax)
            fig.savefig(figfile, dpi=200)
        finally:
            # release the figure even if plotting or saving failed:
            plt.close(fig)
        print('saved waveform figure to', figfile)

Methods

def plot_pulses(self, axs, plot=True, tfac=1.0)
Expand source code
def plot_pulses(self, axs, plot=True, tfac=1.0):
    """Plot the detected pulses on the traces and their inter-pulse frequencies.

    Rows of ``self.pulses`` hold label (column 0), group id (column 1),
    channel (column 2) and sample index (column 3) of a pulse.

    Parameters
    ----------
    axs: sequence of matplotlib axes
        One axes per shown trace, followed by the axes for the
        inter-pulse frequencies (``axs[-1]``).
    plot: bool
        If True, plot into `axs` (e.g. for a figure that is saved)
        without touching the stored artists. If False, update the
        interactive figure: stored pulse and IPI artists are reused via
        `set_data()`, missing ones are created and stored, and surplus
        pulse artists are emptied.
    tfac: float
        Factor converting seconds to the unit of the time axis
        (e.g. 1000 for milliseconds).
    """

    def plot_pulse_traces(pulses, i, pak):
        # Plot one group of pulses on all shown traces in color i and
        # return the index of the next pulse artist. In update mode every
        # trace gets an artist, even an empty one. That keeps artist
        # `pak` on axes `pak % self.traces`. Skipping empty traces
        # shifted the artists, so set_data() drew pulses into the wrong
        # subplot and on_pick() mapped artists to the wrong label.
        color = self.pulse_colors[i % len(self.pulse_colors)]
        for t in range(self.traces):
            c = self.show_channels[t]
            p = pulses[pulses[:,2] == c,3]
            if plot and len(p) == 0:
                continue
            x = tfac*p/self.samplerate
            y = self.data[p,c]
            if plot or pak >= len(self.pulse_artist):
                pa, = axs[t].plot(x, y, 'o', picker=5, color=color)
                if not plot:
                    self.pulse_artist.append(pa)
            else:
                self.pulse_artist[pak].set_data(x, y)
                self.pulse_artist[pak].set_color(color)
            pak += 1
        return pak

    # pulses (without any pulses the table may lack column 3):
    pak = 0
    if len(self.pulses) > 0:
        if self.show_gid:
            for g in range(len(self.pulse_colors)):
                pulses = self.pulses[self.pulses[:,1] % len(self.pulse_colors) == g,:]
                pak = plot_pulse_traces(pulses, g, pak)
        else:
            for l in self.labels:
                pulses = self.pulses[self.pulses[:,0] == l,:]
                pak = plot_pulse_traces(pulses, l, pak)
    if not plot:
        # empty the stored artists that are not needed anymore; in plot
        # mode they belong to the interactive figure and are left alone:
        for pa in self.pulse_artist[pak:]:
            pa.set_data([], [])
    # ipis:
    for l in self.labels:
        if l >= len(self.pulse_times):
            continue
        pt = self.pulse_times[l]/self.samplerate
        if len(pt) <= 10:
            continue
        if plot or l not in self.ipis_labels:
            pa, = axs[-1].plot(tfac*pt[:-1], 1.0/np.diff(pt),
                               '-o', picker=5,
                               color=self.pulse_colors[l%len(self.pulse_colors)])
            if not plot:
                self.ipis_artist.append(pa)
                self.ipis_labels.append(l)
        else:
            iak = self.ipis_labels.index(l)
            self.ipis_artist[iak].set_data(tfac*pt[:-1],
                                           1.0/np.diff(pt))
def update_plots(self)
Expand source code
def update_plots(self):
    t0 = int(np.round(self.toffset * self.samplerate))
    t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
    if t1 > len(self.data):
        t1 = len(self.data)
    time = np.arange(t0, t1) / self.samplerate
    for t in range(self.traces):
        c = self.show_channels[t]
        self.axs[t].set_xlim(self.toffset, self.toffset + self.twindow)
        if self.trace_artist[t] == None:
            self.trace_artist[t], = self.axs[t].plot(time, self.data[t0:t1,c])
        else:
            self.trace_artist[t].set_data(time, self.data[t0:t1,c])
        if t1 - t0 < 200:
            self.trace_artist[t].set_marker('o')
            self.trace_artist[t].set_markersize(3)
        else:
            self.trace_artist[t].set_marker('None')
        self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
    self.plot_pulses(self.axs, False)
    self.fig.canvas.draw()
def on_pick(self, event)
Expand source code
def on_pick(self, event):
    """Mark the pulse or inter-pulse interval picked with the mouse.

    A pick on a pulse marker (one of ``self.pulse_artist``) selects the
    picked pulse. A pick on an inter-pulse frequency curve (one of
    ``self.ipis_artist``) selects the pulse starting the picked interval.
    All pulses with the selected label and group id (gid) are then
    highlighted with large circles on the traces. The interval starting
    at the selected pulse is highlighted on the inter-pulse frequency
    axes, if that pulse has a successor. If a feature figure exists, the
    pulse heights on all channels are plotted there. Picks on other
    artists are ignored.

    Parameters
    ----------
    event: matplotlib PickEvent
        Provides the picked ``artist`` and the indices ``ind`` of the
        picked data points.
    """
    # index of pulse artist:
    pk = -1
    for k, pa in enumerate(self.pulse_artist):
        if event.artist == pa:
            pk = k
            break
    pi = -1
    if pk >= 0:
        # Pulse artists come in groups of self.traces, one group per
        # label, or per color class of the gids when self.show_gid is set.
        # NOTE(review): assumes plot_pulses() stored one artist for every
        # shown trace of each group - verify.
        group = pk // self.traces
        cc = self.show_channels[pk % self.traces]
        if self.show_gid:
            gsel = self.pulses[:,1] % len(self.pulse_colors) == group
            gpulses = self.pulses[gsel,:]
            picked = gpulses[gpulses[:,2] == cc,:][event.ind[0]]
            ll = picked[0]
            gid = picked[1]
        else:
            ll = self.labels[group]
            lpulses = self.pulses[self.pulses[:,0] == ll,:]
            gid = lpulses[lpulses[:,2] == cc,1][event.ind[0]]
        if ll in self.ipis_labels:
            pi = self.pulse_gids[ll].index(gid)
    else:
        ik = -1
        for k, ia in enumerate(self.ipis_artist):
            if event.artist == ia:
                ik = k
                break
        if ik < 0:
            return
        ll = self.ipis_labels[ik]
        pi = event.ind[0]
        gid = self.pulse_gids[ll][pi]
    color = self.pulse_colors[ll % len(self.pulse_colors)]
    # mark pulses:
    pulses = self.pulses[self.pulses[:,0] == ll,:]
    pulses = pulses[pulses[:,1] == gid,:]
    for t in range(self.traces):
        c = self.show_channels[t]
        pt = pulses[pulses[:,2] == c,3]
        marker = self.marker_artist[t]
        if len(pt) > 0:
            # Line2D.set_data() requires sequences; recent matplotlib
            # versions reject scalars:
            x = [pt[0]/self.samplerate]
            y = [self.data[pt[0],c]]
            if marker is None:
                self.marker_artist[t], = self.axs[t].plot(x, y, 'o', ms=10,
                                                          color=color)
            else:
                marker.set_data(x, y)
                marker.set_color(color)
        elif marker is not None:
            marker.set_data([], [])
    # mark ipi:
    pt0 = -1.0
    ipi_marker = self.marker_artist[self.traces]
    # the last pulse of a train has no successor and thus no interval:
    if 0 <= pi < len(self.pulse_times[ll]) - 1:
        pt0 = self.pulse_times[ll][pi]/self.samplerate
        pt1 = self.pulse_times[ll][pi+1]/self.samplerate
        pf = 1.0/(pt1-pt0)
        if ipi_marker is None:
            self.marker_artist[self.traces], = \
                self.axs[self.traces].plot([pt0], [pf], 'o', ms=10,
                                           color=color)
        else:
            ipi_marker.set_data([pt0], [pf])
            ipi_marker.set_color(color)
    elif ipi_marker is not None:
        ipi_marker.set_data([], [])
    self.fig.canvas.draw()
    # show features (only if the feature figure exists):
    if self.axf is not None and self.figf is not None:
        # difference of the data at the sample indices in columns 3 and 4
        # of the pulse table for each channel (presumably peak and
        # trough - confirm); channels without this pulse stay zero:
        heights = np.zeros(self.channels)
        heights[pulses[:,2]] = \
            self.data[pulses[:,3],pulses[:,2]] - \
            self.data[pulses[:,4],pulses[:,2]]
        self.axf.plot(heights, color=color)
        print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
        self.figf.canvas.draw()
def resize(self, event)
Expand source code
def resize(self, event):
    """Keep constant pixel margins around the axes on window resizes.

    Parameters
    ----------
    event: matplotlib ResizeEvent
        Carries the new canvas size in pixels (``width``, ``height``).
    """
    left, right = 80.0, 20.0     # horizontal margins in pixels
    bottom, top = 50.0, 20.0     # vertical margins in pixels
    w = event.width
    h = event.height
    self.fig.subplots_adjust(left=left/w, right=1.0 - right/w,
                             bottom=bottom/h, top=1.0 - top/h,
                             hspace=0)
def keypress(self, event)
Expand source code
def keypress(self, event):
    """Handle key presses in the trace window.

    Keys:

    - ``+``, ``=``, ``X``: zoom in (halve the time window while it holds
      more than 20 samples).
    - ``-``, ``x``: zoom out (double the time window while it is shorter
      than the recording).
    - ``pageup``/``pagedown``: move by half a window, with ``ctrl`` by
      five windows.
    - ``up``/``down``: move by 5% of the window.
    - ``home``/``end``: jump to the beginning/end of the recording.
    - ``y``/``Y``: double/halve the y-range of each trace around its
      center.
    - ``v``: common y-range fitted to all shown traces in the window.
    - ``ctrl+v``: y-range fitted to each trace in the window.
    - ``ctrl+V``: y-range fitted to each trace over the whole data.
    - ``V``: reset the y-ranges to [-1, 1].
    - ``c``: center the y-ranges on zero, keeping their widths.
    - ``g``: toggle coloring pulses by group id instead of label.
    - ``i``/``I``: double/halve the range of the inter-pulse frequency
      axes.
    - ``h``: toggle the help text.
    - ``p``/``P``: play the time window/the whole recording.
    - ``S``: save the time window to an audio file.
    - ``w``: save a figure of the time window.
    - digits: reserved for toggling channels (not implemented yet).

    Parameters
    ----------
    event: matplotlib KeyEvent
        The pressed key is taken from ``event.key``, which may be None.
    """
    key = event.key
    if key is None:
        # matplotlib reports None for keys it cannot name:
        return

    def window_indices():
        # sample indices of the current time window:
        t0 = int(np.round(self.toffset * self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
        return t0, t1

    def fit_ylim(t, lo, hi):
        # y-range of trace t covering [lo, hi] plus a 3% margin per side:
        h = 0.53 * (hi - lo)
        c = 0.5 * (hi + lo)
        self.ymin[t] = c - h
        self.ymax[t] = c + h

    def apply_ylims():
        for t in range(self.traces):
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()

    if key in ('+', '=', 'X'):
        if self.twindow * self.samplerate > 20:
            self.twindow *= 0.5
            self.update_plots()
    elif key in ('-', 'x'):
        if self.twindow < self.tmax:
            self.twindow *= 2.0
            self.update_plots()
    elif key == 'pagedown':
        if self.toffset + 0.5 * self.twindow < self.tmax:
            self.toffset += 0.5 * self.twindow
            self.update_plots()
    elif key == 'pageup':
        if self.toffset > 0:
            self.toffset = max(self.toffset - 0.5 * self.twindow, 0.0)
            self.update_plots()
    elif key == 'ctrl+pagedown':
        if self.toffset + 5.0 * self.twindow < self.tmax:
            self.toffset += 5.0 * self.twindow
            self.update_plots()
    elif key == 'ctrl+pageup':
        if self.toffset > 0:
            self.toffset = max(self.toffset - 5.0 * self.twindow, 0.0)
            self.update_plots()
    elif key == 'down':
        if self.toffset + self.twindow < self.tmax:
            self.toffset += 0.05 * self.twindow
            self.update_plots()
    elif key == 'up':
        if self.toffset > 0.0:
            self.toffset = max(self.toffset - 0.05 * self.twindow, 0.0)
            self.update_plots()
    elif key == 'home':
        if self.toffset > 0.0:
            self.toffset = 0.0
            self.update_plots()
    elif key == 'end':
        toffs = np.floor(self.tmax / self.twindow) * self.twindow
        if self.tmax - toffs <= 0.0:
            toffs -= self.twindow
        if self.tmax - toffs < self.twindow/2:
            toffs -= self.twindow/2
        if self.toffset < toffs:
            self.toffset = toffs
            self.update_plots()
    elif key in ('y', 'Y'):
        # 'y' doubles, 'Y' halves the y-range around its center:
        fac = 1.0 if key == 'y' else 0.25
        for t in range(self.traces):
            h = fac * (self.ymax[t] - self.ymin[t])
            c = 0.5 * (self.ymax[t] + self.ymin[t])
            self.ymin[t] = c - h
            self.ymax[t] = c + h
        apply_ylims()
    elif key == 'v':
        t0, t1 = window_indices()
        window = self.data[t0:t1, self.show_channels]
        lo = np.min(window)
        hi = np.max(window)
        for t in range(self.traces):
            fit_ylim(t, lo, hi)
        apply_ylims()
    elif key == 'ctrl+v':
        t0, t1 = window_indices()
        for t in range(self.traces):
            trace = self.data[t0:t1, self.show_channels[t]]
            fit_ylim(t, np.min(trace), np.max(trace))
        apply_ylims()
    elif key == 'ctrl+V':
        for t in range(self.traces):
            trace = self.data[:, self.show_channels[t]]
            fit_ylim(t, np.min(trace), np.max(trace))
        apply_ylims()
    elif key == 'V':
        self.ymin[:] = -1.0
        self.ymax[:] = +1.0
        apply_ylims()
    elif key == 'c':
        for t in range(self.traces):
            dy = self.ymax[t] - self.ymin[t]
            self.ymin[t] = -dy/2
            self.ymax[t] = +dy/2
        apply_ylims()
    elif key == 'g':
        self.show_gid = not self.show_gid
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()
    elif key in ('i', 'I'):
        # 'i' doubles, 'I' halves the inter-pulse frequency range:
        if len(self.pulses) > 0:
            if key == 'i':
                self.fmax *= 2
            else:
                self.fmax /= 2
            self.axs[-1].set_ylim(0.0, self.fmax)
            self.fig.canvas.draw()
    elif key in tuple('0123456789'):
        # TODO: toggle the visibility of the corresponding channel;
        # simply hiding self.axs[int(key)] is not yet what we want.
        pass
    elif key == 'h':
        self.help = not self.help
        for ht in self.helptext:
            ht.set_visible(self.help)
        self.fig.canvas.draw()
    elif key == 'p':
        self.play_segment()
    elif key == 'P':
        self.play_all()
    elif key == 'S':
        self.save_segment()
    elif key == 'w':
        self.plot_traces()
def play_segment(self)
Expand source code
def play_segment(self):
    """Play the currently shown time window as a mono signal.

    The shown channels are averaged. The result is passed to audioio's
    ``fade()`` (in place) with a fade time of 0.1 s, or a tenth of the
    window for windows of at most 0.5 s. It is then played without
    blocking the GUI.
    """
    start = int(np.round(self.toffset * self.samplerate))
    stop = int(np.round((self.toffset + self.twindow) * self.samplerate))
    segment = self.data[start:stop, self.show_channels]
    playdata = 1.0 * np.mean(segment, axis=1)
    if self.twindow > 0.5:
        fade_time = 0.1
    else:
        fade_time = 0.1 * self.twindow
    fade(playdata, self.samplerate, fade_time)
    self.audio.play(playdata, self.samplerate, blocking=False)
def play_all(self)
Expand source code
def play_all(self):
    """Play the entire recording without blocking the GUI.

    The shown channels are averaged into a single mono signal.
    """
    mono = np.mean(self.data[:, self.show_channels], axis=1)
    self.audio.play(mono, self.samplerate, blocking=False)
def save_segment(self)
Expand source code
def save_segment(self):
    """Write the currently shown time window to a wav file.

    The file name is ``self.filename`` without its extension, followed by
    the start and end times of the window in seconds, e.g.
    ``rec-1.5s-2s.wav``. If not all channels are shown, only the shown
    channels are written and the index of the first shown channel is
    appended (``rec-1.5s-2s-c3.wav``).
    """
    # Window boundaries in seconds, formatted with '.4g' in the file name.
    # Rounding them to whole seconds first produced misleading and
    # clashing names for windows shorter than a second.
    t0s = self.toffset
    t1s = self.toffset + self.twindow
    t0 = int(np.round(t0s * self.samplerate))
    t1 = int(np.round(t1s * self.samplerate))
    # Strip only the extension. Splitting at the first '.' truncated
    # paths like './data/rec.wav' or names like 'rec.2020.wav'.
    filename = os.path.splitext(self.filename)[0]
    if self.traces == self.channels:
        segment_filename = f'{filename}-{t0s:.4g}s-{t1s:.4g}s.wav'
        write_audio(segment_filename, self.data[t0:t1,:], self.samplerate)
    else:
        segment_filename = f'{filename}-{t0s:.4g}s-{t1s:.4g}s-c{self.show_channels[0]}.wav'
        write_audio(segment_filename,
                    self.data[t0:t1,self.show_channels], self.samplerate)
    print('saved segment to: ' , segment_filename)
def plot_traces(self)
Expand source code
def plot_traces(self):
    """Save a figure of the traces in the current time window.

    One panel is drawn per shown trace, plus a panel with the inter-pulse
    frequencies if there are pulses. The time axis is in milliseconds if
    both the window offset and its width are below one second, and in
    seconds otherwise. The figure is saved at 200 dpi as
    ``<name>-<toffset>s-traces.png``, where ``<name>`` is
    ``self.filename`` without its extension. If not all channels are
    shown, the first shown channel is inserted before ``-traces``.
    """
    splts = self.traces
    if len(self.pulses) > 0:
        splts += 1
    fig, axs = plt.subplots(splts, 1, squeeze=False, sharex=True,
                            figsize=(15, 9))
    try:
        axs = axs.flat
        fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                            hspace=0)
        # strip only the extension, keep dots in directory or file names:
        name = os.path.splitext(self.filename)[0]
        if self.traces < self.channels:
            figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
        else:
            figfile = f'{name}-{self.toffset:.4g}s-traces.png'
        axs[0].set_title(self.filename)
        t0 = int(np.round(self.toffset * self.samplerate))
        t1 = int(np.round((self.toffset + self.twindow) * self.samplerate))
        t1 = min(t1, len(self.data))
        time = np.arange(t0, t1)/self.samplerate
        # milliseconds for short windows near the start, else seconds:
        if self.toffset < 1.0 and self.twindow < 1.0:
            tfac = 1000.0
            axs[-1].set_xlabel('Time [ms]')
        else:
            tfac = 1.0
            axs[-1].set_xlabel('Time [s]')
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_xlim(tfac * self.toffset,
                            tfac * (self.toffset + self.twindow))
            axs[t].plot(tfac * time, self.data[t0:t1,c])
        self.plot_pulses(axs, True, tfac)
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_ylim(self.ymin[t], self.ymax[t])
            axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
        if len(self.pulses) > 0:
            axs[-1].set_ylabel('IP freq [Hz]')
            axs[-1].set_ylim(0.0, self.fmax)
        fig.savefig(figfile, dpi=200)
    finally:
        # release the figure even if plotting or saving failed:
        plt.close(fig)
    print('saved waveform figure to', figfile)