Module thunderlab.multivariateexplorer
Simple GUI for viewing and exploring multivariate data.
- class MultivariateExplorer: simple matplotlib-based GUI for viewing and exploring multivariate data.
- categorize(): convert categorical string data into integer categories.
- select_features(): assemble list of column indices.
- select_coloring(): select column from data table for colorizing the data.
- list_available_features(): print available features on console.
def categorize(data)
def categorize(data):
    """Convert categorical string data into integer categories.

    Parameters
    ----------
    data: list or ndarray of str
        Data with textual categories.

    Returns
    -------
    categories: list of str
        A sorted unique list of the strings in `data`,
        that is the names of the categories.
    cdata: ndarray of int
        A copy of the input `data` where each string value is
        replaced by an integer number that is an index
        into the returned `categories`.
    """
    cats = sorted(set(data))
    cdata = np.array([cats.index(x) for x in data], dtype=int)
    return cats, cdata
Convert categorical string data into integer categories.
Parameters
- data (list or ndarray of str): Data with textual categories.
Returns
- categories (list of str): A sorted unique list of the strings in `data`, that is the names of the categories.
- cdata (ndarray of int): A copy of the input `data` where each string value is replaced by an integer number that is an index into the returned `categories`.
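The mapping can be checked by hand on a small input. The following is a minimal sketch that mirrors the listed source in plain Python, so it returns a list of integers rather than an `ndarray`. The helper name `categorize_py` is an assumption made for illustration and is not part of thunderlab; it also swaps `list.index()` for a dictionary lookup.

```python
def categorize_py(data):
    """Pure-Python sketch of categorize(): sorted unique names plus integer codes."""
    cats = sorted(set(data))
    # map each category name to its position in the sorted list
    lookup = {name: i for i, name in enumerate(cats)}
    codes = [lookup[x] for x in data]
    return cats, codes


cats, codes = categorize_py(['bbb', 'aaa', 'ccc', 'aaa'])
print(cats)   # ['aaa', 'bbb', 'ccc']
print(codes)  # [1, 0, 2, 0]
```

Because the categories are sorted, the integer codes follow alphabetical order and not the order of first appearance in the data.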
def select_features(data, columns)
def select_features(data, columns):
    """Assemble list of column indices.

    Parameters
    ----------
    data: TableData
        The table from which to select features.
    columns: list of str
        Feature names (column headers) to be selected from the data.
        A column that is listed an even number of times (for example
        twice) is removed again and not selected.
        If the list is empty, all columns are selected.

    Returns
    -------
    data_cols: list of int
        List of indices into data columns for selecting features.
    """
    if len(columns) == 0:
        data_cols = list(np.arange(len(data)))
    else:
        data_cols = []
        for c in columns:
            idx = data.index(c)
            if idx is None:
                print(f'"{c}" is not a valid data column')
            elif idx in data_cols:
                data_cols.remove(idx)
            else:
                data_cols.append(idx)
    return data_cols
Assemble list of column indices.
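Parameters
- data (TableData): The table from which to select features.
- columns (list of str): Feature names (column headers) to be selected from the data. A column that is listed an even number of times (for example twice) is removed again and not selected. If the list is empty, all columns are selected.
Returns
- data_cols (list of int): List of indices into data columns for selecting features.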
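The toggling behavior for repeated column names is easiest to see in a small run. The sketch below assumes a hypothetical stand-in class `FakeTable` that only provides `__len__()` and an `index()` returning `None` for unknown names, which is how the listed source uses the table. Neither `FakeTable` nor `select_features_py` is part of thunderlab.

```python
class FakeTable:
    """Hypothetical stand-in for TableData, just enough for this sketch."""

    def __init__(self, names):
        self.names = list(names)

    def __len__(self):
        # number of columns
        return len(self.names)

    def index(self, name):
        # column index, or None if the name is unknown
        return self.names.index(name) if name in self.names else None


def select_features_py(data, columns):
    """Same logic as the listed select_features(), without NumPy."""
    if len(columns) == 0:
        return list(range(len(data)))
    data_cols = []
    for c in columns:
        idx = data.index(c)
        if idx is None:
            print(f'"{c}" is not a valid data column')
        elif idx in data_cols:
            data_cols.remove(idx)   # second mention toggles the column off
        else:
            data_cols.append(idx)
    return data_cols


table = FakeTable(['A', 'B', 'C'])
print(select_features_py(table, []))               # [0, 1, 2]
print(select_features_py(table, ['C', 'A']))       # [2, 0]
print(select_features_py(table, ['A', 'B', 'A']))  # [1]
```

The returned indices keep the order in which the names were requested, so the order on the command line determines the order of the features shown.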
def select_coloring(data, data_cols, color_col)
def select_coloring(data, data_cols, color_col):
    """Select column from data table for colorizing the data.

    Pass the output of this function on to MultivariateExplorer.set_colors().

    Parameters
    ----------
    data: TableData
        Table with all EOD properties from which columns are selected.
    data_cols: list of int
        List of columns selected to be explored.
    color_col: str or int
        Column to be selected for coloring the data.
        If 'row' then use the row index of the data in the table for coloring.

    Returns
    -------
    colors: int or list of values or None
        Either index of `data_cols` or additional data from the data table
        to be used for coloring.
    color_label: str or None
        Label for labeling the color bar.
    color_idx: int or None
        Index of color column in `data`.
    error: None or str
        In case an invalid column is selected, an error string.
    """
    color_idx = data.index(color_col)
    colors = None
    color_label = None
    if color_idx is None and color_col != 'row':
        return None, None, None, f'"{color_col}" is not a valid column for color code'
    if color_idx is None:
        # color by row index:
        colors = -2
    elif color_idx in data_cols:
        # color by one of the explored columns:
        colors = data_cols.index(color_idx)
    else:
        # color by an additional column of the table:
        if len(data.unit(color_idx)) > 0 and data.unit(color_idx) not in ['-', '1']:
            color_label = f'{data.label(color_idx)} [{data.unit(color_idx)}]'
        else:
            color_label = data.label(color_idx)
        colors = data[:, color_idx]
    return colors, color_label, color_idx, None
Select column from data table for colorizing the data.
Pass the output of this function on to MultivariateExplorer.set_colors().
Parameters
- data (TableData): Table with all EOD properties from which columns are selected.
- data_cols (list of int): List of columns selected to be explored.
- color_col (str or int): Column to be selected for coloring the data. If 'row' then use the row index of the data in the table for coloring.
Returns
- colors (int or list of values or None): Either index of `data_cols` or additional data from the data table to be used for coloring.
- color_label (str or None): Label for labeling the color bar.
- color_idx (int or None): Index of color column in `data`.
- error (None or str): In case an invalid column is selected, an error string.
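The function distinguishes four cases: an invalid column, the special name 'row', a column that is already among the explored ones, and an extra column taken from the table. The sketch below reproduces that decision and the rule for the color bar label on plain Python lists. The names `coloring_case` and `color_bar_label` are hypothetical helpers for illustration, and only column names (not integer indices) are handled.

```python
def coloring_case(column_names, data_cols, color_col):
    """Report which branch select_coloring() would take (sketch)."""
    idx = column_names.index(color_col) if color_col in column_names else None
    if idx is None and color_col != 'row':
        return 'error'
    if idx is None:
        return 'row'        # colors = -2 in the listed source
    if idx in data_cols:
        return 'explored'   # colors = position within data_cols
    return 'extra'          # colors = the column's data


def color_bar_label(label, unit):
    """Append the unit in brackets unless it is empty, '-' or '1'."""
    if len(unit) > 0 and unit not in ['-', '1']:
        return f'{label} [{unit}]'
    return label


names = ['freq', 'amp', 'phase']
print(coloring_case(names, [0, 1], 'amp'))     # explored
print(coloring_case(names, [0, 1], 'phase'))   # extra
print(coloring_case(names, [0, 1], 'row'))     # row
print(coloring_case(names, [0, 1], 'width'))   # error
print(color_bar_label('freq', 'Hz'))           # freq [Hz]
print(color_bar_label('phase', '1'))           # phase
```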
def list_available_features(data, data_cols=[], color_col=None)
def list_available_features(data, data_cols=[], color_col=None):
    """Print available features on console.

    Parameters
    ----------
    data: TableData
        The full data table.
    data_cols: list of int
        List of indices of columns (features) in the table
        that are passed on to the MultivariateExplorer.
    color_col: int or None
        Index of the column (feature) in the table
        that is initially used for color coding the data.
    """
    print('available features:')
    for k, c in enumerate(data.keys()):
        s = [' '] * 3
        if k in data_cols:
            s[1] = '*'
        if color_col is not None and k == color_col:
            s[0] = 'C'
        print(''.join(s) + c)
    if len(data_cols) > 0:
        print('*: feature selected for exploration')
    if color_col is not None:
        print('C: feature selected for color coding the data')
Print available features on console.
Parameters
- data (TableData): The full data table.
- data_cols (list of int): List of indices of columns (features) in the table that are passed on to the MultivariateExplorer.
- color_col (int or None): Index of the column (feature) in the table that is initially used for color coding the data.
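Each printed line starts with a three-character marker field: 'C' in the first position for the color column and '*' in the second for an explored feature. The following sketch builds these lines from a plain list of names so the layout can be inspected. `feature_lines` is a hypothetical helper that returns the lines rather than printing them.

```python
def feature_lines(names, data_cols=(), color_col=None):
    """Build the marker lines printed by list_available_features() (sketch)."""
    lines = []
    for k, name in enumerate(names):
        s = [' '] * 3
        if k in data_cols:
            s[1] = '*'   # feature selected for exploration
        if color_col is not None and k == color_col:
            s[0] = 'C'   # feature selected for color coding
        lines.append(''.join(s) + name)
    return lines


for line in feature_lines(['freq', 'amp', 'phase'], data_cols=[0, 2], color_col=2):
    print(line)
```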
def demo()
def demo():
    """Run the multivariate explorer with a random test data set.
    """
    # generate data:
    n = 100
    data = []
    data.append(np.random.randn(n) + 2.0)
    data.append(1.0 + 0.1*data[0] + 1.5*np.random.randn(n))
    data.append(10*(-3.0*data[0] + 2.0*data[1] + 1.8*np.random.randn(n)))
    idx = np.random.randint(0, 3, n)
    names = ['aaa', 'bbb', 'ccc']
    data.append([names[i] for i in idx])
    # generate waveforms:
    waveforms = []
    time = np.arange(0.0, 10.0, 0.01)
    for r in range(len(data[0])):
        x = data[0][r]*np.sin(2.0*np.pi*data[1][r]*time + data[2][r])
        y = data[0][r]*np.exp(-0.5*((time - data[1][r])/(0.2*data[2][r]))**2.0)
        waveforms.append(np.column_stack((time, x, y)))
        #waveforms.append([np.column_stack((time, x)), np.column_stack((time, y))])
    # initialize explorer:
    expl = MultivariateExplorer(data,
                                list(map(chr, np.arange(len(data)) + ord('A'))),
                                'Explorer')
    expl.set_wave_data(waveforms, 'Time', ['Sine', 'Gauss'])
    # explore data:
    expl.set_colors()
    expl.show()
Run the multivariate explorer with a random test data set.
def main(*cargs)
def main(*cargs):
    # parse command line:
    parser = argparse.ArgumentParser(add_help=False,
        description='View and explore multivariate data.',
        epilog=f'version {__version__} by Benda-Lab (2019-{__year__})')
    parser.add_argument('-h', '--help', nargs=0, action=PrintHelp,
                        help='show this help message and exit')
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-l', dest='list_features', action='store_true',
                        help='list all available data columns (features) and exit')
    parser.add_argument('-d', dest='data_cols', action='append', default=[],
                        metavar='COLUMN',
                        help='data columns (features) to be explored')
    parser.add_argument('-c', dest='color_col', default=None, type=str,
                        metavar='COLUMN',
                        help='data column to be used for color code or "row"')
    parser.add_argument('-m', dest='color_map', default='jet', type=str,
                        metavar='CMAP',
                        help='name of color map to be used')
    parser.add_argument('file', nargs='?', default='', type=str,
                        help='a file containing a table of data (csv file or similar)')
    if len(cargs) == 0:
        cargs = None
    args = parser.parse_args(cargs)
    if args.file:
        # load data:
        data = TableData(args.file)
        data_cols = select_features(data, args.data_cols)
        # select column used for coloring the data:
        colors, color_label, color_col, error = \
            select_coloring(data, data_cols, args.color_col)
        if error:
            parser.error(error)
        # list features:
        if args.list_features:
            list_available_features(data, data_cols, color_col)
            parser.exit()
        # explore data:
        expl = MultivariateExplorer(data[:, data_cols])
        expl.set_colors(colors, color_label, args.color_map)
        expl.show()
    else:
        demo()
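How the command-line options end up in `args` can be tried without thunderlab. The sketch below rebuilds only the data-related options of the parser shown in the source; the custom `PrintHelp` action and the version option are left out, and the file name `data.csv` is a made-up example. The `-d` option uses `action='append'`, so every occurrence adds one column name to the list.

```python
import argparse


def build_parser():
    """Reduced copy of the option set used by main() (sketch)."""
    parser = argparse.ArgumentParser(description='View and explore multivariate data.')
    parser.add_argument('-l', dest='list_features', action='store_true',
                        help='list all available data columns (features) and exit')
    parser.add_argument('-d', dest='data_cols', action='append', default=[],
                        metavar='COLUMN', help='data columns (features) to be explored')
    parser.add_argument('-c', dest='color_col', default=None, type=str,
                        metavar='COLUMN', help='data column to be used for color code or "row"')
    parser.add_argument('-m', dest='color_map', default='jet', type=str,
                        metavar='CMAP', help='name of color map to be used')
    parser.add_argument('file', nargs='?', default='', type=str,
                        help='a file containing a table of data')
    return parser


args = build_parser().parse_args(['-d', 'A', '-d', 'C', '-c', 'row', 'data.csv'])
print(args.data_cols)   # ['A', 'C']
print(args.color_col)   # row
print(args.color_map)   # jet
print(args.file)        # data.csv
```

Without a file argument `args.file` is the empty string, which is the case in which the listed `main()` falls back to `demo()`.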
class MultivariateExplorer (data, labels=None, title=None)
class MultivariateExplorer(object): """Simple matplotlib-based GUI for viewing and exploring multivariate data. Shown are scatter plots of all pairs of variables or PCA axes. Points in the scatter plots are colored according to the values of one of the variables. Data points can be selected and optionally corresponding waveforms are shown. First you initialize the explorer with the data. Then you optionally specify how to colorize the data and provide waveform data associated with the data. Finally you show the figure: ``` expl = MultivariateExplorer(data) expl.set_colors(2) expl.set_wave_data(waveforms, 'Time [s]', ['Sine']) expl.show() ``` The `compute_pca()` function computes a principal component analysis (PCA) on the input data, and `save_pca()` writes the principal components to a file. Customize the appearance and information provided by subclassing MultivariateExplorer and reimplementing the functions - fix_scatter_plot() - fix_waveform_plot() - list_selection() - analyze_selection() See the documentation of these functions for details. 
""" mouse_actions = [ ('left click', 'select sample'), ('left and drag', 'rectangular selection of samples and/or zoom'), ('shift + left click/drag', 'add samples to selection'), ('ctrl + left click/drag', 'remove samples from selection') ] """List of tuples with mouse actions and a description of their action.""" key_actions = [ ('c, C', 'cycle color map through data columns'), ('p, P', 'toggle between features, PCs, and scaled PCs'), ('<, pageup', 'decrease number of displayed features/PCs'), ('>, pagedown', 'increase number of displayed features/PCs'), ('o, z', 'toggle zoom mode on or off'), ('backspace', 'zoom back'), ('n, N', 'decrease, increase number of bins of histograms'), ('H', 'toggle between scatter plot and 2D histogram'), ('left, right, up, down', 'show and move magnified scatter plot'), ('escape', 'close magnified scatter plot'), ('ctrl + a', 'select all'), ('+, -', 'increase, decrease pick radius'), ('0', 'reset pick radius'), ('l', 'list selection on console'), ('w', 'toggle maximized waveform plot'), ('h', 'toggle help window'), ] """List of tuples with key shortcuts and a description of their action.""" def __init__(self, data, labels=None, title=None): """Initialize explorer with scatter-plot data. Parameters ---------- data: TableData, 2D array, or list of 1D arrays The data to be explored. Each column is a variable. For the 2D array the columns are the second dimension, for a list of 1D arrays, the list goes over columns, i.e. each 1D array is one column. labels: list of str If data is not a TableData, then this provides labels for the data columns. title: str Title for the window. """ # data. 
categories and labels: self.raw_data = None # original data table as 2D numpy array (samples x features) self.raw_labels = None # for each feature a label, optional with unit self.categories = [] # for each feature None or list of categories if isinstance(data, TableData): for c, col in enumerate(data): if not isinstance(col[0], (int, float, np.integer, np.floating)): # categorial data: #print(data[:,c]) cats, data[:,c] = categorize(col) self.categories.append(cats) else: self.categories.append(None) self.raw_data = data.array() if labels is None: self.raw_labels = [] for c in range(len(data)): if len(data.unit(c)) > 0 and not data.unit(c) in ['-', '1']: self.raw_labels.append(f'{data.label(c)} [{data.unit(c)}]') else: self.raw_labels.append(data.label(c)) else: self.raw_labels = labels else: if isinstance(data, np.ndarray): self.raw_data = data self.categories = [None] * data.shape[1] else: for c, col in enumerate(data): if not isinstance(col[0], (int, float, np.integer, np.floating)): # categorial data: cats, data[c] = categorize(col) self.categories.append(cats) else: self.categories.append(None) self.raw_data = np.asarray(data).T self.raw_labels = labels # remove columns containing only invalid numbers: cols = np.all(~np.isfinite(self.raw_data), 0) if np.sum(cols) > 0: print('removed columns containing no numbers:', [l for l, c in zip(self.raw_labels, cols) if c]) self.raw_data = self.raw_data[:, ~cols] self.raw_labels = [l for l, c in zip(self.raw_labels, cols) if not c] # remove rows containing invalid numbers: self.valid_samples = ~np.any(~np.isfinite(self.raw_data), 1) self.raw_data = self.raw_data[self.valid_samples, :] if np.sum(~self.valid_samples) > 0: print(f'removed {np.sum(~self.valid_samples)} rows containing invalid numbers:') for k in range(len(self.valid_samples)): if not self.valid_samples[k]: print(k) self.valid_rows = [k for k in range(len(self.valid_samples)) if self.valid_samples[k]] # title for the window: self.title = title if title is not 
None else 'MultivariateExplorer' # data, pca-data, scaled-pca data (no pca data yet): self.all_data = [self.raw_data, None, None] self.all_labels = [self.raw_labels, None, None] self.all_maxcols = [self.raw_data.shape[1], None, None] self.all_titles = ['data', 'PCA', 'scaled PCA'] # added to window title # pca: self.pca_tables = [None, None] # raw and scaled pca coefficients self._pca_header(self.raw_data, self.raw_labels) # prepare header of the pca tables # start showing raw data: self.show_mode = 0 # show data, pca or scaled pca self.data = self.all_data[self.show_mode] # the data shown self.labels = self.all_labels[self.show_mode] # the feature labels shown self.maxcols = self.all_maxcols[self.show_mode] # maximum number of features currently shown if self.maxcols > 6: self.maxcols = 6 # waveform data: self.wave_data = [] self.wave_nested = False self.wave_has_xticks = [] self.wave_xlabels = [] self.wave_ylabels = [] self.wave_title = False # colors: self.color_map = plt.get_cmap('jet') self.extra_colors = None # additional data column to be used for coloring self.extra_color_label = None # label for extra_colors self.extra_categories = None # category name for extra_colors if needed self.color_set_index = 0 # -1: rows and extra_colors, 0: data, 1: pca, 2: scaled pca self.color_index = 0 # column used for coloring with color_set_index self.color_values = None # data column currently used for coloring as specified by color_set_index and color_index self.color_label = None # label of data currently used for coloring self.data_colors = None # actual colors for color_values self.color_vmin = None self.color_vmax = None self.color_ticks = None self.cbax = None # axes of color bar # figure variables: self.plt_params = {} for k in ['toolbar', 'keymap.quit', 'keymap.back', 'keymap.forward', 'keymap.zoom', 'keymap.pan', 'keymap.xscale', 'keymap.yscale']: self.plt_params[k] = plt.rcParams[k] if k != 'toolbar': plt.rcParams[k] = '' self.xborder = 100.0 # pixel for ylabels 
self.yborder = 50.0 # pixel for xlabels self.spacing = 10.0 # pixel between plots self.mborder = 20.0 # pixel around magnified plot self.pick_radius = 4.0 # histogram plots: self.hist_ax = [] # list of histogram axes self.hist_indices = [] # feature index of the histogram axes self.hist_selector = [] # for each histogram axes a selector self.hist_nbins = 30 # number of bins for computing histograms # scatter plots: self.scatter_ax = [] # list of axes with scatter plots (1D) self.scatter_indices = [] # for each axes a tuple of the column and row index self.scatter_artists = [] # artists of selected scatter points self.scatter_selector = [] # selector for each axes self.scatter = True # scatter (True) or density (False) self.mark_data = [] # list of indices of selected data self.significance_level = 0.05 # r is bold if p is below self.select_zooms = False self.zoom_stack = [] # magnified scatter plot: self.magnified_on = False self.magnified_backdrop = None self.magnified_size = np.array([0.6, 0.6]) # waveform plots: self.wave_ax = [] # help window: self.help_ax = None def set_wave_data(self, data, xlabels='', ylabels=[], title=False): """Add waveform data to explorer. Parameters ---------- data: list of (list of) 2D arrays Waveform data associated with each row of the data. Elements of the outer list correspond to the rows of the data. The inner 2D arrays contain a common x-axes (first column) and one or more corresponding y-values (second and optional higher columns). Each column for y-values is plotted in its own axes on top of each other, from top to bottom. The optional inner list of 2D arrays contains several 2D arrays as ascribed above each with its own common x-axes. xlabel: str or list of str The xlabels for the waveform plots. If only a string is given, then there will be a common xaxis for all the plots, and only the lowest one gets a labeled xaxis. If a list of strings is given, each waveform plot gets its own labeled x-axis. 
ylabels: list of str The ylabels for each of the waveform plots. title: bool or str If True or a string, provide space on top of the waveform plots for a title. If string, set this as the title for the waveform plots. """ self.wave_data = [] if data is not None and len(data) > 0: self.wave_data = data self.wave_has_xticks = [] self.wave_nested = isinstance(data[0], (list, tuple)) if self.wave_nested: for data in self.wave_data[0]: for k in range(data.shape[1]-2): self.wave_has_xticks.append(False) self.wave_has_xticks.append(True) else: for k in range(self.wave_data[0].shape[1]-2): self.wave_has_xticks.append(False) self.wave_has_xticks.append(True) if isinstance(xlabels, (list, tuple)): self.wave_xlabels = xlabels else: self.wave_xlabels = [xlabels] self.wave_ylabels = ylabels self.wave_title = title self.wave_ax = [] def set_colors(self, colors=0, color_label=None, color_map=None): """Set data column used to color scatter plots. Parameters ---------- colors: int or 1D array Index of the column in data to be used for coloring scatter plots. -2 for coloring row index of data. Or data array used to color scatter plots. color_label: str If colors is an array, this is a label describing the data. It is used to label the color bar. color_map: str or None Name of a matplotlib color map. If None 'jet' is used. """ if isinstance(colors, (np.integer, int)): if colors < 0: self.color_set_index = -1 self.color_index = 0 else: self.color_set_index = 0 self.color_index = colors else: if not isinstance(colors[0], (int, float, np.integer, np.floating)): # categorical data: self.extra_categories, self.extra_colors = categorize(colors) else: self.extra_colors = colors self.extra_colors = self.extra_colors[self.valid_samples] self.extra_color_label = color_label self.color_set_index = -1 self.color_index = 1 self.color_map = plt.get_cmap(color_map if color_map else 'jet') def show(self, ioff=True): """Show interactive scatter plots for exploration. 
""" if ioff: plt.ioff() else: plt.ion() plt.rcParams['toolbar'] = 'None' plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, ctrl+q, q' plt.rcParams['font.size'] = 12 self.fig = plt.figure(facecolor='white', figsize=(10, 8)) self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode]) self.fig.canvas.mpl_connect('key_press_event', self._on_key) self.fig.canvas.mpl_connect('resize_event', self._on_resize) self.fig.canvas.mpl_connect('pick_event', self._on_pick) if self.color_map is None: self.color_map = plt.get_cmap('jet') self._set_color_column() self._init_hist_plots() self._init_scatter_plots() self.wave_ax = [] if self.wave_data is not None and len(self.wave_data) > 0: axx = None xi = 0 for k, has_xticks in enumerate(self.wave_has_xticks): ax = self.fig.add_subplot(1, len(self.wave_has_xticks), 1+k, sharex=axx) self.wave_ax.append(ax) if has_xticks: if xi >= len(self.wave_xlabels): self.wave_xlabels.append('') ax.set_xlabel(self.wave_xlabels[xi]) xi += 1 axx = None else: #ax.xaxis.set_major_formatter(plt.NullFormatter()) if axx is None: axx = ax for ax, ylabel in zip(self.wave_ax, self.wave_ylabels): ax.set_ylabel(ylabel) if not isinstance(self.wave_title, bool) and self.wave_title: self.wave_ax[0].set_title(self.wave_title) self.fix_waveform_plot(self.wave_ax, self.mark_data) self._plot_magnified_scatter() self._plot_help() plt.show() def _pca_header(self, data, labels): """Set up header for the table of principal components. Parameters ---------- data: ndarray of float The data (samples x features) without invalid (infinite or NaN) numbers. labels: list of str Labels of the features. 
""" lbs = [] for l, d in zip(labels, data): if '[' in l: lbs.append(l.split('[')[0].strip()) elif '/' in l: lbs.append(l.split('/')[0].strip()) else: lbs.append(l) header = TableData(header=lbs) header.set_formats('%.3f') header.insert(0, ['PC'] + ['-']*header.nsecs, '', '%d') header.insert(1, 'variance', '%', '%.3f') for k in range(len(self.pca_tables)): self.pca_tables[k] = TableData(header) def compute_pca(self, scale=False, write=False): """Compute PCA based on the data. Parameters ---------- scale: boolean If True standardize data before computing PCA, i.e. remove mean of each variable and divide by its standard deviation. write: boolean If True write PCA components to standard out. """ # pca: pca = decomposition.PCA() if scale: scaler = preprocessing.StandardScaler() scaler.fit(self.raw_data) pca.fit(scaler.transform(self.raw_data)) pca_label = 'sPC' else: pca.fit(self.raw_data) pca_label = 'PC' for k in range(len(pca.components_)): if np.abs(np.min(pca.components_[k])) > np.max(pca.components_[k]): pca.components_[k] *= -1.0 pca_data = pca.transform(scaler.transform(self.raw_data) if scale else self.raw_data) pca_labels = [f'{pca_label}{k+1} ' + (f'({100*v:.1f}%)' if v > 0.01 else (f'{100*v:.2f}%')) for k, v in enumerate(pca.explained_variance_ratio_)] if np.min(pca.explained_variance_ratio_) >= 0.01: pca_maxcols = pca_data.shape[1] else: pca_maxcols = np.argmax(pca.explained_variance_ratio_ < 0.01) if pca_maxcols < 2: pca_maxcols = 2 if pca_maxcols > 6: pca_maxcols = 6 # table with PCA feature weights: pca_table = self.pca_tables[1] if scale else self.pca_tables[0] pca_table.clear_data() pca_table.set_section(pca_label, 0, pca_table.nsecs) for k, comp in enumerate(pca.components_): pca_table.append_data(k+1, 0) pca_table.append_data(100.0*pca.explained_variance_ratio_[k]) pca_table.append_data(comp) if write: pca_table.write(table_format='out', unit_style='none') # submit data: if scale: self.all_data[2] = pca_data self.all_labels[2] = pca_labels 
self.all_maxcols[2] = pca_maxcols else: self.all_data[1] = pca_data self.all_labels[1] = pca_labels self.all_maxcols[1] = pca_maxcols def 
save_pca(self, file_name, scale, **kwargs): """Write PCA data to file. Parameters ---------- file_name: str Name of output file. scale: boolean If True write PCA components of standardized PCA. kwargs: dict Additional parameter for TableData.write() """ if scale: pca_file = file_name + '-pcacor' pca_table = self.pca_tables[1] else: pca_file = file_name + '-pcacov' pca_table = self.pca_tables[0] if 'unit_style' in kwargs: del kwargs['unit_style'] if 'table_format' in kwargs: pca_table.write(pca_file, unit_style='none', **kwargs) else: pca_file += '.dat' pca_table.write(pca_file, unit_style='none') def _set_color_column(self): """Initialize variables used for colorization of scatter points.""" if self.color_set_index == -1: if self.color_index == 0: self.color_values = np.arange(self.data.shape[0], dtype=float) self.color_label = 'sample' elif self.color_index == 1: self.color_values = self.extra_colors self.color_label = self.extra_color_label else: self.color_values = self.all_data[self.color_set_index][:,self.color_index] self.color_label = self.all_labels[self.color_set_index][self.color_index] self.color_vmin, self.color_vmax, self.color_ticks = \ self.fix_scatter_plot(self.cbax, self.color_values, self.color_label, 'c') if self.color_ticks is None: if self.color_set_index == 0 and \ self.categories[self.color_index] is not None: self.color_ticks = np.arange(len(self.categories[self.color_index])) elif self.color_set_index == -1 and \ self.color_index == 1 and \ self.extra_categories is not None: self.color_ticks = np.arange(len(self.extra_categories)) self.data_colors = self.color_map((self.color_values - self.color_vmin)/(self.color_vmax - self.color_vmin)) def _add_backdrop(self, ax): bbox = ax.get_tightbbox(self.fig.canvas.get_renderer()) if bbox is not None: self.magnified_backdrop = \ patches.Rectangle((bbox.x0 - self.mborder, bbox.y0 - self.mborder), bbox.width + 2*self.mborder, bbox.height + 2*self.mborder, transform=None, clip_on=False, 
facecolor='#f7f7f7', 
edgecolor='none', zorder=-5) ax.add_patch(self.magnified_backdrop) def _create_selector(self, ax): try: selector = \ widgets.RectangleSelector(ax, self._on_select, useblit=True, button=1, minspanx=0, minspany=0, spancoords='pixels', props=dict(facecolor='gray', edgecolor='gray', alpha=0.2, fill=True), state_modifier_keys=dict(move='', clear='', square='', center='ctrl')) except TypeError: # old matplotlib: selector = widgets.RectangleSelector(ax, self._on_select, useblit=True, button=1) return selector def _plot_hist(self, ax, magnifiedax): """Plot and label a histogram.""" try: idx = self.hist_ax.index(ax) c = self.hist_indices[idx] in_hist = True except ValueError: idx = self.scatter_ax.index(ax) c = self.scatter_indices[idx][0] in_hist = False ax.clear() #ax.relim() #ax.autoscale(True) x = self.data[:,c] ax.hist(x, self.hist_nbins) #ax.autoscale(False) ax.set_xlabel(self.labels[c]) ax.xaxis.set_major_locator(plt.AutoLocator()) ax.xaxis.set_major_formatter(plt.ScalarFormatter()) if self.show_mode == 0: if self.categories[c] is not None: ax.set_xticks(np.arange(len(self.categories[c]))) ax.set_xticklabels(self.categories[c]) self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x') if magnifiedax: ax.text(0.05, 0.9, f'n={len(self.data)}', transform=ax.transAxes, zorder=100) ax.set_ylabel('count') cax = self.hist_ax[self.scatter_indices[-1][0]] ax.set_xlim(cax.get_xlim()) else: if c == 0: ax.text(0.05, 0.9, f'n={len(self.data)}', transform=ax.transAxes, zorder=100) ax.set_ylabel('count') else: ax.yaxis.set_major_formatter(plt.NullFormatter()) selector = self._create_selector(ax) if in_hist: self.hist_selector[idx] = selector else: self.scatter_selector[idx] = selector self.scatter_artists[idx] = None ax.relim(True) if magnifiedax: self._add_backdrop(ax) def _set_hist_ylim(self): ymax = np.max([ax.get_ylim() for ax in self.hist_ax[:self.maxcols]], 0)[1] for ax in self.hist_ax: ax.set_ylim(0, ymax) def _init_hist_plots(self): """Initial plots of the 
histograms.""" n = self.data.shape[1] self.hist_ax = [] for r in 
range(n): ax = self.fig.add_subplot(n, n, (n-1)*n+r+1) self.hist_ax.append(ax) self.hist_indices.append(r) self.hist_selector.append(None) self._plot_hist(ax, False) self._set_hist_ylim() def _plot_scatter(self, ax, magnifiedax, cax=None): """Plot a scatter plot.""" idx = self.scatter_ax.index(ax) c, r = self.scatter_indices[idx] if self.scatter: # scatter plot ax.clear() a = ax.scatter(self.data[:,c], self.data[:,r], s=50, edgecolors='white', linewidths=0.5, picker=self.pick_radius, zorder=10) a.set_facecolor(self.data_colors) pr, pp = pearsonr(self.data[:,c], self.data[:,r]) fw = 'bold' if pp < self.significance_level/self.data.shape[1] else 'normal' if pr < 0: ax.text(0.95, 0.9, f'r={pr:.2f}, p={pp:.3f}', fontweight=fw, transform=ax.transAxes, ha='right', zorder=100) else: ax.text(0.05, 0.9, f'r={pr:.2f}, p={pp:.3f}', fontweight=fw, transform=ax.transAxes, zorder=100) # color bar: if cax is not None: a = ax.scatter(self.data[:, c], self.data[:, r], c=self.color_values, cmap=self.color_map) self.fig.colorbar(a, cax=cax, ticks=self.color_ticks) a.remove() cax.set_ylabel(self.color_label) self.color_vmin, self.color_vmax, self.color_ticks = \ self.fix_scatter_plot(self.cbax, self.color_values, self.color_label, 'c') if self.color_ticks is None: if self.color_set_index == 0 and \ self.categories[self.color_index] is not None: cax.set_yticklabels(self.categories[self.color_index]) elif self.color_set_index == -1 and \ self.color_index == 1 and \ self.extra_categories is not None: cax.set_yticklabels(self.extra_categories) else: # histogram if self.show_mode == 0: self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x') self.fix_scatter_plot(ax, self.data[:,r], self.labels[r], 'y') axrange = [ax.get_xlim(), ax.get_ylim()] ax.clear() ax.hist2d(self.data[:,c], self.data[:,r], self.hist_nbins, range=axrange, cmap=plt.get_cmap('Greys')) # selected data: a = ax.scatter(self.data[self.mark_data, c], self.data[self.mark_data, r], s=100, edgecolors='black', 
linewidths=0.5, picker=self.pick_radius, zorder=11) a.set_facecolor(self.data_colors[self.mark_data]) self.scatter_artists[idx] = a 
ax.xaxis.set_major_locator(plt.AutoLocator()) ax.yaxis.set_major_locator(plt.AutoLocator()) ax.xaxis.set_major_formatter(plt.ScalarFormatter()) ax.yaxis.set_major_formatter(plt.ScalarFormatter()) if self.show_mode == 0: if self.categories[c] is not None: ax.set_xticks(np.arange(len(self.categories[c]))) ax.set_xticklabels(self.categories[c]) if self.categories[r] is not None: ax.set_yticks(np.arange(len(self.categories[r]))) ax.set_yticklabels(self.categories[r]) if magnifiedax: ax.set_xlabel(self.labels[c]) ax.set_ylabel(self.labels[r]) cax = self.scatter_ax[self.scatter_indices[:-1].index(self.scatter_indices[-1])] ax.set_xlim(cax.get_xlim()) ax.set_ylim(cax.get_ylim()) else: if c == 0: ax.set_ylabel(self.labels[r]) if self.show_mode == 0: self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x') self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y') if not magnifiedax: ax.xaxis.set_major_formatter(plt.NullFormatter()) if c > 0: ax.yaxis.set_major_formatter(plt.NullFormatter()) ax.set_xlim(*self.hist_ax[c].get_xlim()) ax.set_ylim(*self.hist_ax[r].get_xlim()) if magnifiedax: self._add_backdrop(ax) selector = self._create_selector(ax) self.scatter_selector[idx] = selector ax.relim(True) def _init_scatter_plots(self): """Initial plots of scatter plots.""" self.cbax = self.fig.add_axes([0.5, 0.5, 0.1, 0.5]) cbax = self.cbax n = self.data.shape[1] for r in range(1, n): for c in range(r): ax = self.fig.add_subplot(n, n, (r-1)*n+c+1) self.scatter_ax.append(ax) self.scatter_indices.append([c, r]) self.scatter_artists.append(None) self.scatter_selector.append(None) self._plot_scatter(ax, False, cbax) cbax = None def _plot_magnified_scatter(self): """Initial plot of the magnified scatter plot.""" ax = self.fig.add_axes([0.5, 0.9, 0.05, 0.05]) ax.set_visible(False) self.magnified_on = False c = 0 r = 1 a = ax.scatter(self.data[:, c], self.data[:, r], s=50, edgecolors='none') a.set_facecolor(self.data_colors) a = ax.scatter(self.data[self.mark_data, c], 
self.data[self.mark_data, r], s=80) 
a.set_facecolor(self.data_colors[self.mark_data]) ax.set_xlabel(self.labels[c]) ax.set_ylabel(self.labels[r]) self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x') self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y') self.scatter_ax.append(ax) self.scatter_indices.append([c, r]) self.scatter_artists.append(a) self.scatter_selector.append(None) def _plot_help(self): ax = self.fig.add_subplot(1, 1, 1) ax.set_position([0.02, 0.02, 0.96, 0.96]) ax.xaxis.set_major_locator(plt.NullLocator()) ax.yaxis.set_major_locator(plt.NullLocator()) n = len(self.mouse_actions) + len(self.key_actions) + 4 dy = 1/n y = 1 - 1.5*dy ax.text(0.05, y, 'Key shortcuts', transform=ax.transAxes, fontweight='bold') y -= dy for a, d in self.key_actions: ax.text(0.05, y, a, transform=ax.transAxes) ax.text(0.3, y, d, transform=ax.transAxes) y -= dy y -= dy ax.text(0.05, y, 'Mouse actions', transform=ax.transAxes, fontweight='bold') y -= dy for a, d in self.mouse_actions: ax.text(0.05, y, a, transform=ax.transAxes) ax.text(0.3, y, d, transform=ax.transAxes) y -= dy ax.set_visible(False) self.help_ax = ax def fix_scatter_plot(self, ax, data, label, axis): """Customize an axes of a scatter plot. This function is called after a scatter plot has been plotted. Once for the x axis, once for the y axis and once for the color bar. Reimplement this function to set appropriate limits and ticks. Return values are only used for the color bar (`axis='c'`). Otherwise they are ignored. For example, ticks for phase variables can be nicely labeled using the unicode character for pi: ``` if 'phase' in label: if axis == 'y': ax.set_ylim(0.0, 2.0*np.pi) ax.set_yticks(np.arange(0.0, 2.5*np.pi, 0.5*np.pi)) ax.set_yticklabels(['0', u'\u03c0/2', u'\u03c0', u'3\u03c0/2', u'2\u03c0']) ``` Parameters ---------- ax: matplotlib axes Axes of the scatter plot or color bar to be worked on. data: 1D array Data array of the axes. label: str Label corresponding to the data array. 
axis: str 'x', 'y': set properties of x or y axes of ax. 'c': set properies of color bar axes (note that ax can be None!) and return vmin, vmax, and ticks. Returns ------- min: float minimum value of color bar axis max: float maximum value of color bar axis ticks: list of float position of ticks for color bar axis """ return np.nanmin(data), np.nanmax(data), None def fix_waveform_plot(self, axs, indices): """Customize waveform plots. This function is called once after new data have been plotted into the waveform plots. Reimplement this function to customize these plots. In particular to set axis limits and labels, plot title, etc. You may even open a new figure (with non-blocking `show()`). The following member variables might be usefull: - `self.wave_data`: the full list of waveform data. - `self.wave_nested`: True if the elements of `self.wave_data` are lists of 2D arrays. Otherwise the elements are 2D arrays. The first column of a 2D array contains the x-values, further columns y-values. - `self.wave_has_xticks`: List of booleans for each axis. True if the axis has its own xticks. - `self.wave_xlabels`: List of xlabels (only for the axis where the corresponding entry in `self.wave_has_xticks` is True). - `self.wave_ylabels`: for each axis its ylabel For example, you can set the linewidth of all plotted waveforms via: ``` for ax in axs: for l in ax.lines: l.set_linewidth(3.0) ``` or enable markers to be plotted: ``` for ax, yl in zip(axs, self.wave_ylabels): if 'Power' in yl: for l in ax.lines: l.set_marker('.') l.set_markersize(15.0) l.set_markeredgewidth(0.5) l.set_markeredgecolor('k') l.set_markerfacecolor(l.get_color()) ``` Usefull is to reduce the maximum number of y-ticks: ``` axs[0].yaxis.get_major_locator().set_params(nbins=7) ``` or ``` import matplotlib.ticker as ticker axs[0].yaxis.set_major_locator(ticker.MaxNLocator(nbins=4)) ``` Parameters ---------- axs: list of matplotlib axes Axis of the waveform plots to be worked on. 
indices: list of int Indices of the waveforms that have been selected and plotted. """ pass def list_selection(self, indices): """List information about the current selection of data points. This function is called when 'l' is pressed. Reimplement this function, for example, to print some meaningfull information about the current selection of data points on console. You may do, however, whatever you want in this function. Parameters ---------- indices: list of int Indices of the data points that have been selected. """ print('') print('selected rows in data table:') for i in indices: print(self.valid_rows[i]) def analyze_selection(self, index): """Provide further information about a single selected data point. This function is called when a single data item was double clicked. Reimplement this function to provide some further details on this data point. This can be an additional figure window. In this case show it non-blocking: `` Parameters ---------- index: int The index of the selected data point. 
""" pass def _set_magnified_pos(self, width, height): """Set position of magnified plot.""" if self.magnified_on: xoffs = self.xborder/width yoffs = self.yborder/height if self.scatter_indices[-1][1] <[1]: idx = self.scatter_indices[:-1].index(self.scatter_indices[-1]) pos = self.scatter_ax[idx].get_position().get_points() else: pos = self.hist_ax[self.scatter_indices[-1][0]].get_position().get_points() pos[0] = np.mean(pos, 0) - 0.5*self.magnified_size if pos[0][0] < xoffs: pos[0][0] = xoffs if pos[0][1] < yoffs: pos[0][1] = yoffs pos[1] = pos[0] + self.magnified_size if pos[1][0] > 1.0-self.spacing/width: pos[1][0] = 1.0-self.spacing/width if pos[1][1] > 1.0-self.spacing/height: pos[1][1] = 1.0-self.spacing/height pos[0] = pos[1] - self.magnified_size self.scatter_ax[-1].set_position([pos[0][0], pos[0][1], self.magnified_size[0], self.magnified_size[1]]) self.scatter_ax[-1].set_visible(True) else: self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05]) self.scatter_ax[-1].set_visible(False) def _make_selection(self, ax, key, x0, x1, y0, y1): """Select points from a scatter or histogram plot.""" if not key in ['shift', 'control']: self.mark_data = [] if ax in self.scatter_ax: axi = self.scatter_ax.index(ax) # from scatter plots: c, r = self.scatter_indices[axi] if r <[1]: # from scatter: for ind, (x, y) in enumerate(zip([:, c],[:, r])): if x >= x0 and x <= x1 and y >= y0 and y <= y1: if ind in self.mark_data: if key == 'control': self.mark_data.remove(ind) elif key != 'control': self.mark_data.append(ind) else: # from histogram: for ind, x in enumerate([:, c]): if x >= x0 and x <= x1: if ind in self.mark_data: if key == 'control': self.mark_data.remove(ind) elif key != 'control': self.mark_data.append(ind) elif ax in self.hist_ax: r = self.hist_indices[self.hist_ax.index(ax)] # from histogram: for ind, x in enumerate([:, r]): if x >= x0 and x <= x1: if ind in self.mark_data: if key == 'control': self.mark_data.remove(ind) elif key != 'control': 
self.mark_data.append(ind) def _update_selection(self): """Highlight selected points in the scatter plots and plot corresponding waveforms.""" # update scatter plots: for artist, (c, r) in zip(self.scatter_artists, self.scatter_indices): if artist is not None: if len(self.mark_data) == 0: artist.set_offsets(np.zeros((0, 2))) else: artist.set_offsets(list(zip([self.mark_data, c],[self.mark_data, r]))) artist.set_facecolors(self.data_colors[self.mark_data]) # waveform plots: if len(self.wave_ax) > 0: axdi = 0 axti = 1 for xi, ax in enumerate(self.wave_ax): ax.clear() if len(self.mark_data) > 0: for idx in self.mark_data: if self.wave_nested: data = self.wave_data[idx][axdi] else: data = self.wave_data[idx] if data is not None: ax.plot(data[:, 0], data[:, axti], c=self.data_colors[idx], picker=self.pick_radius) axti += 1 if self.wave_has_xticks[xi]: ax.set_xlabel(self.wave_xlabels[axdi]) axti = 1 axdi += 1 #else: # ax.xaxis.set_major_formatter(plt.NullFormatter()) for ax, ylabel in zip(self.wave_ax, self.wave_ylabels): ax.set_ylabel(ylabel) if not isinstance(self.wave_title, bool) and self.wave_title: self.wave_ax[0].set_title(self.wave_title) self.fix_waveform_plot(self.wave_ax, self.mark_data) self.fig.canvas.draw() def _set_limits(self, ax, x0, x1, y0, y1): if ax in self.hist_ax: ax.set_xlim(x0, x1) for hax in self.hist_ax: hax.set_ylim(y0, y1) cc = self.hist_indices[self.hist_ax.index(ax)] for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices): if c == cc: sax.set_xlim(x0, x1) if r == cc: sax.set_ylim(x0, x1) if ax in self.scatter_ax: idx = self.scatter_ax.index(ax) cc, rr = self.scatter_indices[idx] self.hist_ax[cc].set_xlim(x0, x1) self.hist_ax[rr].set_xlim(y0, y1) for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices): if c == cc: sax.set_xlim(x0, x1) if c == rr: sax.set_xlim(y0, y1) if r == cc: sax.set_ylim(x0, x1) if r == rr: sax.set_ylim(y0, y1) def _on_key(self, event): """Handle key events.""" #print('pressed', event.key) if event.key in 
['left', 'right', 'up', 'down']: if self.magnified_on: mc, mr = self.scatter_indices[-1] if event.key == 'left': if mc > 0: self.scatter_indices[-1][0] -= 1 elif mr > 1: if mr >=[1]: self.scatter_indices[-1][1] = self.maxcols - 1 else: self.scatter_indices[-1][1] -= 1 self.scatter_indices[-1][0] = self.scatter_indices[-1][1] - 1 else: self.scatter_indices[-1][0] =[1] - 1 self.scatter_indices[-1][1] =[1] elif event.key == 'right': if mc < mr - 1 and mc < self.maxcols - 1: self.scatter_indices[-1][0] += 1 elif mr < self.maxcols: self.scatter_indices[-1][0] = 0 self.scatter_indices[-1][1] += 1 if mr >= self.maxcols: self.scatter_indices[-1][1] =[1] else: self.scatter_indices[-1][0] = 0 self.scatter_indices[-1][1] = 1 elif event.key == 'up': if mr > mc + 1: if mr >=[1]: self.scatter_indices[-1][1] = self.maxcols - 1 else: self.scatter_indices[-1][1] -= 1 elif mc > 0: self.scatter_indices[-1][0] -= 1 self.scatter_indices[-1][1] =[1] else: self.scatter_indices[-1][0] =[1] - 1 self.scatter_indices[-1][1] =[1] elif event.key == 'down': if mr < self.maxcols: self.scatter_indices[-1][1] += 1 if mr >= self.maxcols: self.scatter_indices[-1][1] =[1] elif mc < self.maxcols - 1: self.scatter_indices[-1][0] += 1 self.scatter_indices[-1][1] = mc + 2 if self.scatter_indices[-1][1] >= self.maxcols: self.scatter_indices[-1][1] =[1] else: self.scatter_indices[-1][0] = 0 self.scatter_indices[-1][1] = 1 for k in reversed(range(len(self.zoom_stack))): if self.zoom_stack[k][0] == self.scatter_ax[-1]: del self.zoom_stack[k] self.scatter_ax[-1].clear() self.scatter_ax[-1].set_visible(True) self.magnified_on = True self._set_magnified_pos(self.fig.get_window_extent().width, self.fig.get_window_extent().height) if self.scatter_indices[-1][1] <[1]: self._plot_scatter(self.scatter_ax[-1], True) else: self._plot_hist(self.scatter_ax[-1], True) self.fig.canvas.draw() else: if event.key == 'escape': if len(self.scatter_ax) >=[1]: self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05]) 
self.magnified_on = False self.scatter_ax[-1].set_visible(False) self.fig.canvas.draw() elif event.key in 'oz': self.select_zooms = not self.select_zooms elif event.key == 'backspace': if len(self.zoom_stack) > 0: self._set_limits(*self.zoom_stack.pop()) self.fig.canvas.draw() elif event.key in '+=': self.pick_radius *= 1.5 elif event.key in '-': if self.pick_radius > 5.0: self.pick_radius /= 1.5 elif event.key in '0': self.pick_radius = 4.0 elif event.key in ['pageup', 'pagedown', '<', '>']: if event.key in ['pageup', '<'] and self.maxcols > 2: self.maxcols -= 1 elif event.key in ['pagedown', '>'] and self.maxcols < self.raw_data.shape[1]: self.maxcols += 1 for ax in self.hist_ax: self._plot_hist(ax, False) self._update_layout() elif event.key == 'w': if len(self.wave_data) > 0: if self.maxcols > 0: self.all_maxcols[self.show_mode] = self.maxcols self.maxcols = 0 else: self.maxcols = self.all_maxcols[self.show_mode] self._set_layout(self.fig.get_window_extent().width, self.fig.get_window_extent().height) self.fig.canvas.draw() elif event.key == 'ctrl+a': self.mark_data = list(range(len( self._update_selection() elif event.key in 'cC': if event.key in 'c': self.color_index -= 1 if self.color_index < 0: self.color_set_index -= 1 if self.color_set_index < -1: self.color_set_index = len(self.all_data)-1 if self.color_set_index >= 0: if self.all_data[self.color_set_index] is None: self.compute_pca(self.color_set_index>1, True) self.color_index = self.all_data[self.color_set_index].shape[1]-1 else: self.color_index = 0 if self.extra_colors is None else 1 self._set_color_column() else: self.color_index += 1 if (self.color_set_index >= 0 and \ self.color_index >= self.all_data[self.color_set_index].shape[1]) or \ (self.color_set_index < 0 and \ self.color_index >= (1 if self.extra_colors is None else 2)): self.color_index = 0 self.color_set_index += 1 if self.color_set_index >= len(self.all_data): self.color_set_index = -1 elif self.all_data[self.color_set_index] is None: 
self.compute_pca(self.color_set_index>1, True) self._set_color_column() for ax in self.scatter_ax: ax.collections[0].set_facecolors(self.data_colors) for a in self.scatter_artists: if a is not None: a.set_facecolors(self.data_colors[self.mark_data]) for ax in self.wave_ax: for l, c in zip(ax.lines, self.data_colors[self.mark_data]): l.set_color(c) l.set_markerfacecolor(c) self._plot_scatter(self.scatter_ax[0], False, self.cbax) self.fix_scatter_plot(self.cbax, self.color_values, self.color_label, 'c') self.fig.canvas.draw() elif event.key in 'nN': if event.key in 'N': self.hist_nbins = (self.hist_nbins*3)//2 elif self.hist_nbins >= 15: self.hist_nbins = (self.hist_nbins*2)//3 else: self.hist_nbins = 10 for ax in self.hist_ax: self._plot_hist(ax, False) self._set_hist_ylim() if self.scatter_indices[-1][1] >=[1]: self._plot_hist(self.scatter_ax[-1], True, True) elif not self.scatter: self._plot_scatter(self.scatter_ax[-1], True) if not self.scatter: for ax in self.scatter_ax[:-1]: self._plot_scatter(ax, False) self.fig.canvas.draw() elif event.key in 'H': self.scatter = not self.scatter for ax in self.scatter_ax[:-1]: self._plot_scatter(ax, False) if self.scatter_indices[-1][1] <[1]: self._plot_scatter(self.scatter_ax[-1], True) self.fig.canvas.draw() elif event.key in 'pP': if len(self.scatter_ax) >=[1]: self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05]) self.scatter_indices[-1] = [0, 1] self.magnified_on = False self.scatter_ax[-1].set_visible(False) self.all_maxcols[self.show_mode] = self.maxcols if event.key == 'P': self.show_mode += 1 if self.show_mode >= len(self.all_data): self.show_mode = 0 else: self.show_mode -= 1 if self.show_mode < 0: self.show_mode = len(self.all_data)-1 if self.show_mode == 1: print('principal components') elif self.show_mode == 2: print('scaled principal components') else: print('data') if self.all_data[self.show_mode] is None: self.compute_pca(self.show_mode>1, True) = self.all_data[self.show_mode] self.labels = 
self.all_labels[self.show_mode] self.maxcols = self.all_maxcols[self.show_mode] self.zoom_stack = [] self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode]) for ax in self.hist_ax: self._plot_hist(ax, False) self._set_hist_ylim() for ax in self.scatter_ax: self._plot_scatter(ax, False) self._update_layout() elif event.key in 'l': if len(self.mark_data) > 0: self.list_selection(self.mark_data) elif event.key in 'h': self.help_ax.set_visible(not self.help_ax.get_visible()) self.fig.canvas.draw() def _on_select(self, eclick, erelease): """Handle selection events.""" if eclick.dblclick: if len(self.mark_data) > 0: self.analyze_selection(self.mark_data[-1]) return x0 = min(eclick.xdata, erelease.xdata) x1 = max(eclick.xdata, erelease.xdata) y0 = min(eclick.ydata, erelease.ydata) y1 = max(eclick.ydata, erelease.ydata) ax = erelease.inaxes if ax is None: ax = eclick.inaxes xmin, xmax = ax.get_xlim() ymin, ymax = ax.get_ylim() dx = 0.02*(xmax-xmin) dy = 0.02*(ymax-ymin) if x1 - x0 < dx and y1 - y0 < dy: bbox = ax.get_window_extent().transformed(self.fig.dpi_scale_trans.inverted()) width, height = bbox.width, bbox.height width *= self.fig.dpi height *= self.fig.dpi dx = self.pick_radius*(xmax-xmin)/width dy = self.pick_radius*(ymax-ymin)/height x0 = erelease.xdata - dx x1 = erelease.xdata + dx y0 = erelease.ydata - dy y1 = erelease.ydata + dy elif self.select_zooms: self.zoom_stack.append((ax, xmin, xmax, ymin, ymax)) self._set_limits(ax, x0, x1, y0, y1) self._make_selection(ax, erelease.key, x0, x1, y0, y1) self._update_selection() def _on_pick(self, event): """Handle pick events.""" for ax in self.wave_ax: for k, l in enumerate(ax.lines): if l is event.artist: self.mark_data = [self.mark_data[k]] for ax in self.scatter_ax: if ax.collections[0] is event.artist: self.mark_data = event.ind self._update_selection() if event.mouseevent.dblclick: if len(self.mark_data) > 0: self.analyze_selection(self.mark_data[-1]) def _set_layout(self, 
width, height): """Update positions and visibility of all plots.""" xoffs = self.xborder/width yoffs = self.yborder/height xs = self.spacing/width ys = self.spacing/height if self.maxcols > 0: dx = (1.0-xoffs)/self.maxcols dy = (1.0-yoffs)/self.maxcols xw = dx - xs yw = dy - ys # histograms: for c, ax in enumerate(self.hist_ax): if c < self.maxcols: ax.set_position([xoffs+c*dx, yoffs, xw, yw]) ax.set_visible(True) else: ax.set_visible(False) ax.set_position([0.99, 0.01, 0.01, 0.01]) # scatter plots: for ax, (c, r) in zip(self.scatter_ax[:-1], self.scatter_indices[:-1]): if r < self.maxcols: ax.set_position([xoffs+c*dx, yoffs+(self.maxcols-r)*dy, xw, yw]) ax.set_visible(True) else: ax.set_visible(False) ax.set_position([0.99, 0.01, 0.01, 0.01]) # color bar: if self.maxcols > 0: self.cbax.set_position([xoffs+dx, yoffs+(self.maxcols-1)*dy, 0.3*xoffs, yw]) self.cbax.set_visible(True) else: self.cbax.set_visible(False) self.cbax.set_position([0.99, 0.01, 0.01, 0.01]) # magnified plot: if self.maxcols > 0: self._set_magnified_pos(width, height) if self.magnified_backdrop is not None: bbox = self.scatter_ax[-1].get_tightbbox(self.fig.canvas.get_renderer()) if bbox is not None: self.magnified_backdrop.set_bounds(bbox.x0 - self.mborder, bbox.y0 - self.mborder, bbox.width + 2*self.mborder, bbox.height + 2*self.mborder) else: self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05]) self.scatter_ax[-1].set_visible(False) # waveform plots: if len(self.wave_ax) > 0: if self.maxcols > 0: x0 = xoffs+((self.maxcols+1)//2)*dx y0 = ((self.maxcols+1)//2)*dy if self.maxcols%2 == 0: x0 += xoffs y0 += yoffs - ys else: y0 += ys else: x0 = xoffs y0 = 0.0 yp = 1.0 dy = 1.0-y0 dy -= np.sum(self.wave_has_xticks)*yoffs yp -= ys dy -= ys if self.wave_title: yp -= 2*ys dy -= 2*ys dy /= len(self.wave_ax) for ax, has_xticks in zip(self.wave_ax, self.wave_has_xticks): yp -= dy ax.set_position([x0, yp, 1.0-x0-xs, dy]) if has_xticks: yp -= yoffs else: yp -= ys def _update_layout(self): """Update 
content and position of magnified plot.""" if self.scatter_indices[-1][1] <[1]: if self.scatter_indices[-1][1] >= self.maxcols: self.scatter_indices[-1][1] = self.maxcols-1 if self.scatter_indices[-1][0] >= self.scatter_indices[-1][1]: self.scatter_indices[-1][0] = self.scatter_indices[-1][1]-1 self._plot_scatter(self.scatter_ax[-1], True) else: if self.scatter_indices[-1][0] >= self.maxcols: self.scatter_indices[-1][0] = self.maxcols-1 self._plot_hist(self.scatter_ax[-1], True) self._set_hist_ylim() self._set_layout(self.fig.get_window_extent().width, self.fig.get_window_extent().height) self.fig.canvas.draw() def _on_resize(self, event): """Adapt layout of plots to new figure size.""" self._set_layout(event.width, event.height)
Simple matplotlib-based GUI for viewing and exploring multivariate data.
The explorer shows scatter plots of all pairs of variables or of PCA axes. Points in the scatter plots are colored according to the values of one of the variables. Data points can be selected and, optionally, the corresponding waveforms are shown.
First you initialize the explorer with the data. Then you optionally specify how to colorize the data and provide waveform data associated with the data. Finally you show the figure:
    expl = MultivariateExplorer(data)
    expl.set_colors(2)
    expl.set_wave_data(waveforms, 'Time [s]', 'Sine')
    expl.show()
The `compute_pca()` function computes a principal component analysis (PCA) on the input data, and `save_pca()` writes the principal components to a file.

Customize the appearance and information provided by subclassing MultivariateExplorer and reimplementing the functions
- fix_scatter_plot()
- fix_waveform_plot()
- list_selection()
- analyze_selection()

See the documentation of these functions for details.
Initialize explorer with scatter-plot data.
data: TableData, 2D array, or list of 1D arrays
- The data to be explored. Each column is a variable. For the 2D array the columns are the second dimension, for a list of 1D arrays, the list goes over columns, i.e. each 1D array is one column.
- If data is not a TableData, then this provides labels for the data columns.
- Title for the window.
Class variables
var mouse_actions
List of tuples with mouse actions and a description of their action.
var key_actions
List of tuples with key shortcuts and a description of their action.
def set_wave_data(self, data, xlabels='', ylabels=[], title=False)
def set_wave_data(self, data, xlabels='', ylabels=[], title=False):
    """Add waveform data to explorer.

    Parameters
    ----------
    data: list of (list of) 2D arrays
        Waveform data associated with each row of the data.
        Elements of the outer list correspond to the rows of the data.
        The inner 2D arrays contain a common x-axis (first column)
        and one or more corresponding y-values (second and optional
        higher columns).
        Each column of y-values is plotted in its own axes on top of
        each other, from top to bottom.
        The optional inner list of 2D arrays contains several 2D arrays
        as described above, each with its own common x-axis.
    xlabels: str or list of str
        The xlabels for the waveform plots. If only a string is given,
        then there will be a common xaxis for all the plots, and only
        the lowest one gets a labeled xaxis. If a list of strings is
        given, each waveform plot gets its own labeled x-axis.
    ylabels: list of str
        The ylabels for each of the waveform plots.
    title: bool or str
        If True or a string, provide space on top of the waveform
        plots for a title. If string, set this as the title for the
        waveform plots.
    """
    self.wave_data = []
    if data is not None and len(data) > 0:
        self.wave_data = data
        self.wave_has_xticks = []
        self.wave_nested = isinstance(data[0], (list, tuple))
        if self.wave_nested:
            for data in self.wave_data[0]:
                for k in range(data.shape[1]-2):
                    self.wave_has_xticks.append(False)
                self.wave_has_xticks.append(True)
        else:
            for k in range(self.wave_data[0].shape[1]-2):
                self.wave_has_xticks.append(False)
            self.wave_has_xticks.append(True)
        if isinstance(xlabels, (list, tuple)):
            self.wave_xlabels = xlabels
        else:
            self.wave_xlabels = [xlabels]
        self.wave_ylabels = ylabels
        self.wave_title = title
    self.wave_ax = []
Add waveform data to explorer.
data: list of (list of) 2D arrays
- Waveform data associated with each row of the data. Elements of the outer list correspond to the rows of the data. The inner 2D arrays contain a common x-axis (first column) and one or more corresponding y-values (second and optional higher columns). Each column of y-values is plotted in its own axes on top of each other, from top to bottom. The optional inner list of 2D arrays contains several 2D arrays as described above, each with its own common x-axis.
- The xlabels for the waveform plots. If only a string is given, then there will be a common xaxis for all the plots, and only the lowest one gets a labeled xaxis. If a list of strings is given, each waveform plot gets its own labeled x-axis.
- The ylabels for each of the waveform plots.
- If True or a string, provide space on top of the waveform plots for a title. If string, set this as the title for the waveform plots.
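The layout rule described above can be sketched in isolation: each 2D array contributes one axes per y-column, and only the last axes of each array carries labeled x-ticks. The helper name `xtick_flags` is hypothetical and not part of the module; it only mirrors the bookkeeping loop in `set_wave_data()` and works on column counts instead of real arrays.

```python
def xtick_flags(column_counts):
    """Return one boolean per waveform axes.

    column_counts: number of columns of each 2D array belonging to one
    data row (first column holds x, all remaining columns hold y-values).
    Hypothetical helper that mirrors the loop in set_wave_data().
    """
    flags = []
    for ncols in column_counts:
        # one axes per y-column; all but the last one share the x-axis below
        flags.extend([False] * (ncols - 2))
        flags.append(True)
    return flags


# a single array with x and three y-columns: three stacked axes
print(xtick_flags([4]))
# nested case: two arrays with two and one y-column, respectively
print(xtick_flags([3, 2]))
```

The resulting list has the same meaning as `self.wave_has_xticks`: the number of entries is the number of waveform axes, and each `True` consumes one entry of the x-labels.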
def set_colors(self, colors=0, color_label=None, color_map=None)
def set_colors(self, colors=0, color_label=None, color_map=None):
    """Set data column used to color scatter plots.

    Parameters
    ----------
    colors: int or 1D array
        Index to column in data to be used for coloring scatter plots.
        A negative value (e.g. -2) colors by row index of data.
        Or data array used to color scatter plots.
    color_label: str
        If colors is an array, this is a label describing the data.
        It is used to label the color bar.
    color_map: str or None
        Name of a matplotlib color map.
        If None 'jet' is used.
    """
    if isinstance(colors, (np.integer, int)):
        if colors < 0:
            self.color_set_index = -1
            self.color_index = 0
        else:
            self.color_set_index = 0
            self.color_index = colors
    else:
        if not isinstance(colors[0], (int, float, np.integer, np.floating)):
            # categorial data:
            self.extra_categories, self.extra_colors = categorize(colors)
        else:
            self.extra_colors = colors
        self.extra_colors = self.extra_colors[self.valid_samples]
        self.extra_color_label = color_label
        self.color_set_index = -1
        self.color_index = 1
    self.color_map = plt.get_cmap(color_map if color_map else 'jet')
Set data column used to color scatter plots.
colors: int or 1D array
- Index to column in data to be used for coloring scatter plots. A negative value (e.g. -2) colors by row index of data. Alternatively, a data array used to color the scatter plots.
- If colors is an array, this is a label describing the data. It is used to label the color bar.
- Name of a matplotlib color map. If None 'jet' is used.
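When `colors` is an array of non-numeric values, the source of `set_colors()` passes it through the module-level `categorize()` before using it for coloring. The following is a minimal pure-Python sketch of that conversion. It returns a plain list rather than an ndarray, and the example labels are made up for illustration.

```python
def categorize(values):
    """Pure-Python variant of the module's categorize().

    Returns the sorted unique category names and, for each value,
    its index into that list (as a list, not an ndarray).
    """
    cats = sorted(set(values))
    lookup = {name: k for k, name in enumerate(cats)}
    return cats, [lookup[v] for v in values]


# hypothetical categorial color data, one entry per data row
kinds = ['wave', 'pulse', 'wave', 'unknown']
cats, codes = categorize(kinds)
print(cats)
print(codes)
```

Because the categories are sorted alphabetically, the integer assigned to a category depends on which other categories are present, not on the order of first appearance.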
def show(self, ioff=True)
def show(self, ioff=True):
    """Show interactive scatter plots for exploration.
    """
    if ioff:
        plt.ioff()
    else:
        plt.ion()
    plt.rcParams['toolbar'] = 'None'
    plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, ctrl+q, q'
    plt.rcParams['font.size'] = 12
    self.fig = plt.figure(facecolor='white', figsize=(10, 8))
    self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode])
    self.fig.canvas.mpl_connect('key_press_event', self._on_key)
    self.fig.canvas.mpl_connect('resize_event', self._on_resize)
    self.fig.canvas.mpl_connect('pick_event', self._on_pick)
    if self.color_map is None:
        self.color_map = plt.get_cmap('jet')
    self._set_color_column()
    self._init_hist_plots()
    self._init_scatter_plots()
    self.wave_ax = []
    if self.wave_data is not None and len(self.wave_data) > 0:
        axx = None
        xi = 0
        for k, has_xticks in enumerate(self.wave_has_xticks):
            ax = self.fig.add_subplot(1, len(self.wave_has_xticks), 1+k, sharex=axx)
            self.wave_ax.append(ax)
            if has_xticks:
                if xi >= len(self.wave_xlabels):
                    self.wave_xlabels.append('')
                ax.set_xlabel(self.wave_xlabels[xi])
                xi += 1
                axx = None
            else:
                #ax.xaxis.set_major_formatter(plt.NullFormatter())
                if axx is None:
                    axx = ax
        for ax, ylabel in zip(self.wave_ax, self.wave_ylabels):
            ax.set_ylabel(ylabel)
        if not isinstance(self.wave_title, bool) and self.wave_title:
            self.wave_ax[0].set_title(self.wave_title)
        self.fix_waveform_plot(self.wave_ax, self.mark_data)
    self._plot_magnified_scatter()
    self._plot_help()
Show interactive scatter plots for exploration.
def compute_pca(self, scale=False, write=False)
def compute_pca(self, scale=False, write=False):
    """Compute PCA based on the data.

    Parameters
    ----------
    scale: boolean
        If True standardize data before computing PCA, i.e. remove
        mean of each variable and divide by its standard deviation.
    write: boolean
        If True write PCA components to standard output.
    """
    # pca:
    pca = decomposition.PCA()
    if scale:
        # standardize the data and fit the PCA on the standardized data:
        scaler = preprocessing.StandardScaler()
        scaler.fit(self.raw_data)
        pca.fit(scaler.transform(self.raw_data))
        pca_label = 'sPC'
    else:
        pca.fit(self.raw_data)
        pca_label = 'PC'
    # flip components such that their largest-magnitude weight is positive:
    for k in range(len(pca.components_)):
        if np.abs(np.min(pca.components_[k])) > np.max(pca.components_[k]):
            pca.components_[k] *= -1.0
    # project the data onto the components:
    if scale:
        pca_data = pca.transform(scaler.transform(self.raw_data))
    else:
        pca_data = pca.transform(self.raw_data)
    pca_labels = [f'{pca_label}{k+1} ' + (f'({100*v:.1f}%)' if v > 0.01 else (f'{100*v:.2f}%'))
                  for k, v in enumerate(pca.explained_variance_ratio_)]
    # show only components explaining at least 1% of the variance:
    if np.min(pca.explained_variance_ratio_) >= 0.01:
        pca_maxcols = pca_data.shape[1]
    else:
        pca_maxcols = np.argmax(pca.explained_variance_ratio_ < 0.01)
    if pca_maxcols < 2:
        pca_maxcols = 2
    if pca_maxcols > 6:
        pca_maxcols = 6
    # table with PCA feature weights:
    pca_table = self.pca_tables[1] if scale else self.pca_tables[0]
    pca_table.clear_data()
    pca_table.set_section(pca_label, 0, pca_table.nsecs)
    for k, comp in enumerate(pca.components_):
        pca_table.append_data(k+1, 0)
        pca_table.append_data(100.0*pca.explained_variance_ratio_[k])
        pca_table.append_data(comp)
    if write:
        pca_table.write(table_format='out', unit_style='none')
    # submit data:
    if scale:
        self.all_data[2] = pca_data
        self.all_labels[2] = pca_labels
        self.all_maxcols[2] = pca_maxcols
    else:
        self.all_data[1] = pca_data
        self.all_labels[1] = pca_labels
        self.all_maxcols[1] = pca_maxcols
Compute PCA based on the data.
- If True standardize data before computing PCA, i.e. remove mean of each variable and divide by its standard deviation.
- If True write PCA components to standard output.
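Besides the components themselves, the source of `compute_pca()` also decides how many PCA axes are shown initially: components are counted up to the first one that explains less than 1% of the variance, and the count is clamped to the range 2 to 6. A minimal sketch of that rule in plain Python follows. The function name `initial_pca_columns` is hypothetical, and the variance ratios are made-up numbers.

```python
def initial_pca_columns(explained_variance_ratio):
    """Number of PCA axes shown initially.

    Mirrors the rule in compute_pca(): stop at the first component that
    explains less than 1% of the variance, then clamp to 2..6.
    Hypothetical helper, not part of the module.
    """
    ratios = list(explained_variance_ratio)
    if min(ratios) >= 0.01:
        ncols = len(ratios)
    else:
        # index of the first component below the 1% threshold
        ncols = next(k for k, v in enumerate(ratios) if v < 0.01)
    return max(2, min(ncols, 6))


print(initial_pca_columns([0.6, 0.3, 0.095, 0.005]))
print(initial_pca_columns([0.995, 0.005]))
```

The number of columns can still be changed interactively afterwards; this rule only sets the starting value stored for the PCA view.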
def save_pca(self, file_name, scale, **kwargs)
def save_pca(self, file_name, scale, **kwargs):
    """Write PCA data to file.

    Parameters
    ----------
    file_name: str
        Name of output file.
    scale: boolean
        If True write PCA components of standardized PCA.
    kwargs: dict
        Additional parameter for TableData.write()
    """
    if scale:
        pca_file = file_name + '-pcacor'
        pca_table = self.pca_tables[1]
    else:
        pca_file = file_name + '-pcacov'
        pca_table = self.pca_tables[0]
    # units are never written to the PCA table:
    if 'unit_style' in kwargs:
        del kwargs['unit_style']
    if 'table_format' in kwargs:
        pca_table.write(pca_file, unit_style='none', **kwargs)
    else:
        pca_file += '.dat'
        pca_table.write(pca_file, unit_style='none')
Write PCA data to file.
- Name of output file.
- If True write PCA components of standardized PCA.
- Additional parameter for TableData.write()
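The name of the file that is actually written differs from `file_name`. According to the source of `save_pca()`, a suffix `-pcacor` (standardized PCA) or `-pcacov` (plain PCA) is appended, and `.dat` is added only when no `table_format` is passed on. The sketch below reproduces just this naming logic. The helper `pca_file_name` is hypothetical, and what `TableData.write()` does with a name lacking an extension is not shown here.

```python
def pca_file_name(file_name, scale, **kwargs):
    """File name handed to the table writer by save_pca().

    Hypothetical helper that only mirrors the naming logic of the method.
    """
    name = file_name + ('-pcacor' if scale else '-pcacov')
    if 'table_format' not in kwargs:
        # without an explicit table format a '.dat' file is written
        name += '.dat'
    return name


print(pca_file_name('fish', True))
print(pca_file_name('fish', False, table_format='csv'))
```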
def fix_scatter_plot(self, ax, data, label, axis)
def fix_scatter_plot(self, ax, data, label, axis):
    """Customize an axes of a scatter plot.

    This function is called after a scatter plot has been plotted.
    Once for the x axis, once for the y axis and once for the color bar.
    Reimplement this function to set appropriate limits and ticks.

    Return values are only used for the color bar (`axis='c'`).
    Otherwise they are ignored.

    For example, ticks for phase variables can be nicely labeled
    using the unicode character for pi:
    ```
    if 'phase' in label:
        if axis == 'y':
            ax.set_ylim(0.0, 2.0*np.pi)
            ax.set_yticks(np.arange(0.0, 2.5*np.pi, 0.5*np.pi))
            ax.set_yticklabels(['0', u'\u03c0/2', u'\u03c0', u'3\u03c0/2', u'2\u03c0'])
    ```

    Parameters
    ----------
    ax: matplotlib axes
        Axes of the scatter plot or color bar to be worked on.
    data: 1D array
        Data array of the axes.
    label: str
        Label corresponding to the data array.
    axis: str
        'x', 'y': set properties of x or y axes of ax.
        'c': set properties of color bar axes (note that ax can be None!)
        and return vmin, vmax, and ticks.

    Returns
    -------
    min: float
        minimum value of color bar axis
    max: float
        maximum value of color bar axis
    ticks: list of float
        position of ticks for color bar axis
    """
    return np.nanmin(data), np.nanmax(data), None
Customize an axes of a scatter plot.
This function is called after a scatter plot has been plotted: once for the x axis, once for the y axis, and once for the color bar. Reimplement this function to set appropriate limits and ticks.
Return values are only used for the color bar (`axis='c'`). Otherwise they are ignored.

For example, ticks for phase variables can be nicely labeled using the unicode character for pi:

    if 'phase' in label:
        if axis == 'y':
            ax.set_ylim(0.0, 2.0*np.pi)
            ax.set_yticks(np.arange(0.0, 2.5*np.pi, 0.5*np.pi))
            ax.set_yticklabels(['0', u'π/2', u'π', u'3π/2', u'2π'])
ax: matplotlib axes
- Axes of the scatter plot or color bar to be worked on.
data: 1D array
- Data array of the axes.
- Label corresponding to the data array.
- 'x', 'y': set properties of x or y axes of ax. 'c': set properties of color bar axes (note that ax can be None!) and return vmin, vmax, and ticks.
- minimum value of color bar axis
- maximum value of color bar axis
- position of ticks for color bar axis
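The three return values matter only for the color bar call. As a sketch of how an override might treat all three `axis` values consistently, the function below is written as a plain function without `self`, so that it runs on its own. The name `fix_phase_axis` and the assumption that any label containing 'phase' is an angle in radians are purely illustrative.

```python
import math


def fix_phase_axis(ax, data, label, axis):
    """Possible body of a fix_scatter_plot() override, written without self.

    Hypothetical example: any variable whose label contains 'phase'
    is assumed to be an angle in radians between 0 and 2 pi.
    """
    if 'phase' in label:
        ticks = [0.5*math.pi*k for k in range(5)]  # 0, pi/2, ..., 2 pi
        if axis == 'x':
            ax.set_xlim(0.0, 2.0*math.pi)
            ax.set_xticks(ticks)
        elif axis == 'y':
            ax.set_ylim(0.0, 2.0*math.pi)
            ax.set_yticks(ticks)
        # for axis == 'c' ax may be None, so only the return value is used:
        return 0.0, 2.0*math.pi, ticks
    # default behavior: range of the data ignoring NaNs, automatic ticks
    finite = [v for v in data if not math.isnan(v)]
    return min(finite), max(finite), None


print(fix_phase_axis(None, [0.1, 3.0], 'phase [rad]', 'c'))
```

In a subclass the same body would go into `fix_scatter_plot(self, ax, data, label, axis)`. Note that the color bar branch never touches `ax`, which is what makes it safe when `ax` is None.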
def fix_waveform_plot(self, axs, indices)
def fix_waveform_plot(self, axs, indices):
    """Customize waveform plots.

    This function is called once after new data have been plotted
    into the waveform plots. Reimplement this function to customize
    these plots, in particular to set axis limits and labels, plot
    title, etc.
    You may even open a new figure (with non-blocking `show()`).

    The following member variables might be useful:
    - `self.wave_data`: the full list of waveform data.
    - `self.wave_nested`: True if the elements of `self.wave_data`
      are lists of 2D arrays. Otherwise the elements are 2D arrays.
      The first column of a 2D array contains the x-values, further
      columns y-values.
    - `self.wave_has_xticks`: List of booleans for each axis.
      True if the axis has its own xticks.
    - `self.wave_xlabels`: List of xlabels (only for the axis where
      the corresponding entry in `self.wave_has_xticks` is True).
    - `self.wave_ylabels`: for each axis its ylabel

    For example, you can set the linewidth of all plotted waveforms via:
    ```
    for ax in axs:
        for l in ax.lines:
            l.set_linewidth(3.0)
    ```
    or enable markers to be plotted:
    ```
    for ax, yl in zip(axs, self.wave_ylabels):
        if 'Power' in yl:
            for l in ax.lines:
                l.set_marker('.')
                l.set_markersize(15.0)
                l.set_markeredgewidth(0.5)
                l.set_markeredgecolor('k')
                l.set_markerfacecolor(l.get_color())
    ```
    It is also useful to reduce the maximum number of y-ticks:
    ```
    axs[0].yaxis.get_major_locator().set_params(nbins=7)
    ```
    or
    ```
    import matplotlib.ticker as ticker
    axs[0].yaxis.set_major_locator(ticker.MaxNLocator(nbins=4))
    ```

    Parameters
    ----------
    axs: list of matplotlib axes
        Axes of the waveform plots to be worked on.
    indices: list of int
        Indices of the waveforms that have been selected and plotted.
    """
    pass
Customize waveform plots.
This function is called once after new data have been plotted into the waveform plots. Reimplement this function to customize these plots, in particular to set axis limits and labels, plot title, etc. You may even open a new figure (with non-blocking `show()`).

The following member variables might be useful:
- `self.wave_data`: the full list of waveform data.
- `self.wave_nested`: True if the elements of `self.wave_data` are lists of 2D arrays. Otherwise the elements are 2D arrays. The first column of a 2D array contains the x-values, further columns y-values.
- `self.wave_has_xticks`: List of booleans for each axis. True if the axis has its own xticks.
- `self.wave_xlabels`: List of xlabels (only for the axis where the corresponding entry in `self.wave_has_xticks` is True).
- `self.wave_ylabels`: for each axis its ylabel

For example, you can set the linewidth of all plotted waveforms via:
    for ax in axs:
        for l in ax.lines:
            l.set_linewidth(3.0)
or enable markers to be plotted:
    for ax, yl in zip(axs, self.wave_ylabels):
        if 'Power' in yl:
            for l in ax.lines:
                l.set_marker('.')
                l.set_markersize(15.0)
                l.set_markeredgewidth(0.5)
                l.set_markeredgecolor('k')
                l.set_markerfacecolor(l.get_color())
It is also useful to reduce the maximum number of y-ticks:

    axs[0].yaxis.get_major_locator().set_params(nbins=7)

or

    import matplotlib.ticker as ticker
    axs[0].yaxis.set_major_locator(ticker.MaxNLocator(nbins=4))
- axs (list of matplotlib axes): Axes of the waveform plots to be worked on.
- indices (list of int): Indices of the waveforms that have been selected and plotted.
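A reimplementation that sets axis limits has to walk the waveform data in the layout described above. The following is a minimal sketch of that traversal, assuming plain nested lists as stand-ins for the 2D arrays. The helper name `selected_x_range()` is hypothetical and not part of the module.

```python
def selected_x_range(wave_data, wave_nested, indices):
    """Return (xmin, xmax) over the selected waveforms, or None if empty.

    Hypothetical helper: `wave_data` and `wave_nested` mirror the member
    variables described above, `indices` the selected waveforms.
    """
    xs = []
    for i in indices:
        # Nested: each element is a list of 2D arrays. Otherwise it is
        # a single 2D array, which is wrapped in a list for uniformity.
        arrays = wave_data[i] if wave_nested else [wave_data[i]]
        for array in arrays:
            # The first column of each 2D array holds the x-values.
            xs.extend(row[0] for row in array)
    if not xs:
        return None
    return min(xs), max(xs)


# Two waveforms, each a 2D array with rows of [x, y].
waves = [[[0.0, 1.0], [1.0, 2.0]],
         [[-1.0, 0.5], [3.0, 0.2]]]
print(selected_x_range(waves, False, [0, 1]))  # (-1.0, 3.0)
```

Inside `fix_waveform_plot()`, such a range could be passed to `ax.set_xlim()` for each axis in `axs`.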
def list_selection(self, indices)
def list_selection(self, indices):
    """List information about the current selection of data points.

    This function is called when 'l' is pressed.  Reimplement this
    function, for example, to print some meaningful information
    about the current selection of data points on the console.
    You may do, however, whatever you want in this function.

    Parameters
    ----------
    indices: list of int
        Indices of the data points that have been selected.
    """
    print('')
    print('selected rows in data table:')
    for i in indices:
        print(self.valid_rows[i])
List information about the current selection of data points.
This function is called when 'l' is pressed. Reimplement this function, for example, to print some meaningful information about the current selection of data points on the console. You may, however, do whatever you want in this function.
- indices (list of int): Indices of the data points that have been selected.
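Reimplementing `list_selection()` in a subclass could look like the sketch below. To keep it runnable without the GUI, it uses a hypothetical stand-in class instead of the real explorer. It also assumes that `self.valid_rows` maps each plotted data point to its row in the data table, as the default implementation suggests. The names `ExplorerStandIn`, `SummaryExplorer` and `selection_summary()` are illustrative only.

```python
class ExplorerStandIn:
    """Hypothetical stand-in for the explorer class (not the real one)."""

    def __init__(self, valid_rows):
        # Assumption: table row index for each plotted data point.
        self.valid_rows = valid_rows

    def list_selection(self, indices):
        # Default behavior, copied from the source shown above.
        print('')
        print('selected rows in data table:')
        for i in indices:
            print(self.valid_rows[i])


class SummaryExplorer(ExplorerStandIn):
    """Reimplements list_selection() to print a one-line summary."""

    def selection_summary(self, indices):
        rows = sorted(self.valid_rows[i] for i in indices)
        if not rows:
            return 'no data points selected'
        return f'{len(rows)} data points selected, table rows: {rows}'

    def list_selection(self, indices):
        print(self.selection_summary(indices))


explorer = SummaryExplorer(valid_rows=[0, 2, 3, 7])
explorer.list_selection([1, 3])
# prints: 2 data points selected, table rows: [2, 7]
```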
def analyze_selection(self, index)
def analyze_selection(self, index):
    """Provide further information about a single selected data point.

    This function is called when a single data item was double
    clicked.  Reimplement this function to provide some further
    details on this data point.  This can be an additional figure
    window. In this case show it non-blocking:
    ``

    Parameters
    ----------
    index: int
        The index of the selected data point.
    """
    pass
Provide further information about a single selected data point.
This function is called when a single data item is double-clicked. Reimplement this function to provide further details on this data point. This can be an additional figure window; in this case, show the figure non-blocking.
- index (int): The index of the selected data point.
class PrintHelp (option_strings,
class PrintHelp(argparse.Action):
    def __call__(self, parser, namespace, values, option_string):
        parser.print_help()
        print('')
        print('mouse:')
        for ma in MultivariateExplorer.mouse_actions:
            print('%-23s %s' % ma)
        print('%-23s %s' % ('double left click', 'run thunderfish on selected EOD waveform'))
        print('')
        print('key shortcuts:')
        for ka in MultivariateExplorer.key_actions:
            print('%-23s %s' % ka)
        parser.exit()
Information about how to convert command line strings to Python objects.
Action objects are used by an ArgumentParser to represent the information needed to parse a single argument from one or more strings from the command line. The keyword arguments to the Action constructor are also all attributes of Action instances.
Keyword Arguments:
- option_strings -- A list of command-line option strings which should be associated with this action.
- dest -- The name of the attribute to hold the created object(s)
- nargs -- The number of command-line arguments that should be consumed. By default, one argument will be consumed and a single value will be produced. Other values include:
    - N (an integer) consumes N arguments (and produces a list)
    - '?' consumes zero or one arguments
    - '*' consumes zero or more arguments (and produces a list)
    - '+' consumes one or more arguments (and produces a list)

  Note that the difference between the default and nargs=1 is that with the default, a single value will be produced, while with nargs=1, a list containing a single value will be produced.
- const -- The value to be produced if the option is specified and the option uses an action that takes no values.
- default -- The value to be produced if the option is not specified.
- type -- A callable that accepts a single string argument, and returns the converted value. The standard Python types str, int, float, and complex are useful examples of such callables. If None, str is used.
- choices -- A container of values that should be allowed. If not None, after a command-line argument has been converted to the appropriate type, an exception will be raised if it is not a member of this collection.
- required -- True if the action must always be specified at the command line. This is only meaningful for optional command-line arguments.
- help -- The help string describing the argument.
- metavar -- The name to be used for the option's argument with the help string. If None, the 'dest' value will be used as the name.
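The note on `nargs` is easiest to see in a small experiment. The sketch below uses only the standard `argparse` module. The option names `--single`, `--listed` and `--many` are made up for illustration.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--single')            # default: consumes one argument, yields one value
parser.add_argument('--listed', nargs=1)   # nargs=1: consumes one argument, yields a list
parser.add_argument('--many', nargs='*')   # zero or more arguments, yields a list

args = parser.parse_args(['--single', 'a', '--listed', 'b', '--many', 'c', 'd'])
print(args.single)  # a
print(args.listed)  # ['b']
print(args.many)    # ['c', 'd']
```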
- argparse.Action
- argparse._AttributeHolder
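The `PrintHelp` source above follows a common pattern: a custom `argparse.Action` whose `__call__` prints the regular help, appends extra lines, and exits. Below is a minimal self-contained sketch of that pattern. The `KEY_ACTIONS` table, the class name `PrintExtraHelp` and `build_parser()` are hypothetical. Registering the action on `-h` with `nargs=0` is an assumption about how such an action might be wired up, not something this page documents.

```python
import argparse

# Hypothetical table of (key, description) pairs, for illustration only.
KEY_ACTIONS = [('l', 'list selection on console'),
               ('escape', 'clear selection')]


class PrintExtraHelp(argparse.Action):
    """Print the standard help followed by extra lines, then exit."""

    def __call__(self, parser, namespace, values, option_string=None):
        parser.print_help()
        print('')
        print('key shortcuts:')
        for key, description in KEY_ACTIONS:
            print('%-23s %s' % (key, description))
        parser.exit()  # exits with status 0 by default


def build_parser():
    # Disable the built-in -h so the custom action can take its place.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('-h', '--help', nargs=0, action=PrintExtraHelp,
                        help='show this help message and exit')
    return parser
```

Calling `build_parser().parse_args(['-h'])` prints the usage text plus the shortcut table and then raises `SystemExit`.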