Coverage for src/thunderlab/multivariateexplorer.py: 88%
981 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-18 22:10 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-18 22:10 +0000
1"""Simple GUI for viewing and exploring multivariate data.
3- `class MultivariateExplorer`: simple matplotlib-based GUI for viewing and exploring multivariate data.
4- `categorize()`: convert categorical string data into integer categories.
5- `select_features()`: assemble list of column indices.
6- `select_coloring()`: select column from data table for colorizing the data.
7- `list_available_features()`: print available features on console.
8"""
10import sys
11import numpy as np
12from scipy.stats import pearsonr
13from sklearn import decomposition
14from sklearn import preprocessing
15import matplotlib.pyplot as plt
16import matplotlib.patches as patches
17import matplotlib.widgets as widgets
18import argparse
19from .version import __version__, __year__
20from .tabledata import TableData
23class MultivariateExplorer(object):
24 """Simple matplotlib-based GUI for viewing and exploring multivariate data.
26 Shown are scatter plots of all pairs of variables or PCA axis.
27 Points in the scatter plots are colored according to the values of one of the variables.
28 Data points can be selected and optionally corresponding waveforms are shown.
30 First you initialize the explorer with the data. Then you optionally
31 specify how to colorize the data and provide waveform data
32 associated with the data. Finally you show the figure:
33 ```
34 expl = MultivariateExplorer(data)
35 expl.set_colors(2)
36 expl.set_wave_data(waveforms, 'Time [s]', 'Sine')
37 expl.show()
38 ```
40 The `compute_pca()` function computes a principal component analysis (PCA)
41 on the input data, and `save_pca()` writes the principal components to a file.
43 Customize the appearance and information provided by subclassing
44 MultivariateExplorer and reimplementing the functions
45 - fix_scatter_plot()
46 - fix_waveform_plot()
47 - list_selection()
48 - analyze_selection()
49 See the documentation of these functions for details.
50 """
52 mouse_actions = [
53 ('left click', 'select sample'),
54 ('left and drag', 'rectangular selection of samples and/or zoom'),
55 ('shift + left click/drag', 'add samples to selection'),
56 ('ctrl + left click/drag', 'remove samples from selection')
57 ]
58 """List of tuples with mouse actions and a description of their action."""
60 key_actions = [
61 ('c, C', 'cycle color map trough data columns'),
62 ('p,P', 'toggle between features, PCs, and scaled PCs'),
63 ('<, pageup', 'decrease number of displayed featured/PCs'),
64 ('>, pagedown', 'increase number of displayed features/PCs'),
65 ('o, z', 'toggle zoom mode on or off'),
66 ('backspace', 'zoom back'),
67 ('n, N', 'decrease, increase number of bins of histograms'),
68 ('H', 'toggle between scatter plot and 2D histogram'),
69 ('left, right, up, down', 'show and move magnified scatter plot'),
70 ('escape', 'close magnified scatter plot'),
71 ('ctrl + a', 'select all'),
72 ('+, -', 'increase, decrease pick radius'),
73 ('0', 'reset pick radius'),
74 ('l', 'list selection on console'),
75 ('w', 'toggle maximized waveform plot'),
76 ('h', 'toggle help window'),
77 ]
78 """List of tuples with key shortcuts and a description of their action."""
def __init__(self, data, labels=None, title=None):
    """Initialize explorer with scatter-plot data.

    Non-numeric columns are converted to integer codes via
    `categorize()`. Columns that contain no finite value at all and
    rows that contain any non-finite value are removed and reported
    on the console. Several matplotlib `rcParams` key maps are saved
    in `self.plt_params` and then cleared.

    Parameters
    ----------
    data: TableData, 2D array, or list of 1D arrays
        The data to be explored. Each column is a variable.
        For the 2D array the columns are the second dimension,
        for a list of 1D arrays, the list goes over columns,
        i.e. each 1D array is one column.
        Categorical columns of a TableData or of a list are replaced
        in place by their integer codes.
    labels: list of str
        If data is not a TableData, then this provides labels
        for the data columns and must be given in that case.
        For a TableData it overrides the labels of the table.
    title: str
        Title for the window.
    """
    numeric_types = (int, float, np.integer, np.floating)
    # data, categories and labels:
    self.raw_data = None    # original data table as 2D numpy array (samples x features)
    self.raw_labels = None  # for each feature a label, optional with unit
    self.categories = []    # for each feature None or list of categories
    if isinstance(data, TableData):
        for c, col in enumerate(data):
            if not isinstance(data[col][0], numeric_types):
                # categorical data:
                cats, data[:,c] = categorize(data[col])
                self.categories.append(cats)
            else:
                self.categories.append(None)
        self.raw_data = data.array()
        if labels is None:
            self.raw_labels = []
            for c in range(data.columns()):
                if len(data.unit(c)) > 0 and not data.unit(c) in ['-', '1']:
                    self.raw_labels.append(f'{data.label(c)} [{data.unit(c)}]')
                else:
                    self.raw_labels.append(data.label(c))
        else:
            self.raw_labels = labels
    else:
        if isinstance(data, np.ndarray):
            self.raw_data = data
            self.categories = [None] * data.shape[1]
        else:
            for c, col in enumerate(data):
                if not isinstance(col[0], numeric_types):
                    # categorical data:
                    cats, data[c] = categorize(col)
                    self.categories.append(cats)
                else:
                    self.categories.append(None)
            self.raw_data = np.asarray(data).T
        self.raw_labels = labels
    # remove columns containing only invalid numbers:
    cols = np.all(~np.isfinite(self.raw_data), 0)
    if np.sum(cols) > 0:
        print('removed columns containing no numbers:',
              [l for l, c in zip(self.raw_labels, cols) if c])
        self.raw_data = self.raw_data[:, ~cols]
        self.raw_labels = [l for l, c in zip(self.raw_labels, cols) if not c]
        # keep the categories aligned with the remaining columns,
        # they are indexed by column number later on:
        self.categories = [cats for cats, c in zip(self.categories, cols)
                           if not c]
    # remove rows containing invalid numbers:
    self.valid_samples = ~np.any(~np.isfinite(self.raw_data), 1)
    self.raw_data = self.raw_data[self.valid_samples, :]
    if np.sum(~self.valid_samples) > 0:
        print(f'removed {np.sum(~self.valid_samples)} rows containing invalid numbers:')
        for k, valid in enumerate(self.valid_samples):
            if not valid:
                print(k)
    # indices of the kept rows in the original data:
    self.valid_rows = [k for k, valid in enumerate(self.valid_samples)
                       if valid]
    # title for the window:
    self.title = title if title is not None else 'MultivariateExplorer'
    # data, pca-data, scaled-pca data (no pca data yet):
    self.all_data = [self.raw_data, None, None]
    self.all_labels = [self.raw_labels, None, None]
    self.all_maxcols = [self.raw_data.shape[1], None, None]
    self.all_titles = ['data', 'PCA', 'scaled PCA']  # added to window title
    # pca:
    self.pca_tables = [None, None]  # raw and scaled pca coefficients
    self._pca_header(self.raw_data, self.raw_labels)  # prepare header of the pca tables
    # start showing raw data:
    self.show_mode = 0  # show data, pca or scaled pca
    self.data = self.all_data[self.show_mode]        # the data shown
    self.labels = self.all_labels[self.show_mode]    # the feature labels shown
    self.maxcols = self.all_maxcols[self.show_mode]  # maximum number of features currently shown
    if self.maxcols > 6:
        self.maxcols = 6
    # waveform data:
    self.wave_data = []
    self.wave_nested = False
    self.wave_has_xticks = []
    self.wave_xlabels = []
    self.wave_ylabels = []
    self.wave_title = False
    # colors:
    self.color_map = plt.get_cmap('jet')
    self.extra_colors = None       # additional data column to be used for coloring
    self.extra_color_label = None  # label for extra_colors
    self.extra_categories = None   # category name for extra_colors if needed
    self.color_set_index = 0  # -1: rows and extra_colors, 0: data, 1: pca, 2: scaled pca
    self.color_index = 0      # column used for coloring with color_set_index
    self.color_values = None  # data column currently used for coloring as specified by color_set_index and color_index
    self.color_label = None   # label of data currently used for coloring
    self.data_colors = None   # actual colors for color_values
    self.color_vmin = None
    self.color_vmax = None
    self.color_ticks = None
    self.cbax = None  # axes of color bar
    # figure variables:
    # remember the current matplotlib settings and disable the key maps:
    self.plt_params = {}
    for k in ['toolbar', 'keymap.quit', 'keymap.back', 'keymap.forward',
              'keymap.zoom', 'keymap.pan', 'keymap.xscale', 'keymap.yscale']:
        self.plt_params[k] = plt.rcParams[k]
        if k != 'toolbar':
            plt.rcParams[k] = ''
    self.xborder = 100.0  # pixel for ylabels
    self.yborder = 50.0   # pixel for xlabels
    self.spacing = 10.0   # pixel between plots
    self.mborder = 20.0   # pixel around magnified plot
    self.pick_radius = 4.0
    # histogram plots:
    self.hist_ax = []        # list of histogram axes
    self.hist_indices = []   # feature index of the histogram axes
    self.hist_selector = []  # for each histogram axes a selector
    self.hist_nbins = 30     # number of bins for computing histograms
    # scatter plots:
    self.scatter_ax = []        # list of axes with scatter plots (1D)
    self.scatter_indices = []   # for each axes a tuple of the column and row index
    self.scatter_artists = []   # artists of selected scatter points
    self.scatter_selector = []  # selector for each axes
    self.scatter = True         # scatter (True) or density (False)
    self.mark_data = []         # list of indices of selected data
    self.significance_level = 0.05  # r is bold if p is below
    self.select_zooms = False
    self.zoom_stack = []
    # magnified scatter plot:
    self.magnified_on = False
    self.magnified_backdrop = None
    self.magnified_size = np.array([0.6, 0.6])
    # waveform plots:
    self.wave_ax = []
    # help window:
    self.help_ax = None
