Module thunderfish.thunderbrowse

Functions

def short_user_warning(message, category, filename, lineno, file=None, line='')
Expand source code
def short_user_warning(message, category, filename, lineno, file=None, line=''):
    if file is None:
        file = sys.stderr
    if category == UserWarning:
        file.write('%s line %d: %s\n' % ('/'.join(filename.split('/')[-2:]), lineno, message))
    else:
        s = warnings.formatwarning(message, category, filename, lineno, line)
        file.write(s)
def main(cargs=None)
Expand source code
def main(cargs=None):
    warnings.showwarning = short_user_warning

    # config file name:
    cfgfile = __package__ + '.cfg'

    # command line arguments:
    if cargs is None:
        cargs = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description='Browse mutlichannel EOD recordings.',
        epilog='version %s by Benda-Lab (2022-%s)' % (__version__, __year__))
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-v', action='count', dest='verbose')
    parser.add_argument('-c', dest='channels', default='',
                        type=str, metavar='CHANNELS',
                        help='Comma separated list of channels to be displayed (first channel is 0).')
    parser.add_argument('-t', dest='tmax', default=None,
                        type=float, metavar='TMAX',
                        help='Process and show only the first TMAX seconds.')
    parser.add_argument('-f', dest='fcutoff', default=None,
                        type=float, metavar='FREQ',
                        help='Cutoff frequency of optional high-pass filter.')
    parser.add_argument('-p', dest='pulses', action='store_true',
                        help='detect pulse fish EODs')
    parser.add_argument('file', nargs=1, default='', type=str,
                        help='name of the file with the time series data')
    args = parser.parse_args(cargs)
    filepath = args.file[0]
    cs = [s.strip() for s in args.channels.split(',')]
    channels = [int(c) for c in cs if len(c)>0]
    tmax = args.tmax
    fcutoff = args.fcutoff
    pulses = args.pulses

    # set verbosity level from command line:
    verbose = 0
    if args.verbose != None:
        verbose = args.verbose

    # load data:
    filename = os.path.basename(filepath)
    with DataLoader(filepath, 10*60.0, 5.0, verbose) as data:
        SignalPlot(data, data.rate, data.unit, filename,
                   channels, tmax, fcutoff, pulses)

Classes