226 def set_wave_data(self, data, xlabels='', ylabels=[], title=False):
227 """Add waveform data to explorer.
229 Parameters
230 ----------
231 data: list of (list of) 2D arrays
232 Waveform data associated with each row of the data.
233 Elements of the outer list correspond to the rows of the data.
234 The inner 2D arrays contain a common x-axes (first column)
235 and one or more corresponding y-values (second and optional higher columns).
236 Each column for y-values is plotted in its own axes on top of each other,
237 from top to bottom.
238 The optional inner list of 2D arrays contains several 2D arrays as ascribed above
239 each with its own common x-axes.
240 xlabel: str or list of str
241 The xlabels for the waveform plots. If only a string is given, then
242 there will be a common xaxis for all the plots, and only the lowest
243 one gets a labeled xaxis. If a list of strings is given, each waveform
244 plot gets its own labeled x-axis.
245 ylabels: list of str
246 The ylabels for each of the waveform plots.
247 title: bool or str
248 If True or a string, povide space on top of the waveform plots for a title.
249 If string, set this as the title for the waveform plots.
250 """
251 self.wave_data = []
252 if data is not None and len(data) > 0:
253 self.wave_data = data
254 self.wave_has_xticks = []
255 self.wave_nested = isinstance(data[0], (list, tuple))
256 if self.wave_nested:
257 for data in self.wave_data[0]:
258 for k in range(data.shape[1]-2):
259 self.wave_has_xticks.append(False)
260 self.wave_has_xticks.append(True)
261 else:
262 for k in range(self.wave_data[0].shape[1]-2):
263 self.wave_has_xticks.append(False)
264 self.wave_has_xticks.append(True)
265 if isinstance(xlabels, (list, tuple)):
266 self.wave_xlabels = xlabels
267 else:
268 self.wave_xlabels = [xlabels]
269 self.wave_ylabels = ylabels
270 self.wave_title = title
271 self.wave_ax = []
def set_colors(self, colors=0, color_label=None, color_map=None):
    """Set data column used to color scatter plots.

    Parameters
    ----------
    colors: int or 1D array
        Index to column in data to be used for coloring scatter plots.
        Any negative index colors by the row index of the data.
        Or a data array with one value for each row of the original
        data, used to color the scatter plots. Non-numeric values
        are converted to categories.
    color_label: str
        If colors is an array, this is a label describing the data.
        It is used to label the color bar.
    color_map: str or None
        Name of a matplotlib color map.
        If None 'jet' is used.
    """
    if isinstance(colors, (np.integer, int)):
        if colors < 0:
            self.color_set_index = -1
            self.color_index = 0
        else:
            self.color_set_index = 0
            self.color_index = colors
    else:
        if not isinstance(colors[0], (int, float,
                                      np.integer, np.floating)):
            # categorical data:
            self.extra_categories, extra_colors = categorize(colors)
        else:
            # forget categories of previously set categorical colors:
            self.extra_categories = None
            extra_colors = colors
        # convert to an array, a plain list cannot be indexed
        # by the boolean mask of valid rows:
        self.extra_colors = np.asarray(extra_colors)[self.valid_samples]
        self.extra_color_label = color_label
        self.color_set_index = -1
        self.color_index = 1
    self.color_map = plt.get_cmap(color_map if color_map else 'jet')
def show(self, ioff=True):
    """Show interactive scatter plots for exploration.

    Creates the figure, connects the event handlers, plots
    histograms, scatter plots and, if waveform data were set,
    the waveform axes, and finally calls `plt.show()`.

    Parameters
    ----------
    ioff: bool
        If True call `plt.ioff()` before creating the figure,
        otherwise call `plt.ion()`.
    """
    if ioff:
        plt.ioff()
    else:
        plt.ion()
    plt.rcParams['toolbar'] = 'None'
    plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, ctrl+q, q'
    plt.rcParams['font.size'] = 12
    self.fig = plt.figure(facecolor='white', figsize=(10, 8))
    self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode])
    self.fig.canvas.mpl_connect('key_press_event', self._on_key)
    self.fig.canvas.mpl_connect('resize_event', self._on_resize)
    self.fig.canvas.mpl_connect('pick_event', self._on_pick)
    if self.color_map is None:
        self.color_map = plt.get_cmap('jet')
    # colors need to be set before the scatter plots are drawn:
    self._set_color_column()
    self._init_hist_plots()
    self._init_scatter_plots()
    self.wave_ax = []
    if self.wave_data is not None and len(self.wave_data) > 0:
        axx = None  # axes whose x-axis is shared by the following axes
        xi = 0      # index of the next xlabel
        for k, has_xticks in enumerate(self.wave_has_xticks):
            ax = self.fig.add_subplot(1, len(self.wave_has_xticks),
                                      1+k, sharex=axx)
            self.wave_ax.append(ax)
            if has_xticks:
                # fill up missing xlabels with empty strings:
                if xi >= len(self.wave_xlabels):
                    self.wave_xlabels.append('')
                ax.set_xlabel(self.wave_xlabels[xi])
                xi += 1
                # the next axes starts a new group of shared x-axes:
                axx = None
            else:
                #ax.xaxis.set_major_formatter(plt.NullFormatter())
                if axx is None:
                    axx = ax
        for ax, ylabel in zip(self.wave_ax, self.wave_ylabels):
            ax.set_ylabel(ylabel)
        # wave_title may be a bool (only reserve space) or the title string:
        if not isinstance(self.wave_title, bool) and self.wave_title:
            self.wave_ax[0].set_title(self.wave_title)
        self.fix_waveform_plot(self.wave_ax, self.mark_data)
    self._plot_magnified_scatter()
    self._plot_help()
    plt.show()
def _pca_header(self, data, labels):
    """Set up header for the table of principal components.

    For each entry of `self.pca_tables` a new TableData is created
    with a 'PC' column, a 'variance' column, and one column
    for each feature.

    Parameters
    ----------
    data: ndarray of float
        The data (samples x features) without invalid (infinite or
        NaN) numbers. Not used, the columns of the table are
        determined by `labels` alone.
    labels: list of str
        Labels of the features. A unit appended to a label after
        a '[' or a '/' is stripped.
    """
    # Iterate over the labels only. Pairing them with the rows of data
    # would drop labels whenever there are less samples than features.
    names = []
    for label in labels:
        for separator in ('[', '/'):
            if separator in label:
                label = label.split(separator)[0].strip()
                break
        names.append(label)
    header = TableData(header=names)
    header.set_formats('%.3f')
    header.insert(0, ['PC'] + ['-']*header.nsecs, '', '%d')
    header.insert(1, 'variance', '%', '%.3f')
    for k in range(len(self.pca_tables)):
        self.pca_tables[k] = TableData(header)
def compute_pca(self, scale=False, write=False):
    """Compute PCA based on the data.

    The projected data, their labels and the number of components
    to be shown are stored in slot 2 (`scale=True`) or slot 1
    (`scale=False`) of `self.all_data`, `self.all_labels`, and
    `self.all_maxcols`. The feature weights of the components are
    stored in the corresponding table of `self.pca_tables`.

    Parameters
    ----------
    scale: boolean
        If True standardize data before computing PCA, i.e. remove mean
        of each variable and divide by its standard deviation.
    write: boolean
        If True write PCA components to standard out.
    """
    # pca:
    pca = decomposition.PCA()
    if scale:
        scaler = preprocessing.StandardScaler()
        scaler.fit(self.raw_data)
        pca_input = scaler.transform(self.raw_data)
        pca_label = 'sPC'
    else:
        pca_input = self.raw_data
        pca_label = 'PC'
    pca.fit(pca_input)
    # flip components such that their largest weight is positive:
    for k in range(len(pca.components_)):
        if np.abs(np.min(pca.components_[k])) > np.max(pca.components_[k]):
            pca.components_[k] *= -1.0
    # project the same data the PCA was fitted on, i.e. the
    # standardized data for a scaled PCA:
    pca_data = pca.transform(pca_input)
    pca_labels = [f'{pca_label}{k+1} ' + (f'({100*v:.1f}%)' if v > 0.01 else f'({100*v:.2f}%)')
                  for k, v in enumerate(pca.explained_variance_ratio_)]
    # show only components explaining at least 1% of the variance,
    # but at least two and at most six:
    if np.min(pca.explained_variance_ratio_) >= 0.01:
        pca_maxcols = pca_data.shape[1]
    else:
        pca_maxcols = np.argmax(pca.explained_variance_ratio_ < 0.01)
    if pca_maxcols < 2:
        pca_maxcols = 2
    if pca_maxcols > 6:
        pca_maxcols = 6
    # table with PCA feature weights:
    pca_table = self.pca_tables[1] if scale else self.pca_tables[0]
    pca_table.clear_data()
    pca_table.set_section(0, pca_label, pca_table.nsecs)
    for k, comp in enumerate(pca.components_):
        pca_table.add(k+1, 0)
        pca_table.add(100.0*pca.explained_variance_ratio_[k])
        pca_table.add(comp)
    if write:
        pca_table.write(table_format='out', unit_style='none')
    # submit data:
    slot = 2 if scale else 1
    self.all_data[slot] = pca_data
    self.all_labels[slot] = pca_labels
    self.all_maxcols[slot] = pca_maxcols
442 def save_pca(self, file_name, scale, **kwargs):
443 """Write PCA data to file.
445 Parameters
446 ----------
447 file_name: str
448 Name of ouput file.
449 scale: boolean
450 If True write PCA components of standardized PCA.
451 kwargs: dict
452 Additional parameter for TableData.write()
453 """
454 if scale:
455 pca_file = file_name + '-pcacor'
456 pca_table = self.pca_tables[1]
457 else:
458 pca_file = file_name + '-pcacov'
459 pca_table = self.pca_tables[0]
460 if 'unit_style' in kwargs:
461 del kwargs['unit_style']
462 if 'table_format' in kwargs:
463 pca_table.write(pca_file, unit_style='none', **kwargs)
464 else:
465 pca_file += '.dat'
466 pca_table.write(pca_file, unit_style='none')
469 def _set_color_column(self):
470 """Initialize variables used for colorization of scatter points."""
471 if self.color_set_index == -1:
472 if self.color_index == 0:
473 self.color_values = np.arange(self.data.shape[0], dtype=float)
474 self.color_label = 'sample'
475 elif self.color_index == 1:
476 self.color_values = self.extra_colors
477 self.color_label = self.extra_color_label
478 else:
479 self.color_values = self.all_data[self.color_set_index][:,self.color_index]
480 self.color_label = self.all_labels[self.color_set_index][self.color_index]
481 self.color_vmin, self.color_vmax, self.color_ticks = \
482 self.fix_scatter_plot(self.cbax, self.color_values,
483 self.color_label, 'c')
484 if self.color_ticks is None:
485 if self.color_set_index == 0 and \
486 self.categories[self.color_index] is not None:
487 self.color_ticks = np.arange(len(self.categories[self.color_index]))
488 elif self.color_set_index == -1 and \
489 self.color_index == 1 and \
490 self.extra_categories is not None:
491 self.color_ticks = np.arange(len(self.extra_categories))
492 self.data_colors = self.color_map((self.color_values - self.color_vmin)/(self.color_vmax - self.color_vmin))
495 def _add_backdrop(self, ax):
496 bbox = ax.get_tightbbox(self.fig.canvas.get_renderer())
497 if bbox is not None:
498 self.magnified_backdrop = \
499 patches.Rectangle((bbox.x0 - self.mborder,
500 bbox.y0 - self.mborder),
501 bbox.width + 2*self.mborder,
502 bbox.height + 2*self.mborder,
503 transform=None, clip_on=False,
504 facecolor='#f7f7f7', edgecolor='none',
505 zorder=-5)
506 ax.add_patch(self.magnified_backdrop)
def _create_selector(self, ax):
    """Create a rectangle selector for an axes.

    The selector reports selections made with the left mouse
    button to `self._on_select()`. If the installed matplotlib
    does not accept the full set of arguments (TypeError),
    a selector with basic arguments only is created.
    """
    rectangle_style = dict(facecolor='gray', edgecolor='gray',
                           alpha=0.2, fill=True)
    modifier_keys = dict(move='', clear='', square='', center='ctrl')
    try:
        return widgets.RectangleSelector(ax, self._on_select,
                                         useblit=True, button=1,
                                         minspanx=0, minspany=0,
                                         spancoords='pixels',
                                         props=rectangle_style,
                                         state_modifier_keys=modifier_keys)
    except TypeError:
        # old matplotlib:
        return widgets.RectangleSelector(ax, self._on_select,
                                         useblit=True, button=1)
def _plot_hist(self, ax, magnifiedax):
    """Plot and label a histogram.

    Parameters
    ----------
    ax: matplotlib axes
        Axes for the histogram. Either one of `self.hist_ax`
        or one of `self.scatter_ax`.
    magnifiedax: bool
        If True, label the axes as the magnified plot, take the x-range
        from the corresponding histogram axes, and add a backdrop.
    """
    try:
        idx = self.hist_ax.index(ax)
        c = self.hist_indices[idx]
        in_hist = True
    except ValueError:
        # not a histogram axes, must be one of the scatter axes then:
        idx = self.scatter_ax.index(ax)
        c = self.scatter_indices[idx][0]
        in_hist = False
    ax.clear()
    #ax.relim()
    #ax.autoscale(True)
    x = self.data[:,c]
    ax.hist(x, self.hist_nbins)
    #ax.autoscale(False)
    ax.set_xlabel(self.labels[c])
    ax.xaxis.set_major_locator(plt.AutoLocator())
    ax.xaxis.set_major_formatter(plt.ScalarFormatter())
    if self.show_mode == 0:
        # raw data are shown, label categories by their names:
        if self.categories[c] is not None:
            ax.set_xticks(np.arange(len(self.categories[c])))
            ax.set_xticklabels(self.categories[c])
        self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x')
    if magnifiedax:
        ax.text(0.05, 0.9, f'n={len(self.data)}',
                transform=ax.transAxes, zorder=100)
        ax.set_ylabel('count')
        # same x-range as the histogram of the magnified column:
        cax = self.hist_ax[self.scatter_indices[-1][0]]
        ax.set_xlim(cax.get_xlim())
    else:
        # only the histogram of the first column gets labeled:
        if c == 0:
            ax.text(0.05, 0.9, f'n={len(self.data)}',
                    transform=ax.transAxes, zorder=100)
            ax.set_ylabel('count')
        else:
            ax.yaxis.set_major_formatter(plt.NullFormatter())
    # a new selector is created after the axes has been cleared:
    selector = self._create_selector(ax)
    if in_hist:
        self.hist_selector[idx] = selector
    else:
        self.scatter_selector[idx] = selector
        self.scatter_artists[idx] = None
    ax.relim(True)
    if magnifiedax:
        self._add_backdrop(ax)
579 def _set_hist_ylim(self):
580 ymax = np.max([ax.get_ylim() for ax in self.hist_ax[:self.maxcols]], 0)[1]
581 for ax in self.hist_ax:
582 ax.set_ylim(0, ymax)
def _init_hist_plots(self):
    """Initial plots of the histograms.

    For each column of the data a histogram is plotted
    in the bottom row of the grid of subplots.
    """
    n = self.data.shape[1]
    # reset all lists that are indexed in parallel, otherwise
    # the indices and selectors accumulate with each call:
    self.hist_ax = []
    self.hist_indices = []
    self.hist_selector = []
    for r in range(n):
        ax = self.fig.add_subplot(n, n, (n-1)*n+r+1)
        self.hist_ax.append(ax)
        self.hist_indices.append(r)
        self.hist_selector.append(None)
        self._plot_hist(ax, False)
    self._set_hist_ylim()