class SignalPlot (data, rate, unit, filename, show_channels=[], tmax=None, fcutoff=None, pulses=False)
Expand source code
class SignalPlot:
    def __init__(self, data, rate, unit, filename,
                 show_channels=[], tmax=None, fcutoff=None,
                 pulses=False):
        self.filename = filename
        self.rate = rate
        self.data = data
        self.channels = self.data.shape[1] if len(self.data.shape) > 1 else 1
        self.unit = unit
        self.tmax = (len(self.data)-1)/self.rate
        if not tmax is None:
            self.tmax = tmax
            self.data = data[:int(tmax*self.rate),:]
        self.toffset = 0.0
        self.twindow = 10.0
        if self.twindow > self.tmax:
            self.twindow = np.round(2 ** (np.floor(np.log(self.tmax) / np.log(2.0)) + 1.0))
            if not tmax is None:
                self.twindow = tmax
        self.pulses = np.zeros((0, 3), dtype=int)
        self.labels = []
        self.fishes = []
        self.pulse_times = []
        self.pulse_gids = []
        if len(show_channels) == 0:
            self.show_channels = np.arange(self.channels)
        else:
            self.show_channels = np.array(show_channels)
        self.traces = len(self.show_channels)
        self.ymin = -1.0 * np.ones(self.traces)
        self.ymax = +1.0 * np.ones(self.traces)
        self.fmax = 100.0
        self.trace_artist = [None] * self.traces
        self.show_gid = False
        self.pulse_artist = []
        self.marker_artist = [None] * (self.traces + 1)
        self.ipis_artist = []
        self.ipis_labels = []
        self.figf = None
        self.axf = None
        self.pulse_colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C0']
        self.help = False
        self.helptext = []
        self.audio = PlayAudio()

        # filter data:
        if not fcutoff is None:
            sos = butter(2, fcutoff, 'high', fs=rate, output='sos')
            self.data = sosfiltfilt(sos, self.data[:], 0)

        # pulse detection:
        if pulses:
            # label, group, channel, peak index, trough index
            all_pulses = np.zeros((0, 5), dtype=int)
            for c in range(self.channels):
                #thresh = 1*np.std(self.data[:int(2*self.rate),c])
                thresh = median_std_threshold(self.data[:,c], self.rate,
                                              thresh_fac=6.0)
                thresh = 0.01
                #p, t = detect_peaks(self.data[:,c], thresh)
                p, t, w, h = detect_pulses(self.data[:,c], self.rate,
                                           thresh,
                                           min_rel_slope_diff=0.25,
                                           min_width=0.0001,
                                           max_width=0.01,
                                           width_fac=5.0)
                # label, group, channel, peak, trough:
                pulses = np.hstack((np.arange(len(p))[:,np.newaxis],
                                    np.zeros((len(p), 1), dtype=int),
                                    np.ones((len(p), 1), dtype=int)*c,
                                    p[:,np.newaxis], t[:,np.newaxis]))
                all_pulses = np.vstack((all_pulses, pulses))
            self.pulses = all_pulses[np.argsort(all_pulses[:,3]),:]
            # grouping over channels:
            max_di = int(0.0002*self.rate)   # TODO: parameter
            l = -1
            k = 0
            while k < len(self.pulses):
                tp = self.pulses[k,3]
                tt = self.pulses[k,4]
                height = self.data[self.pulses[k,3],self.pulses[k,2]] - \
                    self.data[self.pulses[k,4],self.pulses[k,2]]
                channel_counts = np.zeros(self.channels, dtype=int)
                channel_counts[self.pulses[k,2]] += 1
                for c in range(1, 3*self.channels):
                    if k+c >= len(self.pulses):
                        break
                    # pulse too far away:
                    if channel_counts[self.pulses[k+c,2]] > 1 or \
                       (np.abs(self.pulses[k+c,3] - tp) > max_di and
                        np.abs(self.pulses[k+c,3] - tt) > max_di and
                        np.abs(self.pulses[k+c,4] - tp) > max_di and
                        np.abs(self.pulses[k+c,4] - tt) > max_di):
                        break
                    channel_counts[self.pulses[k+c,2]] += 1
                    height_kc = self.data[self.pulses[k+c,3],self.pulses[k+c,2]] - \
                        self.data[self.pulses[k+c,4],self.pulses[k+c,2]]
                    # heighest pulse sets time reference:
                    if height_kc > height:
                        tp = self.pulses[k+c,3]
                        tt = self.pulses[k+c,4]
                        height = height_kc
                # all pulses too small:
                if height < 0.02:    # TODO parameter
                    self.pulses[k:k+c,0] = -1
                    k += c
                    continue
                # new label:
                l += 1
                # remove lost pulses:
                for j in range(c):
                    if (np.abs(self.pulses[k+j,3] - tp) > max_di and
                        np.abs(self.pulses[k+j,3] - tt) > max_di and
                        np.abs(self.pulses[k+j,4] - tp) > max_di and
                        np.abs(self.pulses[k+j,4] - tt) > max_di):
                        self.pulses[k+j,0] = -1
                        channel_counts[self.pulses[k+j,2]] -= 1
                    else:
                        self.pulses[k+j,0] = l
                        self.pulses[k+j,1] = l
                # keep only the largest pulse of each channel:
                pulses = self.pulses[k:k+c,:]
                for dc in np.where(channel_counts > 1)[0]:
                    idx = np.where(self.pulses[k:k+c,2] == dc)[0]
                    heights = self.data[pulses[idx,3],dc] - \
                        self.data[pulses[idx,4],dc]
                    for i in range(len(idx)):
                        if i != np.argmax(heights):
                            channel_counts[self.pulses[k+idx[i],2]] -= 1
                            self.pulses[k+idx[i],0] = -1
                k += c
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]

            # clustering:
            min_dists = []
            recent = []
            k = 0
            while k < len(self.pulses):
                # select pulse group:
                j = k
                gid = self.pulses[j,1]
                for c in range(self.channels):
                    k += 1
                    if k >= len(self.pulses) or \
                       self.pulses[k,1] != gid:
                        break
                heights = np.zeros(self.channels)
                heights[self.pulses[j:k,2]] = \
                    self.data[self.pulses[j:k,3],self.pulses[j:k,2]] - \
                    self.data[self.pulses[j:k,4],self.pulses[j:k,2]]
                # time of largest pulse:
                pulse_time = self.pulses[j+np.argmax(heights[self.pulses[j:k,2]]),3]
                # assign to cluster:
                if len(self.pulse_times) == 0:
                    label = len(self.pulse_times)
                    self.pulse_times.append([])
                    self.pulse_gids.append([])
                else:
                    # compute metrics of recent fishes:
                    # mean relative height difference:
                    dists = np.array([np.mean(np.abs(hh - heights)/np.max(hh))
                                        for ll, tt, hh in recent])
                    thresh = 0.1   # TODO: make parameter
                    # distance between pulses:
                    ipis = np.array([(pulse_time - tt)/self.rate
                                     for ll, tt, hh in recent])
                    ## how can ipis be 0, or just one sample?
                    ##if len(ipis[ipis<0.001]) > 0:
                    ##    print(ipis[ipis<0.001])
                    # ensure minimum IP distance:
                    dists[1/ipis > 300.0] = 2*np.max(dists)  # TODO: make parameter
                    # minimum ditance:
                    min_dist_idx = np.argmin(dists)
                    min_dists.append(dists[min_dist_idx])
                    if dists[min_dist_idx] < thresh:
                        label = recent[min_dist_idx][0]
                    else:
                        label = len(self.pulse_times)
                        self.pulse_times.append([])
                        self.pulse_gids.append([])
                self.pulses[j:k,0] = label
                self.pulse_times[label].append(pulse_time)
                self.pulse_gids[label].append(gid)
                self.fishes.append([label, pulse_time, heights])
                recent.append([label, pulse_time, heights])
                # remove old fish:
                for i, (ll, tt, hh) in enumerate(recent):
                    # TODO: make parameter:
                    if (pulse_time - tt)/self.rate <= 0.2:
                        recent = recent[i:]
                        break
                # only consider the n most recent pulses of a fish:
                n = 5    # TODO make parameter
                labels = np.array([ll for ll, tt, hh in recent])
                if np.sum(labels == label) > n:
                    del recent[np.where(labels == label)[0][0]]
            # pulse times to arrays:
            for k in range(len(self.pulse_times)):
                self.pulse_times[k] = np.array(self.pulse_times[k])


                
            """
            # find temporally missing pulses:
            npulses = np.array([len(pts) for pts in self.pulse_times],
                               dtype=int)
            idx = np.argsort(npulses)
            for i in range(len(idx)):
                li = idx[len(idx)-1-i]
                if len(self.pulse_times[li]) < 10 or \
                   len(self.pulse_times[li])/npulses[li] < 0.5:
                    continue
                ipis = np.diff(self.pulse_times[li])
                n = 4 # TODO: make parameter
                k = 0
                while k < len(ipis)-n:
                    mipi = np.median(ipis[k:k+n])
                    if ipis[k+n-2] > 1.8*mipi:
                        # search for pulse closest to pt:
                        pt = self.pulse_times[li][k+n-2] + mipi
                        mlj = -1
                        mpj = -1
                        mdj = 10*mipi
                        for lj in range(len(self.pulse_times)):
                            if lj == li or len(self.pulse_times[lj]) == 0:
                                continue
                            pj = np.argmin(np.abs(self.pulse_times[lj] - pt))
                            dj = np.abs(self.pulse_times[lj][pj] - pt)
                            if dj < int(0.001*self.rate) and dj < mdj:
                                mdj = dj
                                mpj = pj
                                mlj = lj
                        if mlj >= 0:
                            # there is a pulse close to pt:
                            ptj = self.pulse_times[mlj][mpj]
                            pulses = self.pulses[self.pulses[:,0] == mlj,:]
                            gid = pulses[np.argmin(np.abs(pulses[:,3] - ptj)),1]
                            self.pulse_times[li] = np.insert(self.pulse_times[li], k+n-1, ptj)
                            self.pulse_gids[li].insert(k+n-1, gid)
                            # maybe don't delete but always duplicate and flag it:
                            if False:  # can be deleted
                                self.pulse_times[mlj] = np.delete(self.pulse_times[mlj], mpj)
                                self.pulse_gids[mlj].pop(mpj)
                                self.pulses[self.pulses[:,1] == gid,0] = li
                            else:     # pulse needs to be duplicated:
                                self.pulses[self.pulses[:,1] == gid,0] = li
                            ipis = np.diff(self.pulse_times[li])
                    k += 1


                    
            # clean up pulses:
            for l in range(len(self.pulse_times)):
                if len(self.pulse_times[l])/npulses[l] < 0.5:
                    self.pulse_times[l] = np.array([])
                    self.pulse_gids[l] = []
                    self.pulses[self.pulses[:,0] == l,0] = -1
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]
            """
            
            """
            # remove labels that are too close to others:
            widths = np.zeros(len(self.pulse_times), dtype=int)
            for k in range(len(self.pulse_times)):
                widths[k] = int(np.mean(np.abs(self.pulses[self.pulses[:,0] == k,3] - self.pulses[self.pulses[:,0] == k,4])))
            for k in range(len(self.pulse_times)):
                if len(self.pulse_times[k]) > 1:
                    for j in range(k+1, len(self.pulse_times)):
                        if len(self.pulse_times[j]) > 1:
                            di = 10*max(widths[k], widths[j])
                            dts = np.array([np.min(np.abs(self.pulse_times[k] - pt)) for pt in self.pulse_times[j]])
                            if k == 1 and j == 2:
                                print(di, np.sum(dts < di), len(dts))
                                plt.hist(dts, 50)
                                plt.show()
                            if np.sum(dts < 2*max_di)/len(dts) > 0.6:
                                r = k
                                if np.sum(self.fishes[k][2]) > np.sum(self.fishes[j][2]):
                                    r = j
                                self.pulse_times[r] = np.array([])
                                self.pulses[self.pulses[:,0] == r] = -1
                                self.fishes[r] = []
            self.pulses = self.pulses[self.pulses[:,0] >= 0,:]
            """
            # all labels:
            self.labels = np.unique(self.pulses[:,0])
            # report:
            print(f'found {len(self.pulse_times)} fish:')
            for k in range(len(self.pulse_times)):
                print(f'{k:3d}: {len(self.pulse_times[k]):5d} pulses')
            ## plot histogtram of distances:
            #plt.hist(min_dists, 100)
            #plt.show()
            ## plot features:
            """
            nn = np.array([(k, len(self.pulse_times[k]))
                           for k in range(len(self.pulse_times))])
            fig, axs = plt.subplots(5, 5, figsize=(15, 9),
                                    constrained_layout=True)
            ni = np.argsort(nn[:,1])           # largest cluster ...
            ln = np.sort(nn[ni[-axs.size:],0]) # ... sort by label
            for l, ax in zip(ln, axs.flat):
                h = np.array([hh for ll, tt, hh in self.fishes if ll == l])
                ax.plot(h.T, 'o-', ms=2, lw=0.5,
                        color=self.pulse_colors[l%len(self.pulse_colors)])
                ax.text(0.05, 0.9, f'label: {l}', transform=ax.transAxes)
            """
        
        # set key bindings:
        plt.rcParams['keymap.fullscreen'] = 'f'
        plt.rcParams['keymap.pan'] = 'ctrl+m'
        plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, q'
        plt.rcParams['keymap.yscale'] = ''
        plt.rcParams['keymap.xscale'] = ''
        plt.rcParams['keymap.grid'] = ''
        #plt.rcParams['keymap.all_axes'] = ''

        # the figure:
        plt.ioff()
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        self.fig, self.axs = plt.subplots(splts, 1, squeeze=False,
                                          figsize=(15, 9), sharex=True)
        self.axs = self.axs.flat
        if self.traces == self.channels:
            self.fig.canvas.manager.set_window_title(self.filename)
        else:
            cs = ' c%d' % self.show_channels[0]
            self.fig.canvas.manager.set_window_title(self.filename + ' ' + cs)
        self.fig.canvas.mpl_connect('key_press_event', self.keypress)
        self.fig.canvas.mpl_connect('resize_event', self.resize)
        self.fig.canvas.mpl_connect('pick_event', self.on_pick)
        # trace plots:
        for t in range(self.traces):
            self.axs[t].set_ylabel(f'C-{self.show_channels[t]+1} [{self.unit}]')
        #for t in range(self.traces-1):
        #    self.axs[t].xaxis.set_major_formatter(plt.NullFormatter())
        if len(self.pulses) > 0:
            self.axs[-1].set_ylim(0, self.fmax)
            self.axs[-1].set_ylabel('IP freq [Hz]')
        self.axs[-1].set_xlabel('Time [s]')
        ht = self.axs[0].text(0.98, 0.05, '(ctrl+) page and arrow up, down, home, end: scroll', ha='right',
                           transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.1, '+, -, X, x: zoom time in/out', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.2, 'y, Y, v, V, ctrl+v, ctrl+V: zoom amplitudes out/in/max/default/max per trace/global max per trace', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.3, 'i, I: zoom IPI frequency in/out', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.4, 'p, P: play audio (display, all)', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.5, 'f: full screen', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.6, 'w: plot waveforms into png file', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.7, 'S: save audiosegment', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.8, 'q: quit', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        ht = self.axs[0].text(0.98, 0.9, 'h: toggle this help', ha='right', transform=self.axs[0].transAxes)
        self.helptext.append(ht)
        # plot:
        for ht in self.helptext:
            ht.set_visible(self.help)
        self.update_plots()
        # feature plot:
        if len(self.labels) > 0:
            self.figf, self.axf = plt.subplots()
        plt.show()

    def __del__(self):
        pass
        #self.audio.close()

    def plot_pulses(self, axs, plot=True, tfac=1.0):
        
        def plot_pulse_traces(pulses, i, pak):
            for t in range(self.traces):
                c = self.show_channels[t]
                p = pulses[pulses[:,2] == c,3]
                if len(p) == 0:
                    continue
                if plot or pak >= len(self.pulse_artist):
                    pa, = axs[t].plot(tfac*p/self.rate,
                                      self.data[p,c], 'o', picker=5,
                                      color=self.pulse_colors[i%len(self.pulse_colors)])
                    if not plot:
                        self.pulse_artist.append(pa)
                else:
                    self.pulse_artist[pak].set_data(tfac*p/self.rate,
                                                    self.data[p,c])
                    self.pulse_artist[pak].set_color(self.pulse_colors[i%len(self.pulse_colors)])
                #if len(p) > 1 and len(p) <= 10:
                #    self.pulse_artist[pak].set_markersize(15)
                pak += 1
            return pak

        # pulses:
        pak = 0
        if self.show_gid:
            for g in range(len(self.pulse_colors)):
                pulses = self.pulses[self.pulses[:,1] % len(self.pulse_colors) == g,:]
                pak = plot_pulse_traces(pulses, g, pak)
        else:
            for l in self.labels:
                pulses = self.pulses[self.pulses[:,0] == l,:]
                pak = plot_pulse_traces(pulses, l, pak)
        while pak < len(self.pulse_artist):
            self.pulse_artist[pak].set_data([], [])
            pak += 1
        # ipis:
        for l in self.labels:
            if l < len(self.pulse_times):
                pt = self.pulse_times[l]/self.rate
                if len(pt) > 10:
                    if plot or not l in self.ipis_labels:
                        pa, = axs[-1].plot(tfac*pt[:-1], 1.0/np.diff(pt),
                                           '-o', picker=5,
                                           color=self.pulse_colors[l%len(self.pulse_colors)])
                        if not plot:
                            self.ipis_artist.append(pa)
                            self.ipis_labels.append(l)
                    else:
                        iak = self.ipis_labels.index(l)
                        self.ipis_artist[iak].set_data(tfac*pt[:-1],
                                                       1.0/np.diff(pt))

    def update_plots(self):
        t0 = int(np.round(self.toffset * self.rate))
        t1 = int(np.round((self.toffset + self.twindow) * self.rate))
        if t1 > len(self.data):
            t1 = len(self.data)
        time = np.arange(t0, t1) / self.rate
        for t in range(self.traces):
            c = self.show_channels[t]
            self.axs[t].set_xlim(self.toffset, self.toffset + self.twindow)
            if self.trace_artist[t] == None:
                self.trace_artist[t], = self.axs[t].plot(time, self.data[t0:t1,c])
            else:
                self.trace_artist[t].set_data(time, self.data[t0:t1,c])
            if t1 - t0 < 200:
                self.trace_artist[t].set_marker('o')
                self.trace_artist[t].set_markersize(3)
            else:
                self.trace_artist[t].set_marker('None')
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()

    def on_pick(self, event):
        # index of pulse artist:
        pk = -1
        for k, pa in enumerate(self.pulse_artist):
            if event.artist == pa:
                pk = k
                break
        li = -1
        pi = -1
        if pk >= 0:
            # find label and pulses of pulse artist:
            ll = self.labels[pk//self.traces]
            cc = self.show_channels[pk % self.traces]
            pulses = self.pulses[self.pulses[:,0] == ll,:]
            gid = pulses[pulses[:,2] == cc,1][event.ind[0]]
            if ll in self.ipis_labels:
                li = self.ipis_labels.index(ll)
                pi = self.pulse_gids[ll].index(gid)
        else:
            ik = -1
            for k, ia in enumerate(self.ipis_artist):
                if event.artist == ia:
                    ik = k
                    break
            if ik < 0:
                return
            li = ik
            ll = self.ipis_labels[li]
            pi = event.ind[0]
            gid = self.pulse_gids[ll][pi]
        # mark pulses:
        pulses = self.pulses[self.pulses[:,0] == ll,:]
        pulses = pulses[pulses[:,1] == gid,:]
        for t in range(self.traces):
            c = self.show_channels[t]
            pt = pulses[pulses[:,2] == c,3]
            if len(pt) > 0:
                if self.marker_artist[t] is None:
                    pa, = self.axs[t].plot(pt[0]/self.rate,
                                           self.data[pt[0],c], 'o', ms=10,
                                           color=self.pulse_colors[ll%len(self.pulse_colors)])
                    self.marker_artist[t] = pa
                else:
                    self.marker_artist[t].set_data(pt[0]/self.rate,
                                                   self.data[pt[0],c])
                    self.marker_artist[t].set_color(self.pulse_colors[ll%len(self.pulse_colors)])
            elif self.marker_artist[t] is not None:
                self.marker_artist[t].set_data([], [])
        # mark ipi:
        pt0 = -1.0
        pt1 = -1.0
        pf = -1.0
        if pi >= 0:
            pt0 = self.pulse_times[ll][pi]/self.rate
            pt1 = self.pulse_times[ll][pi+1]/self.rate
            pf = 1.0/(pt1-pt0)
            if self.marker_artist[self.traces] is None:
                pa, = self.axs[self.traces].plot(pt0, pf, 'o', ms=10,
                                                 color=self.pulse_colors[ll%len(self.pulse_colors)])
                self.marker_artist[self.traces] = pa
            else:
                self.marker_artist[self.traces].set_data(pt0, pf)
                self.marker_artist[self.traces].set_color(self.pulse_colors[ll%len(self.pulse_colors)])
        elif not self.marker_artist[self.traces] is None:
            self.marker_artist[self.traces].set_data([], [])
        self.fig.canvas.draw()
        # show features:
        if not self.axf is None and not self.fig is None:
            heights = np.zeros(self.channels)
            heights[pulses[:,2]] = \
                self.data[pulses[:,3],pulses[:,2]] - \
                self.data[pulses[:,4],pulses[:,2]]
            self.axf.plot(heights, color=self.pulse_colors[ll%len(self.pulse_colors)])
            print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
            self.figf.canvas.draw()

    def resize(self, event):
        # print('resized', event.width, event.height)
        leftpixel = 80.0
        rightpixel = 20.0
        bottompixel = 50.0
        toppixel = 20.0
        x0 = leftpixel / event.width
        x1 = 1.0 - rightpixel / event.width
        y0 = bottompixel / event.height
        y1 = 1.0 - toppixel / event.height
        self.fig.subplots_adjust(left=x0, right=x1, bottom=y0, top=y1,
                                 hspace=0)

    def keypress(self, event):
        # print('pressed', event.key)
        if event.key in '+=X':
            if self.twindow * self.rate > 20:
                self.twindow *= 0.5
                self.update_plots()
        elif event.key in '-x':
            if self.twindow < self.tmax:
                self.twindow *= 2.0
                self.update_plots()
        elif event.key == 'pagedown':
            if self.toffset + 0.5 * self.twindow < self.tmax:
                self.toffset += 0.5 * self.twindow
                self.update_plots()
        elif event.key == 'pageup':
            if self.toffset > 0:
                self.toffset -= 0.5 * self.twindow
                if self.toffset < 0.0:
                    self.toffset = 0.0
                self.update_plots()
        elif event.key == 'ctrl+pagedown':
            if self.toffset + 5.0 * self.twindow < self.tmax:
                self.toffset += 5.0 * self.twindow
                self.update_plots()
        elif event.key == 'ctrl+pageup':
            if self.toffset > 0:
                self.toffset -= 5.0 * self.twindow
                if self.toffset < 0.0:
                    self.toffset = 0.0
                self.update_plots()
        elif event.key == 'down':
            if self.toffset + self.twindow < self.tmax:
                self.toffset += 0.05 * self.twindow
                self.update_plots()
        elif event.key == 'up':
            if self.toffset > 0.0:
                self.toffset -= 0.05 * self.twindow
                if self.toffset < 0.0:
                    self.toffset = 0.0
                self.update_plots()
        elif event.key == 'home':
            if self.toffset > 0.0:
                self.toffset = 0.0
                self.update_plots()
        elif event.key == 'end':
            toffs = np.floor(self.tmax / self.twindow) * self.twindow
            if self.tmax - toffs <= 0.0:
                toffs -= self.twindow
            if self.tmax - toffs < self.twindow/2:
                toffs -= self.twindow/2
            if self.toffset < toffs:
                self.toffset = toffs
                self.update_plots()
        elif event.key == 'y':
            for t in range(self.traces):
                h = self.ymax[t] - self.ymin[t]
                c = 0.5 * (self.ymax[t] + self.ymin[t])
                self.ymin[t] = c - h
                self.ymax[t] = c + h
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'Y':
            for t in range(self.traces):
                h = 0.25 * (self.ymax[t] - self.ymin[t])
                c = 0.5 * (self.ymax[t] + self.ymin[t])
                self.ymin[t] = c - h
                self.ymax[t] = c + h
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'v':
            t0 = int(np.round(self.toffset * self.rate))
            t1 = int(np.round((self.toffset + self.twindow) * self.rate))
            min = np.min(self.data[t0:t1,self.show_channels])
            max = np.max(self.data[t0:t1,self.show_channels])
            h = 0.53 * (max - min)
            c = 0.5 * (max + min)
            self.ymin[:] = c - h
            self.ymax[:] = c + h
            for t in range(self.traces):
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'ctrl+v':
            t0 = int(np.round(self.toffset * self.rate))
            t1 = int(np.round((self.toffset + self.twindow) * self.rate))
            for t in range(self.traces):
                min = np.min(self.data[t0:t1,self.show_channels[t]])
                max = np.max(self.data[t0:t1,self.show_channels[t]])
                h = 0.53 * (max - min)
                c = 0.5 * (max + min)
                self.ymin[t] = c - h
                self.ymax[t] = c + h
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'ctrl+V':
            for t in range(self.traces):
                min = np.min(self.data[:,self.show_channels[t]])
                max = np.max(self.data[:,self.show_channels[t]])
                h = 0.53 * (max - min)
                c = 0.5 * (max + min)
                self.ymin[t] = c - h
                self.ymax[t] = c + h
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'V':
            self.ymin[:] = -1.0
            self.ymax[:] = +1.0
            for t in range(self.traces):
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'c':
            for t in range(self.traces):
                dy = self.ymax[t] - self.ymin[t]
                self.ymin[t] = -dy/2
                self.ymax[t] = +dy/2
                self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
            self.fig.canvas.draw()
        elif event.key == 'g':
            self.show_gid = not self.show_gid
            self.plot_pulses(self.axs, False)
            self.fig.canvas.draw()
        elif event.key == 'i':
            if len(self.pulses) > 0:
                self.fmax *= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif event.key == 'I':
            if len(self.pulses) > 0:
                self.fmax /= 2
                self.axs[-1].set_ylim(0.0, self.fmax)
                self.fig.canvas.draw()
        elif event.key in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            cc = int(event.key)
            # TODO: this is not yet what we want:
            """
            if cc < self.channels:
                self.axs[cc].set_visible(not self.axs[cc].get_visible())
            self.fig.canvas.draw()
            """
        elif event.key in 'h':
            self.help = not self.help
            for ht in self.helptext:
                ht.set_visible(self.help)
            self.fig.canvas.draw()
        elif event.key in 'p':
            self.play_segment()
        elif event.key in 'P':
            self.play_all()
        elif event.key in 'S':
            self.save_segment()
        elif event.key in 'w':
            self.plot_traces()

    def play_segment(self):
        t0 = int(np.round(self.toffset * self.rate))
        t1 = int(np.round((self.toffset + self.twindow) * self.rate))
        playdata = 1.0 * np.mean(self.data[t0:t1,self.show_channels], 1)
        f = 0.1 if self.twindow > 0.5 else 0.1*self.twindow
        fade(playdata, self.rate, f)
        self.audio.play(playdata, self.rate, blocking=False)
        
    def play_all(self):
        self.audio.play(np.mean(self.data[:,self.show_channels], 1),
                        self.rate, blocking=False)

    def save_segment(self):
        t0s = int(np.round(self.toffset))
        t1s = int(np.round(self.toffset + self.twindow))
        t0 = int(np.round(self.toffset * self.rate))
        t1 = int(np.round((self.toffset + self.twindow) * self.rate))
        filename = self.filename.split('.')[0]
        if self.traces == self.channels:
            segment_filename = f'{filename}-{t0s:.4g}s-{t1s:.4g}s.wav'
            write_audio(segment_filename, self.data[t0:t1,:], self.rate)
        else:
            segment_filename = f'{filename}-{t0s:.4g}s-{t1s:.4g}s-c{self.show_channels[0]}.wav'
            write_audio(segment_filename,
                        self.data[t0:t1,self.show_channels], self.rate)
        print('saved segment to: ' , segment_filename)

    def plot_traces(self):
        splts = self.traces
        if len(self.pulses) > 0:
            splts += 1
        fig, axs = plt.subplots(splts, 1, squeeze=False, sharex=True,
                                figsize=(15, 9))
        axs = axs.flat
        fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                            hspace=0)
        name = self.filename.split('.')[0]
        figfile = f'{name}-{self.toffset:.4g}s-traces.png'
        if self.traces < self.channels:
            figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
        axs[0].set_title(self.filename)
        t0 = int(np.round(self.toffset * self.rate))
        t1 = int(np.round((self.toffset + self.twindow) * self.rate))
        if t1>len(self.data):
            t1 = len(self.data)
        time = np.arange(t0, t1)/self.rate
        if self.toffset < 1.0 and self.twindow < 1.0:
            axs[-1].set_xlabel('Time [ms]')
            for t in range(self.traces):
                c = self.show_channels[t]
                axs[t].set_xlim(1000.0 * self.toffset,
                                1000.0 * (self.toffset + self.twindow))
                axs[t].plot(1000.0 * time, self.data[t0:t1,c])
            self.plot_pulses(axs, True, 1000.0)
        else:
            axs[-1].set_xlabel('Time [s]')
            for t in range(self.traces):
                c = self.show_channels[t]
                axs[t].set_xlim(self.toffset, self.toffset + self.twindow)
                axs[t].plot(time, self.data[t0:t1,c])
            self.plot_pulses(axs, True, 1.0)
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_ylim(self.ymin[t], self.ymax[t])
            axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
        if len(self.pulses) > 0:
            axs[-1].set_ylabel('IP freq [Hz]')
            axs[-1].set_ylim(0.0, self.fmax)
        #for t in range(self.traces-1):
        #    axs[t].xaxis.set_major_formatter(plt.NullFormatter())
        fig.savefig(figfile, dpi=200)
        plt.close(fig)
        print('saved waveform figure to', figfile)

Methods

def plot_pulses(self, axs, plot=True, tfac=1.0)
Expand source code
def plot_pulses(self, axs, plot=True, tfac=1.0):
    
    def plot_pulse_traces(pulses, i, pak):
        for t in range(self.traces):
            c = self.show_channels[t]
            p = pulses[pulses[:,2] == c,3]
            if len(p) == 0:
                continue
            if plot or pak >= len(self.pulse_artist):
                pa, = axs[t].plot(tfac*p/self.rate,
                                  self.data[p,c], 'o', picker=5,
                                  color=self.pulse_colors[i%len(self.pulse_colors)])
                if not plot:
                    self.pulse_artist.append(pa)
            else:
                self.pulse_artist[pak].set_data(tfac*p/self.rate,
                                                self.data[p,c])
                self.pulse_artist[pak].set_color(self.pulse_colors[i%len(self.pulse_colors)])
            #if len(p) > 1 and len(p) <= 10:
            #    self.pulse_artist[pak].set_markersize(15)
            pak += 1
        return pak

    # pulses:
    pak = 0
    if self.show_gid:
        for g in range(len(self.pulse_colors)):
            pulses = self.pulses[self.pulses[:,1] % len(self.pulse_colors) == g,:]
            pak = plot_pulse_traces(pulses, g, pak)
    else:
        for l in self.labels:
            pulses = self.pulses[self.pulses[:,0] == l,:]
            pak = plot_pulse_traces(pulses, l, pak)
    while pak < len(self.pulse_artist):
        self.pulse_artist[pak].set_data([], [])
        pak += 1
    # ipis:
    for l in self.labels:
        if l < len(self.pulse_times):
            pt = self.pulse_times[l]/self.rate
            if len(pt) > 10:
                if plot or not l in self.ipis_labels:
                    pa, = axs[-1].plot(tfac*pt[:-1], 1.0/np.diff(pt),
                                       '-o', picker=5,
                                       color=self.pulse_colors[l%len(self.pulse_colors)])
                    if not plot:
                        self.ipis_artist.append(pa)
                        self.ipis_labels.append(l)
                else:
                    iak = self.ipis_labels.index(l)
                    self.ipis_artist[iak].set_data(tfac*pt[:-1],
                                                   1.0/np.diff(pt))
def update_plots(self)
Expand source code
def update_plots(self):
    t0 = int(np.round(self.toffset * self.rate))
    t1 = int(np.round((self.toffset + self.twindow) * self.rate))
    if t1 > len(self.data):
        t1 = len(self.data)
    time = np.arange(t0, t1) / self.rate
    for t in range(self.traces):
        c = self.show_channels[t]
        self.axs[t].set_xlim(self.toffset, self.toffset + self.twindow)
        if self.trace_artist[t] == None:
            self.trace_artist[t], = self.axs[t].plot(time, self.data[t0:t1,c])
        else:
            self.trace_artist[t].set_data(time, self.data[t0:t1,c])
        if t1 - t0 < 200:
            self.trace_artist[t].set_marker('o')
            self.trace_artist[t].set_markersize(3)
        else:
            self.trace_artist[t].set_marker('None')
        self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
    self.plot_pulses(self.axs, False)
    self.fig.canvas.draw()
def on_pick(self, event)
Expand source code
def on_pick(self, event):
    # index of pulse artist:
    pk = -1
    for k, pa in enumerate(self.pulse_artist):
        if event.artist == pa:
            pk = k
            break
    li = -1
    pi = -1
    if pk >= 0:
        # find label and pulses of pulse artist:
        ll = self.labels[pk//self.traces]
        cc = self.show_channels[pk % self.traces]
        pulses = self.pulses[self.pulses[:,0] == ll,:]
        gid = pulses[pulses[:,2] == cc,1][event.ind[0]]
        if ll in self.ipis_labels:
            li = self.ipis_labels.index(ll)
            pi = self.pulse_gids[ll].index(gid)
    else:
        ik = -1
        for k, ia in enumerate(self.ipis_artist):
            if event.artist == ia:
                ik = k
                break
        if ik < 0:
            return
        li = ik
        ll = self.ipis_labels[li]
        pi = event.ind[0]
        gid = self.pulse_gids[ll][pi]
    # mark pulses:
    pulses = self.pulses[self.pulses[:,0] == ll,:]
    pulses = pulses[pulses[:,1] == gid,:]
    for t in range(self.traces):
        c = self.show_channels[t]
        pt = pulses[pulses[:,2] == c,3]
        if len(pt) > 0:
            if self.marker_artist[t] is None:
                pa, = self.axs[t].plot(pt[0]/self.rate,
                                       self.data[pt[0],c], 'o', ms=10,
                                       color=self.pulse_colors[ll%len(self.pulse_colors)])
                self.marker_artist[t] = pa
            else:
                self.marker_artist[t].set_data(pt[0]/self.rate,
                                               self.data[pt[0],c])
                self.marker_artist[t].set_color(self.pulse_colors[ll%len(self.pulse_colors)])
        elif self.marker_artist[t] is not None:
            self.marker_artist[t].set_data([], [])
    # mark ipi:
    pt0 = -1.0
    pt1 = -1.0
    pf = -1.0
    if pi >= 0:
        pt0 = self.pulse_times[ll][pi]/self.rate
        pt1 = self.pulse_times[ll][pi+1]/self.rate
        pf = 1.0/(pt1-pt0)
        if self.marker_artist[self.traces] is None:
            pa, = self.axs[self.traces].plot(pt0, pf, 'o', ms=10,
                                             color=self.pulse_colors[ll%len(self.pulse_colors)])
            self.marker_artist[self.traces] = pa
        else:
            self.marker_artist[self.traces].set_data(pt0, pf)
            self.marker_artist[self.traces].set_color(self.pulse_colors[ll%len(self.pulse_colors)])
    elif not self.marker_artist[self.traces] is None:
        self.marker_artist[self.traces].set_data([], [])
    self.fig.canvas.draw()
    # show features:
    if not self.axf is None and not self.fig is None:
        heights = np.zeros(self.channels)
        heights[pulses[:,2]] = \
            self.data[pulses[:,3],pulses[:,2]] - \
            self.data[pulses[:,4],pulses[:,2]]
        self.axf.plot(heights, color=self.pulse_colors[ll%len(self.pulse_colors)])
        print(f'label={ll:4d} gid={gid:5d} t={pt0:8.4f}s')
        self.figf.canvas.draw()
def resize(self, event)
Expand source code
def resize(self, event):
    # print('resized', event.width, event.height)
    leftpixel = 80.0
    rightpixel = 20.0
    bottompixel = 50.0
    toppixel = 20.0
    x0 = leftpixel / event.width
    x1 = 1.0 - rightpixel / event.width
    y0 = bottompixel / event.height
    y1 = 1.0 - toppixel / event.height
    self.fig.subplots_adjust(left=x0, right=x1, bottom=y0, top=y1,
                             hspace=0)
def keypress(self, event)
Expand source code
def keypress(self, event):
    # print('pressed', event.key)
    if event.key in '+=X':
        if self.twindow * self.rate > 20:
            self.twindow *= 0.5
            self.update_plots()
    elif event.key in '-x':
        if self.twindow < self.tmax:
            self.twindow *= 2.0
            self.update_plots()
    elif event.key == 'pagedown':
        if self.toffset + 0.5 * self.twindow < self.tmax:
            self.toffset += 0.5 * self.twindow
            self.update_plots()
    elif event.key == 'pageup':
        if self.toffset > 0:
            self.toffset -= 0.5 * self.twindow
            if self.toffset < 0.0:
                self.toffset = 0.0
            self.update_plots()
    elif event.key == 'ctrl+pagedown':
        if self.toffset + 5.0 * self.twindow < self.tmax:
            self.toffset += 5.0 * self.twindow
            self.update_plots()
    elif event.key == 'ctrl+pageup':
        if self.toffset > 0:
            self.toffset -= 5.0 * self.twindow
            if self.toffset < 0.0:
                self.toffset = 0.0
            self.update_plots()
    elif event.key == 'down':
        if self.toffset + self.twindow < self.tmax:
            self.toffset += 0.05 * self.twindow
            self.update_plots()
    elif event.key == 'up':
        if self.toffset > 0.0:
            self.toffset -= 0.05 * self.twindow
            if self.toffset < 0.0:
                self.toffset = 0.0
            self.update_plots()
    elif event.key == 'home':
        if self.toffset > 0.0:
            self.toffset = 0.0
            self.update_plots()
    elif event.key == 'end':
        toffs = np.floor(self.tmax / self.twindow) * self.twindow
        if self.tmax - toffs <= 0.0:
            toffs -= self.twindow
        if self.tmax - toffs < self.twindow/2:
            toffs -= self.twindow/2
        if self.toffset < toffs:
            self.toffset = toffs
            self.update_plots()
    elif event.key == 'y':
        for t in range(self.traces):
            h = self.ymax[t] - self.ymin[t]
            c = 0.5 * (self.ymax[t] + self.ymin[t])
            self.ymin[t] = c - h
            self.ymax[t] = c + h
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()
    elif event.key == 'Y':
        for t in range(self.traces):
            h = 0.25 * (self.ymax[t] - self.ymin[t])
            c = 0.5 * (self.ymax[t] + self.ymin[t])
            self.ymin[t] = c - h
            self.ymax[t] = c + h
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()
    elif event.key == 'v':
        t0 = int(np.round(self.toffset * self.rate))
        t1 = int(np.round((self.toffset + self.twindow) * self.rate))
        min = np.min(self.data[t0:t1,self.show_channels])
        max = np.max(self.data[t0:t1,self.show_channels])
        h = 0.53 * (max - min)
        c = 0.5 * (max + min)
        self.ymin[:] = c - h
        self.ymax[:] = c + h
        for t in range(self.traces):
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()
    elif event.key == 'ctrl+v':
        t0 = int(np.round(self.toffset * self.rate))
        t1 = int(np.round((self.toffset + self.twindow) * self.rate))
        for t in range(self.traces):
            min = np.min(self.data[t0:t1,self.show_channels[t]])
            max = np.max(self.data[t0:t1,self.show_channels[t]])
            h = 0.53 * (max - min)
            c = 0.5 * (max + min)
            self.ymin[t] = c - h
            self.ymax[t] = c + h
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()
    elif event.key == 'ctrl+V':
        for t in range(self.traces):
            min = np.min(self.data[:,self.show_channels[t]])
            max = np.max(self.data[:,self.show_channels[t]])
            h = 0.53 * (max - min)
            c = 0.5 * (max + min)
            self.ymin[t] = c - h
            self.ymax[t] = c + h
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()
    elif event.key == 'V':
        self.ymin[:] = -1.0
        self.ymax[:] = +1.0
        for t in range(self.traces):
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()
    elif event.key == 'c':
        for t in range(self.traces):
            dy = self.ymax[t] - self.ymin[t]
            self.ymin[t] = -dy/2
            self.ymax[t] = +dy/2
            self.axs[t].set_ylim(self.ymin[t], self.ymax[t])
        self.fig.canvas.draw()
    elif event.key == 'g':
        self.show_gid = not self.show_gid
        self.plot_pulses(self.axs, False)
        self.fig.canvas.draw()
    elif event.key == 'i':
        if len(self.pulses) > 0:
            self.fmax *= 2
            self.axs[-1].set_ylim(0.0, self.fmax)
            self.fig.canvas.draw()
    elif event.key == 'I':
        if len(self.pulses) > 0:
            self.fmax /= 2
            self.axs[-1].set_ylim(0.0, self.fmax)
            self.fig.canvas.draw()
    elif event.key in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
        cc = int(event.key)
        # TODO: this is not yet what we want:
        """
        if cc < self.channels:
            self.axs[cc].set_visible(not self.axs[cc].get_visible())
        self.fig.canvas.draw()
        """
    elif event.key in 'h':
        self.help = not self.help
        for ht in self.helptext:
            ht.set_visible(self.help)
        self.fig.canvas.draw()
    elif event.key in 'p':
        self.play_segment()
    elif event.key in 'P':
        self.play_all()
    elif event.key in 'S':
        self.save_segment()
    elif event.key in 'w':
        self.plot_traces()
def play_segment(self)
Expand source code
def play_segment(self):
    t0 = int(np.round(self.toffset * self.rate))
    t1 = int(np.round((self.toffset + self.twindow) * self.rate))
    playdata = 1.0 * np.mean(self.data[t0:t1,self.show_channels], 1)
    f = 0.1 if self.twindow > 0.5 else 0.1*self.twindow
    fade(playdata, self.rate, f)
    self.audio.play(playdata, self.rate, blocking=False)
def play_all(self)
Expand source code
def play_all(self):
    self.audio.play(np.mean(self.data[:,self.show_channels], 1),
                    self.rate, blocking=False)
def save_segment(self)
Expand source code
def save_segment(self):
    t0s = int(np.round(self.toffset))
    t1s = int(np.round(self.toffset + self.twindow))
    t0 = int(np.round(self.toffset * self.rate))
    t1 = int(np.round((self.toffset + self.twindow) * self.rate))
    filename = self.filename.split('.')[0]
    if self.traces == self.channels:
        segment_filename = f'{filename}-{t0s:.4g}s-{t1s:.4g}s.wav'
        write_audio(segment_filename, self.data[t0:t1,:], self.rate)
    else:
        segment_filename = f'{filename}-{t0s:.4g}s-{t1s:.4g}s-c{self.show_channels[0]}.wav'
        write_audio(segment_filename,
                    self.data[t0:t1,self.show_channels], self.rate)
    print('saved segment to: ' , segment_filename)
def plot_traces(self)
Expand source code
def plot_traces(self):
    splts = self.traces
    if len(self.pulses) > 0:
        splts += 1
    fig, axs = plt.subplots(splts, 1, squeeze=False, sharex=True,
                            figsize=(15, 9))
    axs = axs.flat
    fig.subplots_adjust(left=0.06, right=0.99, bottom=0.05, top=0.97,
                        hspace=0)
    name = self.filename.split('.')[0]
    figfile = f'{name}-{self.toffset:.4g}s-traces.png'
    if self.traces < self.channels:
        figfile = f'{name}-{self.toffset:.4g}s-c{self.show_channels[0]}-traces.png'
    axs[0].set_title(self.filename)
    t0 = int(np.round(self.toffset * self.rate))
    t1 = int(np.round((self.toffset + self.twindow) * self.rate))
    if t1>len(self.data):
        t1 = len(self.data)
    time = np.arange(t0, t1)/self.rate
    if self.toffset < 1.0 and self.twindow < 1.0:
        axs[-1].set_xlabel('Time [ms]')
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_xlim(1000.0 * self.toffset,
                            1000.0 * (self.toffset + self.twindow))
            axs[t].plot(1000.0 * time, self.data[t0:t1,c])
        self.plot_pulses(axs, True, 1000.0)
    else:
        axs[-1].set_xlabel('Time [s]')
        for t in range(self.traces):
            c = self.show_channels[t]
            axs[t].set_xlim(self.toffset, self.toffset + self.twindow)
            axs[t].plot(time, self.data[t0:t1,c])
        self.plot_pulses(axs, True, 1.0)
    for t in range(self.traces):
        c = self.show_channels[t]
        axs[t].set_ylim(self.ymin[t], self.ymax[t])
        axs[t].set_ylabel(f'C-{c+1} [{self.unit}]')
    if len(self.pulses) > 0:
        axs[-1].set_ylabel('IP freq [Hz]')
        axs[-1].set_ylim(0.0, self.fmax)
    #for t in range(self.traces-1):
    #    axs[t].xaxis.set_major_formatter(plt.NullFormatter())
    fig.savefig(figfile, dpi=200)
    plt.close(fig)
    print('saved waveform figure to', figfile)