def _plot_scatter(self, ax, magnifiedax, cax=None):
    """Plot a scatter plot.

    Depending on `self.scatter` either a scatter plot or
    a 2D histogram is plotted. The selected data points
    (`self.mark_data`) are plotted on top.

    Parameters
    ----------
    ax: matplotlib axes
        One of the axes in `self.scatter_ax`.
    magnifiedax: bool
        If True, label both axes, take the axes ranges from the
        corresponding scatter plot in the grid, and add a backdrop.
    cax: matplotlib axes or None
        If not None, draw the color bar into this axes.
        Only used for scatter plots.
    """
    idx = self.scatter_ax.index(ax)
    c, r = self.scatter_indices[idx]
    if self.scatter: # scatter plot
        ax.clear()
        a = ax.scatter(self.data[:,c], self.data[:,r], s=50,
                       edgecolors='white', linewidths=0.5,
                       picker=self.pick_radius, zorder=10)
        a.set_facecolor(self.data_colors)
        pr, pp = pearsonr(self.data[:,c], self.data[:,r])
        # bold if p is below the significance level divided by the number of columns:
        fw = 'bold' if pp < self.significance_level/self.data.shape[1] else 'normal'
        # negative correlations are annotated in the right, others in the left corner:
        if pr < 0:
            ax.text(0.95, 0.9, f'r={pr:.2f}, p={pp:.3f}', fontweight=fw,
                    transform=ax.transAxes, ha='right', zorder=100)
        else:
            ax.text(0.05, 0.9, f'r={pr:.2f}, p={pp:.3f}', fontweight=fw,
                    transform=ax.transAxes, zorder=100)
        # color bar:
        if cax is not None:
            # temporary scatter plot that is only used for creating the color bar:
            a = ax.scatter(self.data[:, c], self.data[:, r],
                           c=self.color_values, cmap=self.color_map)
            self.fig.colorbar(a, cax=cax, ticks=self.color_ticks)
            a.remove()
            cax.set_ylabel(self.color_label)
            self.color_vmin, self.color_vmax, self.color_ticks = \
                self.fix_scatter_plot(self.cbax, self.color_values,
                                      self.color_label, 'c')
            if self.color_ticks is None:
                # label ticks of categorical data by the category names:
                if self.color_set_index == 0 and \
                   self.categories[self.color_index] is not None:
                    cax.set_yticklabels(self.categories[self.color_index])
                elif self.color_set_index == -1 and \
                     self.color_index == 1 and \
                     self.extra_categories is not None:
                    cax.set_yticklabels(self.extra_categories)
    else: # histogram
        if self.show_mode == 0:
            self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x')
            self.fix_scatter_plot(ax, self.data[:,r], self.labels[r], 'y')
        # the axes limits need to be read before the axes is cleared:
        axrange = [ax.get_xlim(), ax.get_ylim()]
        ax.clear()
        ax.hist2d(self.data[:,c], self.data[:,r], self.hist_nbins,
                  range=axrange, cmap=plt.get_cmap('Greys'))
    # selected data:
    a = ax.scatter(self.data[self.mark_data, c],
                   self.data[self.mark_data, r], s=100,
                   edgecolors='black', linewidths=0.5,
                   picker=self.pick_radius, zorder=11)
    a.set_facecolor(self.data_colors[self.mark_data])
    self.scatter_artists[idx] = a
    ax.xaxis.set_major_locator(plt.AutoLocator())
    ax.yaxis.set_major_locator(plt.AutoLocator())
    ax.xaxis.set_major_formatter(plt.ScalarFormatter())
    ax.yaxis.set_major_formatter(plt.ScalarFormatter())
    if self.show_mode == 0:
        # raw data are shown, label categories by their names:
        if self.categories[c] is not None:
            ax.set_xticks(np.arange(len(self.categories[c])))
            ax.set_xticklabels(self.categories[c])
        if self.categories[r] is not None:
            ax.set_yticks(np.arange(len(self.categories[r])))
            ax.set_yticklabels(self.categories[r])
    if magnifiedax:
        ax.set_xlabel(self.labels[c])
        ax.set_ylabel(self.labels[r])
        # same ranges as the scatter plot of the same columns in the grid:
        cax = self.scatter_ax[self.scatter_indices[:-1].index(self.scatter_indices[-1])]
        ax.set_xlim(cax.get_xlim())
        ax.set_ylim(cax.get_ylim())
    else:
        # only the plots in the first column get a ylabel:
        if c == 0:
            ax.set_ylabel(self.labels[r])
    if self.show_mode == 0:
        self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x')
        self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y')
    if not magnifiedax:
        # no tick labels on the x-axis, and on the y-axis only in the first column:
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        if c > 0:
            ax.yaxis.set_major_formatter(plt.NullFormatter())
        # both ranges are taken from the x-ranges of the histograms:
        ax.set_xlim(*self.hist_ax[c].get_xlim())
        ax.set_ylim(*self.hist_ax[r].get_xlim())
    if magnifiedax:
        self._add_backdrop(ax)
    # a new selector is created after the axes has been cleared:
    selector = self._create_selector(ax)
    self.scatter_selector[idx] = selector
    ax.relim(True)
def _init_scatter_plots(self):
    """Initial plots of scatter plots.

    A scatter plot is created for each pair of columns
    with the first column index being smaller than the second.
    The color bar is drawn together with the first scatter plot
    into the newly created axes `self.cbax`.
    """
    self.cbax = self.fig.add_axes([0.5, 0.5, 0.1, 0.5])
    n = self.data.shape[1]
    column_pairs = [(col, row) for row in range(1, n) for col in range(row)]
    colorbar_axes = self.cbax
    for col, row in column_pairs:
        ax = self.fig.add_subplot(n, n, (row - 1)*n + col + 1)
        self.scatter_ax.append(ax)
        self.scatter_indices.append([col, row])
        self.scatter_artists.append(None)
        self.scatter_selector.append(None)
        self._plot_scatter(ax, False, colorbar_axes)
        # the color bar is needed only once:
        colorbar_axes = None
701 def _plot_magnified_scatter(self):
702 """Initial plot of the magnified scatter plot."""
703 ax = self.fig.add_axes([0.5, 0.9, 0.05, 0.05])
704 ax.set_visible(False)
705 self.magnified_on = False
706 c = 0
707 r = 1
708 a = ax.scatter(self.data[:, c], self.data[:, r],
709 s=50, edgecolors='none')
710 a.set_facecolor(self.data_colors)
711 a = ax.scatter(self.data[self.mark_data, c],
712 self.data[self.mark_data, r], s=80)
713 a.set_facecolor(self.data_colors[self.mark_data])
714 ax.set_xlabel(self.labels[c])
715 ax.set_ylabel(self.labels[r])
716 self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x')
717 self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y')
718 self.scatter_ax.append(ax)
719 self.scatter_indices.append([c, r])
720 self.scatter_artists.append(a)
721 self.scatter_selector.append(None)
def _plot_help(self):
    """Create the invisible help window.

    An axes covering almost the whole figure lists the key shortcuts
    and the mouse actions. It is stored in `self.help_ax`.
    """
    help_ax = self.fig.add_subplot(1, 1, 1)
    help_ax.set_position([0.02, 0.02, 0.96, 0.96])
    help_ax.xaxis.set_major_locator(plt.NullLocator())
    help_ax.yaxis.set_major_locator(plt.NullLocator())
    # one line for each action plus titles and spacing:
    nlines = len(self.mouse_actions) + len(self.key_actions) + 4
    dy = 1/nlines
    y = 1 - 1.5*dy
    sections = (('Key shortcuts', self.key_actions),
                ('Mouse actions', self.mouse_actions))
    for heading, actions in sections:
        help_ax.text(0.05, y, heading, transform=help_ax.transAxes,
                     fontweight='bold')
        y -= dy
        for action, description in actions:
            help_ax.text(0.05, y, action, transform=help_ax.transAxes)
            help_ax.text(0.3, y, description, transform=help_ax.transAxes)
            y -= dy
        # empty line between the sections:
        y -= dy
    help_ax.set_visible(False)
    self.help_ax = help_ax
751 def fix_scatter_plot(self, ax, data, label, axis):
752 """Customize an axes of a scatter plot.
754 This function is called after a scatter plot has been plotted.
755 Once for the x axes, once for the y axis and once for the color bar.
756 Reimplement this function to set appropriate limits and ticks.
758 Return values are only used for the color bar (`axis='c'`).
759 Otherwise they are ignored.
761 For example, ticks for phase variables can be nicely labeled
762 using the unicode character for pi:
763 ```
764 if 'phase' in label:
765 if axis == 'y':
766 ax.set_ylim(0.0, 2.0*np.pi)
767 ax.set_yticks(np.arange(0.0, 2.5*np.pi, 0.5*np.pi))
768 ax.set_yticklabels(['0', u'\u03c0/2', u'\u03c0', u'3\u03c0/2', u'2\u03c0'])
769 ```
771 Parameters
772 ----------
773 ax: matplotlib axes
774 Axes of the scatter plot or color bar to be worked on.
775 data: 1D array
776 Data array of the axes.
777 label: str
778 Label coresponding to the data array.
779 axis: str
780 'x', 'y': set properties of x or y axes of ax.
781 'c': set properies of color bar axes (note that ax can be None!)
782 and return vmin, vmax, and ticks.
784 Returns
785 -------
786 min: float
787 minimum value of color bar axis
788 max: float
789 maximum value of color bar axis
790 ticks: list of float
791 position of ticks for color bar axis
792 """
793 return np.nanmin(data), np.nanmax(data), None
796 def fix_waveform_plot(self, axs, indices):
797 """Customize waveform plots.
799 This function is called once after new data have been plotted
800 into the waveform plots. Reimplement this function to customize
801 these plots. In particular to set axis limits and labels, plot
802 title, etc.
803 You may even open a new figure (with non-blocking `show()`).
805 The following member variables might be usefull:
806 - `self.wave_data`: the full list of waveform data.
807 - `self.wave_nested`: True if the elements of `self.wave_data` are lists of 2D arrays. Otherwise the elements are 2D arrays. The first column of a 2D array contains the x-values, further columns y-values.
808 - `self.wave_has_xticks`: List of booleans for each axis. True if the axis has its own xticks.
809 - `self.wave_xlabels`: List of xlabels (only for the axis where the corresponding entry in `self.wave_has_xticks` is True).
810 - `self.wave_ylabels`: for each axis its ylabel
812 For example, you can set the linewidth of all plotted waveforms via:
813 ```
814 for ax in axs:
815 for l in ax.lines:
816 l.set_linewidth(3.0)
817 ```
818 or enable markers to be plotted:
819 ```
820 for ax, yl in zip(axs, self.wave_ylabels):
821 if 'Power' in yl:
822 for l in ax.lines:
823 l.set_marker('.')
824 l.set_markersize(15.0)
825 l.set_markeredgewidth(0.5)
826 l.set_markeredgecolor('k')
827 l.set_markerfacecolor(l.get_color())
828 ```
829 Usefull is to reduce the maximum number of y-ticks:
830 ```
831 axs[0].yaxis.get_major_locator().set_params(nbins=7)
832 ```
833 or
834 ```
835 import matplotlib.ticker as ticker
836 axs[0].yaxis.set_major_locator(ticker.MaxNLocator(nbins=4))
837 ```
839 Parameters
840 ----------
841 axs: list of matplotlib axes
842 Axis of the waveform plots to be worked on.
843 indices: list of int
844 Indices of the waveforms that have been selected and plotted.
845 """
846 pass
849 def list_selection(self, indices):
850 """List information about the current selection of data points.
852 This function is called when 'l' is pressed. Reimplement this
853 function, for example, to print some meaningfull information
854 about the current selection of data points on console. You may
855 do, however, whatever you want in this function.
857 Parameters
858 ----------
859 indices: list of int
860 Indices of the data points that have been selected.
861 """
862 print('')
863 print('selected rows in data table:')
864 for i in indices:
865 print(self.valid_rows[i])
868 def analyze_selection(self, index):
869 """Provide further information about a single selected data point.
871 This function is called when a single data item was double
872 clicked. Reimplement this function to provide some further
873 details on this data point. This can be an additional figure
874 window. In this case show it non-blocking:
875 `plt.show(block=False)`
877 Parameters
878 ----------
879 index: int
880 The index of the selected data point.
881 """
882 pass
885 def _set_magnified_pos(self, width, height):
886 """Set position of magnified plot."""
887 if self.magnified_on:
888 xoffs = self.xborder/width
889 yoffs = self.yborder/height
890 if self.scatter_indices[-1][1] < self.data.shape[1]:
891 idx = self.scatter_indices[:-1].index(self.scatter_indices[-1])
892 pos = self.scatter_ax[idx].get_position().get_points()
893 else:
894 pos = self.hist_ax[self.scatter_indices[-1][0]].get_position().get_points()
895 pos[0] = np.mean(pos, 0) - 0.5*self.magnified_size
896 if pos[0][0] < xoffs: pos[0][0] = xoffs
897 if pos[0][1] < yoffs: pos[0][1] = yoffs
898 pos[1] = pos[0] + self.magnified_size
899 if pos[1][0] > 1.0-self.spacing/width: pos[1][0] = 1.0-self.spacing/width
900 if pos[1][1] > 1.0-self.spacing/height: pos[1][1] = 1.0-self.spacing/height
901 pos[0] = pos[1] - self.magnified_size
902 self.scatter_ax[-1].set_position([pos[0][0], pos[0][1],
903 self.magnified_size[0], self.magnified_size[1]])
904 self.scatter_ax[-1].set_visible(True)
905 else:
906 self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
907 self.scatter_ax[-1].set_visible(False)
910 def _make_selection(self, ax, key, x0, x1, y0, y1):
911 """Select points from a scatter or histogram plot."""
912 if not key in ['shift', 'control']:
913 self.mark_data = []
914 if ax in self.scatter_ax:
915 axi = self.scatter_ax.index(ax)
916 # from scatter plots:
917 c, r = self.scatter_indices[axi]
918 if r < self.data.shape[1]:
919 # from scatter:
920 for ind, (x, y) in enumerate(zip(self.data[:, c], self.data[:, r])):
921 if x >= x0 and x <= x1 and y >= y0 and y <= y1:
922 if ind in self.mark_data:
923 if key == 'control':
924 self.mark_data.remove(ind)
925 elif key != 'control':
926 self.mark_data.append(ind)
927 else:
928 # from histogram:
929 for ind, x in enumerate(self.data[:, c]):
930 if x >= x0 and x <= x1:
931 if ind in self.mark_data:
932 if key == 'control':
933 self.mark_data.remove(ind)
934 elif key != 'control':
935 self.mark_data.append(ind)
936 elif ax in self.hist_ax:
937 r = self.hist_indices[self.hist_ax.index(ax)]
938 # from histogram:
939 for ind, x in enumerate(self.data[:, r]):
940 if x >= x0 and x <= x1:
941 if ind in self.mark_data:
942 if key == 'control':
943 self.mark_data.remove(ind)
944 elif key != 'control':
945 self.mark_data.append(ind)
948 def _update_selection(self):
949 """Highlight selected points in the scatter plots and plot corresponding waveforms."""
950 # update scatter plots:
951 for artist, (c, r) in zip(self.scatter_artists, self.scatter_indices):
952 if artist is not None:
953 if len(self.mark_data) == 0:
954 artist.set_offsets(np.zeros((0, 2)))
955 else:
956 artist.set_offsets(list(zip(self.data[self.mark_data, c],
957 self.data[self.mark_data, r])))
958 artist.set_facecolors(self.data_colors[self.mark_data])
959 # waveform plots:
960 if len(self.wave_ax) > 0:
961 axdi = 0
962 axti = 1
963 for xi, ax in enumerate(self.wave_ax):
964 ax.clear()
965 if len(self.mark_data) > 0:
966 for idx in self.mark_data:
967 if self.wave_nested:
968 data = self.wave_data[idx][axdi]
969 else:
970 data = self.wave_data[idx]
971 if data is not None:
972 ax.plot(data[:, 0], data[:, axti],
973 c=self.data_colors[idx],
974 picker=self.pick_radius)
975 axti += 1
976 if self.wave_has_xticks[xi]:
977 ax.set_xlabel(self.wave_xlabels[axdi])
978 axti = 1
979 axdi += 1
980 #else:
981 # ax.xaxis.set_major_formatter(plt.NullFormatter())
982 for ax, ylabel in zip(self.wave_ax, self.wave_ylabels):
983 ax.set_ylabel(ylabel)
984 if not isinstance(self.wave_title, bool) and self.wave_title:
985 self.wave_ax[0].set_title(self.wave_title)
986 self.fix_waveform_plot(self.wave_ax, self.mark_data)
987 self.fig.canvas.draw()
990 def _set_limits(self, ax, x0, x1, y0, y1):
991 if ax in self.hist_ax:
992 ax.set_xlim(x0, x1)
993 for hax in self.hist_ax:
994 hax.set_ylim(y0, y1)
995 cc = self.hist_indices[self.hist_ax.index(ax)]
996 for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices):
997 if c == cc:
998 sax.set_xlim(x0, x1)
999 if r == cc:
1000 sax.set_ylim(x0, x1)
1001 if ax in self.scatter_ax:
1002 idx = self.scatter_ax.index(ax)
1003 cc, rr = self.scatter_indices[idx]
1004 self.hist_ax[cc].set_xlim(x0, x1)
1005 self.hist_ax[rr].set_xlim(y0, y1)
1006 for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices):
1007 if c == cc:
1008 sax.set_xlim(x0, x1)
1009 if c == rr:
1010 sax.set_xlim(y0, y1)
1011 if r == cc:
1012 sax.set_ylim(x0, x1)
1013 if r == rr:
1014 sax.set_ylim(y0, y1)
1017 def _on_key(self, event):
1018 """Handle key events."""
1019 #print('pressed', event.key)
1020 if event.key in ['left', 'right', 'up', 'down']:
1021 if self.magnified_on:
1022 mc, mr = self.scatter_indices[-1]
1023 if event.key == 'left':
1024 if mc > 0:
1025 self.scatter_indices[-1][0] -= 1
1026 elif mr > 1:
1027 if mr >= self.data.shape[1]:
1028 self.scatter_indices[-1][1] = self.maxcols - 1
1029 else:
1030 self.scatter_indices[-1][1] -= 1
1031 self.scatter_indices[-1][0] = self.scatter_indices[-1][1] - 1
1032 else:
1033 self.scatter_indices[-1][0] = self.data.shape[1] - 1
1034 self.scatter_indices[-1][1] = self.data.shape[1]
1035 elif event.key == 'right':
1036 if mc < mr - 1 and mc < self.maxcols - 1:
1037 self.scatter_indices[-1][0] += 1
1038 elif mr < self.maxcols:
1039 self.scatter_indices[-1][0] = 0
1040 self.scatter_indices[-1][1] += 1
1041 if mr >= self.maxcols:
1042 self.scatter_indices[-1][1] = self.data.shape[1]
1043 else:
1044 self.scatter_indices[-1][0] = 0
1045 self.scatter_indices[-1][1] = 1
1046 elif event.key == 'up':
1047 if mr > mc + 1:
1048 if mr >= self.data.shape[1]:
1049 self.scatter_indices[-1][1] = self.maxcols - 1
1050 else:
1051 self.scatter_indices[-1][1] -= 1
1052 elif mc > 0:
1053 self.scatter_indices[-1][0] -= 1
1054 self.scatter_indices[-1][1] = self.data.shape[1]
1055 else:
1056 self.scatter_indices[-1][0] = self.data.shape[1] - 1
1057 self.scatter_indices[-1][1] = self.data.shape[1]
1058 elif event.key == 'down':
1059 if mr < self.maxcols:
1060 self.scatter_indices[-1][1] += 1
1061 if mr >= self.maxcols:
1062 self.scatter_indices[-1][1] = self.data.shape[1]
1063 elif mc < self.maxcols - 1:
1064 self.scatter_indices[-1][0] += 1
1065 self.scatter_indices[-1][1] = mc + 2
1066 if self.scatter_indices[-1][1] >= self.maxcols:
1067 self.scatter_indices[-1][1] = self.data.shape[1]
1068 else:
1069 self.scatter_indices[-1][0] = 0
1070 self.scatter_indices[-1][1] = 1
1071 for k in reversed(range(len(self.zoom_stack))):
1072 if self.zoom_stack[k][0] == self.scatter_ax[-1]:
1073 del self.zoom_stack[k]
1074 self.scatter_ax[-1].clear()
1075 self.scatter_ax[-1].set_visible(True)
1076 self.magnified_on = True
1077 self._set_magnified_pos(self.fig.get_window_extent().width,
1078 self.fig.get_window_extent().height)
1079 if self.scatter_indices[-1][1] < self.data.shape[1]:
1080 self._plot_scatter(self.scatter_ax[-1], True)
1081 else:
1082 self._plot_hist(self.scatter_ax[-1], True)
1083 self.fig.canvas.draw()
1084 else:
1085 if event.key == 'escape':
1086 if len(self.scatter_ax) >= self.data.shape[1]:
1087 self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
1088 self.magnified_on = False
1089 self.scatter_ax[-1].set_visible(False)
1090 self.fig.canvas.draw()
1091 elif event.key in 'oz':
1092 self.select_zooms = not self.select_zooms
1093 elif event.key == 'backspace':
1094 if len(self.zoom_stack) > 0:
1095 self._set_limits(*self.zoom_stack.pop())
1096 self.fig.canvas.draw()
1097 elif event.key in '+=':
1098 self.pick_radius *= 1.5
1099 elif event.key in '-':
1100 if self.pick_radius > 5.0:
1101 self.pick_radius /= 1.5
1102 elif event.key in '0':
1103 self.pick_radius = 4.0
1104 elif event.key in ['pageup', 'pagedown', '<', '>']:
1105 if event.key in ['pageup', '<'] and self.maxcols > 2:
1106 self.maxcols -= 1
1107 elif event.key in ['pagedown', '>'] and self.maxcols < self.raw_data.shape[1]:
1108 self.maxcols += 1
1109 for ax in self.hist_ax:
1110 self._plot_hist(ax, False)
1111 self._update_layout()
1112 elif event.key == 'w':
1113 if len(self.wave_data) > 0:
1114 if self.maxcols > 0:
1115 self.all_maxcols[self.show_mode] = self.maxcols
1116 self.maxcols = 0
1117 else:
1118 self.maxcols = self.all_maxcols[self.show_mode]
1119 self._set_layout(self.fig.get_window_extent().width,
1120 self.fig.get_window_extent().height)
1121 self.fig.canvas.draw()
1122 elif event.key == 'ctrl+a':
1123 self.mark_data = list(range(len(self.data)))
1124 self._update_selection()
1125 elif event.key in 'cC':
1126 if event.key in 'c':
1127 self.color_index -= 1
1128 if self.color_index < 0:
1129 self.color_set_index -= 1
1130 if self.color_set_index < -1:
1131 self.color_set_index = len(self.all_data)-1
1132 if self.color_set_index >= 0:
1133 if self.all_data[self.color_set_index] is None:
1134 self.compute_pca(self.color_set_index>1, True)
1135 self.color_index = self.all_data[self.color_set_index].shape[1]-1
1136 else:
1137 self.color_index = 0 if self.extra_colors is None else 1
1138 self._set_color_column()
1139 else:
1140 self.color_index += 1
1141 if (self.color_set_index >= 0 and \
1142 self.color_index >= self.all_data[self.color_set_index].shape[1]) or \
1143 (self.color_set_index < 0 and \
1144 self.color_index >= (1 if self.extra_colors is None else 2)):
1145 self.color_index = 0
1146 self.color_set_index += 1
1147 if self.color_set_index >= len(self.all_data):
1148 self.color_set_index = -1
1149 elif self.all_data[self.color_set_index] is None:
1150 self.compute_pca(self.color_set_index>1, True)
1151 self._set_color_column()
1152 for ax in self.scatter_ax:
1153 ax.collections[0].set_facecolors(self.data_colors)
1154 for a in self.scatter_artists:
1155 if a is not None:
1156 a.set_facecolors(self.data_colors[self.mark_data])
1157 for ax in self.wave_ax:
1158 for l, c in zip(ax.lines, self.data_colors[self.mark_data]):
1159 l.set_color(c)
1160 l.set_markerfacecolor(c)
1161 self._plot_scatter(self.scatter_ax[0], False, self.cbax)
1162 self.fix_scatter_plot(self.cbax, self.color_values,
1163 self.color_label, 'c')
1164 self.fig.canvas.draw()
1165 elif event.key in 'nN':
1166 if event.key in 'N':
1167 self.hist_nbins = (self.hist_nbins*3)//2
1168 elif self.hist_nbins >= 15:
1169 self.hist_nbins = (self.hist_nbins*2)//3
1170 else:
1171 self.hist_nbins = 10
1172 for ax in self.hist_ax:
1173 self._plot_hist(ax, False)
1174 self._set_hist_ylim()
1175 if self.scatter_indices[-1][1] >= self.data.shape[1]:
1176 self._plot_hist(self.scatter_ax[-1], True, True)
1177 elif not self.scatter:
1178 self._plot_scatter(self.scatter_ax[-1], True)
1179 if not self.scatter:
1180 for ax in self.scatter_ax[:-1]:
1181 self._plot_scatter(ax, False)
1182 self.fig.canvas.draw()
1183 elif event.key in 'H':
1184 self.scatter = not self.scatter
1185 for ax in self.scatter_ax[:-1]:
1186 self._plot_scatter(ax, False)
1187 if self.scatter_indices[-1][1] < self.data.shape[1]:
1188 self._plot_scatter(self.scatter_ax[-1], True)
1189 self.fig.canvas.draw()
1190 elif event.key in 'pP':
1191 if len(self.scatter_ax) >= self.data.shape[1]:
1192 self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
1193 self.scatter_indices[-1] = [0, 1]
1194 self.magnified_on = False
1195 self.scatter_ax[-1].set_visible(False)
1196 self.all_maxcols[self.show_mode] = self.maxcols
1197 if event.key == 'P':
1198 self.show_mode += 1
1199 if self.show_mode >= len(self.all_data):
1200 self.show_mode = 0
1201 else:
1202 self.show_mode -= 1
1203 if self.show_mode < 0:
1204 self.show_mode = len(self.all_data)-1
1205 if self.show_mode == 1:
1206 print('principal components')
1207 elif self.show_mode == 2:
1208 print('scaled principal components')
1209 else:
1210 print('data')
1211 if self.all_data[self.show_mode] is None:
1212 self.compute_pca(self.show_mode>1, True)
1213 self.data = self.all_data[self.show_mode]
1214 self.labels = self.all_labels[self.show_mode]
1215 self.maxcols = self.all_maxcols[self.show_mode]
1216 self.zoom_stack = []
1217 self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode])
1218 for ax in self.hist_ax:
1219 self._plot_hist(ax, False)
1220 self._set_hist_ylim()
1221 for ax in self.scatter_ax:
1222 self._plot_scatter(ax, False)
1223 self._update_layout()
1224 elif event.key in 'l':
1225 if len(self.mark_data) > 0:
1226 self.list_selection(self.mark_data)
1227 elif event.key in 'h':
1228 self.help_ax.set_visible(not self.help_ax.get_visible())
1229 self.fig.canvas.draw()
1232 def _on_select(self, eclick, erelease):
1233 """Handle selection events."""
1234 if eclick.dblclick:
1235 if len(self.mark_data) > 0:
1236 self.analyze_selection(self.mark_data[-1])
1237 return
1238 x0 = min(eclick.xdata, erelease.xdata)
1239 x1 = max(eclick.xdata, erelease.xdata)
1240 y0 = min(eclick.ydata, erelease.ydata)
1241 y1 = max(eclick.ydata, erelease.ydata)
1242 ax = erelease.inaxes
1243 if ax is None:
1244 ax = eclick.inaxes
1245 xmin, xmax = ax.get_xlim()
1246 ymin, ymax = ax.get_ylim()
1247 dx = 0.02*(xmax-xmin)
1248 dy = 0.02*(ymax-ymin)
1249 if x1 - x0 < dx and y1 - y0 < dy:
1250 bbox = ax.get_window_extent().transformed(self.fig.dpi_scale_trans.inverted())
1251 width, height = bbox.width, bbox.height
1252 width *= self.fig.dpi
1253 height *= self.fig.dpi
1254 dx = self.pick_radius*(xmax-xmin)/width
1255 dy = self.pick_radius*(ymax-ymin)/height
1256 x0 = erelease.xdata - dx
1257 x1 = erelease.xdata + dx
1258 y0 = erelease.ydata - dy
1259 y1 = erelease.ydata + dy
1260 elif self.select_zooms:
1261 self.zoom_stack.append((ax, xmin, xmax, ymin, ymax))
1262 self._set_limits(ax, x0, x1, y0, y1)
1263 self._make_selection(ax, erelease.key, x0, x1, y0, y1)
1264 self._update_selection()
1267 def _on_pick(self, event):
1268 """Handle pick events."""
1269 for ax in self.wave_ax:
1270 for k, l in enumerate(ax.lines):
1271 if l is event.artist:
1272 self.mark_data = [self.mark_data[k]]
1273 for ax in self.scatter_ax:
1274 if ax.collections[0] is event.artist:
1275 self.mark_data = event.ind
1276 self._update_selection()
1277 if event.mouseevent.dblclick:
1278 if len(self.mark_data) > 0:
1279 self.analyze_selection(self.mark_data[-1])
1282 def _set_layout(self, width, height):
1283 """Update positions and visibility of all plots."""
1284 xoffs = self.xborder/width
1285 yoffs = self.yborder/height
1286 xs = self.spacing/width
1287 ys = self.spacing/height
1288 if self.maxcols > 0:
1289 dx = (1.0-xoffs)/self.maxcols
1290 dy = (1.0-yoffs)/self.maxcols
1291 xw = dx - xs
1292 yw = dy - ys
1293 # histograms:
1294 for c, ax in enumerate(self.hist_ax):
1295 if c < self.maxcols:
1296 ax.set_position([xoffs+c*dx, yoffs, xw, yw])
1297 ax.set_visible(True)
1298 else:
1299 ax.set_visible(False)
1300 ax.set_position([0.99, 0.01, 0.01, 0.01])
1301 # scatter plots:
1302 for ax, (c, r) in zip(self.scatter_ax[:-1], self.scatter_indices[:-1]):
1303 if r < self.maxcols:
1304 ax.set_position([xoffs+c*dx, yoffs+(self.maxcols-r)*dy, xw, yw])
1305 ax.set_visible(True)
1306 else:
1307 ax.set_visible(False)
1308 ax.set_position([0.99, 0.01, 0.01, 0.01])
1309 # color bar:
1310 if self.maxcols > 0:
1311 self.cbax.set_position([xoffs+dx, yoffs+(self.maxcols-1)*dy, 0.3*xoffs, yw])
1312 self.cbax.set_visible(True)
1313 else:
1314 self.cbax.set_visible(False)
1315 self.cbax.set_position([0.99, 0.01, 0.01, 0.01])
1316 # magnified plot:
1317 if self.maxcols > 0:
1318 self._set_magnified_pos(width, height)
1319 if self.magnified_backdrop is not None:
1320 bbox = self.scatter_ax[-1].get_tightbbox(self.fig.canvas.get_renderer())
1321 if bbox is not None:
1322 self.magnified_backdrop.set_bounds(bbox.x0 - self.mborder,
1323 bbox.y0 - self.mborder,
1324 bbox.width + 2*self.mborder,
1325 bbox.height + 2*self.mborder)
1326 else:
1327 self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
1328 self.scatter_ax[-1].set_visible(False)
1329 # waveform plots:
1330 if len(self.wave_ax) > 0:
1331 if self.maxcols > 0:
1332 x0 = xoffs+((self.maxcols+1)//2)*dx
1333 y0 = ((self.maxcols+1)//2)*dy
1334 if self.maxcols%2 == 0:
1335 x0 += xoffs
1336 y0 += yoffs - ys
1337 else:
1338 y0 += ys
1339 else:
1340 x0 = xoffs
1341 y0 = 0.0
1342 yp = 1.0
1343 dy = 1.0-y0
1344 dy -= np.sum(self.wave_has_xticks)*yoffs
1345 yp -= ys
1346 dy -= ys
1347 if self.wave_title:
1348 yp -= 2*ys
1349 dy -= 2*ys
1350 dy /= len(self.wave_ax)
1351 for ax, has_xticks in zip(self.wave_ax, self.wave_has_xticks):
1352 yp -= dy
1353 ax.set_position([x0, yp, 1.0-x0-xs, dy])
1354 if has_xticks:
1355 yp -= yoffs
1356 else:
1357 yp -= ys
1360 def _update_layout(self):
1361 """Update content and position of magnified plot."""
1362 if self.scatter_indices[-1][1] < self.data.shape[1]:
1363 if self.scatter_indices[-1][1] >= self.maxcols:
1364 self.scatter_indices[-1][1] = self.maxcols-1
1365 if self.scatter_indices[-1][0] >= self.scatter_indices[-1][1]:
1366 self.scatter_indices[-1][0] = self.scatter_indices[-1][1]-1
1367 self._plot_scatter(self.scatter_ax[-1], True)
1368 else:
1369 if self.scatter_indices[-1][0] >= self.maxcols:
1370 self.scatter_indices[-1][0] = self.maxcols-1
1371 self._plot_hist(self.scatter_ax[-1], True)
1372 self._set_hist_ylim()
1373 self._set_layout(self.fig.get_window_extent().width,
1374 self.fig.get_window_extent().height)
1375 self.fig.canvas.draw()
1378 def _on_resize(self, event):
1379 """Adapt layout of plots to new figure size."""
1380 self._set_layout(event.width, event.height)
def categorize(data):
    """Convert categorical string data into integer categories.

    Parameters
    ----------
    data: list or ndarray of str
        Data with textual categories.

    Returns
    -------
    categories: list of str
        A sorted unique list of the strings in `data`,
        that is the names of the categories.
    cdata: ndarray of int
        A copy of the input `data` where each string value is replaced
        by an integer number that is an index into the returned `categories`.
    """
    cats = sorted(set(data))
    # dictionary lookup instead of a linear search for each data element:
    cat_index = {cat: k for k, cat in enumerate(cats)}
    cdata = np.array([cat_index[x] for x in data], dtype=int)
    return cats, cdata
def select_features(data, columns):
    """Assemble list of column indices.

    Parameters
    ----------
    data: TableData
        The table from which to select features.
    columns: list of str
        Feature names (column headers) to be selected from the data.
        Listing a column a second time removes it again from the
        selection. Unknown columns are reported on the console and
        skipped. If empty, one index for each of the `len(data)`
        entries of the table is returned.

    Returns
    -------
    data_cols: list of int
        List of indices into data columns for selecting features.
    """
    if len(columns) == 0:
        return list(np.arange(len(data)))
    data_cols = []
    for name in columns:
        idx = data.index(name)
        if idx is None:
            print(f'"{name}" is not a valid data column')
            continue
        if idx in data_cols:
            data_cols.remove(idx)
        else:
            data_cols.append(idx)
    return data_cols
def select_coloring(data, data_cols, color_col):
    """Select column from data table for colorizing the data.

    Pass the output of this function on to MultivariateExplorer.set_colors().

    Parameters
    ----------
    data: TableData
        Table from which columns are selected.
    data_cols: list of int
        List of columns selected to be explored.
    color_col: str or int
        Column to be selected for coloring the data.
        If 'row' then use the row index of the data in the table for coloring.

    Returns
    -------
    colors: int or list of values or None
        Either index of `data_cols` or additional data from the data table
        to be used for coloring. -2 if 'row' was requested.
    color_label: str or None
        Label for labeling the color bar. Only provided if the color
        column is not one of `data_cols`.
    color_idx: int or None
        Index of color column in `data`.
    error: None or str
        In case an invalid column is selected, an error string.
    """
    color_idx = data.index(color_col)
    if color_idx is None:
        if color_col != 'row':
            return None, None, None, f'"{color_col}" is not a valid column for color code'
        # color by row index:
        return -2, None, color_idx, None
    if color_idx in data_cols:
        # color by one of the explored columns:
        return data_cols.index(color_idx), None, color_idx, None
    # color by an additional column of the table:
    unit = data.unit(color_idx)
    if len(unit) > 0 and unit not in ['-', '1']:
        color_label = f'{data.label(color_idx)} [{unit}]'
    else:
        color_label = data.label(color_idx)
    return data[:, color_idx], color_label, color_idx, None
def list_available_features(data, data_cols=None, color_col=None):
    """Print available features on console.

    Each feature is printed on its own line, marked by '*' if it is
    selected for exploration and by 'C' if it is used for color coding.

    Parameters
    ----------
    data: TableData
        The full data table.
    data_cols: list of int or None
        List of indices of columns (features) in the table
        that are passed on to the MultivariateExplorer.
        None is treated like an empty list.
    color_col: int or None
        Index of columns (feature) in the table
        that is initially used for color coding the data.
    """
    if data_cols is None:
        # None instead of a mutable default argument
        data_cols = []
    print('available features:')
    for k, name in enumerate(data.keys()):
        marks = [' '] * 3
        if k in data_cols:
            marks[1] = '*'
        if color_col is not None and k == color_col:
            marks[0] = 'C'
        print(''.join(marks) + name)
    if len(data_cols) > 0:
        print('*: feature selected for exploration')
    if color_col is not None:
        print('C: feature selected for color coding the data')
class PrintHelp(argparse.Action):
    """Action for argparse that prints an extended help message.

    In addition to the usage of the program, the mouse and key
    actions of the MultivariateExplorer are listed. Then the
    program is exited.
    """

    def __call__(self, parser, namespace, values, option_string):
        row = '%-23s %s'
        parser.print_help()
        print('')
        print('mouse:')
        for action in MultivariateExplorer.mouse_actions:
            print(row % action)
        print(row % ('double left click', 'run thunderfish on selected EOD waveform'))
        print('')
        print('key shortcuts:')
        for action in MultivariateExplorer.key_actions:
            print(row % action)
        parser.exit()
def demo():
    """Run the multivariate explorer with a random test data set.

    The data set has three correlated numerical columns and one
    column with categories given as strings. For each row a sine
    wave and a Gaussian are provided as waveforms, with parameters
    taken from the three numerical columns.
    """
    # generate data:
    n = 100
    col_a = np.random.randn(n) + 2.0
    col_b = 1.0+0.1*col_a + 1.5*np.random.randn(n)
    col_c = 10*(-3.0*col_a + 2.0*col_b + 1.8*np.random.randn(n))
    names = ['aaa', 'bbb', 'ccc']
    col_d = [names[i] for i in np.random.randint(0, 3, n)]
    data = [col_a, col_b, col_c, col_d]
    # generate waveforms:
    time = np.arange(0.0, 10.0, 0.01)
    waveforms = []
    for a, b, c in zip(col_a, col_b, col_c):
        sine = a*np.sin(2.0*np.pi*b*time + c)
        gauss = a*np.exp(-0.5*((time-b)/(0.2*c))**2.0)
        waveforms.append(np.column_stack((time, sine, gauss)))
    # initialize explorer with columns labeled 'A', 'B', ...:
    labels = [chr(ord('A') + k) for k in range(len(data))]
    expl = MultivariateExplorer(data, labels, 'Explorer')
    expl.set_wave_data(waveforms, 'Time', ['Sine', 'Gauss'])
    # explore data:
    expl.set_colors()
    expl.show()
def main(*cargs):
    """Explore multivariate data from the command line.

    If a file is specified, its table is loaded and the selected
    columns are explored. Without a file the demo is run.

    Parameters
    ----------
    cargs: list of str
        Command line arguments. If none are given,
        the arguments are taken from `sys.argv`.
    """
    # parse command line:
    parser = argparse.ArgumentParser(add_help=False,
        description='View and explore multivariate data.',
        epilog = f'version {__version__} by Benda-Lab (2019-{__year__})')
    parser.add_argument('-h', '--help', nargs=0, action=PrintHelp,
                        help='show this help message and exit')
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-l', dest='list_features', action='store_true',
                        help='list all available data columns (features) and exit')
    parser.add_argument('-d', dest='data_cols', action='append',
                        default=[], metavar='COLUMN',
                        help='data columns (features) to be explored')
    parser.add_argument('-c', dest='color_col', default=None,
                        type=str, metavar='COLUMN',
                        help='data column to be used for color code or "row"')
    parser.add_argument('-m', dest='color_map', default='jet',
                        type=str, metavar='CMAP',
                        help='name of color map to be used')
    parser.add_argument('file', nargs='?', default='', type=str,
                        help='a file containing a table of data (csv file or similar)')
    if len(cargs) == 0:
        cargs = None
    args = parser.parse_args(cargs)
    if not args.file:
        demo()
        return
    # load data:
    data = TableData(args.file)
    data_cols = select_features(data, args.data_cols)
    # select column used for coloring the data:
    if args.color_col is None:
        # No color column requested. Passing None on to select_coloring()
        # is reported as an invalid column, so color by the first
        # explored column instead of exiting with an error.
        colors = 0
        color_label = None
        color_col = data_cols[0] if len(data_cols) > 0 else None
    else:
        colors, color_label, color_col, error = \
            select_coloring(data, data_cols, args.color_col)
        if error:
            parser.error(error)
    # list features:
    if args.list_features:
        list_available_features(data, data_cols, color_col)
        parser.exit()
    # explore data:
    expl = MultivariateExplorer(data[:, data_cols])
    expl.set_colors(colors, color_label, args.color_map)
    expl.show()


if __name__ == '__main__':
    main(*sys.argv[1:])